TUN-7934: Update quic-go to a version that queues datagrams for better throughput and drops large datagram
Remove TestUnregisterUdpSession
This commit is contained in:
parent
00cd7c333c
commit
8e69f41833
2
go.mod
2
go.mod
|
@ -24,7 +24,7 @@ require (
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/prometheus/client_golang v1.13.0
|
github.com/prometheus/client_golang v1.13.0
|
||||||
github.com/prometheus/client_model v0.2.0
|
github.com/prometheus/client_model v0.2.0
|
||||||
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55
|
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6
|
||||||
github.com/rs/zerolog v1.20.0
|
github.com/rs/zerolog v1.20.0
|
||||||
github.com/stretchr/testify v1.8.4
|
github.com/stretchr/testify v1.8.4
|
||||||
github.com/urfave/cli/v2 v2.3.0
|
github.com/urfave/cli/v2 v2.3.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -324,6 +324,8 @@ github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5
|
||||||
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
|
||||||
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55 h1:I4N3ZRnkZPbDN935Tg8QDf8fRpHp3bZ0U0/L42jBgNE=
|
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55 h1:I4N3ZRnkZPbDN935Tg8QDf8fRpHp3bZ0U0/L42jBgNE=
|
||||||
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c=
|
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c=
|
||||||
|
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6 h1:OI4WiysowCcxLtcZMGBZildo12di3ljcMN4vWdUQpoU=
|
||||||
|
github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA=
|
||||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
|
|
@ -109,63 +109,6 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnregisterUdpSession(t *testing.T) {
|
|
||||||
unregisterMessage := "closed by eyeball"
|
|
||||||
|
|
||||||
var tests = []struct {
|
|
||||||
name string
|
|
||||||
sessionRPCServer mockSessionRPCServer
|
|
||||||
timeout time.Duration
|
|
||||||
}{
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "UnregisterUdpSessionTimesout if the RPC server does not respond",
|
|
||||||
sessionRPCServer: mockSessionRPCServer{
|
|
||||||
sessionID: uuid.New(),
|
|
||||||
dstIP: net.IP{172, 16, 0, 1},
|
|
||||||
dstPort: 8000,
|
|
||||||
closeIdleAfter: testCloseIdleAfterHint,
|
|
||||||
unregisterMessage: unregisterMessage,
|
|
||||||
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
|
|
||||||
},
|
|
||||||
// very very low value so we trigger the timeout every time.
|
|
||||||
timeout: time.Nanosecond * 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
logger := zerolog.Nop()
|
|
||||||
clientStream, serverStream := newMockRPCStreams()
|
|
||||||
sessionRegisteredChan := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
protocol, err := DetermineProtocol(serverStream)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
serverStream.Close()
|
|
||||||
close(sessionRegisteredChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, test.timeout, &logger)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NoError(t, reg.Err)
|
|
||||||
|
|
||||||
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
|
|
||||||
|
|
||||||
rpcClientStream.Close()
|
|
||||||
<-sessionRegisteredChan
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegisterUdpSession(t *testing.T) {
|
func TestRegisterUdpSession(t *testing.T) {
|
||||||
unregisterMessage := "closed by eyeball"
|
unregisterMessage := "closed by eyeball"
|
||||||
|
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
|
||||||
modification, are permitted provided that the following conditions are
|
|
||||||
met:
|
|
||||||
|
|
||||||
* Redistributions of source code must retain the above copyright
|
|
||||||
notice, this list of conditions and the following disclaimer.
|
|
||||||
* Redistributions in binary form must reproduce the above
|
|
||||||
copyright notice, this list of conditions and the following disclaimer
|
|
||||||
in the documentation and/or other materials provided with the
|
|
||||||
distribution.
|
|
||||||
* Neither the name of Google Inc. nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived from
|
|
||||||
this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
||||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
||||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
||||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
||||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
||||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
||||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
||||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
||||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
||||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
@ -1,6 +0,0 @@
|
||||||
# qtls
|
|
||||||
|
|
||||||
[![Go Reference](https://pkg.go.dev/badge/github.com/quic-go/qtls-go1-20.svg)](https://pkg.go.dev/github.com/quic-go/qtls-go1-20)
|
|
||||||
[![.github/workflows/go-test.yml](https://github.com/quic-go/qtls-go1-20/actions/workflows/go-test.yml/badge.svg)](https://github.com/quic-go/qtls-go1-20/actions/workflows/go-test.yml)
|
|
||||||
|
|
||||||
This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/quic-go/quic-go).
|
|
|
@ -1,109 +0,0 @@
|
||||||
// Copyright 2009 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import "strconv"
|
|
||||||
|
|
||||||
// An AlertError is a TLS alert.
|
|
||||||
//
|
|
||||||
// When using a QUIC transport, QUICConn methods will return an error
|
|
||||||
// which wraps AlertError rather than sending a TLS alert.
|
|
||||||
type AlertError uint8
|
|
||||||
|
|
||||||
func (e AlertError) Error() string {
|
|
||||||
return alert(e).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
type alert uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
// alert level
|
|
||||||
alertLevelWarning = 1
|
|
||||||
alertLevelError = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
alertCloseNotify alert = 0
|
|
||||||
alertUnexpectedMessage alert = 10
|
|
||||||
alertBadRecordMAC alert = 20
|
|
||||||
alertDecryptionFailed alert = 21
|
|
||||||
alertRecordOverflow alert = 22
|
|
||||||
alertDecompressionFailure alert = 30
|
|
||||||
alertHandshakeFailure alert = 40
|
|
||||||
alertBadCertificate alert = 42
|
|
||||||
alertUnsupportedCertificate alert = 43
|
|
||||||
alertCertificateRevoked alert = 44
|
|
||||||
alertCertificateExpired alert = 45
|
|
||||||
alertCertificateUnknown alert = 46
|
|
||||||
alertIllegalParameter alert = 47
|
|
||||||
alertUnknownCA alert = 48
|
|
||||||
alertAccessDenied alert = 49
|
|
||||||
alertDecodeError alert = 50
|
|
||||||
alertDecryptError alert = 51
|
|
||||||
alertExportRestriction alert = 60
|
|
||||||
alertProtocolVersion alert = 70
|
|
||||||
alertInsufficientSecurity alert = 71
|
|
||||||
alertInternalError alert = 80
|
|
||||||
alertInappropriateFallback alert = 86
|
|
||||||
alertUserCanceled alert = 90
|
|
||||||
alertNoRenegotiation alert = 100
|
|
||||||
alertMissingExtension alert = 109
|
|
||||||
alertUnsupportedExtension alert = 110
|
|
||||||
alertCertificateUnobtainable alert = 111
|
|
||||||
alertUnrecognizedName alert = 112
|
|
||||||
alertBadCertificateStatusResponse alert = 113
|
|
||||||
alertBadCertificateHashValue alert = 114
|
|
||||||
alertUnknownPSKIdentity alert = 115
|
|
||||||
alertCertificateRequired alert = 116
|
|
||||||
alertNoApplicationProtocol alert = 120
|
|
||||||
)
|
|
||||||
|
|
||||||
var alertText = map[alert]string{
|
|
||||||
alertCloseNotify: "close notify",
|
|
||||||
alertUnexpectedMessage: "unexpected message",
|
|
||||||
alertBadRecordMAC: "bad record MAC",
|
|
||||||
alertDecryptionFailed: "decryption failed",
|
|
||||||
alertRecordOverflow: "record overflow",
|
|
||||||
alertDecompressionFailure: "decompression failure",
|
|
||||||
alertHandshakeFailure: "handshake failure",
|
|
||||||
alertBadCertificate: "bad certificate",
|
|
||||||
alertUnsupportedCertificate: "unsupported certificate",
|
|
||||||
alertCertificateRevoked: "revoked certificate",
|
|
||||||
alertCertificateExpired: "expired certificate",
|
|
||||||
alertCertificateUnknown: "unknown certificate",
|
|
||||||
alertIllegalParameter: "illegal parameter",
|
|
||||||
alertUnknownCA: "unknown certificate authority",
|
|
||||||
alertAccessDenied: "access denied",
|
|
||||||
alertDecodeError: "error decoding message",
|
|
||||||
alertDecryptError: "error decrypting message",
|
|
||||||
alertExportRestriction: "export restriction",
|
|
||||||
alertProtocolVersion: "protocol version not supported",
|
|
||||||
alertInsufficientSecurity: "insufficient security level",
|
|
||||||
alertInternalError: "internal error",
|
|
||||||
alertInappropriateFallback: "inappropriate fallback",
|
|
||||||
alertUserCanceled: "user canceled",
|
|
||||||
alertNoRenegotiation: "no renegotiation",
|
|
||||||
alertMissingExtension: "missing extension",
|
|
||||||
alertUnsupportedExtension: "unsupported extension",
|
|
||||||
alertCertificateUnobtainable: "certificate unobtainable",
|
|
||||||
alertUnrecognizedName: "unrecognized name",
|
|
||||||
alertBadCertificateStatusResponse: "bad certificate status response",
|
|
||||||
alertBadCertificateHashValue: "bad certificate hash value",
|
|
||||||
alertUnknownPSKIdentity: "unknown PSK identity",
|
|
||||||
alertCertificateRequired: "certificate required",
|
|
||||||
alertNoApplicationProtocol: "no application protocol",
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e alert) String() string {
|
|
||||||
s, ok := alertText[e]
|
|
||||||
if ok {
|
|
||||||
return "tls: " + s
|
|
||||||
}
|
|
||||||
return "tls: alert(" + strconv.Itoa(int(e)) + ")"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e alert) Error() string {
|
|
||||||
return e.String()
|
|
||||||
}
|
|
|
@ -1,293 +0,0 @@
|
||||||
// Copyright 2017 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rsa"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// verifyHandshakeSignature verifies a signature against pre-hashed
|
|
||||||
// (if required) handshake contents.
|
|
||||||
func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error {
|
|
||||||
switch sigType {
|
|
||||||
case signatureECDSA:
|
|
||||||
pubKey, ok := pubkey.(*ecdsa.PublicKey)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected an ECDSA public key, got %T", pubkey)
|
|
||||||
}
|
|
||||||
if !ecdsa.VerifyASN1(pubKey, signed, sig) {
|
|
||||||
return errors.New("ECDSA verification failure")
|
|
||||||
}
|
|
||||||
case signatureEd25519:
|
|
||||||
pubKey, ok := pubkey.(ed25519.PublicKey)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey)
|
|
||||||
}
|
|
||||||
if !ed25519.Verify(pubKey, signed, sig) {
|
|
||||||
return errors.New("Ed25519 verification failure")
|
|
||||||
}
|
|
||||||
case signaturePKCS1v15:
|
|
||||||
pubKey, ok := pubkey.(*rsa.PublicKey)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
|
|
||||||
}
|
|
||||||
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case signatureRSAPSS:
|
|
||||||
pubKey, ok := pubkey.(*rsa.PublicKey)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
|
|
||||||
}
|
|
||||||
signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
|
|
||||||
if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return errors.New("internal error: unknown signature type")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
serverSignatureContext = "TLS 1.3, server CertificateVerify\x00"
|
|
||||||
clientSignatureContext = "TLS 1.3, client CertificateVerify\x00"
|
|
||||||
)
|
|
||||||
|
|
||||||
var signaturePadding = []byte{
|
|
||||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
|
||||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
|
||||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
|
||||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
|
||||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
|
||||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
|
||||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
|
||||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
|
||||||
}
|
|
||||||
|
|
||||||
// signedMessage returns the pre-hashed (if necessary) message to be signed by
|
|
||||||
// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3.
|
|
||||||
func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte {
|
|
||||||
if sigHash == directSigning {
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
b.Write(signaturePadding)
|
|
||||||
io.WriteString(b, context)
|
|
||||||
b.Write(transcript.Sum(nil))
|
|
||||||
return b.Bytes()
|
|
||||||
}
|
|
||||||
h := sigHash.New()
|
|
||||||
h.Write(signaturePadding)
|
|
||||||
io.WriteString(h, context)
|
|
||||||
h.Write(transcript.Sum(nil))
|
|
||||||
return h.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// typeAndHashFromSignatureScheme returns the corresponding signature type and
|
|
||||||
// crypto.Hash for a given TLS SignatureScheme.
|
|
||||||
func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) {
|
|
||||||
switch signatureAlgorithm {
|
|
||||||
case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512:
|
|
||||||
sigType = signaturePKCS1v15
|
|
||||||
case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512:
|
|
||||||
sigType = signatureRSAPSS
|
|
||||||
case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512:
|
|
||||||
sigType = signatureECDSA
|
|
||||||
case Ed25519:
|
|
||||||
sigType = signatureEd25519
|
|
||||||
default:
|
|
||||||
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
|
|
||||||
}
|
|
||||||
switch signatureAlgorithm {
|
|
||||||
case PKCS1WithSHA1, ECDSAWithSHA1:
|
|
||||||
hash = crypto.SHA1
|
|
||||||
case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256:
|
|
||||||
hash = crypto.SHA256
|
|
||||||
case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384:
|
|
||||||
hash = crypto.SHA384
|
|
||||||
case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512:
|
|
||||||
hash = crypto.SHA512
|
|
||||||
case Ed25519:
|
|
||||||
hash = directSigning
|
|
||||||
default:
|
|
||||||
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
|
|
||||||
}
|
|
||||||
return sigType, hash, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for
|
|
||||||
// a given public key used with TLS 1.0 and 1.1, before the introduction of
|
|
||||||
// signature algorithm negotiation.
|
|
||||||
func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) {
|
|
||||||
switch pub.(type) {
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
return signaturePKCS1v15, crypto.MD5SHA1, nil
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
return signatureECDSA, crypto.SHA1, nil
|
|
||||||
case ed25519.PublicKey:
|
|
||||||
// RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1,
|
|
||||||
// but it requires holding on to a handshake transcript to do a
|
|
||||||
// full signature, and not even OpenSSL bothers with the
|
|
||||||
// complexity, so we can't even test it properly.
|
|
||||||
return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2")
|
|
||||||
default:
|
|
||||||
return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var rsaSignatureSchemes = []struct {
|
|
||||||
scheme SignatureScheme
|
|
||||||
minModulusBytes int
|
|
||||||
maxVersion uint16
|
|
||||||
}{
|
|
||||||
// RSA-PSS is used with PSSSaltLengthEqualsHash, and requires
|
|
||||||
// emLen >= hLen + sLen + 2
|
|
||||||
{PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13},
|
|
||||||
{PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13},
|
|
||||||
{PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13},
|
|
||||||
// PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires
|
|
||||||
// emLen >= len(prefix) + hLen + 11
|
|
||||||
// TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS.
|
|
||||||
{PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12},
|
|
||||||
{PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12},
|
|
||||||
{PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12},
|
|
||||||
{PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12},
|
|
||||||
}
|
|
||||||
|
|
||||||
// signatureSchemesForCertificate returns the list of supported SignatureSchemes
|
|
||||||
// for a given certificate, based on the public key and the protocol version,
|
|
||||||
// and optionally filtered by its explicit SupportedSignatureAlgorithms.
|
|
||||||
//
|
|
||||||
// This function must be kept in sync with supportedSignatureAlgorithms.
|
|
||||||
// FIPS filtering is applied in the caller, selectSignatureScheme.
|
|
||||||
func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme {
|
|
||||||
priv, ok := cert.PrivateKey.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var sigAlgs []SignatureScheme
|
|
||||||
switch pub := priv.Public().(type) {
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
if version != VersionTLS13 {
|
|
||||||
// In TLS 1.2 and earlier, ECDSA algorithms are not
|
|
||||||
// constrained to a single curve.
|
|
||||||
sigAlgs = []SignatureScheme{
|
|
||||||
ECDSAWithP256AndSHA256,
|
|
||||||
ECDSAWithP384AndSHA384,
|
|
||||||
ECDSAWithP521AndSHA512,
|
|
||||||
ECDSAWithSHA1,
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
switch pub.Curve {
|
|
||||||
case elliptic.P256():
|
|
||||||
sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256}
|
|
||||||
case elliptic.P384():
|
|
||||||
sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384}
|
|
||||||
case elliptic.P521():
|
|
||||||
sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512}
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
size := pub.Size()
|
|
||||||
sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes))
|
|
||||||
for _, candidate := range rsaSignatureSchemes {
|
|
||||||
if size >= candidate.minModulusBytes && version <= candidate.maxVersion {
|
|
||||||
sigAlgs = append(sigAlgs, candidate.scheme)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case ed25519.PublicKey:
|
|
||||||
sigAlgs = []SignatureScheme{Ed25519}
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert.SupportedSignatureAlgorithms != nil {
|
|
||||||
var filteredSigAlgs []SignatureScheme
|
|
||||||
for _, sigAlg := range sigAlgs {
|
|
||||||
if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) {
|
|
||||||
filteredSigAlgs = append(filteredSigAlgs, sigAlg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filteredSigAlgs
|
|
||||||
}
|
|
||||||
return sigAlgs
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectSignatureScheme picks a SignatureScheme from the peer's preference list
|
|
||||||
// that works with the selected certificate. It's only called for protocol
|
|
||||||
// versions that support signature algorithms, so TLS 1.2 and 1.3.
|
|
||||||
func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) {
|
|
||||||
supportedAlgs := signatureSchemesForCertificate(vers, c)
|
|
||||||
if len(supportedAlgs) == 0 {
|
|
||||||
return 0, unsupportedCertificateError(c)
|
|
||||||
}
|
|
||||||
if len(peerAlgs) == 0 && vers == VersionTLS12 {
|
|
||||||
// For TLS 1.2, if the client didn't send signature_algorithms then we
|
|
||||||
// can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1.
|
|
||||||
peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1}
|
|
||||||
}
|
|
||||||
// Pick signature scheme in the peer's preference order, as our
|
|
||||||
// preference order is not configurable.
|
|
||||||
for _, preferredAlg := range peerAlgs {
|
|
||||||
if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) {
|
|
||||||
return preferredAlg, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms")
|
|
||||||
}
|
|
||||||
|
|
||||||
// unsupportedCertificateError returns a helpful error for certificates with
|
|
||||||
// an unsupported private key.
|
|
||||||
func unsupportedCertificateError(cert *Certificate) error {
|
|
||||||
switch cert.PrivateKey.(type) {
|
|
||||||
case rsa.PrivateKey, ecdsa.PrivateKey:
|
|
||||||
return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T",
|
|
||||||
cert.PrivateKey, cert.PrivateKey)
|
|
||||||
case *ed25519.PrivateKey:
|
|
||||||
return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, ok := cert.PrivateKey.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer",
|
|
||||||
cert.PrivateKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch pub := signer.Public().(type) {
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
switch pub.Curve {
|
|
||||||
case elliptic.P256():
|
|
||||||
case elliptic.P384():
|
|
||||||
case elliptic.P521():
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name)
|
|
||||||
}
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms")
|
|
||||||
case ed25519.PublicKey:
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("tls: unsupported certificate key (%T)", pub)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert.SupportedSignatureAlgorithms != nil {
|
|
||||||
return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms")
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey)
|
|
||||||
}
|
|
|
@ -1,95 +0,0 @@
|
||||||
// Copyright 2022 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/x509"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
type cacheEntry struct {
|
|
||||||
refs atomic.Int64
|
|
||||||
cert *x509.Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
// certCache implements an intern table for reference counted x509.Certificates,
|
|
||||||
// implemented in a similar fashion to BoringSSL's CRYPTO_BUFFER_POOL. This
|
|
||||||
// allows for a single x509.Certificate to be kept in memory and referenced from
|
|
||||||
// multiple Conns. Returned references should not be mutated by callers. Certificates
|
|
||||||
// are still safe to use after they are removed from the cache.
|
|
||||||
//
|
|
||||||
// Certificates are returned wrapped in a activeCert struct that should be held by
|
|
||||||
// the caller. When references to the activeCert are freed, the number of references
|
|
||||||
// to the certificate in the cache is decremented. Once the number of references
|
|
||||||
// reaches zero, the entry is evicted from the cache.
|
|
||||||
//
|
|
||||||
// The main difference between this implementation and CRYPTO_BUFFER_POOL is that
|
|
||||||
// CRYPTO_BUFFER_POOL is a more generic structure which supports blobs of data,
|
|
||||||
// rather than specific structures. Since we only care about x509.Certificates,
|
|
||||||
// certCache is implemented as a specific cache, rather than a generic one.
|
|
||||||
//
|
|
||||||
// See https://boringssl.googlesource.com/boringssl/+/master/include/openssl/pool.h
|
|
||||||
// and https://boringssl.googlesource.com/boringssl/+/master/crypto/pool/pool.c
|
|
||||||
// for the BoringSSL reference.
|
|
||||||
type certCache struct {
|
|
||||||
sync.Map
|
|
||||||
}
|
|
||||||
|
|
||||||
var clientCertCache = new(certCache)
|
|
||||||
|
|
||||||
// activeCert is a handle to a certificate held in the cache. Once there are
|
|
||||||
// no alive activeCerts for a given certificate, the certificate is removed
|
|
||||||
// from the cache by a finalizer.
|
|
||||||
type activeCert struct {
|
|
||||||
cert *x509.Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
// active increments the number of references to the entry, wraps the
|
|
||||||
// certificate in the entry in a activeCert, and sets the finalizer.
|
|
||||||
//
|
|
||||||
// Note that there is a race between active and the finalizer set on the
|
|
||||||
// returned activeCert, triggered if active is called after the ref count is
|
|
||||||
// decremented such that refs may be > 0 when evict is called. We consider this
|
|
||||||
// safe, since the caller holding an activeCert for an entry that is no longer
|
|
||||||
// in the cache is fine, with the only side effect being the memory overhead of
|
|
||||||
// there being more than one distinct reference to a certificate alive at once.
|
|
||||||
func (cc *certCache) active(e *cacheEntry) *activeCert {
|
|
||||||
e.refs.Add(1)
|
|
||||||
a := &activeCert{e.cert}
|
|
||||||
runtime.SetFinalizer(a, func(_ *activeCert) {
|
|
||||||
if e.refs.Add(-1) == 0 {
|
|
||||||
cc.evict(e)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
// evict removes a cacheEntry from the cache.
|
|
||||||
func (cc *certCache) evict(e *cacheEntry) {
|
|
||||||
cc.Delete(string(e.cert.Raw))
|
|
||||||
}
|
|
||||||
|
|
||||||
// newCert returns a x509.Certificate parsed from der. If there is already a copy
|
|
||||||
// of the certificate in the cache, a reference to the existing certificate will
|
|
||||||
// be returned. Otherwise, a fresh certificate will be added to the cache, and
|
|
||||||
// the reference returned. The returned reference should not be mutated.
|
|
||||||
func (cc *certCache) newCert(der []byte) (*activeCert, error) {
|
|
||||||
if entry, ok := cc.Load(string(der)); ok {
|
|
||||||
return cc.active(entry.(*cacheEntry)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := x509.ParseCertificate(der)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := &cacheEntry{cert: cert}
|
|
||||||
if entry, loaded := cc.LoadOrStore(string(der), entry); loaded {
|
|
||||||
return cc.active(entry.(*cacheEntry)), nil
|
|
||||||
}
|
|
||||||
return cc.active(entry), nil
|
|
||||||
}
|
|
|
@ -1,691 +0,0 @@
|
||||||
// Copyright 2010 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/des"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rc4"
|
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/sha256"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
|
||||||
"golang.org/x/sys/cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CipherSuite is a TLS cipher suite. Note that most functions in this package
|
|
||||||
// accept and expose cipher suite IDs instead of this type.
|
|
||||||
type CipherSuite struct {
|
|
||||||
ID uint16
|
|
||||||
Name string
|
|
||||||
|
|
||||||
// Supported versions is the list of TLS protocol versions that can
|
|
||||||
// negotiate this cipher suite.
|
|
||||||
SupportedVersions []uint16
|
|
||||||
|
|
||||||
// Insecure is true if the cipher suite has known security issues
|
|
||||||
// due to its primitives, design, or implementation.
|
|
||||||
Insecure bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12}
|
|
||||||
supportedOnlyTLS12 = []uint16{VersionTLS12}
|
|
||||||
supportedOnlyTLS13 = []uint16{VersionTLS13}
|
|
||||||
)
|
|
||||||
|
|
||||||
// CipherSuites returns a list of cipher suites currently implemented by this
|
|
||||||
// package, excluding those with security issues, which are returned by
|
|
||||||
// InsecureCipherSuites.
|
|
||||||
//
|
|
||||||
// The list is sorted by ID. Note that the default cipher suites selected by
|
|
||||||
// this package might depend on logic that can't be captured by a static list,
|
|
||||||
// and might not match those returned by this function.
|
|
||||||
func CipherSuites() []*CipherSuite {
|
|
||||||
return []*CipherSuite{
|
|
||||||
{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
|
|
||||||
{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
|
|
||||||
{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
|
|
||||||
{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
|
|
||||||
|
|
||||||
{TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false},
|
|
||||||
{TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false},
|
|
||||||
{TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false},
|
|
||||||
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
|
|
||||||
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsecureCipherSuites returns a list of cipher suites currently implemented by
|
|
||||||
// this package and which have security issues.
|
|
||||||
//
|
|
||||||
// Most applications should not use the cipher suites in this list, and should
|
|
||||||
// only use those returned by CipherSuites.
|
|
||||||
func InsecureCipherSuites() []*CipherSuite {
|
|
||||||
// This list includes RC4, CBC_SHA256, and 3DES cipher suites. See
|
|
||||||
// cipherSuitesPreferenceOrder for details.
|
|
||||||
return []*CipherSuite{
|
|
||||||
{TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
|
|
||||||
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
|
|
||||||
{TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
|
|
||||||
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
|
|
||||||
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CipherSuiteName returns the standard name for the passed cipher suite ID
|
|
||||||
// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation
|
|
||||||
// of the ID value if the cipher suite is not implemented by this package.
|
|
||||||
func CipherSuiteName(id uint16) string {
|
|
||||||
for _, c := range CipherSuites() {
|
|
||||||
if c.ID == id {
|
|
||||||
return c.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, c := range InsecureCipherSuites() {
|
|
||||||
if c.ID == id {
|
|
||||||
return c.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("0x%04X", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// suiteECDHE indicates that the cipher suite involves elliptic curve
|
|
||||||
// Diffie-Hellman. This means that it should only be selected when the
|
|
||||||
// client indicates that it supports ECC with a curve and point format
|
|
||||||
// that we're happy with.
|
|
||||||
suiteECDHE = 1 << iota
|
|
||||||
// suiteECSign indicates that the cipher suite involves an ECDSA or
|
|
||||||
// EdDSA signature and therefore may only be selected when the server's
|
|
||||||
// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
|
|
||||||
// is RSA based.
|
|
||||||
suiteECSign
|
|
||||||
// suiteTLS12 indicates that the cipher suite should only be advertised
|
|
||||||
// and accepted when using TLS 1.2.
|
|
||||||
suiteTLS12
|
|
||||||
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
|
|
||||||
// handshake hash.
|
|
||||||
suiteSHA384
|
|
||||||
)
|
|
||||||
|
|
||||||
// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange
|
|
||||||
// mechanism, as well as the cipher+MAC pair or the AEAD.
|
|
||||||
type cipherSuite struct {
|
|
||||||
id uint16
|
|
||||||
// the lengths, in bytes, of the key material needed for each component.
|
|
||||||
keyLen int
|
|
||||||
macLen int
|
|
||||||
ivLen int
|
|
||||||
ka func(version uint16) keyAgreement
|
|
||||||
// flags is a bitmask of the suite* values, above.
|
|
||||||
flags int
|
|
||||||
cipher func(key, iv []byte, isRead bool) any
|
|
||||||
mac func(key []byte) hash.Hash
|
|
||||||
aead func(key, fixedNonce []byte) aead
|
|
||||||
}
|
|
||||||
|
|
||||||
var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter.
|
|
||||||
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
|
|
||||||
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
|
|
||||||
{TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM},
|
|
||||||
{TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
|
|
||||||
{TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil},
|
|
||||||
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
|
|
||||||
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
|
|
||||||
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
|
|
||||||
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
|
|
||||||
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
|
|
||||||
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
|
|
||||||
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil},
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which
|
|
||||||
// is also in supportedIDs and passes the ok filter.
|
|
||||||
func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite {
|
|
||||||
for _, id := range ids {
|
|
||||||
candidate := cipherSuiteByID(id)
|
|
||||||
if candidate == nil || !ok(candidate) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, suppID := range supportedIDs {
|
|
||||||
if id == suppID {
|
|
||||||
return candidate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
|
|
||||||
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
|
|
||||||
type cipherSuiteTLS13 struct {
|
|
||||||
id uint16
|
|
||||||
keyLen int
|
|
||||||
aead func(key, fixedNonce []byte) aead
|
|
||||||
hash crypto.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
|
|
||||||
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
|
|
||||||
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
|
|
||||||
{TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384},
|
|
||||||
}
|
|
||||||
|
|
||||||
// cipherSuitesPreferenceOrder is the order in which we'll select (on the
|
|
||||||
// server) or advertise (on the client) TLS 1.0–1.2 cipher suites.
|
|
||||||
//
|
|
||||||
// Cipher suites are filtered but not reordered based on the application and
|
|
||||||
// peer's preferences, meaning we'll never select a suite lower in this list if
|
|
||||||
// any higher one is available. This makes it more defensible to keep weaker
|
|
||||||
// cipher suites enabled, especially on the server side where we get the last
|
|
||||||
// word, since there are no known downgrade attacks on cipher suites selection.
|
|
||||||
//
|
|
||||||
// The list is sorted by applying the following priority rules, stopping at the
|
|
||||||
// first (most important) applicable one:
|
|
||||||
//
|
|
||||||
// - Anything else comes before RC4
|
|
||||||
//
|
|
||||||
// RC4 has practically exploitable biases. See https://www.rc4nomore.com.
|
|
||||||
//
|
|
||||||
// - Anything else comes before CBC_SHA256
|
|
||||||
//
|
|
||||||
// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13
|
|
||||||
// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and
|
|
||||||
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
|
|
||||||
//
|
|
||||||
// - Anything else comes before 3DES
|
|
||||||
//
|
|
||||||
// 3DES has 64-bit blocks, which makes it fundamentally susceptible to
|
|
||||||
// birthday attacks. See https://sweet32.info.
|
|
||||||
//
|
|
||||||
// - ECDHE comes before anything else
|
|
||||||
//
|
|
||||||
// Once we got the broken stuff out of the way, the most important
|
|
||||||
// property a cipher suite can have is forward secrecy. We don't
|
|
||||||
// implement FFDHE, so that means ECDHE.
|
|
||||||
//
|
|
||||||
// - AEADs come before CBC ciphers
|
|
||||||
//
|
|
||||||
// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites
|
|
||||||
// are fundamentally fragile, and suffered from an endless sequence of
|
|
||||||
// padding oracle attacks. See https://eprint.iacr.org/2015/1129,
|
|
||||||
// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and
|
|
||||||
// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/.
|
|
||||||
//
|
|
||||||
// - AES comes before ChaCha20
|
|
||||||
//
|
|
||||||
// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster
|
|
||||||
// than ChaCha20Poly1305.
|
|
||||||
//
|
|
||||||
// When AES hardware is not available, AES-128-GCM is one or more of: much
|
|
||||||
// slower, way more complex, and less safe (because not constant time)
|
|
||||||
// than ChaCha20Poly1305.
|
|
||||||
//
|
|
||||||
// We use this list if we think both peers have AES hardware, and
|
|
||||||
// cipherSuitesPreferenceOrderNoAES otherwise.
|
|
||||||
//
|
|
||||||
// - AES-128 comes before AES-256
|
|
||||||
//
|
|
||||||
// The only potential advantages of AES-256 are better multi-target
|
|
||||||
// margins, and hypothetical post-quantum properties. Neither apply to
|
|
||||||
// TLS, and AES-256 is slower due to its four extra rounds (which don't
|
|
||||||
// contribute to the advantages above).
|
|
||||||
//
|
|
||||||
// - ECDSA comes before RSA
|
|
||||||
//
|
|
||||||
// The relative order of ECDSA and RSA cipher suites doesn't matter,
|
|
||||||
// as they depend on the certificate. Pick one to get a stable order.
|
|
||||||
var cipherSuitesPreferenceOrder = []uint16{
|
|
||||||
// AEADs w/ ECDHE
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
|
||||||
|
|
||||||
// CBC w/ ECDHE
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
|
||||||
|
|
||||||
// AEADs w/o ECDHE
|
|
||||||
TLS_RSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
TLS_RSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
|
|
||||||
// CBC w/o ECDHE
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA,
|
|
||||||
TLS_RSA_WITH_AES_256_CBC_SHA,
|
|
||||||
|
|
||||||
// 3DES
|
|
||||||
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
|
||||||
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
|
||||||
|
|
||||||
// CBC_SHA256
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
|
|
||||||
// RC4
|
|
||||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
|
||||||
TLS_RSA_WITH_RC4_128_SHA,
|
|
||||||
}
|
|
||||||
|
|
||||||
var cipherSuitesPreferenceOrderNoAES = []uint16{
|
|
||||||
// ChaCha20Poly1305
|
|
||||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
|
||||||
|
|
||||||
// AES-GCM w/ ECDHE
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
|
|
||||||
// The rest of cipherSuitesPreferenceOrder.
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
|
||||||
TLS_RSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
TLS_RSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA,
|
|
||||||
TLS_RSA_WITH_AES_256_CBC_SHA,
|
|
||||||
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
|
||||||
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
|
||||||
TLS_RSA_WITH_RC4_128_SHA,
|
|
||||||
}
|
|
||||||
|
|
||||||
// disabledCipherSuites are not used unless explicitly listed in
|
|
||||||
// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder.
|
|
||||||
var disabledCipherSuites = []uint16{
|
|
||||||
// CBC_SHA256
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA256,
|
|
||||||
|
|
||||||
// RC4
|
|
||||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
|
||||||
TLS_RSA_WITH_RC4_128_SHA,
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
|
|
||||||
defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen]
|
|
||||||
)
|
|
||||||
|
|
||||||
// defaultCipherSuitesTLS13 is also the preference order, since there are no
|
|
||||||
// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as
|
|
||||||
// cipherSuitesPreferenceOrder applies.
|
|
||||||
var defaultCipherSuitesTLS13 = []uint16{
|
|
||||||
TLS_AES_128_GCM_SHA256,
|
|
||||||
TLS_AES_256_GCM_SHA384,
|
|
||||||
TLS_CHACHA20_POLY1305_SHA256,
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultCipherSuitesTLS13NoAES = []uint16{
|
|
||||||
TLS_CHACHA20_POLY1305_SHA256,
|
|
||||||
TLS_AES_128_GCM_SHA256,
|
|
||||||
TLS_AES_256_GCM_SHA384,
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
|
|
||||||
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
|
|
||||||
// Keep in sync with crypto/aes/cipher_s390x.go.
|
|
||||||
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
|
|
||||||
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
|
|
||||||
|
|
||||||
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
|
|
||||||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
|
|
||||||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
|
|
||||||
)
|
|
||||||
|
|
||||||
var aesgcmCiphers = map[uint16]bool{
|
|
||||||
// TLS 1.2
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true,
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true,
|
|
||||||
// TLS 1.3
|
|
||||||
TLS_AES_128_GCM_SHA256: true,
|
|
||||||
TLS_AES_256_GCM_SHA384: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
var nonAESGCMAEADCiphers = map[uint16]bool{
|
|
||||||
// TLS 1.2
|
|
||||||
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true,
|
|
||||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true,
|
|
||||||
// TLS 1.3
|
|
||||||
TLS_CHACHA20_POLY1305_SHA256: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// aesgcmPreferred returns whether the first known cipher in the preference list
|
|
||||||
// is an AES-GCM cipher, implying the peer has hardware support for it.
|
|
||||||
func aesgcmPreferred(ciphers []uint16) bool {
|
|
||||||
for _, cID := range ciphers {
|
|
||||||
if c := cipherSuiteByID(cID); c != nil {
|
|
||||||
return aesgcmCiphers[cID]
|
|
||||||
}
|
|
||||||
if c := cipherSuiteTLS13ByID(cID); c != nil {
|
|
||||||
return aesgcmCiphers[cID]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func cipherRC4(key, iv []byte, isRead bool) any {
|
|
||||||
cipher, _ := rc4.NewCipher(key)
|
|
||||||
return cipher
|
|
||||||
}
|
|
||||||
|
|
||||||
func cipher3DES(key, iv []byte, isRead bool) any {
|
|
||||||
block, _ := des.NewTripleDESCipher(key)
|
|
||||||
if isRead {
|
|
||||||
return cipher.NewCBCDecrypter(block, iv)
|
|
||||||
}
|
|
||||||
return cipher.NewCBCEncrypter(block, iv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cipherAES(key, iv []byte, isRead bool) any {
|
|
||||||
block, _ := aes.NewCipher(key)
|
|
||||||
if isRead {
|
|
||||||
return cipher.NewCBCDecrypter(block, iv)
|
|
||||||
}
|
|
||||||
return cipher.NewCBCEncrypter(block, iv)
|
|
||||||
}
|
|
||||||
|
|
||||||
// macSHA1 returns a SHA-1 based constant time MAC.
|
|
||||||
func macSHA1(key []byte) hash.Hash {
|
|
||||||
h := sha1.New
|
|
||||||
h = newConstantTimeHash(h)
|
|
||||||
return hmac.New(h, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
|
|
||||||
// is currently only used in disabled-by-default cipher suites.
|
|
||||||
func macSHA256(key []byte) hash.Hash {
|
|
||||||
return hmac.New(sha256.New, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
type aead interface {
|
|
||||||
cipher.AEAD
|
|
||||||
|
|
||||||
// explicitNonceLen returns the number of bytes of explicit nonce
|
|
||||||
// included in each record. This is eight for older AEADs and
|
|
||||||
// zero for modern ones.
|
|
||||||
explicitNonceLen() int
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
aeadNonceLength = 12
|
|
||||||
noncePrefixLength = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
|
|
||||||
// each call.
|
|
||||||
type prefixNonceAEAD struct {
|
|
||||||
// nonce contains the fixed part of the nonce in the first four bytes.
|
|
||||||
nonce [aeadNonceLength]byte
|
|
||||||
aead cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength }
|
|
||||||
func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
|
||||||
func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() }
|
|
||||||
|
|
||||||
func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
|
||||||
copy(f.nonce[4:], nonce)
|
|
||||||
return f.aead.Seal(out, f.nonce[:], plaintext, additionalData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
|
||||||
copy(f.nonce[4:], nonce)
|
|
||||||
return f.aead.Open(out, f.nonce[:], ciphertext, additionalData)
|
|
||||||
}
|
|
||||||
|
|
||||||
// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
|
|
||||||
// before each call.
|
|
||||||
type xorNonceAEAD struct {
|
|
||||||
nonceMask [aeadNonceLength]byte
|
|
||||||
aead cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
|
|
||||||
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
|
||||||
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
|
|
||||||
|
|
||||||
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
|
||||||
for i, b := range nonce {
|
|
||||||
f.nonceMask[4+i] ^= b
|
|
||||||
}
|
|
||||||
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
|
|
||||||
for i, b := range nonce {
|
|
||||||
f.nonceMask[4+i] ^= b
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
|
||||||
for i, b := range nonce {
|
|
||||||
f.nonceMask[4+i] ^= b
|
|
||||||
}
|
|
||||||
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
|
|
||||||
for i, b := range nonce {
|
|
||||||
f.nonceMask[4+i] ^= b
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func aeadAESGCM(key, noncePrefix []byte) aead {
|
|
||||||
if len(noncePrefix) != noncePrefixLength {
|
|
||||||
panic("tls: internal error: wrong nonce length")
|
|
||||||
}
|
|
||||||
aes, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
var aead cipher.AEAD
|
|
||||||
aead, err = cipher.NewGCM(aes)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := &prefixNonceAEAD{aead: aead}
|
|
||||||
copy(ret.nonce[:], noncePrefix)
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
|
|
||||||
if len(nonceMask) != aeadNonceLength {
|
|
||||||
panic("tls: internal error: wrong nonce length")
|
|
||||||
}
|
|
||||||
aes, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
aead, err := cipher.NewGCM(aes)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := &xorNonceAEAD{aead: aead}
|
|
||||||
copy(ret.nonceMask[:], nonceMask)
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func aeadChaCha20Poly1305(key, nonceMask []byte) aead {
|
|
||||||
if len(nonceMask) != aeadNonceLength {
|
|
||||||
panic("tls: internal error: wrong nonce length")
|
|
||||||
}
|
|
||||||
aead, err := chacha20poly1305.New(key)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ret := &xorNonceAEAD{aead: aead}
|
|
||||||
copy(ret.nonceMask[:], nonceMask)
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
type constantTimeHash interface {
|
|
||||||
hash.Hash
|
|
||||||
ConstantTimeSum(b []byte) []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces
|
|
||||||
// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC.
|
|
||||||
type cthWrapper struct {
|
|
||||||
h constantTimeHash
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cthWrapper) Size() int { return c.h.Size() }
|
|
||||||
func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() }
|
|
||||||
func (c *cthWrapper) Reset() { c.h.Reset() }
|
|
||||||
func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) }
|
|
||||||
func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) }
|
|
||||||
|
|
||||||
func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
|
|
||||||
return func() hash.Hash {
|
|
||||||
return &cthWrapper{h().(constantTimeHash)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
|
|
||||||
func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
|
|
||||||
h.Reset()
|
|
||||||
h.Write(seq)
|
|
||||||
h.Write(header)
|
|
||||||
h.Write(data)
|
|
||||||
res := h.Sum(out)
|
|
||||||
if extra != nil {
|
|
||||||
h.Write(extra)
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func rsaKA(version uint16) keyAgreement {
|
|
||||||
return rsaKeyAgreement{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ecdheECDSAKA(version uint16) keyAgreement {
|
|
||||||
return &ecdheKeyAgreement{
|
|
||||||
isRSA: false,
|
|
||||||
version: version,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ecdheRSAKA(version uint16) keyAgreement {
|
|
||||||
return &ecdheKeyAgreement{
|
|
||||||
isRSA: true,
|
|
||||||
version: version,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// mutualCipherSuite returns a cipherSuite given a list of supported
|
|
||||||
// ciphersuites and the id requested by the peer.
|
|
||||||
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
|
|
||||||
for _, id := range have {
|
|
||||||
if id == want {
|
|
||||||
return cipherSuiteByID(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cipherSuiteByID(id uint16) *cipherSuite {
|
|
||||||
for _, cipherSuite := range cipherSuites {
|
|
||||||
if cipherSuite.id == id {
|
|
||||||
return cipherSuite
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 {
|
|
||||||
for _, id := range have {
|
|
||||||
if id == want {
|
|
||||||
return cipherSuiteTLS13ByID(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 {
|
|
||||||
for _, cipherSuite := range cipherSuitesTLS13 {
|
|
||||||
if cipherSuite.id == id {
|
|
||||||
return cipherSuite
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// A list of cipher suite IDs that are, or have been, implemented by this
|
|
||||||
// package.
|
|
||||||
//
|
|
||||||
// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
|
|
||||||
const (
|
|
||||||
// TLS 1.0 - 1.2 cipher suites.
|
|
||||||
TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
|
|
||||||
TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
|
|
||||||
TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
|
|
||||||
TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c
|
|
||||||
TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c
|
|
||||||
TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d
|
|
||||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a
|
|
||||||
TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011
|
|
||||||
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
|
|
||||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030
|
|
||||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c
|
|
||||||
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8
|
|
||||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9
|
|
||||||
|
|
||||||
// TLS 1.3 cipher suites.
|
|
||||||
TLS_AES_128_GCM_SHA256 uint16 = 0x1301
|
|
||||||
TLS_AES_256_GCM_SHA384 uint16 = 0x1302
|
|
||||||
TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303
|
|
||||||
|
|
||||||
// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
|
|
||||||
// that the client is doing version fallback. See RFC 7507.
|
|
||||||
TLS_FALLBACK_SCSV uint16 = 0x5600
|
|
||||||
|
|
||||||
// Legacy names for the corresponding cipher suites with the correct _SHA256
|
|
||||||
// suffix, retained for backward compatibility.
|
|
||||||
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
|
||||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
|
|
||||||
)
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,782 +0,0 @@
|
||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdh"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"hash"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/cryptobyte"
|
|
||||||
)
|
|
||||||
|
|
||||||
type clientHandshakeStateTLS13 struct {
|
|
||||||
c *Conn
|
|
||||||
ctx context.Context
|
|
||||||
serverHello *serverHelloMsg
|
|
||||||
hello *clientHelloMsg
|
|
||||||
ecdheKey *ecdh.PrivateKey
|
|
||||||
|
|
||||||
session *clientSessionState
|
|
||||||
earlySecret []byte
|
|
||||||
binderKey []byte
|
|
||||||
|
|
||||||
certReq *certificateRequestMsgTLS13
|
|
||||||
usingPSK bool
|
|
||||||
sentDummyCCS bool
|
|
||||||
suite *cipherSuiteTLS13
|
|
||||||
transcript hash.Hash
|
|
||||||
masterSecret []byte
|
|
||||||
trafficSecret []byte // client_application_traffic_secret_0
|
|
||||||
}
|
|
||||||
|
|
||||||
// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheKey, and,
|
|
||||||
// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
|
|
||||||
func (hs *clientHandshakeStateTLS13) handshake() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if needFIPS() {
|
|
||||||
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
|
|
||||||
// sections 4.1.2 and 4.1.3.
|
|
||||||
if c.handshakes > 0 {
|
|
||||||
c.sendAlert(alertProtocolVersion)
|
|
||||||
return errors.New("tls: server selected TLS 1.3 in a renegotiation")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Consistency check on the presence of a keyShare and its parameters.
|
|
||||||
if hs.ecdheKey == nil || len(hs.hello.keyShares) != 1 {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := hs.checkServerHelloOrHRR(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.transcript = hs.suite.hash.New()
|
|
||||||
|
|
||||||
if err := transcriptMsg(hs.hello, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
|
|
||||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.processHelloRetryRequest(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.buffering = true
|
|
||||||
if err := hs.processServerHello(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.establishHandshakeKeys(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.readServerParameters(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.readServerCertificate(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.readServerFinished(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.sendClientCertificate(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.sendClientFinished(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := c.flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.isHandshakeComplete.Store(true)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkServerHelloOrHRR does validity checks that apply to both ServerHello and
|
|
||||||
// HelloRetryRequest messages. It sets hs.suite.
|
|
||||||
func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if hs.serverHello.supportedVersion == 0 {
|
|
||||||
c.sendAlert(alertMissingExtension)
|
|
||||||
return errors.New("tls: server selected TLS 1.3 using the legacy version field")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.serverHello.supportedVersion != VersionTLS13 {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server selected an invalid version after a HelloRetryRequest")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.serverHello.vers != VersionTLS12 {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server sent an incorrect legacy version")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.serverHello.ocspStapling ||
|
|
||||||
hs.serverHello.ticketSupported ||
|
|
||||||
hs.serverHello.secureRenegotiationSupported ||
|
|
||||||
len(hs.serverHello.secureRenegotiation) != 0 ||
|
|
||||||
len(hs.serverHello.alpnProtocol) != 0 ||
|
|
||||||
len(hs.serverHello.scts) != 0 {
|
|
||||||
c.sendAlert(alertUnsupportedExtension)
|
|
||||||
return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server did not echo the legacy session ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.serverHello.compressionMethod != compressionNone {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server selected unsupported compression format")
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite)
|
|
||||||
if hs.suite != nil && selectedSuite != hs.suite {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server changed cipher suite after a HelloRetryRequest")
|
|
||||||
}
|
|
||||||
if selectedSuite == nil {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server chose an unconfigured cipher suite")
|
|
||||||
}
|
|
||||||
hs.suite = selectedSuite
|
|
||||||
c.cipherSuite = hs.suite.id
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
|
||||||
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
|
||||||
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
|
||||||
if hs.c.quic != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if hs.sentDummyCCS {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
hs.sentDummyCCS = true
|
|
||||||
|
|
||||||
return hs.c.writeChangeCipherRecord()
|
|
||||||
}
|
|
||||||
|
|
||||||
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
|
|
||||||
// resends hs.hello, and reads the new ServerHello into hs.serverHello.
|
|
||||||
func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
// The first ClientHello gets double-hashed into the transcript upon a
|
|
||||||
// HelloRetryRequest. (The idea is that the server might offload transcript
|
|
||||||
// storage to the client in the cookie.) See RFC 8446, Section 4.4.1.
|
|
||||||
chHash := hs.transcript.Sum(nil)
|
|
||||||
hs.transcript.Reset()
|
|
||||||
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
|
|
||||||
hs.transcript.Write(chHash)
|
|
||||||
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The only HelloRetryRequest extensions we support are key_share and
|
|
||||||
// cookie, and clients must abort the handshake if the HRR would not result
|
|
||||||
// in any change in the ClientHello.
|
|
||||||
if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.serverHello.cookie != nil {
|
|
||||||
hs.hello.cookie = hs.serverHello.cookie
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.serverHello.serverShare.group != 0 {
|
|
||||||
c.sendAlert(alertDecodeError)
|
|
||||||
return errors.New("tls: received malformed key_share extension")
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the server sent a key_share extension selecting a group, ensure it's
|
|
||||||
// a group we advertised but did not send a key share for, and send a key
|
|
||||||
// share for it this time.
|
|
||||||
if curveID := hs.serverHello.selectedGroup; curveID != 0 {
|
|
||||||
curveOK := false
|
|
||||||
for _, id := range hs.hello.supportedCurves {
|
|
||||||
if id == curveID {
|
|
||||||
curveOK = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !curveOK {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server selected unsupported group")
|
|
||||||
}
|
|
||||||
if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); sentID == curveID {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
|
|
||||||
}
|
|
||||||
if _, ok := curveForCurveID(curveID); !ok {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return errors.New("tls: CurvePreferences includes unsupported curve")
|
|
||||||
}
|
|
||||||
key, err := generateECDHEKey(c.config.rand(), curveID)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hs.ecdheKey = key
|
|
||||||
hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.hello.raw = nil
|
|
||||||
if len(hs.hello.pskIdentities) > 0 {
|
|
||||||
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
|
|
||||||
if pskSuite == nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
if pskSuite.hash == hs.suite.hash {
|
|
||||||
// Update binders and obfuscated_ticket_age.
|
|
||||||
ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond)
|
|
||||||
hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd
|
|
||||||
|
|
||||||
transcript := hs.suite.hash.New()
|
|
||||||
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
|
|
||||||
transcript.Write(chHash)
|
|
||||||
if err := transcriptMsg(hs.serverHello, transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
helloBytes, err := hs.hello.marshalWithoutBinders()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
transcript.Write(helloBytes)
|
|
||||||
pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
|
|
||||||
if err := hs.hello.updateBinders(pskBinders); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Server selected a cipher suite incompatible with the PSK.
|
|
||||||
hs.hello.pskIdentities = nil
|
|
||||||
hs.hello.pskBinders = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.hello.earlyData {
|
|
||||||
hs.hello.earlyData = false
|
|
||||||
c.quicRejectedEarlyData()
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// serverHelloMsg is not included in the transcript
|
|
||||||
msg, err := c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverHello, ok := msg.(*serverHelloMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(serverHello, msg)
|
|
||||||
}
|
|
||||||
hs.serverHello = serverHello
|
|
||||||
|
|
||||||
if err := hs.checkServerHelloOrHRR(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *clientHandshakeStateTLS13) processServerHello() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return errors.New("tls: server sent two HelloRetryRequest messages")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hs.serverHello.cookie) != 0 {
|
|
||||||
c.sendAlert(alertUnsupportedExtension)
|
|
||||||
return errors.New("tls: server sent a cookie in a normal ServerHello")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.serverHello.selectedGroup != 0 {
|
|
||||||
c.sendAlert(alertDecodeError)
|
|
||||||
return errors.New("tls: malformed key_share extension")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.serverHello.serverShare.group == 0 {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server did not send a key share")
|
|
||||||
}
|
|
||||||
if sentID, _ := curveIDForCurve(hs.ecdheKey.Curve()); hs.serverHello.serverShare.group != sentID {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server selected unsupported group")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hs.serverHello.selectedIdentityPresent {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server selected an invalid PSK")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hs.hello.pskIdentities) != 1 || hs.session == nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
|
|
||||||
if pskSuite == nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
if pskSuite.hash != hs.suite.hash {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: server selected an invalid PSK and cipher suite pair")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.usingPSK = true
|
|
||||||
c.didResume = true
|
|
||||||
c.peerCertificates = hs.session.serverCertificates
|
|
||||||
c.verifiedChains = hs.session.verifiedChains
|
|
||||||
c.ocspResponse = hs.session.ocspResponse
|
|
||||||
c.scts = hs.session.scts
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
peerKey, err := hs.ecdheKey.Curve().NewPublicKey(hs.serverHello.serverShare.data)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: invalid server key share")
|
|
||||||
}
|
|
||||||
sharedKey, err := hs.ecdheKey.ECDH(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: invalid server key share")
|
|
||||||
}
|
|
||||||
|
|
||||||
earlySecret := hs.earlySecret
|
|
||||||
if !hs.usingPSK {
|
|
||||||
earlySecret = hs.suite.extract(nil, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
handshakeSecret := hs.suite.extract(sharedKey,
|
|
||||||
hs.suite.deriveSecret(earlySecret, "derived", nil))
|
|
||||||
|
|
||||||
clientSecret := hs.suite.deriveSecret(handshakeSecret,
|
|
||||||
clientHandshakeTrafficLabel, hs.transcript)
|
|
||||||
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
|
|
||||||
serverSecret := hs.suite.deriveSecret(handshakeSecret,
|
|
||||||
serverHandshakeTrafficLabel, hs.transcript)
|
|
||||||
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
|
|
||||||
|
|
||||||
if c.quic != nil {
|
|
||||||
if c.hand.Len() != 0 {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
}
|
|
||||||
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
|
|
||||||
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.masterSecret = hs.suite.extract(nil,
|
|
||||||
hs.suite.deriveSecret(handshakeSecret, "derived", nil))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
msg, err := c.readHandshake(hs.transcript)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(encryptedExtensions, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
|
|
||||||
// RFC 8446 specifies that no_application_protocol is sent by servers, but
|
|
||||||
// does not specify how clients handle the selection of an incompatible protocol.
|
|
||||||
// RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
|
|
||||||
// in this case. Always sending no_application_protocol seems reasonable.
|
|
||||||
c.sendAlert(alertNoApplicationProtocol)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.clientProtocol = encryptedExtensions.alpnProtocol
|
|
||||||
|
|
||||||
if c.quic != nil {
|
|
||||||
if encryptedExtensions.quicTransportParameters == nil {
|
|
||||||
// RFC 9001 Section 8.2.
|
|
||||||
c.sendAlert(alertMissingExtension)
|
|
||||||
return errors.New("tls: server did not send a quic_transport_parameters extension")
|
|
||||||
}
|
|
||||||
c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
|
|
||||||
} else {
|
|
||||||
if encryptedExtensions.quicTransportParameters != nil {
|
|
||||||
c.sendAlert(alertUnsupportedExtension)
|
|
||||||
return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.hello.earlyData && !encryptedExtensions.earlyData {
|
|
||||||
c.quicRejectedEarlyData()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
// Either a PSK or a certificate is always used, but not both.
|
|
||||||
// See RFC 8446, Section 4.1.1.
|
|
||||||
if hs.usingPSK {
|
|
||||||
// Make sure the connection is still being verified whether or not this
|
|
||||||
// is a resumption. Resumptions currently don't reverify certificates so
|
|
||||||
// they don't call verifyServerCertificate. See Issue 31641.
|
|
||||||
if c.config.VerifyConnection != nil {
|
|
||||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err := c.readHandshake(hs.transcript)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
certReq, ok := msg.(*certificateRequestMsgTLS13)
|
|
||||||
if ok {
|
|
||||||
hs.certReq = certReq
|
|
||||||
|
|
||||||
msg, err = c.readHandshake(hs.transcript)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
certMsg, ok := msg.(*certificateMsgTLS13)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(certMsg, msg)
|
|
||||||
}
|
|
||||||
if len(certMsg.certificate.Certificate) == 0 {
|
|
||||||
c.sendAlert(alertDecodeError)
|
|
||||||
return errors.New("tls: received empty certificates message")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.scts = certMsg.certificate.SignedCertificateTimestamps
|
|
||||||
c.ocspResponse = certMsg.certificate.OCSPStaple
|
|
||||||
|
|
||||||
if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// certificateVerifyMsg is included in the transcript, but not until
|
|
||||||
// after we verify the handshake signature, since the state before
|
|
||||||
// this message was sent is used.
|
|
||||||
msg, err = c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
certVerify, ok := msg.(*certificateVerifyMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(certVerify, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 8446, Section 4.4.3.
|
|
||||||
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: certificate used with invalid signature algorithm")
|
|
||||||
}
|
|
||||||
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
|
|
||||||
if err != nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: certificate used with invalid signature algorithm")
|
|
||||||
}
|
|
||||||
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
|
|
||||||
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
|
|
||||||
sigHash, signed, certVerify.signature); err != nil {
|
|
||||||
c.sendAlert(alertDecryptError)
|
|
||||||
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
// finishedMsg is included in the transcript, but not until after we
|
|
||||||
// check the client version, since the state before this message was
|
|
||||||
// sent is used during verification.
|
|
||||||
msg, err := c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
finished, ok := msg.(*finishedMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(finished, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
|
|
||||||
if !hmac.Equal(expectedMAC, finished.verifyData) {
|
|
||||||
c.sendAlert(alertDecryptError)
|
|
||||||
return errors.New("tls: invalid server finished hash")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := transcriptMsg(finished, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Derive secrets that take context through the server Finished.
|
|
||||||
|
|
||||||
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
|
|
||||||
clientApplicationTrafficLabel, hs.transcript)
|
|
||||||
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
|
||||||
serverApplicationTrafficLabel, hs.transcript)
|
|
||||||
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
|
|
||||||
|
|
||||||
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if hs.certReq == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{
|
|
||||||
AcceptableCAs: hs.certReq.certificateAuthorities,
|
|
||||||
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
|
|
||||||
Version: c.vers,
|
|
||||||
ctx: hs.ctx,
|
|
||||||
}))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
certMsg := new(certificateMsgTLS13)
|
|
||||||
|
|
||||||
certMsg.certificate = *cert
|
|
||||||
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
|
|
||||||
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we sent an empty certificate message, skip the CertificateVerify.
|
|
||||||
if len(cert.Certificate) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
certVerifyMsg := new(certificateVerifyMsg)
|
|
||||||
certVerifyMsg.hasSignatureAlgorithm = true
|
|
||||||
|
|
||||||
certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms)
|
|
||||||
if err != nil {
|
|
||||||
// getClientCertificate returned a certificate incompatible with the
|
|
||||||
// CertificateRequestInfo supported signature algorithms.
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm)
|
|
||||||
if err != nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
|
|
||||||
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
|
|
||||||
signOpts := crypto.SignerOpts(sigHash)
|
|
||||||
if sigType == signatureRSAPSS {
|
|
||||||
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
|
|
||||||
}
|
|
||||||
sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return errors.New("tls: failed to sign handshake: " + err.Error())
|
|
||||||
}
|
|
||||||
certVerifyMsg.signature = sig
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
finished := &finishedMsg{
|
|
||||||
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
|
|
||||||
|
|
||||||
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
|
|
||||||
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
|
||||||
resumptionLabel, hs.transcript)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.quic != nil {
|
|
||||||
if c.hand.Len() != 0 {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
}
|
|
||||||
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
|
|
||||||
if !c.isClient {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return errors.New("tls: received new session ticket from a client")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 8446, Section 4.6.1.
|
|
||||||
if msg.lifetime == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
lifetime := time.Duration(msg.lifetime) * time.Second
|
|
||||||
if lifetime > maxSessionTicketLifetime {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: received a session ticket with invalid lifetime")
|
|
||||||
}
|
|
||||||
|
|
||||||
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
|
|
||||||
if cipherSuite == nil || c.resumptionSecret == nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need to save the max_early_data_size that the server sent us, in order
|
|
||||||
// to decide if we're going to try 0-RTT with this ticket.
|
|
||||||
// However, at the same time, the qtls.ClientSessionTicket needs to be equal to
|
|
||||||
// the tls.ClientSessionTicket, so we can't just add a new field to the struct.
|
|
||||||
// We therefore abuse the nonce field (which is a byte slice)
|
|
||||||
nonceWithEarlyData := make([]byte, len(msg.nonce)+4)
|
|
||||||
binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData)
|
|
||||||
copy(nonceWithEarlyData[4:], msg.nonce)
|
|
||||||
|
|
||||||
var appData []byte
|
|
||||||
if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil {
|
|
||||||
appData = c.extraConfig.GetAppDataForSessionState()
|
|
||||||
}
|
|
||||||
var b cryptobyte.Builder
|
|
||||||
b.AddUint16(clientSessionStateVersion) // revision
|
|
||||||
b.AddUint32(msg.maxEarlyData)
|
|
||||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes(appData)
|
|
||||||
})
|
|
||||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes(msg.nonce)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Save the resumption_master_secret and nonce instead of deriving the PSK
|
|
||||||
// to do the least amount of work on NewSessionTicket messages before we
|
|
||||||
// know if the ticket will be used. Forward secrecy of resumed connections
|
|
||||||
// is guaranteed by the requirement for pskModeDHE.
|
|
||||||
session := &clientSessionState{
|
|
||||||
sessionTicket: msg.label,
|
|
||||||
vers: c.vers,
|
|
||||||
cipherSuite: c.cipherSuite,
|
|
||||||
masterSecret: c.resumptionSecret,
|
|
||||||
serverCertificates: c.peerCertificates,
|
|
||||||
verifiedChains: c.verifiedChains,
|
|
||||||
receivedAt: c.config.time(),
|
|
||||||
nonce: b.BytesOrPanic(),
|
|
||||||
useBy: c.config.time().Add(lifetime),
|
|
||||||
ageAdd: msg.ageAdd,
|
|
||||||
ocspResponse: c.ocspResponse,
|
|
||||||
scts: c.scts,
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheKey := c.clientSessionCacheKey()
|
|
||||||
if cacheKey != "" {
|
|
||||||
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,899 +0,0 @@
|
||||||
// Copyright 2009 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/subtle"
|
|
||||||
"crypto/x509"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"io"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// serverHandshakeState contains details of a server handshake in progress.
|
|
||||||
// It's discarded once the handshake has completed.
|
|
||||||
type serverHandshakeState struct {
|
|
||||||
c *Conn
|
|
||||||
ctx context.Context
|
|
||||||
clientHello *clientHelloMsg
|
|
||||||
hello *serverHelloMsg
|
|
||||||
suite *cipherSuite
|
|
||||||
ecdheOk bool
|
|
||||||
ecSignOk bool
|
|
||||||
rsaDecryptOk bool
|
|
||||||
rsaSignOk bool
|
|
||||||
sessionState *sessionState
|
|
||||||
finishedHash finishedHash
|
|
||||||
masterSecret []byte
|
|
||||||
cert *Certificate
|
|
||||||
}
|
|
||||||
|
|
||||||
// serverHandshake performs a TLS handshake as a server.
|
|
||||||
func (c *Conn) serverHandshake(ctx context.Context) error {
|
|
||||||
clientHello, err := c.readClientHello(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.vers == VersionTLS13 {
|
|
||||||
hs := serverHandshakeStateTLS13{
|
|
||||||
c: c,
|
|
||||||
ctx: ctx,
|
|
||||||
clientHello: clientHello,
|
|
||||||
}
|
|
||||||
return hs.handshake()
|
|
||||||
}
|
|
||||||
|
|
||||||
hs := serverHandshakeState{
|
|
||||||
c: c,
|
|
||||||
ctx: ctx,
|
|
||||||
clientHello: clientHello,
|
|
||||||
}
|
|
||||||
return hs.handshake()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) handshake() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if err := hs.processClientHello(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// For an overview of TLS handshaking, see RFC 5246, Section 7.3.
|
|
||||||
c.buffering = true
|
|
||||||
if hs.checkForResumption() {
|
|
||||||
// The client has included a session ticket and so we do an abbreviated handshake.
|
|
||||||
c.didResume = true
|
|
||||||
if err := hs.doResumeHandshake(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.establishKeys(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.sendSessionTicket(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.sendFinished(c.serverFinished[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := c.flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.clientFinishedIsFirst = false
|
|
||||||
if err := hs.readFinished(nil); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// The client didn't include a session ticket, or it wasn't
|
|
||||||
// valid so we do a full handshake.
|
|
||||||
if err := hs.pickCipherSuite(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.doFullHandshake(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.establishKeys(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.readFinished(c.clientFinished[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.clientFinishedIsFirst = true
|
|
||||||
c.buffering = true
|
|
||||||
if err := hs.sendSessionTicket(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.sendFinished(nil); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := c.flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
|
|
||||||
c.isHandshakeComplete.Store(true)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// readClientHello reads a ClientHello message and selects the protocol version.
|
|
||||||
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
|
|
||||||
// clientHelloMsg is included in the transcript, but we haven't initialized
|
|
||||||
// it yet. The respective handshake functions will record it themselves.
|
|
||||||
msg, err := c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
clientHello, ok := msg.(*clientHelloMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return nil, unexpectedMessageError(clientHello, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
var configForClient *config
|
|
||||||
originalConfig := c.config
|
|
||||||
if c.config.GetConfigForClient != nil {
|
|
||||||
chi := newClientHelloInfo(ctx, c, clientHello)
|
|
||||||
if cfc, err := c.config.GetConfigForClient(chi); err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return nil, err
|
|
||||||
} else if cfc != nil {
|
|
||||||
configForClient = fromConfig(cfc)
|
|
||||||
c.config = configForClient
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.ticketKeys = originalConfig.ticketKeys(configForClient)
|
|
||||||
|
|
||||||
clientVersions := clientHello.supportedVersions
|
|
||||||
if len(clientHello.supportedVersions) == 0 {
|
|
||||||
clientVersions = supportedVersionsFromMax(clientHello.vers)
|
|
||||||
}
|
|
||||||
c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertProtocolVersion)
|
|
||||||
return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
|
|
||||||
}
|
|
||||||
c.haveVers = true
|
|
||||||
c.in.version = c.vers
|
|
||||||
c.out.version = c.vers
|
|
||||||
|
|
||||||
return clientHello, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) processClientHello() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
hs.hello = new(serverHelloMsg)
|
|
||||||
hs.hello.vers = c.vers
|
|
||||||
|
|
||||||
foundCompression := false
|
|
||||||
// We only support null compression, so check that the client offered it.
|
|
||||||
for _, compression := range hs.clientHello.compressionMethods {
|
|
||||||
if compression == compressionNone {
|
|
||||||
foundCompression = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !foundCompression {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return errors.New("tls: client does not support uncompressed connections")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.hello.random = make([]byte, 32)
|
|
||||||
serverRandom := hs.hello.random
|
|
||||||
// Downgrade protection canaries. See RFC 8446, Section 4.1.3.
|
|
||||||
maxVers := c.config.maxSupportedVersion(roleServer)
|
|
||||||
if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary {
|
|
||||||
if c.vers == VersionTLS12 {
|
|
||||||
copy(serverRandom[24:], downgradeCanaryTLS12)
|
|
||||||
} else {
|
|
||||||
copy(serverRandom[24:], downgradeCanaryTLS11)
|
|
||||||
}
|
|
||||||
serverRandom = serverRandom[:24]
|
|
||||||
}
|
|
||||||
_, err := io.ReadFull(c.config.rand(), serverRandom)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hs.clientHello.secureRenegotiation) != 0 {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return errors.New("tls: initial handshake had non-empty renegotiation extension")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
|
|
||||||
hs.hello.compressionMethod = compressionNone
|
|
||||||
if len(hs.clientHello.serverName) > 0 {
|
|
||||||
c.serverName = hs.clientHello.serverName
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertNoApplicationProtocol)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hs.hello.alpnProtocol = selectedProto
|
|
||||||
c.clientProtocol = selectedProto
|
|
||||||
|
|
||||||
hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello))
|
|
||||||
if err != nil {
|
|
||||||
if err == errNoCertificates {
|
|
||||||
c.sendAlert(alertUnrecognizedName)
|
|
||||||
} else {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if hs.clientHello.scts {
|
|
||||||
hs.hello.scts = hs.cert.SignedCertificateTimestamps
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints)
|
|
||||||
|
|
||||||
if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 {
|
|
||||||
// Although omitting the ec_point_formats extension is permitted, some
|
|
||||||
// old OpenSSL version will refuse to handshake if not present.
|
|
||||||
//
|
|
||||||
// Per RFC 4492, section 5.1.2, implementations MUST support the
|
|
||||||
// uncompressed point format. See golang.org/issue/31943.
|
|
||||||
hs.hello.supportedPoints = []uint8{pointFormatUncompressed}
|
|
||||||
}
|
|
||||||
|
|
||||||
if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok {
|
|
||||||
switch priv.Public().(type) {
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
hs.ecSignOk = true
|
|
||||||
case ed25519.PublicKey:
|
|
||||||
hs.ecSignOk = true
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
hs.rsaSignOk = true
|
|
||||||
default:
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok {
|
|
||||||
switch priv.Public().(type) {
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
hs.rsaDecryptOk = true
|
|
||||||
default:
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// negotiateALPN picks a shared ALPN protocol that both sides support in server
|
|
||||||
// preference order. If ALPN is not configured or the peer doesn't support it,
|
|
||||||
// it returns "" and no error.
|
|
||||||
func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
|
|
||||||
if len(serverProtos) == 0 || len(clientProtos) == 0 {
|
|
||||||
if quic && len(serverProtos) != 0 {
|
|
||||||
// RFC 9001, Section 8.1
|
|
||||||
return "", fmt.Errorf("tls: client did not request an application protocol")
|
|
||||||
}
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
var http11fallback bool
|
|
||||||
for _, s := range serverProtos {
|
|
||||||
for _, c := range clientProtos {
|
|
||||||
if s == c {
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
if s == "h2" && c == "http/1.1" {
|
|
||||||
http11fallback = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// As a special case, let http/1.1 clients connect to h2 servers as if they
|
|
||||||
// didn't support ALPN. We used not to enforce protocol overlap, so over
|
|
||||||
// time a number of HTTP servers were configured with only "h2", but
|
|
||||||
// expected to accept connections from "http/1.1" clients. See Issue 46310.
|
|
||||||
if http11fallback {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos)
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportsECDHE returns whether ECDHE key exchanges can be used with this
|
|
||||||
// pre-TLS 1.3 client.
|
|
||||||
func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool {
|
|
||||||
supportsCurve := false
|
|
||||||
for _, curve := range supportedCurves {
|
|
||||||
if c.supportsCurve(curve) {
|
|
||||||
supportsCurve = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
supportsPointFormat := false
|
|
||||||
for _, pointFormat := range supportedPoints {
|
|
||||||
if pointFormat == pointFormatUncompressed {
|
|
||||||
supportsPointFormat = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is
|
|
||||||
// missing, uncompressed points are supported. If supportedPoints is empty,
|
|
||||||
// the extension must be missing, as an empty extension body is rejected by
|
|
||||||
// the parser. See https://go.dev/issue/49126.
|
|
||||||
if len(supportedPoints) == 0 {
|
|
||||||
supportsPointFormat = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return supportsCurve && supportsPointFormat
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) pickCipherSuite() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
preferenceOrder := cipherSuitesPreferenceOrder
|
|
||||||
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
|
|
||||||
preferenceOrder = cipherSuitesPreferenceOrderNoAES
|
|
||||||
}
|
|
||||||
|
|
||||||
configCipherSuites := c.config.cipherSuites()
|
|
||||||
preferenceList := make([]uint16, 0, len(configCipherSuites))
|
|
||||||
for _, suiteID := range preferenceOrder {
|
|
||||||
for _, id := range configCipherSuites {
|
|
||||||
if id == suiteID {
|
|
||||||
preferenceList = append(preferenceList, id)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk)
|
|
||||||
if hs.suite == nil {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return errors.New("tls: no cipher suite supported by both client and server")
|
|
||||||
}
|
|
||||||
c.cipherSuite = hs.suite.id
|
|
||||||
|
|
||||||
for _, id := range hs.clientHello.cipherSuites {
|
|
||||||
if id == TLS_FALLBACK_SCSV {
|
|
||||||
// The client is doing a fallback connection. See RFC 7507.
|
|
||||||
if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) {
|
|
||||||
c.sendAlert(alertInappropriateFallback)
|
|
||||||
return errors.New("tls: client using inappropriate protocol fallback")
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
|
|
||||||
if c.flags&suiteECDHE != 0 {
|
|
||||||
if !hs.ecdheOk {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if c.flags&suiteECSign != 0 {
|
|
||||||
if !hs.ecSignOk {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
} else if !hs.rsaSignOk {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
} else if !hs.rsaDecryptOk {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkForResumption reports whether we should perform resumption on this connection.
|
|
||||||
func (hs *serverHandshakeState) checkForResumption() bool {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if c.config.SessionTicketsDisabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket)
|
|
||||||
if plaintext == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
hs.sessionState = &sessionState{usedOldKey: usedOldKey}
|
|
||||||
ok := hs.sessionState.unmarshal(plaintext)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
createdAt := time.Unix(int64(hs.sessionState.createdAt), 0)
|
|
||||||
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Never resume a session for a different TLS version.
|
|
||||||
if c.vers != hs.sessionState.vers {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
cipherSuiteOk := false
|
|
||||||
// Check that the client is still offering the ciphersuite in the session.
|
|
||||||
for _, id := range hs.clientHello.cipherSuites {
|
|
||||||
if id == hs.sessionState.cipherSuite {
|
|
||||||
cipherSuiteOk = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !cipherSuiteOk {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that we also support the ciphersuite from the session.
|
|
||||||
hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite},
|
|
||||||
c.config.cipherSuites(), hs.cipherSuiteOk)
|
|
||||||
if hs.suite == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionHasClientCerts := len(hs.sessionState.certificates) != 0
|
|
||||||
needClientCerts := requiresClientCert(c.config.ClientAuth)
|
|
||||||
if needClientCerts && !sessionHasClientCerts {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) doResumeHandshake() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
hs.hello.cipherSuite = hs.suite.id
|
|
||||||
c.cipherSuite = hs.suite.id
|
|
||||||
// We echo the client's session ID in the ServerHello to let it know
|
|
||||||
// that we're doing a resumption.
|
|
||||||
hs.hello.sessionId = hs.clientHello.sessionId
|
|
||||||
hs.hello.ticketSupported = hs.sessionState.usedOldKey
|
|
||||||
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
|
|
||||||
hs.finishedHash.discardHandshakeBuffer()
|
|
||||||
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.processCertsFromClient(Certificate{
|
|
||||||
Certificate: hs.sessionState.certificates,
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.VerifyConnection != nil {
|
|
||||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.masterSecret = hs.sessionState.masterSecret
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) doFullHandshake() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
|
|
||||||
hs.hello.ocspStapling = true
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
|
|
||||||
hs.hello.cipherSuite = hs.suite.id
|
|
||||||
|
|
||||||
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
|
|
||||||
if c.config.ClientAuth == NoClientCert {
|
|
||||||
// No need to keep a full record of the handshake if client
|
|
||||||
// certificates won't be used.
|
|
||||||
hs.finishedHash.discardHandshakeBuffer()
|
|
||||||
}
|
|
||||||
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
certMsg := new(certificateMsg)
|
|
||||||
certMsg.certificates = hs.cert.Certificate
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.hello.ocspStapling {
|
|
||||||
certStatus := new(certificateStatusMsg)
|
|
||||||
certStatus.response = hs.cert.OCSPStaple
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
keyAgreement := hs.suite.ka(c.vers)
|
|
||||||
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if skx != nil {
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var certReq *certificateRequestMsg
|
|
||||||
if c.config.ClientAuth >= RequestClientCert {
|
|
||||||
// Request a client certificate
|
|
||||||
certReq = new(certificateRequestMsg)
|
|
||||||
certReq.certificateTypes = []byte{
|
|
||||||
byte(certTypeRSASign),
|
|
||||||
byte(certTypeECDSASign),
|
|
||||||
}
|
|
||||||
if c.vers >= VersionTLS12 {
|
|
||||||
certReq.hasSignatureAlgorithm = true
|
|
||||||
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
|
|
||||||
}
|
|
||||||
|
|
||||||
// An empty list of certificateAuthorities signals to
|
|
||||||
// the client that it may send any certificate in response
|
|
||||||
// to our request. When we know the CAs we trust, then
|
|
||||||
// we can send them down, so that the client can choose
|
|
||||||
// an appropriate certificate to give to us.
|
|
||||||
if c.config.ClientCAs != nil {
|
|
||||||
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
|
|
||||||
}
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
helloDone := new(serverHelloDoneMsg)
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := c.flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var pub crypto.PublicKey // public key for client auth, if any
|
|
||||||
|
|
||||||
msg, err := c.readHandshake(&hs.finishedHash)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we requested a client certificate, then the client must send a
|
|
||||||
// certificate message, even if it's empty.
|
|
||||||
if c.config.ClientAuth >= RequestClientCert {
|
|
||||||
certMsg, ok := msg.(*certificateMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(certMsg, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.processCertsFromClient(Certificate{
|
|
||||||
Certificate: certMsg.certificates,
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if len(certMsg.certificates) != 0 {
|
|
||||||
pub = c.peerCertificates[0].PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, err = c.readHandshake(&hs.finishedHash)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.config.VerifyConnection != nil {
|
|
||||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get client key exchange
|
|
||||||
ckx, ok := msg.(*clientKeyExchangeMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(ckx, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
|
|
||||||
if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we received a client cert in response to our certificate request message,
|
|
||||||
// the client will send us a certificateVerifyMsg immediately after the
|
|
||||||
// clientKeyExchangeMsg. This message is a digest of all preceding
|
|
||||||
// handshake-layer messages that is signed using the private key corresponding
|
|
||||||
// to the client's certificate. This allows us to verify that the client is in
|
|
||||||
// possession of the private key of the certificate.
|
|
||||||
if len(c.peerCertificates) > 0 {
|
|
||||||
// certificateVerifyMsg is included in the transcript, but not until
|
|
||||||
// after we verify the handshake signature, since the state before
|
|
||||||
// this message was sent is used.
|
|
||||||
msg, err = c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
certVerify, ok := msg.(*certificateVerifyMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(certVerify, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
var sigType uint8
|
|
||||||
var sigHash crypto.Hash
|
|
||||||
if c.vers >= VersionTLS12 {
|
|
||||||
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: client certificate used with invalid signature algorithm")
|
|
||||||
}
|
|
||||||
sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
|
|
||||||
if err != nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash)
|
|
||||||
if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
|
|
||||||
c.sendAlert(alertDecryptError)
|
|
||||||
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.finishedHash.discardHandshakeBuffer()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) establishKeys() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
|
|
||||||
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
|
|
||||||
|
|
||||||
var clientCipher, serverCipher any
|
|
||||||
var clientHash, serverHash hash.Hash
|
|
||||||
|
|
||||||
if hs.suite.aead == nil {
|
|
||||||
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
|
|
||||||
clientHash = hs.suite.mac(clientMAC)
|
|
||||||
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
|
|
||||||
serverHash = hs.suite.mac(serverMAC)
|
|
||||||
} else {
|
|
||||||
clientCipher = hs.suite.aead(clientKey, clientIV)
|
|
||||||
serverCipher = hs.suite.aead(serverKey, serverIV)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
|
|
||||||
c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) readFinished(out []byte) error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if err := c.readChangeCipherSpec(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// finishedMsg is included in the transcript, but not until after we
|
|
||||||
// check the client version, since the state before this message was
|
|
||||||
// sent is used during verification.
|
|
||||||
msg, err := c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
clientFinished, ok := msg.(*finishedMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(clientFinished, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
verify := hs.finishedHash.clientSum(hs.masterSecret)
|
|
||||||
if len(verify) != len(clientFinished.verifyData) ||
|
|
||||||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return errors.New("tls: client's Finished message is incorrect")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(out, verify)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) sendSessionTicket() error {
|
|
||||||
// ticketSupported is set in a resumption handshake if the
|
|
||||||
// ticket from the client was encrypted with an old session
|
|
||||||
// ticket key and thus a refreshed ticket should be sent.
|
|
||||||
if !hs.hello.ticketSupported {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c := hs.c
|
|
||||||
m := new(newSessionTicketMsg)
|
|
||||||
|
|
||||||
createdAt := uint64(c.config.time().Unix())
|
|
||||||
if hs.sessionState != nil {
|
|
||||||
// If this is re-wrapping an old key, then keep
|
|
||||||
// the original time it was created.
|
|
||||||
createdAt = hs.sessionState.createdAt
|
|
||||||
}
|
|
||||||
|
|
||||||
var certsFromClient [][]byte
|
|
||||||
for _, cert := range c.peerCertificates {
|
|
||||||
certsFromClient = append(certsFromClient, cert.Raw)
|
|
||||||
}
|
|
||||||
state := sessionState{
|
|
||||||
vers: c.vers,
|
|
||||||
cipherSuite: hs.suite.id,
|
|
||||||
createdAt: createdAt,
|
|
||||||
masterSecret: hs.masterSecret,
|
|
||||||
certificates: certsFromClient,
|
|
||||||
}
|
|
||||||
stateBytes, err := state.marshal()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.ticket, err = c.encryptTicket(stateBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeState) sendFinished(out []byte) error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if err := c.writeChangeCipherRecord(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
finished := new(finishedMsg)
|
|
||||||
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(out, finished.verifyData)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// processCertsFromClient takes a chain of client certificates either from a
|
|
||||||
// Certificates message or from a sessionState and verifies them. It returns
|
|
||||||
// the public key of the leaf certificate.
|
|
||||||
func (c *Conn) processCertsFromClient(certificate Certificate) error {
|
|
||||||
certificates := certificate.Certificate
|
|
||||||
certs := make([]*x509.Certificate, len(certificates))
|
|
||||||
var err error
|
|
||||||
for i, asn1Data := range certificates {
|
|
||||||
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return errors.New("tls: failed to parse client certificate: " + err.Error())
|
|
||||||
}
|
|
||||||
if certs[i].PublicKeyAlgorithm == x509.RSA && certs[i].PublicKey.(*rsa.PublicKey).N.BitLen() > maxRSAKeySize {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return fmt.Errorf("tls: client sent certificate containing RSA key larger than %d bits", maxRSAKeySize)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return errors.New("tls: client didn't provide a certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
|
|
||||||
opts := x509.VerifyOptions{
|
|
||||||
Roots: c.config.ClientCAs,
|
|
||||||
CurrentTime: c.config.time(),
|
|
||||||
Intermediates: x509.NewCertPool(),
|
|
||||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cert := range certs[1:] {
|
|
||||||
opts.Intermediates.AddCert(cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
chains, err := certs[0].Verify(opts)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return &CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.verifiedChains = chains
|
|
||||||
}
|
|
||||||
|
|
||||||
c.peerCertificates = certs
|
|
||||||
c.ocspResponse = certificate.OCSPStaple
|
|
||||||
c.scts = certificate.SignedCertificateTimestamps
|
|
||||||
|
|
||||||
if len(certs) > 0 {
|
|
||||||
switch certs[0].PublicKey.(type) {
|
|
||||||
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
|
|
||||||
default:
|
|
||||||
c.sendAlert(alertUnsupportedCertificate)
|
|
||||||
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.VerifyPeerCertificate != nil {
|
|
||||||
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
|
|
||||||
supportedVersions := clientHello.supportedVersions
|
|
||||||
if len(clientHello.supportedVersions) == 0 {
|
|
||||||
supportedVersions = supportedVersionsFromMax(clientHello.vers)
|
|
||||||
}
|
|
||||||
|
|
||||||
return toClientHelloInfo(&clientHelloInfo{
|
|
||||||
CipherSuites: clientHello.cipherSuites,
|
|
||||||
ServerName: clientHello.serverName,
|
|
||||||
SupportedCurves: clientHello.supportedCurves,
|
|
||||||
SupportedPoints: clientHello.supportedPoints,
|
|
||||||
SignatureSchemes: clientHello.supportedSignatureAlgorithms,
|
|
||||||
SupportedProtos: clientHello.alpnProtocols,
|
|
||||||
SupportedVersions: supportedVersions,
|
|
||||||
Conn: c.conn,
|
|
||||||
config: toConfig(c.config),
|
|
||||||
ctx: ctx,
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,986 +0,0 @@
|
||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rsa"
|
|
||||||
"errors"
|
|
||||||
"hash"
|
|
||||||
"io"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// maxClientPSKIdentities is the number of client PSK identities the server will
|
|
||||||
// attempt to validate. It will ignore the rest not to let cheap ClientHello
|
|
||||||
// messages cause too much work in session ticket decryption attempts.
|
|
||||||
const maxClientPSKIdentities = 5
|
|
||||||
|
|
||||||
type serverHandshakeStateTLS13 struct {
|
|
||||||
c *Conn
|
|
||||||
ctx context.Context
|
|
||||||
clientHello *clientHelloMsg
|
|
||||||
hello *serverHelloMsg
|
|
||||||
alpnNegotiationErr error
|
|
||||||
encryptedExtensions *encryptedExtensionsMsg
|
|
||||||
sentDummyCCS bool
|
|
||||||
usingPSK bool
|
|
||||||
suite *cipherSuiteTLS13
|
|
||||||
cert *Certificate
|
|
||||||
sigAlg SignatureScheme
|
|
||||||
earlySecret []byte
|
|
||||||
sharedKey []byte
|
|
||||||
handshakeSecret []byte
|
|
||||||
masterSecret []byte
|
|
||||||
trafficSecret []byte // client_application_traffic_secret_0
|
|
||||||
transcript hash.Hash
|
|
||||||
clientFinished []byte
|
|
||||||
earlyData bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) handshake() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if needFIPS() {
|
|
||||||
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
|
|
||||||
if err := hs.processClientHello(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.checkForResumption(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.pickCertificate(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.buffering = true
|
|
||||||
if err := hs.sendServerParameters(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.sendServerCertificate(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.sendServerFinished(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Note that at this point we could start sending application data without
|
|
||||||
// waiting for the client's second flight, but the application might not
|
|
||||||
// expect the lack of replay protection of the ClientHello parameters.
|
|
||||||
if _, err := c.flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.readClientCertificate(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := hs.readClientFinished(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.isHandshakeComplete.Store(true)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) processClientHello() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
hs.hello = new(serverHelloMsg)
|
|
||||||
hs.encryptedExtensions = new(encryptedExtensionsMsg)
|
|
||||||
|
|
||||||
// TLS 1.3 froze the ServerHello.legacy_version field, and uses
|
|
||||||
// supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1.
|
|
||||||
hs.hello.vers = VersionTLS12
|
|
||||||
hs.hello.supportedVersion = c.vers
|
|
||||||
|
|
||||||
if len(hs.clientHello.supportedVersions) == 0 {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: client used the legacy version field to negotiate TLS 1.3")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Abort if the client is doing a fallback and landing lower than what we
|
|
||||||
// support. See RFC 7507, which however does not specify the interaction
|
|
||||||
// with supported_versions. The only difference is that with
|
|
||||||
// supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4]
|
|
||||||
// handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case,
|
|
||||||
// it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to
|
|
||||||
// TLS 1.2, because a TLS 1.3 server would abort here. The situation before
|
|
||||||
// supported_versions was not better because there was just no way to do a
|
|
||||||
// TLS 1.4 handshake without risking the server selecting TLS 1.3.
|
|
||||||
for _, id := range hs.clientHello.cipherSuites {
|
|
||||||
if id == TLS_FALLBACK_SCSV {
|
|
||||||
// Use c.vers instead of max(supported_versions) because an attacker
|
|
||||||
// could defeat this by adding an arbitrary high version otherwise.
|
|
||||||
if c.vers < c.config.maxSupportedVersion(roleServer) {
|
|
||||||
c.sendAlert(alertInappropriateFallback)
|
|
||||||
return errors.New("tls: client using inappropriate protocol fallback")
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hs.clientHello.compressionMethods) != 1 ||
|
|
||||||
hs.clientHello.compressionMethods[0] != compressionNone {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: TLS 1.3 client supports illegal compression methods")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.hello.random = make([]byte, 32)
|
|
||||||
if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hs.clientHello.secureRenegotiation) != 0 {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return errors.New("tls: initial handshake had non-empty renegotiation extension")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.hello.sessionId = hs.clientHello.sessionId
|
|
||||||
hs.hello.compressionMethod = compressionNone
|
|
||||||
|
|
||||||
preferenceList := defaultCipherSuitesTLS13
|
|
||||||
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
|
|
||||||
preferenceList = defaultCipherSuitesTLS13NoAES
|
|
||||||
}
|
|
||||||
for _, suiteID := range preferenceList {
|
|
||||||
hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID)
|
|
||||||
if hs.suite != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if hs.suite == nil {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return errors.New("tls: no cipher suite supported by both client and server")
|
|
||||||
}
|
|
||||||
c.cipherSuite = hs.suite.id
|
|
||||||
hs.hello.cipherSuite = hs.suite.id
|
|
||||||
hs.transcript = hs.suite.hash.New()
|
|
||||||
|
|
||||||
// Pick the ECDHE group in server preference order, but give priority to
|
|
||||||
// groups with a key share, to avoid a HelloRetryRequest round-trip.
|
|
||||||
var selectedGroup CurveID
|
|
||||||
var clientKeyShare *keyShare
|
|
||||||
GroupSelection:
|
|
||||||
for _, preferredGroup := range c.config.curvePreferences() {
|
|
||||||
for _, ks := range hs.clientHello.keyShares {
|
|
||||||
if ks.group == preferredGroup {
|
|
||||||
selectedGroup = ks.group
|
|
||||||
clientKeyShare = &ks
|
|
||||||
break GroupSelection
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if selectedGroup != 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, group := range hs.clientHello.supportedCurves {
|
|
||||||
if group == preferredGroup {
|
|
||||||
selectedGroup = group
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if selectedGroup == 0 {
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return errors.New("tls: no ECDHE curve supported by both client and server")
|
|
||||||
}
|
|
||||||
if clientKeyShare == nil {
|
|
||||||
if err := hs.doHelloRetryRequest(selectedGroup); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
clientKeyShare = &hs.clientHello.keyShares[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := curveForCurveID(selectedGroup); !ok {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return errors.New("tls: CurvePreferences includes unsupported curve")
|
|
||||||
}
|
|
||||||
key, err := generateECDHEKey(c.config.rand(), selectedGroup)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()}
|
|
||||||
peerKey, err := key.Curve().NewPublicKey(clientKeyShare.data)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: invalid client key share")
|
|
||||||
}
|
|
||||||
hs.sharedKey, err = key.ECDH(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: invalid client key share")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.quic != nil {
|
|
||||||
// RFC 9001 Section 4.2: Clients MUST NOT offer TLS versions older than 1.3.
|
|
||||||
for _, v := range hs.clientHello.supportedVersions {
|
|
||||||
if v < VersionTLS13 {
|
|
||||||
c.sendAlert(alertProtocolVersion)
|
|
||||||
return errors.New("tls: client offered TLS version older than TLS 1.3")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// RFC 9001 Section 8.2.
|
|
||||||
if hs.clientHello.quicTransportParameters == nil {
|
|
||||||
c.sendAlert(alertMissingExtension)
|
|
||||||
return errors.New("tls: client did not send a quic_transport_parameters extension")
|
|
||||||
}
|
|
||||||
c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
|
|
||||||
} else {
|
|
||||||
if hs.clientHello.quicTransportParameters != nil {
|
|
||||||
c.sendAlert(alertUnsupportedExtension)
|
|
||||||
return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.serverName = hs.clientHello.serverName
|
|
||||||
|
|
||||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
|
|
||||||
if err != nil {
|
|
||||||
hs.alpnNegotiationErr = err
|
|
||||||
}
|
|
||||||
hs.encryptedExtensions.alpnProtocol = selectedProto
|
|
||||||
c.clientProtocol = selectedProto
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) checkForResumption() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if c.config.SessionTicketsDisabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
modeOK := false
|
|
||||||
for _, mode := range hs.clientHello.pskModes {
|
|
||||||
if mode == pskModeDHE {
|
|
||||||
modeOK = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !modeOK {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: invalid or missing PSK binders")
|
|
||||||
}
|
|
||||||
if len(hs.clientHello.pskIdentities) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, identity := range hs.clientHello.pskIdentities {
|
|
||||||
if i >= maxClientPSKIdentities {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintext, _ := c.decryptTicket(identity.label)
|
|
||||||
if plaintext == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sessionState := new(sessionStateTLS13)
|
|
||||||
if ok := sessionState.unmarshal(plaintext); !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.clientHello.earlyData {
|
|
||||||
if sessionState.maxEarlyData == 0 {
|
|
||||||
c.sendAlert(alertUnsupportedExtension)
|
|
||||||
return errors.New("tls: client sent unexpected early data")
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol &&
|
|
||||||
c.extraConfig != nil && c.extraConfig.Enable0RTT &&
|
|
||||||
c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) {
|
|
||||||
hs.encryptedExtensions.earlyData = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
createdAt := time.Unix(int64(sessionState.createdAt), 0)
|
|
||||||
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't check the obfuscated ticket age because it's affected by
|
|
||||||
// clock skew and it's only a freshness signal useful for shrinking the
|
|
||||||
// window for replay attacks, which don't affect us as we don't do 0-RTT.
|
|
||||||
|
|
||||||
pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite)
|
|
||||||
if pskSuite == nil || pskSuite.hash != hs.suite.hash {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// PSK connections don't re-establish client certificates, but carry
|
|
||||||
// them over in the session ticket. Ensure the presence of client certs
|
|
||||||
// in the ticket is consistent with the configured requirements.
|
|
||||||
sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0
|
|
||||||
needClientCerts := requiresClientCert(c.config.ClientAuth)
|
|
||||||
if needClientCerts && !sessionHasClientCerts {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption",
|
|
||||||
nil, hs.suite.hash.Size())
|
|
||||||
hs.earlySecret = hs.suite.extract(psk, nil)
|
|
||||||
binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil)
|
|
||||||
// Clone the transcript in case a HelloRetryRequest was recorded.
|
|
||||||
transcript := cloneHash(hs.transcript, hs.suite.hash)
|
|
||||||
if transcript == nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return errors.New("tls: internal error: failed to clone hash")
|
|
||||||
}
|
|
||||||
clientHelloBytes, err := hs.clientHello.marshalWithoutBinders()
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
transcript.Write(clientHelloBytes)
|
|
||||||
pskBinder := hs.suite.finishedHash(binderKey, transcript)
|
|
||||||
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
|
|
||||||
c.sendAlert(alertDecryptError)
|
|
||||||
return errors.New("tls: invalid PSK binder")
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.quic != nil && hs.clientHello.earlyData && hs.encryptedExtensions.earlyData && i == 0 &&
|
|
||||||
sessionState.maxEarlyData > 0 && sessionState.cipherSuite == hs.suite.id {
|
|
||||||
hs.earlyData = true
|
|
||||||
|
|
||||||
transcript := hs.suite.hash.New()
|
|
||||||
if err := transcriptMsg(hs.clientHello, transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
earlyTrafficSecret := hs.suite.deriveSecret(hs.earlySecret, clientEarlyTrafficLabel, transcript)
|
|
||||||
c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.didResume = true
|
|
||||||
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.hello.selectedIdentityPresent = true
|
|
||||||
hs.hello.selectedIdentity = uint16(i)
|
|
||||||
hs.usingPSK = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
|
|
||||||
// interfaces implemented by standard library hashes to clone the state of in
|
|
||||||
// to a new instance of h. It returns nil if the operation fails.
|
|
||||||
func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash {
|
|
||||||
// Recreate the interface to avoid importing encoding.
|
|
||||||
type binaryMarshaler interface {
|
|
||||||
MarshalBinary() (data []byte, err error)
|
|
||||||
UnmarshalBinary(data []byte) error
|
|
||||||
}
|
|
||||||
marshaler, ok := in.(binaryMarshaler)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
state, err := marshaler.MarshalBinary()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := h.New()
|
|
||||||
unmarshaler, ok := out.(binaryMarshaler)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := unmarshaler.UnmarshalBinary(state); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) pickCertificate() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
// Only one of PSK and certificates are used at a time.
|
|
||||||
if hs.usingPSK {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3.
|
|
||||||
if len(hs.clientHello.supportedSignatureAlgorithms) == 0 {
|
|
||||||
return c.sendAlert(alertMissingExtension)
|
|
||||||
}
|
|
||||||
|
|
||||||
certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello))
|
|
||||||
if err != nil {
|
|
||||||
if err == errNoCertificates {
|
|
||||||
c.sendAlert(alertUnrecognizedName)
|
|
||||||
} else {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms)
|
|
||||||
if err != nil {
|
|
||||||
// getCertificate returned a certificate that is unsupported or
|
|
||||||
// incompatible with the client's signature algorithms.
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hs.cert = certificate
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
|
||||||
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
|
||||||
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
|
||||||
if hs.c.quic != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if hs.sentDummyCCS {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
hs.sentDummyCCS = true
|
|
||||||
|
|
||||||
return hs.c.writeChangeCipherRecord()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
// The first ClientHello gets double-hashed into the transcript upon a
|
|
||||||
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
|
|
||||||
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
chHash := hs.transcript.Sum(nil)
|
|
||||||
hs.transcript.Reset()
|
|
||||||
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
|
|
||||||
hs.transcript.Write(chHash)
|
|
||||||
|
|
||||||
helloRetryRequest := &serverHelloMsg{
|
|
||||||
vers: hs.hello.vers,
|
|
||||||
random: helloRetryRequestRandom,
|
|
||||||
sessionId: hs.hello.sessionId,
|
|
||||||
cipherSuite: hs.hello.cipherSuite,
|
|
||||||
compressionMethod: hs.hello.compressionMethod,
|
|
||||||
supportedVersion: hs.hello.supportedVersion,
|
|
||||||
selectedGroup: selectedGroup,
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientHelloMsg is not included in the transcript.
|
|
||||||
msg, err := c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientHello, ok := msg.(*clientHelloMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(clientHello, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: client sent invalid key share in second ClientHello")
|
|
||||||
}
|
|
||||||
|
|
||||||
if clientHello.earlyData {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: client indicated early data in second ClientHello")
|
|
||||||
}
|
|
||||||
|
|
||||||
if illegalClientHelloChange(clientHello, hs.clientHello) {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: client illegally modified second ClientHello")
|
|
||||||
}
|
|
||||||
|
|
||||||
if illegalClientHelloChange(clientHello, hs.clientHello) {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: client illegally modified second ClientHello")
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.clientHello = clientHello
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// illegalClientHelloChange reports whether the two ClientHello messages are
|
|
||||||
// different, with the exception of the changes allowed before and after a
|
|
||||||
// HelloRetryRequest. See RFC 8446, Section 4.1.2.
|
|
||||||
func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
|
|
||||||
if len(ch.supportedVersions) != len(ch1.supportedVersions) ||
|
|
||||||
len(ch.cipherSuites) != len(ch1.cipherSuites) ||
|
|
||||||
len(ch.supportedCurves) != len(ch1.supportedCurves) ||
|
|
||||||
len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) ||
|
|
||||||
len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) ||
|
|
||||||
len(ch.alpnProtocols) != len(ch1.alpnProtocols) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for i := range ch.supportedVersions {
|
|
||||||
if ch.supportedVersions[i] != ch1.supportedVersions[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := range ch.cipherSuites {
|
|
||||||
if ch.cipherSuites[i] != ch1.cipherSuites[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := range ch.supportedCurves {
|
|
||||||
if ch.supportedCurves[i] != ch1.supportedCurves[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := range ch.supportedSignatureAlgorithms {
|
|
||||||
if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := range ch.supportedSignatureAlgorithmsCert {
|
|
||||||
if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := range ch.alpnProtocols {
|
|
||||||
if ch.alpnProtocols[i] != ch1.alpnProtocols[i] {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ch.vers != ch1.vers ||
|
|
||||||
!bytes.Equal(ch.random, ch1.random) ||
|
|
||||||
!bytes.Equal(ch.sessionId, ch1.sessionId) ||
|
|
||||||
!bytes.Equal(ch.compressionMethods, ch1.compressionMethods) ||
|
|
||||||
ch.serverName != ch1.serverName ||
|
|
||||||
ch.ocspStapling != ch1.ocspStapling ||
|
|
||||||
!bytes.Equal(ch.supportedPoints, ch1.supportedPoints) ||
|
|
||||||
ch.ticketSupported != ch1.ticketSupported ||
|
|
||||||
!bytes.Equal(ch.sessionTicket, ch1.sessionTicket) ||
|
|
||||||
ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported ||
|
|
||||||
!bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) ||
|
|
||||||
ch.scts != ch1.scts ||
|
|
||||||
!bytes.Equal(ch.cookie, ch1.cookie) ||
|
|
||||||
!bytes.Equal(ch.pskModes, ch1.pskModes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
earlySecret := hs.earlySecret
|
|
||||||
if earlySecret == nil {
|
|
||||||
earlySecret = hs.suite.extract(nil, nil)
|
|
||||||
}
|
|
||||||
hs.handshakeSecret = hs.suite.extract(hs.sharedKey,
|
|
||||||
hs.suite.deriveSecret(earlySecret, "derived", nil))
|
|
||||||
|
|
||||||
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
|
||||||
clientHandshakeTrafficLabel, hs.transcript)
|
|
||||||
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
|
|
||||||
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
|
||||||
serverHandshakeTrafficLabel, hs.transcript)
|
|
||||||
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
|
|
||||||
|
|
||||||
if c.quic != nil {
|
|
||||||
if c.hand.Len() != 0 {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
}
|
|
||||||
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
|
|
||||||
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertNoApplicationProtocol)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hs.encryptedExtensions.alpnProtocol = selectedProto
|
|
||||||
c.clientProtocol = selectedProto
|
|
||||||
|
|
||||||
if c.quic != nil {
|
|
||||||
p, err := c.quicGetTransportParameters()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
hs.encryptedExtensions.quicTransportParameters = p
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) requestClientCert() bool {
|
|
||||||
return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
// Only one of PSK and certificates are used at a time.
|
|
||||||
if hs.usingPSK {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.requestClientCert() {
|
|
||||||
// Request a client certificate
|
|
||||||
certReq := new(certificateRequestMsgTLS13)
|
|
||||||
certReq.ocspStapling = true
|
|
||||||
certReq.scts = true
|
|
||||||
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
|
|
||||||
if c.config.ClientCAs != nil {
|
|
||||||
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
certMsg := new(certificateMsgTLS13)
|
|
||||||
|
|
||||||
certMsg.certificate = *hs.cert
|
|
||||||
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
|
|
||||||
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
certVerifyMsg := new(certificateVerifyMsg)
|
|
||||||
certVerifyMsg.hasSignatureAlgorithm = true
|
|
||||||
certVerifyMsg.signatureAlgorithm = hs.sigAlg
|
|
||||||
|
|
||||||
sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg)
|
|
||||||
if err != nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
|
|
||||||
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
|
|
||||||
signOpts := crypto.SignerOpts(sigHash)
|
|
||||||
if sigType == signatureRSAPSS {
|
|
||||||
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
|
|
||||||
}
|
|
||||||
sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
|
|
||||||
if err != nil {
|
|
||||||
public := hs.cert.PrivateKey.(crypto.Signer).Public()
|
|
||||||
if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS &&
|
|
||||||
rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS
|
|
||||||
c.sendAlert(alertHandshakeFailure)
|
|
||||||
} else {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
return errors.New("tls: failed to sign handshake: " + err.Error())
|
|
||||||
}
|
|
||||||
certVerifyMsg.signature = sig
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
finished := &finishedMsg{
|
|
||||||
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Derive secrets that take context through the server Finished.
|
|
||||||
|
|
||||||
hs.masterSecret = hs.suite.extract(nil,
|
|
||||||
hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil))
|
|
||||||
|
|
||||||
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
|
|
||||||
clientApplicationTrafficLabel, hs.transcript)
|
|
||||||
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
|
||||||
serverApplicationTrafficLabel, hs.transcript)
|
|
||||||
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
|
|
||||||
|
|
||||||
if c.quic != nil {
|
|
||||||
if c.hand.Len() != 0 {
|
|
||||||
// TODO: Handle this in setTrafficSecret?
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
}
|
|
||||||
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret)
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
|
|
||||||
|
|
||||||
// If we did not request client certificates, at this point we can
|
|
||||||
// precompute the client finished and roll the transcript forward to send
|
|
||||||
// session tickets in our first flight.
|
|
||||||
if !hs.requestClientCert() {
|
|
||||||
if err := hs.sendSessionTickets(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
|
|
||||||
if hs.c.config.SessionTicketsDisabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// QUIC tickets are sent by QUICConn.SendSessionTicket, not automatically.
|
|
||||||
if hs.c.quic != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
|
|
||||||
for _, pskMode := range hs.clientHello.pskModes {
|
|
||||||
if pskMode == pskModeDHE {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
|
|
||||||
finishedMsg := &finishedMsg{
|
|
||||||
verifyData: hs.clientFinished,
|
|
||||||
}
|
|
||||||
if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
|
||||||
resumptionLabel, hs.transcript)
|
|
||||||
|
|
||||||
if !hs.shouldSendSessionTickets() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.sendSessionTicket(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) sendSessionTicket(earlyData bool) error {
|
|
||||||
suite := cipherSuiteTLS13ByID(c.cipherSuite)
|
|
||||||
if suite == nil {
|
|
||||||
return errors.New("tls: internal error: unknown cipher suite")
|
|
||||||
}
|
|
||||||
|
|
||||||
m := new(newSessionTicketMsgTLS13)
|
|
||||||
|
|
||||||
var certsFromClient [][]byte
|
|
||||||
for _, cert := range c.peerCertificates {
|
|
||||||
certsFromClient = append(certsFromClient, cert.Raw)
|
|
||||||
}
|
|
||||||
state := sessionStateTLS13{
|
|
||||||
cipherSuite: suite.id,
|
|
||||||
createdAt: uint64(c.config.time().Unix()),
|
|
||||||
resumptionSecret: c.resumptionSecret,
|
|
||||||
certificate: Certificate{
|
|
||||||
Certificate: certsFromClient,
|
|
||||||
OCSPStaple: c.ocspResponse,
|
|
||||||
SignedCertificateTimestamps: c.scts,
|
|
||||||
},
|
|
||||||
alpn: c.clientProtocol,
|
|
||||||
}
|
|
||||||
if earlyData {
|
|
||||||
state.maxEarlyData = 0xffffffff
|
|
||||||
state.appData = c.extraConfig.GetAppDataForSessionTicket()
|
|
||||||
}
|
|
||||||
stateBytes, err := state.marshal()
|
|
||||||
if err != nil {
|
|
||||||
c.sendAlert(alertInternalError)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.label, err = c.encryptTicket(stateBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
|
|
||||||
|
|
||||||
// ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1
|
|
||||||
// The value is not stored anywhere; we never need to check the ticket age
|
|
||||||
// because 0-RTT is not supported.
|
|
||||||
ageAdd := make([]byte, 4)
|
|
||||||
_, err = c.config.rand().Read(ageAdd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if earlyData {
|
|
||||||
// RFC 9001, Section 4.6.1
|
|
||||||
m.maxEarlyData = 0xffffffff
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := c.writeHandshakeRecord(m, nil); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
if !hs.requestClientCert() {
|
|
||||||
// Make sure the connection is still being verified whether or not
|
|
||||||
// the server requested a client certificate.
|
|
||||||
if c.config.VerifyConnection != nil {
|
|
||||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we requested a client certificate, then the client must send a
|
|
||||||
// certificate message. If it's empty, no CertificateVerify is sent.
|
|
||||||
|
|
||||||
msg, err := c.readHandshake(hs.transcript)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
certMsg, ok := msg.(*certificateMsgTLS13)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(certMsg, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.VerifyConnection != nil {
|
|
||||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
|
||||||
c.sendAlert(alertBadCertificate)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(certMsg.certificate.Certificate) != 0 {
|
|
||||||
// certificateVerifyMsg is included in the transcript, but not until
|
|
||||||
// after we verify the handshake signature, since the state before
|
|
||||||
// this message was sent is used.
|
|
||||||
msg, err = c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
certVerify, ok := msg.(*certificateVerifyMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(certVerify, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 8446, Section 4.4.3.
|
|
||||||
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: client certificate used with invalid signature algorithm")
|
|
||||||
}
|
|
||||||
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
|
|
||||||
if err != nil {
|
|
||||||
return c.sendAlert(alertInternalError)
|
|
||||||
}
|
|
||||||
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
|
|
||||||
c.sendAlert(alertIllegalParameter)
|
|
||||||
return errors.New("tls: client certificate used with invalid signature algorithm")
|
|
||||||
}
|
|
||||||
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
|
|
||||||
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
|
|
||||||
sigHash, signed, certVerify.signature); err != nil {
|
|
||||||
c.sendAlert(alertDecryptError)
|
|
||||||
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we waited until the client certificates to send session tickets, we
|
|
||||||
// are ready to do it now.
|
|
||||||
if err := hs.sendSessionTickets(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
|
|
||||||
c := hs.c
|
|
||||||
|
|
||||||
// finishedMsg is not included in the transcript.
|
|
||||||
msg, err := c.readHandshake(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
finished, ok := msg.(*finishedMsg)
|
|
||||||
if !ok {
|
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
|
||||||
return unexpectedMessageError(finished, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hmac.Equal(hs.clientFinished, finished.verifyData) {
|
|
||||||
c.sendAlert(alertDecryptError)
|
|
||||||
return errors.New("tls: invalid client finished hash")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,366 +0,0 @@
|
||||||
// Copyright 2010 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdh"
|
|
||||||
"crypto/md5"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/x509"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// a keyAgreement implements the client and server side of a TLS key agreement
|
|
||||||
// protocol by generating and processing key exchange messages.
|
|
||||||
type keyAgreement interface {
|
|
||||||
// On the server side, the first two methods are called in order.
|
|
||||||
|
|
||||||
// In the case that the key agreement protocol doesn't use a
|
|
||||||
// ServerKeyExchange message, generateServerKeyExchange can return nil,
|
|
||||||
// nil.
|
|
||||||
generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
|
|
||||||
processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
|
|
||||||
|
|
||||||
// On the client side, the next two methods are called in order.
|
|
||||||
|
|
||||||
// This method may not be called if the server doesn't send a
|
|
||||||
// ServerKeyExchange message.
|
|
||||||
processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error
|
|
||||||
generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
|
|
||||||
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
|
|
||||||
|
|
||||||
// rsaKeyAgreement implements the standard TLS key agreement where the client
|
|
||||||
// encrypts the pre-master secret to the server's public key.
|
|
||||||
type rsaKeyAgreement struct{}
|
|
||||||
|
|
||||||
func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
|
|
||||||
if len(ckx.ciphertext) < 2 {
|
|
||||||
return nil, errClientKeyExchange
|
|
||||||
}
|
|
||||||
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
|
|
||||||
if ciphertextLen != len(ckx.ciphertext)-2 {
|
|
||||||
return nil, errClientKeyExchange
|
|
||||||
}
|
|
||||||
ciphertext := ckx.ciphertext[2:]
|
|
||||||
|
|
||||||
priv, ok := cert.PrivateKey.(crypto.Decrypter)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
|
|
||||||
}
|
|
||||||
// Perform constant time RSA PKCS #1 v1.5 decryption
|
|
||||||
preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// We don't check the version number in the premaster secret. For one,
|
|
||||||
// by checking it, we would leak information about the validity of the
|
|
||||||
// encrypted pre-master secret. Secondly, it provides only a small
|
|
||||||
// benefit against a downgrade attack and some implementations send the
|
|
||||||
// wrong version anyway. See the discussion at the end of section
|
|
||||||
// 7.4.7.1 of RFC 4346.
|
|
||||||
return preMasterSecret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
|
||||||
return errors.New("tls: unexpected ServerKeyExchange")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
|
|
||||||
preMasterSecret := make([]byte, 48)
|
|
||||||
preMasterSecret[0] = byte(clientHello.vers >> 8)
|
|
||||||
preMasterSecret[1] = byte(clientHello.vers)
|
|
||||||
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
|
|
||||||
}
|
|
||||||
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
ckx := new(clientKeyExchangeMsg)
|
|
||||||
ckx.ciphertext = make([]byte, len(encrypted)+2)
|
|
||||||
ckx.ciphertext[0] = byte(len(encrypted) >> 8)
|
|
||||||
ckx.ciphertext[1] = byte(len(encrypted))
|
|
||||||
copy(ckx.ciphertext[2:], encrypted)
|
|
||||||
return preMasterSecret, ckx, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sha1Hash calculates a SHA1 hash over the given byte slices.
|
|
||||||
func sha1Hash(slices [][]byte) []byte {
|
|
||||||
hsha1 := sha1.New()
|
|
||||||
for _, slice := range slices {
|
|
||||||
hsha1.Write(slice)
|
|
||||||
}
|
|
||||||
return hsha1.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
|
|
||||||
// concatenation of an MD5 and SHA1 hash.
|
|
||||||
func md5SHA1Hash(slices [][]byte) []byte {
|
|
||||||
md5sha1 := make([]byte, md5.Size+sha1.Size)
|
|
||||||
hmd5 := md5.New()
|
|
||||||
for _, slice := range slices {
|
|
||||||
hmd5.Write(slice)
|
|
||||||
}
|
|
||||||
copy(md5sha1, hmd5.Sum(nil))
|
|
||||||
copy(md5sha1[md5.Size:], sha1Hash(slices))
|
|
||||||
return md5sha1
|
|
||||||
}
|
|
||||||
|
|
||||||
// hashForServerKeyExchange hashes the given slices and returns their digest
|
|
||||||
// using the given hash function (for >= TLS 1.2) or using a default based on
|
|
||||||
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
|
|
||||||
// do pre-hashing, it returns the concatenation of the slices.
|
|
||||||
func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
|
|
||||||
if sigType == signatureEd25519 {
|
|
||||||
var signed []byte
|
|
||||||
for _, slice := range slices {
|
|
||||||
signed = append(signed, slice...)
|
|
||||||
}
|
|
||||||
return signed
|
|
||||||
}
|
|
||||||
if version >= VersionTLS12 {
|
|
||||||
h := hashFunc.New()
|
|
||||||
for _, slice := range slices {
|
|
||||||
h.Write(slice)
|
|
||||||
}
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
return digest
|
|
||||||
}
|
|
||||||
if sigType == signatureECDSA {
|
|
||||||
return sha1Hash(slices)
|
|
||||||
}
|
|
||||||
return md5SHA1Hash(slices)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ecdheKeyAgreement implements a TLS key agreement where the server
|
|
||||||
// generates an ephemeral EC public/private key pair and signs it. The
|
|
||||||
// pre-master secret is then calculated using ECDH. The signature may
|
|
||||||
// be ECDSA, Ed25519 or RSA.
|
|
||||||
type ecdheKeyAgreement struct {
|
|
||||||
version uint16
|
|
||||||
isRSA bool
|
|
||||||
key *ecdh.PrivateKey
|
|
||||||
|
|
||||||
// ckx and preMasterSecret are generated in processServerKeyExchange
|
|
||||||
// and returned in generateClientKeyExchange.
|
|
||||||
ckx *clientKeyExchangeMsg
|
|
||||||
preMasterSecret []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
|
|
||||||
var curveID CurveID
|
|
||||||
for _, c := range clientHello.supportedCurves {
|
|
||||||
if config.supportsCurve(c) {
|
|
||||||
curveID = c
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if curveID == 0 {
|
|
||||||
return nil, errors.New("tls: no supported elliptic curves offered")
|
|
||||||
}
|
|
||||||
if _, ok := curveForCurveID(curveID); !ok {
|
|
||||||
return nil, errors.New("tls: CurvePreferences includes unsupported curve")
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := generateECDHEKey(config.rand(), curveID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ka.key = key
|
|
||||||
|
|
||||||
// See RFC 4492, Section 5.4.
|
|
||||||
ecdhePublic := key.PublicKey().Bytes()
|
|
||||||
serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic))
|
|
||||||
serverECDHEParams[0] = 3 // named curve
|
|
||||||
serverECDHEParams[1] = byte(curveID >> 8)
|
|
||||||
serverECDHEParams[2] = byte(curveID)
|
|
||||||
serverECDHEParams[3] = byte(len(ecdhePublic))
|
|
||||||
copy(serverECDHEParams[4:], ecdhePublic)
|
|
||||||
|
|
||||||
priv, ok := cert.PrivateKey.(crypto.Signer)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
var signatureAlgorithm SignatureScheme
|
|
||||||
var sigType uint8
|
|
||||||
var sigHash crypto.Hash
|
|
||||||
if ka.version >= VersionTLS12 {
|
|
||||||
signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
|
|
||||||
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
|
|
||||||
}
|
|
||||||
|
|
||||||
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams)
|
|
||||||
|
|
||||||
signOpts := crypto.SignerOpts(sigHash)
|
|
||||||
if sigType == signatureRSAPSS {
|
|
||||||
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
|
|
||||||
}
|
|
||||||
sig, err := priv.Sign(config.rand(), signed, signOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
skx := new(serverKeyExchangeMsg)
|
|
||||||
sigAndHashLen := 0
|
|
||||||
if ka.version >= VersionTLS12 {
|
|
||||||
sigAndHashLen = 2
|
|
||||||
}
|
|
||||||
skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig))
|
|
||||||
copy(skx.key, serverECDHEParams)
|
|
||||||
k := skx.key[len(serverECDHEParams):]
|
|
||||||
if ka.version >= VersionTLS12 {
|
|
||||||
k[0] = byte(signatureAlgorithm >> 8)
|
|
||||||
k[1] = byte(signatureAlgorithm)
|
|
||||||
k = k[2:]
|
|
||||||
}
|
|
||||||
k[0] = byte(len(sig) >> 8)
|
|
||||||
k[1] = byte(len(sig))
|
|
||||||
copy(k[2:], sig)
|
|
||||||
|
|
||||||
return skx, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
|
|
||||||
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
|
|
||||||
return nil, errClientKeyExchange
|
|
||||||
}
|
|
||||||
|
|
||||||
peerKey, err := ka.key.Curve().NewPublicKey(ckx.ciphertext[1:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errClientKeyExchange
|
|
||||||
}
|
|
||||||
preMasterSecret, err := ka.key.ECDH(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errClientKeyExchange
|
|
||||||
}
|
|
||||||
|
|
||||||
return preMasterSecret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
|
||||||
if len(skx.key) < 4 {
|
|
||||||
return errServerKeyExchange
|
|
||||||
}
|
|
||||||
if skx.key[0] != 3 { // named curve
|
|
||||||
return errors.New("tls: server selected unsupported curve")
|
|
||||||
}
|
|
||||||
curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
|
|
||||||
|
|
||||||
publicLen := int(skx.key[3])
|
|
||||||
if publicLen+4 > len(skx.key) {
|
|
||||||
return errServerKeyExchange
|
|
||||||
}
|
|
||||||
serverECDHEParams := skx.key[:4+publicLen]
|
|
||||||
publicKey := serverECDHEParams[4:]
|
|
||||||
|
|
||||||
sig := skx.key[4+publicLen:]
|
|
||||||
if len(sig) < 2 {
|
|
||||||
return errServerKeyExchange
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := curveForCurveID(curveID); !ok {
|
|
||||||
return errors.New("tls: server selected unsupported curve")
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := generateECDHEKey(config.rand(), curveID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ka.key = key
|
|
||||||
|
|
||||||
peerKey, err := key.Curve().NewPublicKey(publicKey)
|
|
||||||
if err != nil {
|
|
||||||
return errServerKeyExchange
|
|
||||||
}
|
|
||||||
ka.preMasterSecret, err = key.ECDH(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return errServerKeyExchange
|
|
||||||
}
|
|
||||||
|
|
||||||
ourPublicKey := key.PublicKey().Bytes()
|
|
||||||
ka.ckx = new(clientKeyExchangeMsg)
|
|
||||||
ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey))
|
|
||||||
ka.ckx.ciphertext[0] = byte(len(ourPublicKey))
|
|
||||||
copy(ka.ckx.ciphertext[1:], ourPublicKey)
|
|
||||||
|
|
||||||
var sigType uint8
|
|
||||||
var sigHash crypto.Hash
|
|
||||||
if ka.version >= VersionTLS12 {
|
|
||||||
signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
|
|
||||||
sig = sig[2:]
|
|
||||||
if len(sig) < 2 {
|
|
||||||
return errServerKeyExchange
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
|
|
||||||
return errors.New("tls: certificate used with invalid signature algorithm")
|
|
||||||
}
|
|
||||||
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
|
|
||||||
return errServerKeyExchange
|
|
||||||
}
|
|
||||||
|
|
||||||
sigLen := int(sig[0])<<8 | int(sig[1])
|
|
||||||
if sigLen+2 != len(sig) {
|
|
||||||
return errServerKeyExchange
|
|
||||||
}
|
|
||||||
sig = sig[2:]
|
|
||||||
|
|
||||||
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams)
|
|
||||||
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
|
|
||||||
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
|
|
||||||
if ka.ckx == nil {
|
|
||||||
return nil, nil, errors.New("tls: missing ServerKeyExchange message")
|
|
||||||
}
|
|
||||||
|
|
||||||
return ka.preMasterSecret, ka.ckx, nil
|
|
||||||
}
|
|
|
@ -1,159 +0,0 @@
|
||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/ecdh"
|
|
||||||
"crypto/hmac"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/cryptobyte"
|
|
||||||
"golang.org/x/crypto/hkdf"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This file contains the functions necessary to compute the TLS 1.3 key
|
|
||||||
// schedule. See RFC 8446, Section 7.
|
|
||||||
|
|
||||||
const (
|
|
||||||
resumptionBinderLabel = "res binder"
|
|
||||||
clientEarlyTrafficLabel = "c e traffic"
|
|
||||||
clientHandshakeTrafficLabel = "c hs traffic"
|
|
||||||
serverHandshakeTrafficLabel = "s hs traffic"
|
|
||||||
clientApplicationTrafficLabel = "c ap traffic"
|
|
||||||
serverApplicationTrafficLabel = "s ap traffic"
|
|
||||||
exporterLabel = "exp master"
|
|
||||||
resumptionLabel = "res master"
|
|
||||||
trafficUpdateLabel = "traffic upd"
|
|
||||||
)
|
|
||||||
|
|
||||||
// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
|
|
||||||
func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte {
|
|
||||||
var hkdfLabel cryptobyte.Builder
|
|
||||||
hkdfLabel.AddUint16(uint16(length))
|
|
||||||
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes([]byte("tls13 "))
|
|
||||||
b.AddBytes([]byte(label))
|
|
||||||
})
|
|
||||||
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes(context)
|
|
||||||
})
|
|
||||||
hkdfLabelBytes, err := hkdfLabel.Bytes()
|
|
||||||
if err != nil {
|
|
||||||
// Rather than calling BytesOrPanic, we explicitly handle this error, in
|
|
||||||
// order to provide a reasonable error message. It should be basically
|
|
||||||
// impossible for this to panic, and routing errors back through the
|
|
||||||
// tree rooted in this function is quite painful. The labels are fixed
|
|
||||||
// size, and the context is either a fixed-length computed hash, or
|
|
||||||
// parsed from a field which has the same length limitation. As such, an
|
|
||||||
// error here is likely to only be caused during development.
|
|
||||||
//
|
|
||||||
// NOTE: another reasonable approach here might be to return a
|
|
||||||
// randomized slice if we encounter an error, which would break the
|
|
||||||
// connection, but avoid panicking. This would perhaps be safer but
|
|
||||||
// significantly more confusing to users.
|
|
||||||
panic(fmt.Errorf("failed to construct HKDF label: %s", err))
|
|
||||||
}
|
|
||||||
out := make([]byte, length)
|
|
||||||
n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out)
|
|
||||||
if err != nil || n != length {
|
|
||||||
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1.
|
|
||||||
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
|
|
||||||
if transcript == nil {
|
|
||||||
transcript = c.hash.New()
|
|
||||||
}
|
|
||||||
return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract implements HKDF-Extract with the cipher suite hash.
|
|
||||||
func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte {
|
|
||||||
if newSecret == nil {
|
|
||||||
newSecret = make([]byte, c.hash.Size())
|
|
||||||
}
|
|
||||||
return hkdf.Extract(c.hash.New, newSecret, currentSecret)
|
|
||||||
}
|
|
||||||
|
|
||||||
// nextTrafficSecret generates the next traffic secret, given the current one,
|
|
||||||
// according to RFC 8446, Section 7.2.
|
|
||||||
func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte {
|
|
||||||
return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
|
|
||||||
func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) {
|
|
||||||
key = c.expandLabel(trafficSecret, "key", nil, c.keyLen)
|
|
||||||
iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// finishedHash generates the Finished verify_data or PskBinderEntry according
|
|
||||||
// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey
|
|
||||||
// selection.
|
|
||||||
func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte {
|
|
||||||
finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size())
|
|
||||||
verifyData := hmac.New(c.hash.New, finishedKey)
|
|
||||||
verifyData.Write(transcript.Sum(nil))
|
|
||||||
return verifyData.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to
|
|
||||||
// RFC 8446, Section 7.5.
|
|
||||||
func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) {
|
|
||||||
expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript)
|
|
||||||
return func(label string, context []byte, length int) ([]byte, error) {
|
|
||||||
secret := c.deriveSecret(expMasterSecret, label, nil)
|
|
||||||
h := c.hash.New()
|
|
||||||
h.Write(context)
|
|
||||||
return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman
|
|
||||||
// according to RFC 8446, Section 4.2.8.2.
|
|
||||||
func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) {
|
|
||||||
curve, ok := curveForCurveID(curveID)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("tls: internal error: unsupported curve")
|
|
||||||
}
|
|
||||||
|
|
||||||
return curve.GenerateKey(rand)
|
|
||||||
}
|
|
||||||
|
|
||||||
func curveForCurveID(id CurveID) (ecdh.Curve, bool) {
|
|
||||||
switch id {
|
|
||||||
case X25519:
|
|
||||||
return ecdh.X25519(), true
|
|
||||||
case CurveP256:
|
|
||||||
return ecdh.P256(), true
|
|
||||||
case CurveP384:
|
|
||||||
return ecdh.P384(), true
|
|
||||||
case CurveP521:
|
|
||||||
return ecdh.P521(), true
|
|
||||||
default:
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func curveIDForCurve(curve ecdh.Curve) (CurveID, bool) {
|
|
||||||
switch curve {
|
|
||||||
case ecdh.X25519():
|
|
||||||
return X25519, true
|
|
||||||
case ecdh.P256():
|
|
||||||
return CurveP256, true
|
|
||||||
case ecdh.P384():
|
|
||||||
return CurveP384, true
|
|
||||||
case ecdh.P521():
|
|
||||||
return CurveP521, true
|
|
||||||
default:
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,18 +0,0 @@
|
||||||
// Copyright 2022 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
func needFIPS() bool { return false }
|
|
||||||
|
|
||||||
func supportedSignatureAlgorithms() []SignatureScheme {
|
|
||||||
return defaultSupportedSignatureAlgorithms
|
|
||||||
}
|
|
||||||
|
|
||||||
func fipsMinVersion(c *config) uint16 { panic("fipsMinVersion") }
|
|
||||||
func fipsMaxVersion(c *config) uint16 { panic("fipsMaxVersion") }
|
|
||||||
func fipsCurvePreferences(c *config) []CurveID { panic("fipsCurvePreferences") }
|
|
||||||
func fipsCipherSuites(c *config) []uint16 { panic("fipsCipherSuites") }
|
|
||||||
|
|
||||||
var fipsSupportedSignatureAlgorithms []SignatureScheme
|
|
|
@ -1,283 +0,0 @@
|
||||||
// Copyright 2009 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/md5"
|
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/sha512"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Split a premaster secret in two as specified in RFC 4346, Section 5.
|
|
||||||
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
|
|
||||||
s1 = secret[0 : (len(secret)+1)/2]
|
|
||||||
s2 = secret[len(secret)/2:]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// pHash implements the P_hash function, as defined in RFC 4346, Section 5.
|
|
||||||
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
|
|
||||||
h := hmac.New(hash, secret)
|
|
||||||
h.Write(seed)
|
|
||||||
a := h.Sum(nil)
|
|
||||||
|
|
||||||
j := 0
|
|
||||||
for j < len(result) {
|
|
||||||
h.Reset()
|
|
||||||
h.Write(a)
|
|
||||||
h.Write(seed)
|
|
||||||
b := h.Sum(nil)
|
|
||||||
copy(result[j:], b)
|
|
||||||
j += len(b)
|
|
||||||
|
|
||||||
h.Reset()
|
|
||||||
h.Write(a)
|
|
||||||
a = h.Sum(nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
|
|
||||||
func prf10(result, secret, label, seed []byte) {
|
|
||||||
hashSHA1 := sha1.New
|
|
||||||
hashMD5 := md5.New
|
|
||||||
|
|
||||||
labelAndSeed := make([]byte, len(label)+len(seed))
|
|
||||||
copy(labelAndSeed, label)
|
|
||||||
copy(labelAndSeed[len(label):], seed)
|
|
||||||
|
|
||||||
s1, s2 := splitPreMasterSecret(secret)
|
|
||||||
pHash(result, s1, labelAndSeed, hashMD5)
|
|
||||||
result2 := make([]byte, len(result))
|
|
||||||
pHash(result2, s2, labelAndSeed, hashSHA1)
|
|
||||||
|
|
||||||
for i, b := range result2 {
|
|
||||||
result[i] ^= b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5.
|
|
||||||
func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
|
|
||||||
return func(result, secret, label, seed []byte) {
|
|
||||||
labelAndSeed := make([]byte, len(label)+len(seed))
|
|
||||||
copy(labelAndSeed, label)
|
|
||||||
copy(labelAndSeed[len(label):], seed)
|
|
||||||
|
|
||||||
pHash(result, secret, labelAndSeed, hashFunc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
masterSecretLength = 48 // Length of a master secret in TLS 1.1.
|
|
||||||
finishedVerifyLength = 12 // Length of verify_data in a Finished message.
|
|
||||||
)
|
|
||||||
|
|
||||||
var masterSecretLabel = []byte("master secret")
|
|
||||||
var keyExpansionLabel = []byte("key expansion")
|
|
||||||
var clientFinishedLabel = []byte("client finished")
|
|
||||||
var serverFinishedLabel = []byte("server finished")
|
|
||||||
|
|
||||||
func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
|
|
||||||
switch version {
|
|
||||||
case VersionTLS10, VersionTLS11:
|
|
||||||
return prf10, crypto.Hash(0)
|
|
||||||
case VersionTLS12:
|
|
||||||
if suite.flags&suiteSHA384 != 0 {
|
|
||||||
return prf12(sha512.New384), crypto.SHA384
|
|
||||||
}
|
|
||||||
return prf12(sha256.New), crypto.SHA256
|
|
||||||
default:
|
|
||||||
panic("unknown version")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
|
|
||||||
prf, _ := prfAndHashForVersion(version, suite)
|
|
||||||
return prf
|
|
||||||
}
|
|
||||||
|
|
||||||
// masterFromPreMasterSecret generates the master secret from the pre-master
|
|
||||||
// secret. See RFC 5246, Section 8.1.
|
|
||||||
func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
|
|
||||||
seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
|
|
||||||
seed = append(seed, clientRandom...)
|
|
||||||
seed = append(seed, serverRandom...)
|
|
||||||
|
|
||||||
masterSecret := make([]byte, masterSecretLength)
|
|
||||||
prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed)
|
|
||||||
return masterSecret
|
|
||||||
}
|
|
||||||
|
|
||||||
// keysFromMasterSecret generates the connection keys from the master
|
|
||||||
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
|
|
||||||
// RFC 2246, Section 6.3.
|
|
||||||
func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
|
|
||||||
seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
|
|
||||||
seed = append(seed, serverRandom...)
|
|
||||||
seed = append(seed, clientRandom...)
|
|
||||||
|
|
||||||
n := 2*macLen + 2*keyLen + 2*ivLen
|
|
||||||
keyMaterial := make([]byte, n)
|
|
||||||
prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed)
|
|
||||||
clientMAC = keyMaterial[:macLen]
|
|
||||||
keyMaterial = keyMaterial[macLen:]
|
|
||||||
serverMAC = keyMaterial[:macLen]
|
|
||||||
keyMaterial = keyMaterial[macLen:]
|
|
||||||
clientKey = keyMaterial[:keyLen]
|
|
||||||
keyMaterial = keyMaterial[keyLen:]
|
|
||||||
serverKey = keyMaterial[:keyLen]
|
|
||||||
keyMaterial = keyMaterial[keyLen:]
|
|
||||||
clientIV = keyMaterial[:ivLen]
|
|
||||||
keyMaterial = keyMaterial[ivLen:]
|
|
||||||
serverIV = keyMaterial[:ivLen]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
|
|
||||||
var buffer []byte
|
|
||||||
if version >= VersionTLS12 {
|
|
||||||
buffer = []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
prf, hash := prfAndHashForVersion(version, cipherSuite)
|
|
||||||
if hash != 0 {
|
|
||||||
return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
|
|
||||||
}
|
|
||||||
|
|
||||||
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A finishedHash calculates the hash of a set of handshake messages suitable
|
|
||||||
// for including in a Finished message.
|
|
||||||
type finishedHash struct {
|
|
||||||
client hash.Hash
|
|
||||||
server hash.Hash
|
|
||||||
|
|
||||||
// Prior to TLS 1.2, an additional MD5 hash is required.
|
|
||||||
clientMD5 hash.Hash
|
|
||||||
serverMD5 hash.Hash
|
|
||||||
|
|
||||||
// In TLS 1.2, a full buffer is sadly required.
|
|
||||||
buffer []byte
|
|
||||||
|
|
||||||
version uint16
|
|
||||||
prf func(result, secret, label, seed []byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *finishedHash) Write(msg []byte) (n int, err error) {
|
|
||||||
h.client.Write(msg)
|
|
||||||
h.server.Write(msg)
|
|
||||||
|
|
||||||
if h.version < VersionTLS12 {
|
|
||||||
h.clientMD5.Write(msg)
|
|
||||||
h.serverMD5.Write(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.buffer != nil {
|
|
||||||
h.buffer = append(h.buffer, msg...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(msg), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h finishedHash) Sum() []byte {
|
|
||||||
if h.version >= VersionTLS12 {
|
|
||||||
return h.client.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
out := make([]byte, 0, md5.Size+sha1.Size)
|
|
||||||
out = h.clientMD5.Sum(out)
|
|
||||||
return h.client.Sum(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientSum returns the contents of the verify_data member of a client's
|
|
||||||
// Finished message.
|
|
||||||
func (h finishedHash) clientSum(masterSecret []byte) []byte {
|
|
||||||
out := make([]byte, finishedVerifyLength)
|
|
||||||
h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// serverSum returns the contents of the verify_data member of a server's
|
|
||||||
// Finished message.
|
|
||||||
func (h finishedHash) serverSum(masterSecret []byte) []byte {
|
|
||||||
out := make([]byte, finishedVerifyLength)
|
|
||||||
h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// hashForClientCertificate returns the handshake messages so far, pre-hashed if
|
|
||||||
// necessary, suitable for signing by a TLS client certificate.
|
|
||||||
func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte {
|
|
||||||
if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil {
|
|
||||||
panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer")
|
|
||||||
}
|
|
||||||
|
|
||||||
if sigType == signatureEd25519 {
|
|
||||||
return h.buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.version >= VersionTLS12 {
|
|
||||||
hash := hashAlg.New()
|
|
||||||
hash.Write(h.buffer)
|
|
||||||
return hash.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sigType == signatureECDSA {
|
|
||||||
return h.server.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
return h.Sum()
|
|
||||||
}
|
|
||||||
|
|
||||||
// discardHandshakeBuffer is called when there is no more need to
|
|
||||||
// buffer the entirety of the handshake messages.
|
|
||||||
func (h *finishedHash) discardHandshakeBuffer() {
|
|
||||||
h.buffer = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// noExportedKeyingMaterial is used as a value of
|
|
||||||
// ConnectionState.ekm when renegotiation is enabled and thus
|
|
||||||
// we wish to fail all key-material export requests.
|
|
||||||
func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
|
|
||||||
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
|
|
||||||
func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
|
|
||||||
return func(label string, context []byte, length int) ([]byte, error) {
|
|
||||||
switch label {
|
|
||||||
case "client finished", "server finished", "master secret", "key expansion":
|
|
||||||
// These values are reserved and may not be used.
|
|
||||||
return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label)
|
|
||||||
}
|
|
||||||
|
|
||||||
seedLen := len(serverRandom) + len(clientRandom)
|
|
||||||
if context != nil {
|
|
||||||
seedLen += 2 + len(context)
|
|
||||||
}
|
|
||||||
seed := make([]byte, 0, seedLen)
|
|
||||||
|
|
||||||
seed = append(seed, clientRandom...)
|
|
||||||
seed = append(seed, serverRandom...)
|
|
||||||
|
|
||||||
if context != nil {
|
|
||||||
if len(context) >= 1<<16 {
|
|
||||||
return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long")
|
|
||||||
}
|
|
||||||
seed = append(seed, byte(len(context)>>8), byte(len(context)))
|
|
||||||
seed = append(seed, context...)
|
|
||||||
}
|
|
||||||
|
|
||||||
keyMaterial := make([]byte, length)
|
|
||||||
prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed)
|
|
||||||
return keyMaterial, nil
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,418 +0,0 @@
|
||||||
// Copyright 2023 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QUICEncryptionLevel represents a QUIC encryption level used to transmit
|
|
||||||
// handshake messages.
|
|
||||||
type QUICEncryptionLevel int
|
|
||||||
|
|
||||||
const (
|
|
||||||
QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
|
|
||||||
QUICEncryptionLevelEarly
|
|
||||||
QUICEncryptionLevelHandshake
|
|
||||||
QUICEncryptionLevelApplication
|
|
||||||
)
|
|
||||||
|
|
||||||
func (l QUICEncryptionLevel) String() string {
|
|
||||||
switch l {
|
|
||||||
case QUICEncryptionLevelInitial:
|
|
||||||
return "Initial"
|
|
||||||
case QUICEncryptionLevelEarly:
|
|
||||||
return "Early"
|
|
||||||
case QUICEncryptionLevelHandshake:
|
|
||||||
return "Handshake"
|
|
||||||
case QUICEncryptionLevelApplication:
|
|
||||||
return "Application"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A QUICConn represents a connection which uses a QUIC implementation as the underlying
|
|
||||||
// transport as described in RFC 9001.
|
|
||||||
//
|
|
||||||
// Methods of QUICConn are not safe for concurrent use.
|
|
||||||
type QUICConn struct {
|
|
||||||
conn *Conn
|
|
||||||
|
|
||||||
sessionTicketSent bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// A QUICConfig configures a QUICConn.
|
|
||||||
type QUICConfig struct {
|
|
||||||
TLSConfig *Config
|
|
||||||
ExtraConfig *ExtraConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// A QUICEventKind is a type of operation on a QUIC connection.
|
|
||||||
type QUICEventKind int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// QUICNoEvent indicates that there are no events available.
|
|
||||||
QUICNoEvent QUICEventKind = iota
|
|
||||||
|
|
||||||
// QUICSetReadSecret and QUICSetWriteSecret provide the read and write
|
|
||||||
// secrets for a given encryption level.
|
|
||||||
// QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
|
|
||||||
//
|
|
||||||
// Secrets for the Initial encryption level are derived from the initial
|
|
||||||
// destination connection ID, and are not provided by the QUICConn.
|
|
||||||
QUICSetReadSecret
|
|
||||||
QUICSetWriteSecret
|
|
||||||
|
|
||||||
// QUICWriteData provides data to send to the peer in CRYPTO frames.
|
|
||||||
// QUICEvent.Data is set.
|
|
||||||
QUICWriteData
|
|
||||||
|
|
||||||
// QUICTransportParameters provides the peer's QUIC transport parameters.
|
|
||||||
// QUICEvent.Data is set.
|
|
||||||
QUICTransportParameters
|
|
||||||
|
|
||||||
// QUICTransportParametersRequired indicates that the caller must provide
|
|
||||||
// QUIC transport parameters to send to the peer. The caller should set
|
|
||||||
// the transport parameters with QUICConn.SetTransportParameters and call
|
|
||||||
// QUICConn.NextEvent again.
|
|
||||||
//
|
|
||||||
// If transport parameters are set before calling QUICConn.Start, the
|
|
||||||
// connection will never generate a QUICTransportParametersRequired event.
|
|
||||||
QUICTransportParametersRequired
|
|
||||||
|
|
||||||
// QUICRejectedEarlyData indicates that the server rejected 0-RTT data even
|
|
||||||
// if we offered it. It's returned before QUICEncryptionLevelApplication
|
|
||||||
// keys are returned.
|
|
||||||
QUICRejectedEarlyData
|
|
||||||
|
|
||||||
// QUICHandshakeDone indicates that the TLS handshake has completed.
|
|
||||||
QUICHandshakeDone
|
|
||||||
)
|
|
||||||
|
|
||||||
// A QUICEvent is an event occurring on a QUIC connection.
|
|
||||||
//
|
|
||||||
// The type of event is specified by the Kind field.
|
|
||||||
// The contents of the other fields are kind-specific.
|
|
||||||
type QUICEvent struct {
|
|
||||||
Kind QUICEventKind
|
|
||||||
|
|
||||||
// Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
|
|
||||||
Level QUICEncryptionLevel
|
|
||||||
|
|
||||||
// Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
|
|
||||||
// The contents are owned by crypto/tls, and are valid until the next NextEvent call.
|
|
||||||
Data []byte
|
|
||||||
|
|
||||||
// Set for QUICSetReadSecret and QUICSetWriteSecret.
|
|
||||||
Suite uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type quicState struct {
|
|
||||||
events []QUICEvent
|
|
||||||
nextEvent int
|
|
||||||
|
|
||||||
// eventArr is a statically allocated event array, large enough to handle
|
|
||||||
// the usual maximum number of events resulting from a single call: transport
|
|
||||||
// parameters, Initial data, Early read secret, Handshake write and read
|
|
||||||
// secrets, Handshake data, Application write secret, Application data.
|
|
||||||
eventArr [8]QUICEvent
|
|
||||||
|
|
||||||
started bool
|
|
||||||
signalc chan struct{} // handshake data is available to be read
|
|
||||||
blockedc chan struct{} // handshake is waiting for data, closed when done
|
|
||||||
cancelc <-chan struct{} // handshake has been canceled
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
// readbuf is shared between HandleData and the handshake goroutine.
|
|
||||||
// HandshakeCryptoData passes ownership to the handshake goroutine by
|
|
||||||
// reading from signalc, and reclaims ownership by reading from blockedc.
|
|
||||||
readbuf []byte
|
|
||||||
|
|
||||||
transportParams []byte // to send to the peer
|
|
||||||
}
|
|
||||||
|
|
||||||
// QUICClient returns a new TLS client side connection using QUICTransport as the
|
|
||||||
// underlying transport. The config cannot be nil.
|
|
||||||
//
|
|
||||||
// The config's MinVersion must be at least TLS 1.3.
|
|
||||||
func QUICClient(config *QUICConfig) *QUICConn {
|
|
||||||
return newQUICConn(Client(nil, config.TLSConfig), config.ExtraConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QUICServer returns a new TLS server side connection using QUICTransport as the
|
|
||||||
// underlying transport. The config cannot be nil.
|
|
||||||
//
|
|
||||||
// The config's MinVersion must be at least TLS 1.3.
|
|
||||||
func QUICServer(config *QUICConfig) *QUICConn {
|
|
||||||
return newQUICConn(Server(nil, config.TLSConfig), config.ExtraConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newQUICConn(conn *Conn, extraConfig *ExtraConfig) *QUICConn {
|
|
||||||
conn.quic = &quicState{
|
|
||||||
signalc: make(chan struct{}),
|
|
||||||
blockedc: make(chan struct{}),
|
|
||||||
}
|
|
||||||
conn.quic.events = conn.quic.eventArr[:0]
|
|
||||||
conn.extraConfig = extraConfig
|
|
||||||
return &QUICConn{
|
|
||||||
conn: conn,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the client or server handshake protocol.
|
|
||||||
// It may produce connection events, which may be read with NextEvent.
|
|
||||||
//
|
|
||||||
// Start must be called at most once.
|
|
||||||
func (q *QUICConn) Start(ctx context.Context) error {
|
|
||||||
if q.conn.quic.started {
|
|
||||||
return quicError(errors.New("tls: Start called more than once"))
|
|
||||||
}
|
|
||||||
q.conn.quic.started = true
|
|
||||||
if q.conn.config.MinVersion < VersionTLS13 {
|
|
||||||
return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
|
|
||||||
}
|
|
||||||
go q.conn.HandshakeContext(ctx)
|
|
||||||
if _, ok := <-q.conn.quic.blockedc; !ok {
|
|
||||||
return q.conn.handshakeErr
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NextEvent returns the next event occurring on the connection.
|
|
||||||
// It returns an event with a Kind of QUICNoEvent when no events are available.
|
|
||||||
func (q *QUICConn) NextEvent() QUICEvent {
|
|
||||||
qs := q.conn.quic
|
|
||||||
if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
|
|
||||||
// Write over some of the previous event's data,
|
|
||||||
// to catch callers erroniously retaining it.
|
|
||||||
qs.events[last].Data[0] = 0
|
|
||||||
}
|
|
||||||
if qs.nextEvent >= len(qs.events) {
|
|
||||||
qs.events = qs.events[:0]
|
|
||||||
qs.nextEvent = 0
|
|
||||||
return QUICEvent{Kind: QUICNoEvent}
|
|
||||||
}
|
|
||||||
e := qs.events[qs.nextEvent]
|
|
||||||
qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
|
|
||||||
qs.nextEvent++
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the connection and stops any in-progress handshake.
|
|
||||||
func (q *QUICConn) Close() error {
|
|
||||||
if q.conn.quic.cancel == nil {
|
|
||||||
return nil // never started
|
|
||||||
}
|
|
||||||
q.conn.quic.cancel()
|
|
||||||
for range q.conn.quic.blockedc {
|
|
||||||
// Wait for the handshake goroutine to return.
|
|
||||||
}
|
|
||||||
return q.conn.handshakeErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleData handles handshake bytes received from the peer.
|
|
||||||
// It may produce connection events, which may be read with NextEvent.
|
|
||||||
func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
|
|
||||||
c := q.conn
|
|
||||||
if c.in.level != level {
|
|
||||||
return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
|
|
||||||
}
|
|
||||||
c.quic.readbuf = data
|
|
||||||
<-c.quic.signalc
|
|
||||||
_, ok := <-c.quic.blockedc
|
|
||||||
if ok {
|
|
||||||
// The handshake goroutine is waiting for more data.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// The handshake goroutine has exited.
|
|
||||||
c.handshakeMutex.Lock()
|
|
||||||
defer c.handshakeMutex.Unlock()
|
|
||||||
c.hand.Write(c.quic.readbuf)
|
|
||||||
c.quic.readbuf = nil
|
|
||||||
for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
|
|
||||||
b := q.conn.hand.Bytes()
|
|
||||||
n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
|
||||||
if n > maxHandshake {
|
|
||||||
q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if len(b) < 4+n {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := q.conn.handlePostHandshakeMessage(); err != nil {
|
|
||||||
q.conn.handshakeErr = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if q.conn.handshakeErr != nil {
|
|
||||||
return quicError(q.conn.handshakeErr)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendSessionTicket sends a session ticket to the client.
|
|
||||||
// It produces connection events, which may be read with NextEvent.
|
|
||||||
// Currently, it can only be called once.
|
|
||||||
func (q *QUICConn) SendSessionTicket(earlyData bool) error {
|
|
||||||
c := q.conn
|
|
||||||
if !c.isHandshakeComplete.Load() {
|
|
||||||
return quicError(errors.New("tls: SendSessionTicket called before handshake completed"))
|
|
||||||
}
|
|
||||||
if c.isClient {
|
|
||||||
return quicError(errors.New("tls: SendSessionTicket called on the client"))
|
|
||||||
}
|
|
||||||
if q.sessionTicketSent {
|
|
||||||
return quicError(errors.New("tls: SendSessionTicket called multiple times"))
|
|
||||||
}
|
|
||||||
q.sessionTicketSent = true
|
|
||||||
return quicError(c.sendSessionTicket(earlyData))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionState returns basic TLS details about the connection.
|
|
||||||
func (q *QUICConn) ConnectionState() ConnectionState {
|
|
||||||
return q.conn.ConnectionState()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTransportParameters sets the transport parameters to send to the peer.
|
|
||||||
//
|
|
||||||
// Server connections may delay setting the transport parameters until after
|
|
||||||
// receiving the client's transport parameters. See QUICTransportParametersRequired.
|
|
||||||
func (q *QUICConn) SetTransportParameters(params []byte) {
|
|
||||||
if params == nil {
|
|
||||||
params = []byte{}
|
|
||||||
}
|
|
||||||
q.conn.quic.transportParams = params
|
|
||||||
if q.conn.quic.started {
|
|
||||||
<-q.conn.quic.signalc
|
|
||||||
<-q.conn.quic.blockedc
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// quicError ensures err is an AlertError.
|
|
||||||
// If err is not already, quicError wraps it with alertInternalError.
|
|
||||||
func quicError(err error) error {
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var ae AlertError
|
|
||||||
if errors.As(err, &ae) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var a alert
|
|
||||||
if !errors.As(err, &a) {
|
|
||||||
a = alertInternalError
|
|
||||||
}
|
|
||||||
// Return an error wrapping the original error and an AlertError.
|
|
||||||
// Truncate the text of the alert to 0 characters.
|
|
||||||
return fmt.Errorf("%w%.0w", err, AlertError(a))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) quicReadHandshakeBytes(n int) error {
|
|
||||||
for c.hand.Len() < n {
|
|
||||||
if err := c.quicWaitForSignal(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
|
|
||||||
c.quic.events = append(c.quic.events, QUICEvent{
|
|
||||||
Kind: QUICSetReadSecret,
|
|
||||||
Level: level,
|
|
||||||
Suite: suite,
|
|
||||||
Data: secret,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
|
|
||||||
c.quic.events = append(c.quic.events, QUICEvent{
|
|
||||||
Kind: QUICSetWriteSecret,
|
|
||||||
Level: level,
|
|
||||||
Suite: suite,
|
|
||||||
Data: secret,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
|
|
||||||
var last *QUICEvent
|
|
||||||
if len(c.quic.events) > 0 {
|
|
||||||
last = &c.quic.events[len(c.quic.events)-1]
|
|
||||||
}
|
|
||||||
if last == nil || last.Kind != QUICWriteData || last.Level != level {
|
|
||||||
c.quic.events = append(c.quic.events, QUICEvent{
|
|
||||||
Kind: QUICWriteData,
|
|
||||||
Level: level,
|
|
||||||
})
|
|
||||||
last = &c.quic.events[len(c.quic.events)-1]
|
|
||||||
}
|
|
||||||
last.Data = append(last.Data, data...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) quicSetTransportParameters(params []byte) {
|
|
||||||
c.quic.events = append(c.quic.events, QUICEvent{
|
|
||||||
Kind: QUICTransportParameters,
|
|
||||||
Data: params,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) quicGetTransportParameters() ([]byte, error) {
|
|
||||||
if c.quic.transportParams == nil {
|
|
||||||
c.quic.events = append(c.quic.events, QUICEvent{
|
|
||||||
Kind: QUICTransportParametersRequired,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
for c.quic.transportParams == nil {
|
|
||||||
if err := c.quicWaitForSignal(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return c.quic.transportParams, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) quicHandshakeComplete() {
|
|
||||||
c.quic.events = append(c.quic.events, QUICEvent{
|
|
||||||
Kind: QUICHandshakeDone,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) quicRejectedEarlyData() {
|
|
||||||
c.quic.events = append(c.quic.events, QUICEvent{
|
|
||||||
Kind: QUICRejectedEarlyData,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
|
|
||||||
// and waits for a signal that the handshake should proceed.
|
|
||||||
//
|
|
||||||
// The handshake may become blocked waiting for handshake bytes
|
|
||||||
// or for the user to provide transport parameters.
|
|
||||||
func (c *Conn) quicWaitForSignal() error {
|
|
||||||
// Drop the handshake mutex while blocked to allow the user
|
|
||||||
// to call ConnectionState before the handshake completes.
|
|
||||||
c.handshakeMutex.Unlock()
|
|
||||||
defer c.handshakeMutex.Lock()
|
|
||||||
// Send on blockedc to notify the QUICConn that the handshake is blocked.
|
|
||||||
// Exported methods of QUICConn wait for the handshake to become blocked
|
|
||||||
// before returning to the user.
|
|
||||||
select {
|
|
||||||
case c.quic.blockedc <- struct{}{}:
|
|
||||||
case <-c.quic.cancelc:
|
|
||||||
return c.sendAlertLocked(alertCloseNotify)
|
|
||||||
}
|
|
||||||
// The QUICConn reads from signalc to notify us that the handshake may
|
|
||||||
// be able to proceed. (The QUICConn reads, because we close signalc to
|
|
||||||
// indicate that the handshake has completed.)
|
|
||||||
select {
|
|
||||||
case c.quic.signalc <- struct{}{}:
|
|
||||||
c.hand.Write(c.quic.readbuf)
|
|
||||||
c.quic.readbuf = nil
|
|
||||||
case <-c.quic.cancelc:
|
|
||||||
return c.sendAlertLocked(alertCloseNotify)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,203 +0,0 @@
|
||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/subtle"
|
|
||||||
"errors"
|
|
||||||
"golang.org/x/crypto/cryptobyte"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// sessionState contains the information that is serialized into a session
|
|
||||||
// ticket in order to later resume a connection.
|
|
||||||
type sessionState struct {
|
|
||||||
vers uint16
|
|
||||||
cipherSuite uint16
|
|
||||||
createdAt uint64
|
|
||||||
masterSecret []byte // opaque master_secret<1..2^16-1>;
|
|
||||||
// struct { opaque certificate<1..2^24-1> } Certificate;
|
|
||||||
certificates [][]byte // Certificate certificate_list<0..2^24-1>;
|
|
||||||
|
|
||||||
// usedOldKey is true if the ticket from which this session came from
|
|
||||||
// was encrypted with an older key and thus should be refreshed.
|
|
||||||
usedOldKey bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *sessionState) marshal() ([]byte, error) {
|
|
||||||
var b cryptobyte.Builder
|
|
||||||
b.AddUint16(m.vers)
|
|
||||||
b.AddUint16(m.cipherSuite)
|
|
||||||
addUint64(&b, m.createdAt)
|
|
||||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes(m.masterSecret)
|
|
||||||
})
|
|
||||||
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
for _, cert := range m.certificates {
|
|
||||||
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes(cert)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return b.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *sessionState) unmarshal(data []byte) bool {
|
|
||||||
*m = sessionState{usedOldKey: m.usedOldKey}
|
|
||||||
s := cryptobyte.String(data)
|
|
||||||
if ok := s.ReadUint16(&m.vers) &&
|
|
||||||
s.ReadUint16(&m.cipherSuite) &&
|
|
||||||
readUint64(&s, &m.createdAt) &&
|
|
||||||
readUint16LengthPrefixed(&s, &m.masterSecret) &&
|
|
||||||
len(m.masterSecret) != 0; !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var certList cryptobyte.String
|
|
||||||
if !s.ReadUint24LengthPrefixed(&certList) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for !certList.Empty() {
|
|
||||||
var cert []byte
|
|
||||||
if !readUint24LengthPrefixed(&certList, &cert) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
m.certificates = append(m.certificates, cert)
|
|
||||||
}
|
|
||||||
return s.Empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
|
|
||||||
// version (revision = 0) doesn't carry any of the information needed for 0-RTT
|
|
||||||
// validation and the nonce is always empty.
|
|
||||||
// version (revision = 1) carries the max_early_data_size sent in the ticket.
|
|
||||||
// version (revision = 2) carries the ALPN sent in the ticket.
|
|
||||||
type sessionStateTLS13 struct {
|
|
||||||
// uint8 version = 0x0304;
|
|
||||||
// uint8 revision = 2;
|
|
||||||
cipherSuite uint16
|
|
||||||
createdAt uint64
|
|
||||||
resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>;
|
|
||||||
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
|
|
||||||
maxEarlyData uint32
|
|
||||||
alpn string
|
|
||||||
|
|
||||||
appData []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *sessionStateTLS13) marshal() ([]byte, error) {
|
|
||||||
var b cryptobyte.Builder
|
|
||||||
b.AddUint16(VersionTLS13)
|
|
||||||
b.AddUint8(2) // revision
|
|
||||||
b.AddUint16(m.cipherSuite)
|
|
||||||
addUint64(&b, m.createdAt)
|
|
||||||
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes(m.resumptionSecret)
|
|
||||||
})
|
|
||||||
marshalCertificate(&b, m.certificate)
|
|
||||||
b.AddUint32(m.maxEarlyData)
|
|
||||||
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes([]byte(m.alpn))
|
|
||||||
})
|
|
||||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
|
||||||
b.AddBytes(m.appData)
|
|
||||||
})
|
|
||||||
return b.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *sessionStateTLS13) unmarshal(data []byte) bool {
|
|
||||||
*m = sessionStateTLS13{}
|
|
||||||
s := cryptobyte.String(data)
|
|
||||||
var version uint16
|
|
||||||
var revision uint8
|
|
||||||
var alpn []byte
|
|
||||||
ret := s.ReadUint16(&version) &&
|
|
||||||
version == VersionTLS13 &&
|
|
||||||
s.ReadUint8(&revision) &&
|
|
||||||
revision == 2 &&
|
|
||||||
s.ReadUint16(&m.cipherSuite) &&
|
|
||||||
readUint64(&s, &m.createdAt) &&
|
|
||||||
readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
|
|
||||||
len(m.resumptionSecret) != 0 &&
|
|
||||||
unmarshalCertificate(&s, &m.certificate) &&
|
|
||||||
s.ReadUint32(&m.maxEarlyData) &&
|
|
||||||
readUint8LengthPrefixed(&s, &alpn) &&
|
|
||||||
readUint16LengthPrefixed(&s, &m.appData) &&
|
|
||||||
s.Empty()
|
|
||||||
m.alpn = string(alpn)
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
|
|
||||||
if len(c.ticketKeys) == 0 {
|
|
||||||
return nil, errors.New("tls: internal error: session ticket keys unavailable")
|
|
||||||
}
|
|
||||||
|
|
||||||
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
|
|
||||||
keyName := encrypted[:ticketKeyNameLen]
|
|
||||||
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
|
|
||||||
macBytes := encrypted[len(encrypted)-sha256.Size:]
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
key := c.ticketKeys[0]
|
|
||||||
copy(keyName, key.keyName[:])
|
|
||||||
block, err := aes.NewCipher(key.aesKey[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
|
|
||||||
}
|
|
||||||
cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state)
|
|
||||||
|
|
||||||
mac := hmac.New(sha256.New, key.hmacKey[:])
|
|
||||||
mac.Write(encrypted[:len(encrypted)-sha256.Size])
|
|
||||||
mac.Sum(macBytes[:0])
|
|
||||||
|
|
||||||
return encrypted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
|
|
||||||
if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
keyName := encrypted[:ticketKeyNameLen]
|
|
||||||
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
|
|
||||||
macBytes := encrypted[len(encrypted)-sha256.Size:]
|
|
||||||
ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
|
|
||||||
|
|
||||||
keyIndex := -1
|
|
||||||
for i, candidateKey := range c.ticketKeys {
|
|
||||||
if bytes.Equal(keyName, candidateKey.keyName[:]) {
|
|
||||||
keyIndex = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if keyIndex == -1 {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
key := &c.ticketKeys[keyIndex]
|
|
||||||
|
|
||||||
mac := hmac.New(sha256.New, key.hmacKey[:])
|
|
||||||
mac.Write(encrypted[:len(encrypted)-sha256.Size])
|
|
||||||
expected := mac.Sum(nil)
|
|
||||||
|
|
||||||
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
block, err := aes.NewCipher(key.aesKey[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
plaintext = make([]byte, len(ciphertext))
|
|
||||||
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
|
|
||||||
|
|
||||||
return plaintext, keyIndex > 0
|
|
||||||
}
|
|
|
@ -1,356 +0,0 @@
|
||||||
// Copyright 2009 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// package qtls partially implements TLS 1.2, as specified in RFC 5246,
|
|
||||||
// and TLS 1.3, as specified in RFC 8446.
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
// BUG(agl): The crypto/tls package only implements some countermeasures
|
|
||||||
// against Lucky13 attacks on CBC-mode encryption, and only on SHA1
|
|
||||||
// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
|
|
||||||
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server returns a new TLS server side connection
|
|
||||||
// using conn as the underlying transport.
|
|
||||||
// The configuration config must be non-nil and must include
|
|
||||||
// at least one certificate or else set GetCertificate.
|
|
||||||
func Server(conn net.Conn, config *Config) *Conn {
|
|
||||||
c := &Conn{
|
|
||||||
conn: conn,
|
|
||||||
config: fromConfig(config),
|
|
||||||
}
|
|
||||||
c.handshakeFn = c.serverHandshake
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client returns a new TLS client side connection
|
|
||||||
// using conn as the underlying transport.
|
|
||||||
// The config cannot be nil: users must set either ServerName or
|
|
||||||
// InsecureSkipVerify in the config.
|
|
||||||
func Client(conn net.Conn, config *Config) *Conn {
|
|
||||||
c := &Conn{
|
|
||||||
conn: conn,
|
|
||||||
config: fromConfig(config),
|
|
||||||
isClient: true,
|
|
||||||
}
|
|
||||||
c.handshakeFn = c.clientHandshake
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// A listener implements a network listener (net.Listener) for TLS connections.
|
|
||||||
type listener struct {
|
|
||||||
net.Listener
|
|
||||||
config *Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accept waits for and returns the next incoming TLS connection.
|
|
||||||
// The returned connection is of type *Conn.
|
|
||||||
func (l *listener) Accept() (net.Conn, error) {
|
|
||||||
c, err := l.Listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return Server(c, l.config), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewListener creates a Listener which accepts connections from an inner
|
|
||||||
// Listener and wraps each connection with Server.
|
|
||||||
// The configuration config must be non-nil and must include
|
|
||||||
// at least one certificate or else set GetCertificate.
|
|
||||||
func NewListener(inner net.Listener, config *Config) net.Listener {
|
|
||||||
l := new(listener)
|
|
||||||
l.Listener = inner
|
|
||||||
l.config = config
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen creates a TLS listener accepting connections on the
|
|
||||||
// given network address using net.Listen.
|
|
||||||
// The configuration config must be non-nil and must include
|
|
||||||
// at least one certificate or else set GetCertificate.
|
|
||||||
func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
|
||||||
if config == nil || len(config.Certificates) == 0 &&
|
|
||||||
config.GetCertificate == nil && config.GetConfigForClient == nil {
|
|
||||||
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
|
|
||||||
}
|
|
||||||
l, err := net.Listen(network, laddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewListener(l, config), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type timeoutError struct{}
|
|
||||||
|
|
||||||
func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
|
|
||||||
func (timeoutError) Timeout() bool { return true }
|
|
||||||
func (timeoutError) Temporary() bool { return true }
|
|
||||||
|
|
||||||
// DialWithDialer connects to the given network address using dialer.Dial and
|
|
||||||
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
|
||||||
// timeout or deadline given in the dialer apply to connection and TLS
|
|
||||||
// handshake as a whole.
|
|
||||||
//
|
|
||||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
|
||||||
// configuration; see the documentation of Config for the defaults.
|
|
||||||
//
|
|
||||||
// DialWithDialer uses context.Background internally; to specify the context,
|
|
||||||
// use Dialer.DialContext with NetDialer set to the desired dialer.
|
|
||||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
|
||||||
return dial(context.Background(), dialer, network, addr, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
|
||||||
if netDialer.Timeout != 0 {
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
if !netDialer.Deadline.IsZero() {
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
rawConn, err := netDialer.DialContext(ctx, network, addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
colonPos := strings.LastIndex(addr, ":")
|
|
||||||
if colonPos == -1 {
|
|
||||||
colonPos = len(addr)
|
|
||||||
}
|
|
||||||
hostname := addr[:colonPos]
|
|
||||||
|
|
||||||
if config == nil {
|
|
||||||
config = defaultConfig()
|
|
||||||
}
|
|
||||||
// If no ServerName is set, infer the ServerName
|
|
||||||
// from the hostname we're connecting to.
|
|
||||||
if config.ServerName == "" {
|
|
||||||
// Make a copy to avoid polluting argument or default.
|
|
||||||
c := config.Clone()
|
|
||||||
c.ServerName = hostname
|
|
||||||
config = c
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := Client(rawConn, config)
|
|
||||||
if err := conn.HandshakeContext(ctx); err != nil {
|
|
||||||
rawConn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the given network address using net.Dial
|
|
||||||
// and then initiates a TLS handshake, returning the resulting
|
|
||||||
// TLS connection.
|
|
||||||
// Dial interprets a nil configuration as equivalent to
|
|
||||||
// the zero configuration; see the documentation of Config
|
|
||||||
// for the defaults.
|
|
||||||
func Dial(network, addr string, config *Config) (*Conn, error) {
|
|
||||||
return DialWithDialer(new(net.Dialer), network, addr, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dialer dials TLS connections given a configuration and a Dialer for the
|
|
||||||
// underlying connection.
|
|
||||||
type Dialer struct {
|
|
||||||
// NetDialer is the optional dialer to use for the TLS connections'
|
|
||||||
// underlying TCP connections.
|
|
||||||
// A nil NetDialer is equivalent to the net.Dialer zero value.
|
|
||||||
NetDialer *net.Dialer
|
|
||||||
|
|
||||||
// Config is the TLS configuration to use for new connections.
|
|
||||||
// A nil configuration is equivalent to the zero
|
|
||||||
// configuration; see the documentation of Config for the
|
|
||||||
// defaults.
|
|
||||||
Config *Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the given network address and initiates a TLS
|
|
||||||
// handshake, returning the resulting TLS connection.
|
|
||||||
//
|
|
||||||
// The returned Conn, if any, will always be of type *Conn.
|
|
||||||
//
|
|
||||||
// Dial uses context.Background internally; to specify the context,
|
|
||||||
// use DialContext.
|
|
||||||
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
|
|
||||||
return d.DialContext(context.Background(), network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Dialer) netDialer() *net.Dialer {
|
|
||||||
if d.NetDialer != nil {
|
|
||||||
return d.NetDialer
|
|
||||||
}
|
|
||||||
return new(net.Dialer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext connects to the given network address and initiates a TLS
|
|
||||||
// handshake, returning the resulting TLS connection.
|
|
||||||
//
|
|
||||||
// The provided Context must be non-nil. If the context expires before
|
|
||||||
// the connection is complete, an error is returned. Once successfully
|
|
||||||
// connected, any expiration of the context will not affect the
|
|
||||||
// connection.
|
|
||||||
//
|
|
||||||
// The returned Conn, if any, will always be of type *Conn.
|
|
||||||
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
|
|
||||||
if err != nil {
|
|
||||||
// Don't return c (a typed nil) in an interface.
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadX509KeyPair reads and parses a public/private key pair from a pair
|
|
||||||
// of files. The files must contain PEM encoded data. The certificate file
|
|
||||||
// may contain intermediate certificates following the leaf certificate to
|
|
||||||
// form a certificate chain. On successful return, Certificate.Leaf will
|
|
||||||
// be nil because the parsed form of the certificate is not retained.
|
|
||||||
func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
|
|
||||||
certPEMBlock, err := os.ReadFile(certFile)
|
|
||||||
if err != nil {
|
|
||||||
return Certificate{}, err
|
|
||||||
}
|
|
||||||
keyPEMBlock, err := os.ReadFile(keyFile)
|
|
||||||
if err != nil {
|
|
||||||
return Certificate{}, err
|
|
||||||
}
|
|
||||||
return X509KeyPair(certPEMBlock, keyPEMBlock)
|
|
||||||
}
|
|
||||||
|
|
||||||
// X509KeyPair parses a public/private key pair from a pair of
|
|
||||||
// PEM encoded data. On successful return, Certificate.Leaf will be nil because
|
|
||||||
// the parsed form of the certificate is not retained.
|
|
||||||
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
|
||||||
fail := func(err error) (Certificate, error) { return Certificate{}, err }
|
|
||||||
|
|
||||||
var cert Certificate
|
|
||||||
var skippedBlockTypes []string
|
|
||||||
for {
|
|
||||||
var certDERBlock *pem.Block
|
|
||||||
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
|
|
||||||
if certDERBlock == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if certDERBlock.Type == "CERTIFICATE" {
|
|
||||||
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
|
|
||||||
} else {
|
|
||||||
skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cert.Certificate) == 0 {
|
|
||||||
if len(skippedBlockTypes) == 0 {
|
|
||||||
return fail(errors.New("tls: failed to find any PEM data in certificate input"))
|
|
||||||
}
|
|
||||||
if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
|
|
||||||
return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
|
|
||||||
}
|
|
||||||
return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
|
|
||||||
}
|
|
||||||
|
|
||||||
skippedBlockTypes = skippedBlockTypes[:0]
|
|
||||||
var keyDERBlock *pem.Block
|
|
||||||
for {
|
|
||||||
keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
|
|
||||||
if keyDERBlock == nil {
|
|
||||||
if len(skippedBlockTypes) == 0 {
|
|
||||||
return fail(errors.New("tls: failed to find any PEM data in key input"))
|
|
||||||
}
|
|
||||||
if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
|
|
||||||
return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
|
|
||||||
}
|
|
||||||
return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
|
|
||||||
}
|
|
||||||
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't need to parse the public key for TLS, but we so do anyway
|
|
||||||
// to check that it looks sane and matches the private key.
|
|
||||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
|
||||||
if err != nil {
|
|
||||||
return fail(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
return fail(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch pub := x509Cert.PublicKey.(type) {
|
|
||||||
case *rsa.PublicKey:
|
|
||||||
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
|
|
||||||
if !ok {
|
|
||||||
return fail(errors.New("tls: private key type does not match public key type"))
|
|
||||||
}
|
|
||||||
if pub.N.Cmp(priv.N) != 0 {
|
|
||||||
return fail(errors.New("tls: private key does not match public key"))
|
|
||||||
}
|
|
||||||
case *ecdsa.PublicKey:
|
|
||||||
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
|
|
||||||
if !ok {
|
|
||||||
return fail(errors.New("tls: private key type does not match public key type"))
|
|
||||||
}
|
|
||||||
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
|
|
||||||
return fail(errors.New("tls: private key does not match public key"))
|
|
||||||
}
|
|
||||||
case ed25519.PublicKey:
|
|
||||||
priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
|
|
||||||
if !ok {
|
|
||||||
return fail(errors.New("tls: private key type does not match public key type"))
|
|
||||||
}
|
|
||||||
if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
|
|
||||||
return fail(errors.New("tls: private key does not match public key"))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fail(errors.New("tls: unknown public key algorithm"))
|
|
||||||
}
|
|
||||||
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
|
|
||||||
// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
|
|
||||||
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
|
|
||||||
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
|
|
||||||
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
|
|
||||||
switch key := key.(type) {
|
|
||||||
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
|
|
||||||
return key, nil
|
|
||||||
default:
|
|
||||||
return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if key, err := x509.ParseECPrivateKey(der); err == nil {
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("tls: failed to parse private key")
|
|
||||||
}
|
|
|
@ -1,101 +0,0 @@
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"reflect"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if !structsEqual(&tls.ConnectionState{}, &connectionState{}) {
|
|
||||||
panic("qtls.ConnectionState doesn't match")
|
|
||||||
}
|
|
||||||
if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) {
|
|
||||||
panic("qtls.ClientSessionState doesn't match")
|
|
||||||
}
|
|
||||||
if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) {
|
|
||||||
panic("qtls.CertificateRequestInfo doesn't match")
|
|
||||||
}
|
|
||||||
if !structsEqual(&tls.Config{}, &config{}) {
|
|
||||||
panic("qtls.Config doesn't match")
|
|
||||||
}
|
|
||||||
if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) {
|
|
||||||
panic("qtls.ClientHelloInfo doesn't match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func toConnectionState(c connectionState) ConnectionState {
|
|
||||||
return *(*ConnectionState)(unsafe.Pointer(&c))
|
|
||||||
}
|
|
||||||
|
|
||||||
func toClientSessionState(s *clientSessionState) *ClientSessionState {
|
|
||||||
return (*ClientSessionState)(unsafe.Pointer(s))
|
|
||||||
}
|
|
||||||
|
|
||||||
func fromClientSessionState(s *ClientSessionState) *clientSessionState {
|
|
||||||
return (*clientSessionState)(unsafe.Pointer(s))
|
|
||||||
}
|
|
||||||
|
|
||||||
func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo {
|
|
||||||
return (*CertificateRequestInfo)(unsafe.Pointer(i))
|
|
||||||
}
|
|
||||||
|
|
||||||
func toConfig(c *config) *Config {
|
|
||||||
return (*Config)(unsafe.Pointer(c))
|
|
||||||
}
|
|
||||||
|
|
||||||
func fromConfig(c *Config) *config {
|
|
||||||
return (*config)(unsafe.Pointer(c))
|
|
||||||
}
|
|
||||||
|
|
||||||
func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo {
|
|
||||||
return (*ClientHelloInfo)(unsafe.Pointer(chi))
|
|
||||||
}
|
|
||||||
|
|
||||||
func structsEqual(a, b interface{}) bool {
|
|
||||||
return compare(reflect.ValueOf(a), reflect.ValueOf(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
func compare(a, b reflect.Value) bool {
|
|
||||||
sa := a.Elem()
|
|
||||||
sb := b.Elem()
|
|
||||||
if sa.NumField() != sb.NumField() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := 0; i < sa.NumField(); i++ {
|
|
||||||
fa := sa.Type().Field(i)
|
|
||||||
fb := sb.Type().Field(i)
|
|
||||||
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
|
|
||||||
if fa.Type.Kind() != fb.Type.Kind() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if fa.Type.Kind() == reflect.Slice {
|
|
||||||
if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareStruct(a, b reflect.Type) bool {
|
|
||||||
if a.NumField() != b.NumField() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := 0; i < a.NumField(); i++ {
|
|
||||||
fa := a.Field(i)
|
|
||||||
fb := b.Field(i)
|
|
||||||
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitSessionTicketKeys triggers the initialization of session ticket keys.
|
|
||||||
func InitSessionTicketKeys(conf *Config) {
|
|
||||||
fromConfig(conf).ticketKeys(nil)
|
|
||||||
}
|
|
|
@ -3,15 +3,15 @@ run:
|
||||||
- internal/handshake/cipher_suite.go
|
- internal/handshake/cipher_suite.go
|
||||||
linters-settings:
|
linters-settings:
|
||||||
depguard:
|
depguard:
|
||||||
type: blacklist
|
rules:
|
||||||
packages:
|
qtls:
|
||||||
- github.com/marten-seemann/qtls
|
list-mode: lax
|
||||||
- github.com/quic-go/qtls-go1-19
|
files:
|
||||||
- github.com/quic-go/qtls-go1-20
|
- "!internal/qtls/**"
|
||||||
packages-with-error-message:
|
- "$all"
|
||||||
- github.com/marten-seemann/qtls: "importing qtls only allowed in internal/qtls"
|
deny:
|
||||||
- github.com/quic-go/qtls-go1-19: "importing qtls only allowed in internal/qtls"
|
- pkg: github.com/quic-go/qtls-go1-20
|
||||||
- github.com/quic-go/qtls-go1-20: "importing qtls only allowed in internal/qtls"
|
desc: "importing qtls only allowed in internal/qtls"
|
||||||
misspell:
|
misspell:
|
||||||
ignore-words:
|
ignore-words:
|
||||||
- ect
|
- ect
|
||||||
|
|
|
@ -227,12 +227,13 @@ http.Client{
|
||||||
## Projects using quic-go
|
## Projects using quic-go
|
||||||
|
|
||||||
| Project | Description | Stars |
|
| Project | Description | Stars |
|
||||||
| --------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
|
| ---------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
|
||||||
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) |
|
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) |
|
||||||
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
|
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
|
||||||
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
|
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
|
||||||
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
|
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
|
||||||
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
|
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
|
||||||
|
| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go | ![GitHub Repo stars](https://img.shields.io/github/stars/go-gost/gost?style=flat-square) |
|
||||||
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
|
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
|
||||||
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
|
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
|
||||||
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
|
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
|
||||||
|
@ -247,11 +248,6 @@ If you'd like to see your project added to this list, please send us a PR.
|
||||||
|
|
||||||
quic-go always aims to support the latest two Go releases.
|
quic-go always aims to support the latest two Go releases.
|
||||||
|
|
||||||
### Dependency on forked crypto/tls
|
|
||||||
|
|
||||||
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20).
|
|
||||||
This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.
|
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/qerr"
|
"github.com/quic-go/quic-go/internal/qerr"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,7 +59,7 @@ func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
|
||||||
// transport parameter.
|
// transport parameter.
|
||||||
// We currently don't send the preferred_address transport parameter,
|
// We currently don't send the preferred_address transport parameter,
|
||||||
// so we can issue (limit - 1) connection IDs.
|
// so we can issue (limit - 1) connection IDs.
|
||||||
for i := uint64(len(m.activeSrcConnIDs)); i < utils.Min(limit, protocol.MaxIssuedConnectionIDs); i++ {
|
for i := uint64(len(m.activeSrcConnIDs)); i < min(limit, protocol.MaxIssuedConnectionIDs); i++ {
|
||||||
if err := m.issueNewConnID(); err != nil {
|
if err := m.issueNewConnID(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,7 +145,7 @@ func (h *connIDManager) updateConnectionID() {
|
||||||
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
h.queueControlFrame(&wire.RetireConnectionIDFrame{
|
||||||
SequenceNumber: h.activeSequenceNumber,
|
SequenceNumber: h.activeSequenceNumber,
|
||||||
})
|
})
|
||||||
h.highestRetired = utils.Max(h.highestRetired, h.activeSequenceNumber)
|
h.highestRetired = max(h.highestRetired, h.activeSequenceNumber)
|
||||||
if h.activeStatelessResetToken != nil {
|
if h.activeStatelessResetToken != nil {
|
||||||
h.removeStatelessResetToken(*h.activeStatelessResetToken)
|
h.removeStatelessResetToken(*h.activeStatelessResetToken)
|
||||||
}
|
}
|
||||||
|
|
|
@ -629,7 +629,7 @@ runLoop:
|
||||||
sendQueueAvailable = s.sendQueue.Available()
|
sendQueueAvailable = s.sendQueue.Available()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := s.triggerSending(); err != nil {
|
if err := s.triggerSending(now); err != nil {
|
||||||
s.closeLocal(err)
|
s.closeLocal(err)
|
||||||
}
|
}
|
||||||
if s.sendQueue.WouldBlock() {
|
if s.sendQueue.WouldBlock() {
|
||||||
|
@ -681,7 +681,7 @@ func (s *connection) ConnectionState() ConnectionState {
|
||||||
|
|
||||||
// Time when the connection should time out
|
// Time when the connection should time out
|
||||||
func (s *connection) nextIdleTimeoutTime() time.Time {
|
func (s *connection) nextIdleTimeoutTime() time.Time {
|
||||||
idleTimeout := utils.Max(s.idleTimeout, s.rttStats.PTO(true)*3)
|
idleTimeout := max(s.idleTimeout, s.rttStats.PTO(true)*3)
|
||||||
return s.idleTimeoutStartTime().Add(idleTimeout)
|
return s.idleTimeoutStartTime().Add(idleTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -691,7 +691,7 @@ func (s *connection) nextKeepAliveTime() time.Time {
|
||||||
if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() {
|
if s.config.KeepAlivePeriod == 0 || s.keepAlivePingSent || !s.firstAckElicitingPacketAfterIdleSentTime.IsZero() {
|
||||||
return time.Time{}
|
return time.Time{}
|
||||||
}
|
}
|
||||||
keepAliveInterval := utils.Max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2)
|
keepAliveInterval := max(s.keepAliveInterval, s.rttStats.PTO(true)*3/2)
|
||||||
return s.lastPacketReceivedTime.Add(keepAliveInterval)
|
return s.lastPacketReceivedTime.Add(keepAliveInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -731,6 +731,10 @@ func (s *connection) handleHandshakeComplete() error {
|
||||||
s.connIDManager.SetHandshakeComplete()
|
s.connIDManager.SetHandshakeComplete()
|
||||||
s.connIDGenerator.SetHandshakeComplete()
|
s.connIDGenerator.SetHandshakeComplete()
|
||||||
|
|
||||||
|
if s.tracer != nil && s.tracer.ChoseALPN != nil {
|
||||||
|
s.tracer.ChoseALPN(s.cryptoStreamHandler.ConnectionState().NegotiatedProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
// The server applies transport parameters right away, but the client side has to wait for handshake completion.
|
// The server applies transport parameters right away, but the client side has to wait for handshake completion.
|
||||||
// During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets.
|
// During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets.
|
||||||
if s.perspective == protocol.PerspectiveClient {
|
if s.perspective == protocol.PerspectiveClient {
|
||||||
|
@ -776,7 +780,7 @@ func (s *connection) handleHandshakeConfirmed() error {
|
||||||
if maxPacketSize == 0 {
|
if maxPacketSize == 0 {
|
||||||
maxPacketSize = protocol.MaxByteCount
|
maxPacketSize = protocol.MaxByteCount
|
||||||
}
|
}
|
||||||
s.mtuDiscoverer.Start(utils.Min(maxPacketSize, protocol.MaxPacketBufferSize))
|
s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1751,7 +1755,7 @@ func (s *connection) applyTransportParameters() {
|
||||||
params := s.peerParams
|
params := s.peerParams
|
||||||
// Our local idle timeout will always be > 0.
|
// Our local idle timeout will always be > 0.
|
||||||
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
|
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
|
||||||
s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval))
|
s.keepAliveInterval = min(s.config.KeepAlivePeriod, min(s.idleTimeout/2, protocol.MaxKeepAliveInterval))
|
||||||
s.streamsMap.UpdateLimits(params)
|
s.streamsMap.UpdateLimits(params)
|
||||||
s.frameParser.SetAckDelayExponent(params.AckDelayExponent)
|
s.frameParser.SetAckDelayExponent(params.AckDelayExponent)
|
||||||
s.connFlowController.UpdateSendWindow(params.InitialMaxData)
|
s.connFlowController.UpdateSendWindow(params.InitialMaxData)
|
||||||
|
@ -1767,9 +1771,8 @@ func (s *connection) applyTransportParameters() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *connection) triggerSending() error {
|
func (s *connection) triggerSending(now time.Time) error {
|
||||||
s.pacingDeadline = time.Time{}
|
s.pacingDeadline = time.Time{}
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
sendMode := s.sentPacketHandler.SendMode(now)
|
sendMode := s.sentPacketHandler.SendMode(now)
|
||||||
//nolint:exhaustive // No need to handle pacing limited here.
|
//nolint:exhaustive // No need to handle pacing limited here.
|
||||||
|
@ -1801,7 +1804,7 @@ func (s *connection) triggerSending() error {
|
||||||
s.scheduleSending()
|
s.scheduleSending()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.triggerSending()
|
return s.triggerSending(now)
|
||||||
case ackhandler.SendPTOHandshake:
|
case ackhandler.SendPTOHandshake:
|
||||||
if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil {
|
if err := s.sendProbePacket(protocol.EncryptionHandshake, now); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1810,7 +1813,7 @@ func (s *connection) triggerSending() error {
|
||||||
s.scheduleSending()
|
s.scheduleSending()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.triggerSending()
|
return s.triggerSending(now)
|
||||||
case ackhandler.SendPTOAppData:
|
case ackhandler.SendPTOAppData:
|
||||||
if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil {
|
if err := s.sendProbePacket(protocol.Encryption1RTT, now); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1819,7 +1822,7 @@ func (s *connection) triggerSending() error {
|
||||||
s.scheduleSending()
|
s.scheduleSending()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.triggerSending()
|
return s.triggerSending(now)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("BUG: invalid send mode %d", sendMode)
|
return fmt.Errorf("BUG: invalid send mode %d", sendMode)
|
||||||
}
|
}
|
||||||
|
@ -1988,7 +1991,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
|
||||||
if packet == nil {
|
if packet == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.sendPackedCoalescedPacket(packet, ecn, time.Now())
|
return s.sendPackedCoalescedPacket(packet, ecn, now)
|
||||||
}
|
}
|
||||||
|
|
||||||
ecn := s.sentPacketHandler.ECNMode(true)
|
ecn := s.sentPacketHandler.ECNMode(true)
|
||||||
|
@ -2356,7 +2359,7 @@ func (s *connection) SendDatagram(p []byte) error {
|
||||||
}
|
}
|
||||||
f.Data = make([]byte, len(p))
|
f.Data = make([]byte, len(p))
|
||||||
copy(f.Data, p)
|
copy(f.Data, p)
|
||||||
return s.datagramQueue.AddAndWait(f)
|
return s.datagramQueue.Add(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/qerr"
|
"github.com/quic-go/quic-go/internal/qerr"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -56,7 +55,7 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||||
// could e.g. be a retransmission
|
// could e.g. be a retransmission
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s.highestOffset = utils.Max(s.highestOffset, highestOffset)
|
s.highestOffset = max(s.highestOffset, highestOffset)
|
||||||
if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
|
if err := s.queue.Push(f.Data, f.Offset, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -99,7 +98,7 @@ func (s *cryptoStreamImpl) HasData() bool {
|
||||||
|
|
||||||
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
func (s *cryptoStreamImpl) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||||
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||||
n := utils.Min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
|
n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
|
||||||
f.Data = s.writeBuf[:n]
|
f.Data = s.writeBuf[:n]
|
||||||
s.writeBuf = s.writeBuf[n:]
|
s.writeBuf = s.writeBuf[n:]
|
||||||
s.writeOffset += n
|
s.writeOffset += n
|
||||||
|
|
|
@ -4,14 +4,20 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
"github.com/quic-go/quic-go/internal/utils"
|
||||||
|
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
|
||||||
"github.com/quic-go/quic-go/internal/wire"
|
"github.com/quic-go/quic-go/internal/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxDatagramSendQueueLen = 32
|
||||||
|
maxDatagramRcvQueueLen = 128
|
||||||
|
)
|
||||||
|
|
||||||
type datagramQueue struct {
|
type datagramQueue struct {
|
||||||
sendQueue chan *wire.DatagramFrame
|
sendMx sync.Mutex
|
||||||
nextFrame *wire.DatagramFrame
|
sendQueue ringbuffer.RingBuffer[*wire.DatagramFrame]
|
||||||
|
sent chan struct{} // used to notify Add that a datagram was dequeued
|
||||||
|
|
||||||
rcvMx sync.Mutex
|
rcvMx sync.Mutex
|
||||||
rcvQueue [][]byte
|
rcvQueue [][]byte
|
||||||
|
@ -22,60 +28,65 @@ type datagramQueue struct {
|
||||||
|
|
||||||
hasData func()
|
hasData func()
|
||||||
|
|
||||||
dequeued chan struct{}
|
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
|
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
|
||||||
return &datagramQueue{
|
return &datagramQueue{
|
||||||
hasData: hasData,
|
hasData: hasData,
|
||||||
sendQueue: make(chan *wire.DatagramFrame, 1),
|
rcvd: make(chan struct{}, 1),
|
||||||
rcvd: make(chan struct{}, 1),
|
sent: make(chan struct{}, 1),
|
||||||
dequeued: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
closed: make(chan struct{}),
|
logger: logger,
|
||||||
logger: logger,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddAndWait queues a new DATAGRAM frame for sending.
|
// Add queues a new DATAGRAM frame for sending.
|
||||||
// It blocks until the frame has been dequeued.
|
// Up to 32 DATAGRAM frames will be queued.
|
||||||
func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error {
|
// Once that limit is reached, Add blocks until the queue size has reduced.
|
||||||
select {
|
func (h *datagramQueue) Add(f *wire.DatagramFrame) error {
|
||||||
case h.sendQueue <- f:
|
h.sendMx.Lock()
|
||||||
h.hasData()
|
|
||||||
case <-h.closed:
|
|
||||||
return h.closeErr
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
for {
|
||||||
case <-h.dequeued:
|
if h.sendQueue.Len() < maxDatagramSendQueueLen {
|
||||||
return nil
|
h.sendQueue.PushBack(f)
|
||||||
case <-h.closed:
|
h.sendMx.Unlock()
|
||||||
return h.closeErr
|
h.hasData()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-h.sent: // drain the queue so we don't loop immediately
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
h.sendMx.Unlock()
|
||||||
|
select {
|
||||||
|
case <-h.closed:
|
||||||
|
return h.closeErr
|
||||||
|
case <-h.sent:
|
||||||
|
}
|
||||||
|
h.sendMx.Lock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peek gets the next DATAGRAM frame for sending.
|
// Peek gets the next DATAGRAM frame for sending.
|
||||||
// If actually sent out, Pop needs to be called before the next call to Peek.
|
// If actually sent out, Pop needs to be called before the next call to Peek.
|
||||||
func (h *datagramQueue) Peek() *wire.DatagramFrame {
|
func (h *datagramQueue) Peek() *wire.DatagramFrame {
|
||||||
if h.nextFrame != nil {
|
h.sendMx.Lock()
|
||||||
return h.nextFrame
|
defer h.sendMx.Unlock()
|
||||||
}
|
if h.sendQueue.Empty() {
|
||||||
select {
|
|
||||||
case h.nextFrame = <-h.sendQueue:
|
|
||||||
h.dequeued <- struct{}{}
|
|
||||||
default:
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return h.nextFrame
|
return h.sendQueue.PeekFront()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *datagramQueue) Pop() {
|
func (h *datagramQueue) Pop() {
|
||||||
if h.nextFrame == nil {
|
h.sendMx.Lock()
|
||||||
panic("datagramQueue BUG: Pop called for nil frame")
|
defer h.sendMx.Unlock()
|
||||||
|
_ = h.sendQueue.PopFront()
|
||||||
|
select {
|
||||||
|
case h.sent <- struct{}{}:
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
h.nextFrame = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleDatagramFrame handles a received DATAGRAM frame.
|
// HandleDatagramFrame handles a received DATAGRAM frame.
|
||||||
|
@ -84,7 +95,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
|
||||||
copy(data, f.Data)
|
copy(data, f.Data)
|
||||||
var queued bool
|
var queued bool
|
||||||
h.rcvMx.Lock()
|
h.rcvMx.Lock()
|
||||||
if len(h.rcvQueue) < protocol.DatagramRcvQueueLen {
|
if len(h.rcvQueue) < maxDatagramRcvQueueLen {
|
||||||
h.rcvQueue = append(h.rcvQueue, data)
|
h.rcvQueue = append(h.rcvQueue, data)
|
||||||
queued = true
|
queued = true
|
||||||
select {
|
select {
|
||||||
|
@ -94,7 +105,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
|
||||||
}
|
}
|
||||||
h.rcvMx.Unlock()
|
h.rcvMx.Unlock()
|
||||||
if !queued && h.logger.Debug() {
|
if !queued && h.logger.Debug() {
|
||||||
h.logger.Debugf("Discarding DATAGRAM frame (%d bytes payload)", len(f.Data))
|
h.logger.Debugf("Discarding received DATAGRAM frame (%d bytes payload)", len(f.Data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,8 @@ type framer interface {
|
||||||
Handle0RTTRejection() error
|
Handle0RTTRejection() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const maxPathResponses = 256
|
||||||
|
|
||||||
type framerI struct {
|
type framerI struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
|
@ -33,6 +35,7 @@ type framerI struct {
|
||||||
|
|
||||||
controlFrameMutex sync.Mutex
|
controlFrameMutex sync.Mutex
|
||||||
controlFrames []wire.Frame
|
controlFrames []wire.Frame
|
||||||
|
pathResponses []*wire.PathResponseFrame
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ framer = &framerI{}
|
var _ framer = &framerI{}
|
||||||
|
@ -52,20 +55,43 @@ func (f *framerI) HasData() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
f.controlFrameMutex.Lock()
|
f.controlFrameMutex.Lock()
|
||||||
hasData = len(f.controlFrames) > 0
|
defer f.controlFrameMutex.Unlock()
|
||||||
f.controlFrameMutex.Unlock()
|
return len(f.controlFrames) > 0 || len(f.pathResponses) > 0
|
||||||
return hasData
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *framerI) QueueControlFrame(frame wire.Frame) {
|
func (f *framerI) QueueControlFrame(frame wire.Frame) {
|
||||||
f.controlFrameMutex.Lock()
|
f.controlFrameMutex.Lock()
|
||||||
|
defer f.controlFrameMutex.Unlock()
|
||||||
|
|
||||||
|
if pr, ok := frame.(*wire.PathResponseFrame); ok {
|
||||||
|
// Only queue up to maxPathResponses PATH_RESPONSE frames.
|
||||||
|
// This limit should be high enough to never be hit in practice,
|
||||||
|
// unless the peer is doing something malicious.
|
||||||
|
if len(f.pathResponses) >= maxPathResponses {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.pathResponses = append(f.pathResponses, pr)
|
||||||
|
return
|
||||||
|
}
|
||||||
f.controlFrames = append(f.controlFrames, frame)
|
f.controlFrames = append(f.controlFrames, frame)
|
||||||
f.controlFrameMutex.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
|
func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol.ByteCount, v protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount) {
|
||||||
var length protocol.ByteCount
|
|
||||||
f.controlFrameMutex.Lock()
|
f.controlFrameMutex.Lock()
|
||||||
|
defer f.controlFrameMutex.Unlock()
|
||||||
|
|
||||||
|
var length protocol.ByteCount
|
||||||
|
// add a PATH_RESPONSE first, but only pack a single PATH_RESPONSE per packet
|
||||||
|
if len(f.pathResponses) > 0 {
|
||||||
|
frame := f.pathResponses[0]
|
||||||
|
frameLen := frame.Length(v)
|
||||||
|
if frameLen <= maxLen {
|
||||||
|
frames = append(frames, ackhandler.Frame{Frame: frame})
|
||||||
|
length += frameLen
|
||||||
|
f.pathResponses = f.pathResponses[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for len(f.controlFrames) > 0 {
|
for len(f.controlFrames) > 0 {
|
||||||
frame := f.controlFrames[len(f.controlFrames)-1]
|
frame := f.controlFrames[len(f.controlFrames)-1]
|
||||||
frameLen := frame.Length(v)
|
frameLen := frame.Length(v)
|
||||||
|
@ -76,7 +102,6 @@ func (f *framerI) AppendControlFrames(frames []ackhandler.Frame, maxLen protocol
|
||||||
length += frameLen
|
length += frameLen
|
||||||
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
|
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
|
||||||
}
|
}
|
||||||
f.controlFrameMutex.Unlock()
|
|
||||||
return frames, length
|
return frames, length
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -187,8 +187,12 @@ type Connection interface {
|
||||||
// Warning: This API should not be considered stable and might change soon.
|
// Warning: This API should not be considered stable and might change soon.
|
||||||
ConnectionState() ConnectionState
|
ConnectionState() ConnectionState
|
||||||
|
|
||||||
// SendDatagram sends a message as a datagram, as specified in RFC 9221.
|
// SendDatagram sends a message using a QUIC datagram, as specified in RFC 9221.
|
||||||
SendDatagram([]byte) error
|
// There is no delivery guarantee for DATAGRAM frames, they are not retransmitted if lost.
|
||||||
|
// The payload of the datagram needs to fit into a single QUIC packet.
|
||||||
|
// In addition, a datagram may be dropped before being sent out if the available packet size suddenly decreases.
|
||||||
|
// If the payload is too large to be sent at the current time, a DatagramTooLargeError is returned.
|
||||||
|
SendDatagram(payload []byte) error
|
||||||
// ReceiveDatagram gets a message received in a datagram, as specified in RFC 9221.
|
// ReceiveDatagram gets a message received in a datagram, as specified in RFC 9221.
|
||||||
ReceiveDatagram(context.Context) ([]byte, error)
|
ReceiveDatagram(context.Context) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
2
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go
generated
vendored
2
vendor/github.com/quic-go/quic-go/internal/ackhandler/packet_number_generator.go
generated
vendored
|
@ -80,5 +80,5 @@ func (p *skippingPacketNumberGenerator) Pop() (bool, protocol.PacketNumber) {
|
||||||
func (p *skippingPacketNumberGenerator) generateNewSkip() {
|
func (p *skippingPacketNumberGenerator) generateNewSkip() {
|
||||||
// make sure that there are never two consecutive packet numbers that are skipped
|
// make sure that there are never two consecutive packet numbers that are skipped
|
||||||
p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
|
p.nextToSkip = p.next + 3 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
|
||||||
p.period = utils.Min(2*p.period, p.maxPeriod)
|
p.period = min(2*p.period, p.maxPeriod)
|
||||||
}
|
}
|
||||||
|
|
2
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
2
vendor/github.com/quic-go/quic-go/internal/ackhandler/received_packet_tracker.go
generated
vendored
|
@ -179,7 +179,7 @@ func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
|
||||||
ack = &wire.AckFrame{}
|
ack = &wire.AckFrame{}
|
||||||
}
|
}
|
||||||
ack.Reset()
|
ack.Reset()
|
||||||
ack.DelayTime = utils.Max(0, now.Sub(h.largestObservedRcvdTime))
|
ack.DelayTime = max(0, now.Sub(h.largestObservedRcvdTime))
|
||||||
ack.ECT0 = h.ect0
|
ack.ECT0 = h.ect0
|
||||||
ack.ECT1 = h.ect1
|
ack.ECT1 = h.ect1
|
||||||
ack.ECNCE = h.ecnce
|
ack.ECNCE = h.ecnce
|
||||||
|
|
|
@ -245,7 +245,7 @@ func (h *sentPacketHandler) SentPacket(
|
||||||
|
|
||||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||||
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
|
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
|
||||||
for p := utils.Max(0, pnSpace.largestSent+1); p < pn; p++ {
|
for p := max(0, pnSpace.largestSent+1); p < pn; p++ {
|
||||||
h.logger.Debugf("Skipping packet number %d", p)
|
h.logger.Debugf("Skipping packet number %d", p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -336,7 +336,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
|
||||||
// don't use the ack delay for Initial and Handshake packets
|
// don't use the ack delay for Initial and Handshake packets
|
||||||
var ackDelay time.Duration
|
var ackDelay time.Duration
|
||||||
if encLevel == protocol.Encryption1RTT {
|
if encLevel == protocol.Encryption1RTT {
|
||||||
ackDelay = utils.Min(ack.DelayTime, h.rttStats.MaxAckDelay())
|
ackDelay = min(ack.DelayTime, h.rttStats.MaxAckDelay())
|
||||||
}
|
}
|
||||||
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
|
@ -354,7 +354,7 @@ func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.En
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pnSpace.largestAcked = utils.Max(pnSpace.largestAcked, largestAcked)
|
pnSpace.largestAcked = max(pnSpace.largestAcked, largestAcked)
|
||||||
|
|
||||||
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
|
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
@ -446,7 +446,7 @@ func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encL
|
||||||
|
|
||||||
for _, p := range h.ackedPackets {
|
for _, p := range h.ackedPackets {
|
||||||
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
|
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
|
||||||
h.lowestNotConfirmedAcked = utils.Max(h.lowestNotConfirmedAcked, p.LargestAcked+1)
|
h.lowestNotConfirmedAcked = max(h.lowestNotConfirmedAcked, p.LargestAcked+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, f := range p.Frames {
|
for _, f := range p.Frames {
|
||||||
|
@ -607,11 +607,11 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E
|
||||||
pnSpace := h.getPacketNumberSpace(encLevel)
|
pnSpace := h.getPacketNumberSpace(encLevel)
|
||||||
pnSpace.lossTime = time.Time{}
|
pnSpace.lossTime = time.Time{}
|
||||||
|
|
||||||
maxRTT := float64(utils.Max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
maxRTT := float64(max(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
|
||||||
lossDelay := time.Duration(timeThreshold * maxRTT)
|
lossDelay := time.Duration(timeThreshold * maxRTT)
|
||||||
|
|
||||||
// Minimum time of granularity before packets are deemed lost.
|
// Minimum time of granularity before packets are deemed lost.
|
||||||
lossDelay = utils.Max(lossDelay, protocol.TimerGranularity)
|
lossDelay = max(lossDelay, protocol.TimerGranularity)
|
||||||
|
|
||||||
// Packets sent before this time are deemed lost.
|
// Packets sent before this time are deemed lost.
|
||||||
lostSendTime := now.Add(-lossDelay)
|
lostSendTime := now.Add(-lossDelay)
|
||||||
|
@ -890,7 +890,7 @@ func (h *sentPacketHandler) ResetForRetry(now time.Time) error {
|
||||||
// Otherwise, we don't know which Initial the Retry was sent in response to.
|
// Otherwise, we don't know which Initial the Retry was sent in response to.
|
||||||
if h.ptoCount == 0 {
|
if h.ptoCount == 0 {
|
||||||
// Don't set the RTT to a value lower than 5ms here.
|
// Don't set the RTT to a value lower than 5ms here.
|
||||||
h.rttStats.UpdateRTT(utils.Max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
|
h.rttStats.UpdateRTT(max(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// This cubic implementation is based on the one found in Chromiums's QUIC
|
// This cubic implementation is based on the one found in Chromiums's QUIC
|
||||||
|
@ -187,7 +186,7 @@ func (c *Cubic) CongestionWindowAfterAck(
|
||||||
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
|
||||||
}
|
}
|
||||||
// Limit the CWND increase to half the acked bytes.
|
// Limit the CWND increase to half the acked bytes.
|
||||||
targetCongestionWindow = utils.Min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
targetCongestionWindow = min(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
|
||||||
|
|
||||||
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
// Increase the window by approximately Alpha * 1 MSS of bytes every
|
||||||
// time we ack an estimated tcp window of bytes. For small
|
// time we ack an estimated tcp window of bytes. For small
|
||||||
|
|
|
@ -178,7 +178,7 @@ func (c *cubicSender) OnPacketAcked(
|
||||||
priorInFlight protocol.ByteCount,
|
priorInFlight protocol.ByteCount,
|
||||||
eventTime time.Time,
|
eventTime time.Time,
|
||||||
) {
|
) {
|
||||||
c.largestAckedPacketNumber = utils.Max(ackedPacketNumber, c.largestAckedPacketNumber)
|
c.largestAckedPacketNumber = max(ackedPacketNumber, c.largestAckedPacketNumber)
|
||||||
if c.InRecovery() {
|
if c.InRecovery() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -246,7 +246,7 @@ func (c *cubicSender) maybeIncreaseCwnd(
|
||||||
c.numAckedPackets = 0
|
c.numAckedPackets = 0
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
c.congestionWindow = utils.Min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
c.congestionWindow = min(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Note(pwestin): the magic clamping numbers come from the original code in
|
// Note(pwestin): the magic clamping numbers come from the original code in
|
||||||
|
@ -75,8 +74,8 @@ func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT ti
|
||||||
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
|
||||||
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
|
||||||
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
|
||||||
minRTTincreaseThresholdUs = utils.Min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
minRTTincreaseThresholdUs = min(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
|
||||||
minRTTincreaseThreshold := time.Duration(utils.Max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
minRTTincreaseThreshold := time.Duration(max(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
|
||||||
|
|
||||||
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
|
||||||
s.hystartFound = true
|
s.hystartFound = true
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxBurstSizePackets = 10
|
const maxBurstSizePackets = 10
|
||||||
|
@ -52,11 +51,11 @@ func (p *pacer) Budget(now time.Time) protocol.ByteCount {
|
||||||
if budget < 0 { // protect against overflows
|
if budget < 0 { // protect against overflows
|
||||||
budget = protocol.MaxByteCount
|
budget = protocol.MaxByteCount
|
||||||
}
|
}
|
||||||
return utils.Min(p.maxBurstSize(), budget)
|
return min(p.maxBurstSize(), budget)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *pacer) maxBurstSize() protocol.ByteCount {
|
func (p *pacer) maxBurstSize() protocol.ByteCount {
|
||||||
return utils.Max(
|
return max(
|
||||||
protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9,
|
protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.adjustedBandwidth())/1e9,
|
||||||
maxBurstSizePackets*p.maxDatagramSize,
|
maxBurstSizePackets*p.maxDatagramSize,
|
||||||
)
|
)
|
||||||
|
@ -77,7 +76,7 @@ func (p *pacer) TimeUntilSend() time.Time {
|
||||||
if diff%bw > 0 {
|
if diff%bw > 0 {
|
||||||
d++
|
d++
|
||||||
}
|
}
|
||||||
return p.lastSentTime.Add(utils.Max(protocol.MinPacingDelay, time.Duration(d)*time.Nanosecond))
|
return p.lastSentTime.Add(max(protocol.MinPacingDelay, time.Duration(d)*time.Nanosecond))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) {
|
func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) {
|
||||||
|
|
|
@ -107,7 +107,7 @@ func (c *baseFlowController) maybeAdjustWindowSize() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
|
||||||
// window is consumed too fast, try to increase the window size
|
// window is consumed too fast, try to increase the window size
|
||||||
newSize := utils.Min(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
newSize := min(2*c.receiveWindowSize, c.maxReceiveWindowSize)
|
||||||
if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) {
|
if newSize > c.receiveWindowSize && (c.allowWindowIncrease == nil || c.allowWindowIncrease(newSize-c.receiveWindowSize)) {
|
||||||
c.receiveWindowSize = newSize
|
c.receiveWindowSize = newSize
|
||||||
}
|
}
|
||||||
|
|
2
vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
2
vendor/github.com/quic-go/quic-go/internal/flowcontrol/connection_flow_controller.go
generated
vendored
|
@ -87,7 +87,7 @@ func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCoun
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
if inc > c.receiveWindowSize {
|
if inc > c.receiveWindowSize {
|
||||||
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
|
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
|
||||||
newSize := utils.Min(inc, c.maxReceiveWindowSize)
|
newSize := min(inc, c.maxReceiveWindowSize)
|
||||||
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
|
if delta := newSize - c.receiveWindowSize; delta > 0 && c.allowWindowIncrease(delta) {
|
||||||
c.receiveWindowSize = newSize
|
c.receiveWindowSize = newSize
|
||||||
}
|
}
|
||||||
|
|
2
vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
2
vendor/github.com/quic-go/quic-go/internal/flowcontrol/stream_flow_controller.go
generated
vendored
|
@ -123,7 +123,7 @@ func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
|
||||||
return utils.Min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
|
return min(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
|
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
|
func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD {
|
||||||
|
@ -82,7 +81,7 @@ func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []
|
||||||
// It uses the nonce provided here and XOR it with the IV.
|
// It uses the nonce provided here and XOR it with the IV.
|
||||||
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
|
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
o.highestRcvdPN = utils.Max(o.highestRcvdPN, pn)
|
o.highestRcvdPN = max(o.highestRcvdPN, pn)
|
||||||
} else {
|
} else {
|
||||||
err = ErrDecryptionFailed
|
err = ErrDecryptionFailed
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,11 +25,11 @@ type quicVersionContextKey struct{}
|
||||||
|
|
||||||
var QUICVersionContextKey = &quicVersionContextKey{}
|
var QUICVersionContextKey = &quicVersionContextKey{}
|
||||||
|
|
||||||
const clientSessionStateRevision = 3
|
const clientSessionStateRevision = 4
|
||||||
|
|
||||||
type cryptoSetup struct {
|
type cryptoSetup struct {
|
||||||
tlsConf *tls.Config
|
tlsConf *tls.Config
|
||||||
conn *qtls.QUICConn
|
conn *tls.QUICConn
|
||||||
|
|
||||||
events []Event
|
events []Event
|
||||||
|
|
||||||
|
@ -93,12 +93,12 @@ func NewCryptoSetupClient(
|
||||||
|
|
||||||
tlsConf = tlsConf.Clone()
|
tlsConf = tlsConf.Clone()
|
||||||
tlsConf.MinVersion = tls.VersionTLS13
|
tlsConf.MinVersion = tls.VersionTLS13
|
||||||
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
|
||||||
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
|
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
|
||||||
cs.tlsConf = tlsConf
|
cs.tlsConf = tlsConf
|
||||||
cs.allow0RTT = enable0RTT
|
cs.allow0RTT = enable0RTT
|
||||||
|
|
||||||
cs.conn = qtls.QUICClient(quicConf)
|
cs.conn = tls.QUICClient(quicConf)
|
||||||
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
|
cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient))
|
||||||
|
|
||||||
return cs
|
return cs
|
||||||
|
@ -127,12 +127,12 @@ func NewCryptoSetupServer(
|
||||||
)
|
)
|
||||||
cs.allow0RTT = allow0RTT
|
cs.allow0RTT = allow0RTT
|
||||||
|
|
||||||
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
|
quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
|
||||||
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
|
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
|
||||||
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
|
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
|
||||||
|
|
||||||
cs.tlsConf = quicConf.TLSConfig
|
cs.tlsConf = quicConf.TLSConfig
|
||||||
cs.conn = qtls.QUICServer(quicConf)
|
cs.conn = tls.QUICServer(quicConf)
|
||||||
|
|
||||||
return cs
|
return cs
|
||||||
}
|
}
|
||||||
|
@ -264,28 +264,28 @@ func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLev
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
|
func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) (done bool, err error) {
|
||||||
switch ev.Kind {
|
switch ev.Kind {
|
||||||
case qtls.QUICNoEvent:
|
case tls.QUICNoEvent:
|
||||||
return true, nil
|
return true, nil
|
||||||
case qtls.QUICSetReadSecret:
|
case tls.QUICSetReadSecret:
|
||||||
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
|
h.SetReadKey(ev.Level, ev.Suite, ev.Data)
|
||||||
return false, nil
|
return false, nil
|
||||||
case qtls.QUICSetWriteSecret:
|
case tls.QUICSetWriteSecret:
|
||||||
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
|
h.SetWriteKey(ev.Level, ev.Suite, ev.Data)
|
||||||
return false, nil
|
return false, nil
|
||||||
case qtls.QUICTransportParameters:
|
case tls.QUICTransportParameters:
|
||||||
return false, h.handleTransportParameters(ev.Data)
|
return false, h.handleTransportParameters(ev.Data)
|
||||||
case qtls.QUICTransportParametersRequired:
|
case tls.QUICTransportParametersRequired:
|
||||||
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
|
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
|
||||||
return false, nil
|
return false, nil
|
||||||
case qtls.QUICRejectedEarlyData:
|
case tls.QUICRejectedEarlyData:
|
||||||
h.rejected0RTT()
|
h.rejected0RTT()
|
||||||
return false, nil
|
return false, nil
|
||||||
case qtls.QUICWriteData:
|
case tls.QUICWriteData:
|
||||||
h.WriteRecord(ev.Level, ev.Data)
|
h.writeRecord(ev.Level, ev.Data)
|
||||||
return false, nil
|
return false, nil
|
||||||
case qtls.QUICHandshakeDone:
|
case tls.QUICHandshakeDone:
|
||||||
h.handshakeComplete()
|
h.handshakeComplete()
|
||||||
return false, nil
|
return false, nil
|
||||||
default:
|
default:
|
||||||
|
@ -313,19 +313,24 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// must be called after receiving the transport parameters
|
// must be called after receiving the transport parameters
|
||||||
func (h *cryptoSetup) marshalDataForSessionState() []byte {
|
func (h *cryptoSetup) marshalDataForSessionState(earlyData bool) []byte {
|
||||||
b := make([]byte, 0, 256)
|
b := make([]byte, 0, 256)
|
||||||
b = quicvarint.Append(b, clientSessionStateRevision)
|
b = quicvarint.Append(b, clientSessionStateRevision)
|
||||||
b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds()))
|
b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds()))
|
||||||
return h.peerParams.MarshalForSessionTicket(b)
|
if earlyData {
|
||||||
|
// only save the transport parameters for 0-RTT enabled session tickets
|
||||||
|
return h.peerParams.MarshalForSessionTicket(b)
|
||||||
|
}
|
||||||
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) handleDataFromSessionState(data []byte) (allowEarlyData bool) {
|
func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (allowEarlyData bool) {
|
||||||
tp, err := h.handleDataFromSessionStateImpl(data)
|
rtt, tp, err := decodeDataFromSessionState(data, earlyData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
|
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.rttStats.SetInitialRTT(rtt)
|
||||||
// The session ticket might have been saved from a connection that allowed 0-RTT,
|
// The session ticket might have been saved from a connection that allowed 0-RTT,
|
||||||
// and therefore contain transport parameters.
|
// and therefore contain transport parameters.
|
||||||
// Only use them if 0-RTT is actually used on the new connection.
|
// Only use them if 0-RTT is actually used on the new connection.
|
||||||
|
@ -336,25 +341,28 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte) (allowEarlyData bo
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
|
func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
|
||||||
r := bytes.NewReader(data)
|
r := bytes.NewReader(data)
|
||||||
ver, err := quicvarint.Read(r)
|
ver, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
if ver != clientSessionStateRevision {
|
if ver != clientSessionStateRevision {
|
||||||
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
|
return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
|
||||||
}
|
}
|
||||||
rtt, err := quicvarint.Read(r)
|
rttEncoded, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
rtt := time.Duration(rttEncoded) * time.Microsecond
|
||||||
|
if !earlyData {
|
||||||
|
return rtt, nil, nil
|
||||||
}
|
}
|
||||||
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
|
|
||||||
var tp wire.TransportParameters
|
var tp wire.TransportParameters
|
||||||
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
|
||||||
return nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
return &tp, nil
|
return rtt, &tp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
||||||
|
@ -371,7 +379,9 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
|
||||||
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
|
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
|
||||||
// It is only valid for the server.
|
// It is only valid for the server.
|
||||||
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||||
if err := qtls.SendSessionTicket(h.conn, h.allow0RTT); err != nil {
|
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
|
||||||
|
EarlyData: h.allow0RTT,
|
||||||
|
}); err != nil {
|
||||||
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
|
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
|
||||||
// We can't check h.tlsConfig here, since the actual config might have been obtained from
|
// We can't check h.tlsConfig here, since the actual config might have been obtained from
|
||||||
// the GetConfigForClient callback.
|
// the GetConfigForClient callback.
|
||||||
|
@ -383,11 +393,11 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ev := h.conn.NextEvent()
|
ev := h.conn.NextEvent()
|
||||||
if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication {
|
if ev.Kind != tls.QUICWriteData || ev.Level != tls.QUICEncryptionLevelApplication {
|
||||||
panic("crypto/tls bug: where's my session ticket?")
|
panic("crypto/tls bug: where's my session ticket?")
|
||||||
}
|
}
|
||||||
ticket := ev.Data
|
ticket := ev.Data
|
||||||
if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent {
|
if ev := h.conn.NextEvent(); ev.Kind != tls.QUICNoEvent {
|
||||||
panic("crypto/tls bug: why more than one ticket?")
|
panic("crypto/tls bug: why more than one ticket?")
|
||||||
}
|
}
|
||||||
return ticket, nil
|
return ticket, nil
|
||||||
|
@ -434,12 +444,12 @@ func (h *cryptoSetup) rejected0RTT() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||||
suite := getCipherSuite(suiteID)
|
suite := getCipherSuite(suiteID)
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||||
switch el {
|
switch el {
|
||||||
case qtls.QUICEncryptionLevelEarly:
|
case tls.QUICEncryptionLevelEarly:
|
||||||
if h.perspective == protocol.PerspectiveClient {
|
if h.perspective == protocol.PerspectiveClient {
|
||||||
panic("Received 0-RTT read key for the client")
|
panic("Received 0-RTT read key for the client")
|
||||||
}
|
}
|
||||||
|
@ -451,7 +461,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||||
}
|
}
|
||||||
case qtls.QUICEncryptionLevelHandshake:
|
case tls.QUICEncryptionLevelHandshake:
|
||||||
h.handshakeOpener = newLongHeaderOpener(
|
h.handshakeOpener = newLongHeaderOpener(
|
||||||
createAEAD(suite, trafficSecret, h.version),
|
createAEAD(suite, trafficSecret, h.version),
|
||||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||||
|
@ -459,7 +469,7 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||||
}
|
}
|
||||||
case qtls.QUICEncryptionLevelApplication:
|
case tls.QUICEncryptionLevelApplication:
|
||||||
h.aead.SetReadKey(suite, trafficSecret)
|
h.aead.SetReadKey(suite, trafficSecret)
|
||||||
h.has1RTTOpener = true
|
h.has1RTTOpener = true
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
|
@ -475,12 +485,12 @@ func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, tr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) {
|
||||||
suite := getCipherSuite(suiteID)
|
suite := getCipherSuite(suiteID)
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
//nolint:exhaustive // The TLS stack doesn't export Initial keys.
|
||||||
switch el {
|
switch el {
|
||||||
case qtls.QUICEncryptionLevelEarly:
|
case tls.QUICEncryptionLevelEarly:
|
||||||
if h.perspective == protocol.PerspectiveServer {
|
if h.perspective == protocol.PerspectiveServer {
|
||||||
panic("Received 0-RTT write key for the server")
|
panic("Received 0-RTT write key for the server")
|
||||||
}
|
}
|
||||||
|
@ -497,7 +507,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
|
||||||
}
|
}
|
||||||
// don't set used0RTT here. 0-RTT might still get rejected.
|
// don't set used0RTT here. 0-RTT might still get rejected.
|
||||||
return
|
return
|
||||||
case qtls.QUICEncryptionLevelHandshake:
|
case tls.QUICEncryptionLevelHandshake:
|
||||||
h.handshakeSealer = newLongHeaderSealer(
|
h.handshakeSealer = newLongHeaderSealer(
|
||||||
createAEAD(suite, trafficSecret, h.version),
|
createAEAD(suite, trafficSecret, h.version),
|
||||||
newHeaderProtector(suite, trafficSecret, true, h.version),
|
newHeaderProtector(suite, trafficSecret, true, h.version),
|
||||||
|
@ -505,7 +515,7 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
|
||||||
}
|
}
|
||||||
case qtls.QUICEncryptionLevelApplication:
|
case tls.QUICEncryptionLevelApplication:
|
||||||
h.aead.SetWriteKey(suite, trafficSecret)
|
h.aead.SetWriteKey(suite, trafficSecret)
|
||||||
h.has1RTTSealer = true
|
h.has1RTTSealer = true
|
||||||
if h.logger.Debug() {
|
if h.logger.Debug() {
|
||||||
|
@ -529,15 +539,15 @@ func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteRecord is called when TLS writes data
|
// writeRecord is called when TLS writes data
|
||||||
func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) {
|
func (h *cryptoSetup) writeRecord(encLevel tls.QUICEncryptionLevel, p []byte) {
|
||||||
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
|
//nolint:exhaustive // handshake records can only be written for Initial and Handshake.
|
||||||
switch encLevel {
|
switch encLevel {
|
||||||
case qtls.QUICEncryptionLevelInitial:
|
case tls.QUICEncryptionLevelInitial:
|
||||||
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
|
h.events = append(h.events, Event{Kind: EventWriteInitialData, Data: p})
|
||||||
case qtls.QUICEncryptionLevelHandshake:
|
case tls.QUICEncryptionLevelHandshake:
|
||||||
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
|
h.events = append(h.events, Event{Kind: EventWriteHandshakeData, Data: p})
|
||||||
case qtls.QUICEncryptionLevelApplication:
|
case tls.QUICEncryptionLevelApplication:
|
||||||
panic("unexpected write")
|
panic("unexpected write")
|
||||||
default:
|
default:
|
||||||
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
|
panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel))
|
||||||
|
@ -684,7 +694,7 @@ func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||||
|
|
||||||
func wrapError(err error) error {
|
func wrapError(err error) error {
|
||||||
// alert 80 is an internal error
|
// alert 80 is an internal error
|
||||||
if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
|
if alertErr := tls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 {
|
||||||
return qerr.NewLocalCryptoError(uint8(alertErr), err)
|
return qerr.NewLocalCryptoError(uint8(alertErr), err)
|
||||||
}
|
}
|
||||||
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
|
return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()}
|
||||||
|
|
|
@ -172,7 +172,7 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
a.highestRcvdPN = utils.Max(a.highestRcvdPN, pn)
|
a.highestRcvdPN = max(a.highestRcvdPN, pn)
|
||||||
}
|
}
|
||||||
return dec, err
|
return dec, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -129,9 +129,6 @@ const MaxPostHandshakeCryptoFrameSize = 1000
|
||||||
// but must ensure that a maximum size ACK frame fits into one packet.
|
// but must ensure that a maximum size ACK frame fits into one packet.
|
||||||
const MaxAckFrameSize ByteCount = 1000
|
const MaxAckFrameSize ByteCount = 1000
|
||||||
|
|
||||||
// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames (RFC 9221)
|
|
||||||
const DatagramRcvQueueLen = 128
|
|
||||||
|
|
||||||
// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame.
|
// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame.
|
||||||
// It also serves as a limit for the packet history.
|
// It also serves as a limit for the packet history.
|
||||||
// If at any point we keep track of more ranges, old ranges are discarded.
|
// If at any point we keep track of more ranges, old ranges are discarded.
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
package qerr
|
package qerr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/qtls"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TransportErrorCode is a QUIC transport error.
|
// TransportErrorCode is a QUIC transport error.
|
||||||
|
@ -40,7 +39,7 @@ func (e TransportErrorCode) Message() string {
|
||||||
if !e.IsCryptoError() {
|
if !e.IsCryptoError() {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return qtls.AlertError(e - 0x100).Error()
|
return tls.AlertError(e - 0x100).Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e TransportErrorCode) String() string {
|
func (e TransportErrorCode) String() string {
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
//go:build go1.21
|
|
||||||
|
|
||||||
package qtls
|
package qtls
|
||||||
|
|
||||||
import (
|
import (
|
|
@ -7,8 +7,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type clientSessionCache struct {
|
type clientSessionCache struct {
|
||||||
getData func() []byte
|
getData func(earlyData bool) []byte
|
||||||
setData func([]byte) (allowEarlyData bool)
|
setData func(data []byte, earlyData bool) (allowEarlyData bool)
|
||||||
wrapped tls.ClientSessionCache
|
wrapped tls.ClientSessionCache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) {
|
||||||
c.wrapped.Put(key, cs)
|
c.wrapped.Put(key, cs)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
state.Extra = append(state.Extra, addExtraPrefix(c.getData()))
|
state.Extra = append(state.Extra, addExtraPrefix(c.getData(state.EarlyData)))
|
||||||
newCS, err := tls.NewResumptionState(ticket, state)
|
newCS, err := tls.NewResumptionState(ticket, state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// It's not clear why this would error. Just save the original state.
|
// It's not clear why this would error. Just save the original state.
|
||||||
|
@ -46,12 +46,13 @@ func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
|
||||||
c.wrapped.Put(key, nil)
|
c.wrapped.Put(key, nil)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
var earlyData bool
|
|
||||||
// restore QUIC transport parameters and RTT stored in state.Extra
|
// restore QUIC transport parameters and RTT stored in state.Extra
|
||||||
if extra := findExtraData(state.Extra); extra != nil {
|
if extra := findExtraData(state.Extra); extra != nil {
|
||||||
earlyData = c.setData(extra)
|
earlyData := c.setData(extra, state.EarlyData)
|
||||||
|
if state.EarlyData {
|
||||||
|
state.EarlyData = earlyData
|
||||||
|
}
|
||||||
}
|
}
|
||||||
state.EarlyData = earlyData
|
|
||||||
session, err := tls.NewResumptionState(ticket, state)
|
session, err := tls.NewResumptionState(ticket, state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// It's not clear why this would error.
|
// It's not clear why this would error.
|
||||||
|
|
|
@ -1,147 +0,0 @@
|
||||||
//go:build go1.20 && !go1.21
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
|
||||||
|
|
||||||
"github.com/quic-go/qtls-go1-20"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
QUICConn = qtls.QUICConn
|
|
||||||
QUICConfig = qtls.QUICConfig
|
|
||||||
QUICEvent = qtls.QUICEvent
|
|
||||||
QUICEventKind = qtls.QUICEventKind
|
|
||||||
QUICEncryptionLevel = qtls.QUICEncryptionLevel
|
|
||||||
AlertError = qtls.AlertError
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
QUICEncryptionLevelInitial = qtls.QUICEncryptionLevelInitial
|
|
||||||
QUICEncryptionLevelEarly = qtls.QUICEncryptionLevelEarly
|
|
||||||
QUICEncryptionLevelHandshake = qtls.QUICEncryptionLevelHandshake
|
|
||||||
QUICEncryptionLevelApplication = qtls.QUICEncryptionLevelApplication
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
QUICNoEvent = qtls.QUICNoEvent
|
|
||||||
QUICSetReadSecret = qtls.QUICSetReadSecret
|
|
||||||
QUICSetWriteSecret = qtls.QUICSetWriteSecret
|
|
||||||
QUICWriteData = qtls.QUICWriteData
|
|
||||||
QUICTransportParameters = qtls.QUICTransportParameters
|
|
||||||
QUICTransportParametersRequired = qtls.QUICTransportParametersRequired
|
|
||||||
QUICRejectedEarlyData = qtls.QUICRejectedEarlyData
|
|
||||||
QUICHandshakeDone = qtls.QUICHandshakeDone
|
|
||||||
)
|
|
||||||
|
|
||||||
func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, handleSessionTicket func([]byte, bool) bool) {
|
|
||||||
qtls.InitSessionTicketKeys(conf.TLSConfig)
|
|
||||||
conf.TLSConfig = conf.TLSConfig.Clone()
|
|
||||||
conf.TLSConfig.MinVersion = tls.VersionTLS13
|
|
||||||
conf.ExtraConfig = &qtls.ExtraConfig{
|
|
||||||
Enable0RTT: enable0RTT,
|
|
||||||
Accept0RTT: func(data []byte) bool {
|
|
||||||
return handleSessionTicket(data, true)
|
|
||||||
},
|
|
||||||
GetAppDataForSessionTicket: getDataForSessionTicket,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte) bool) {
|
|
||||||
conf.ExtraConfig = &qtls.ExtraConfig{
|
|
||||||
GetAppDataForSessionState: getDataForSessionState,
|
|
||||||
SetAppDataFromSessionState: setDataFromSessionState,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func QUICServer(config *QUICConfig) *QUICConn {
|
|
||||||
return qtls.QUICServer(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
func QUICClient(config *QUICConfig) *QUICConn {
|
|
||||||
return qtls.QUICClient(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ToTLSEncryptionLevel(e protocol.EncryptionLevel) qtls.QUICEncryptionLevel {
|
|
||||||
switch e {
|
|
||||||
case protocol.EncryptionInitial:
|
|
||||||
return qtls.QUICEncryptionLevelInitial
|
|
||||||
case protocol.EncryptionHandshake:
|
|
||||||
return qtls.QUICEncryptionLevelHandshake
|
|
||||||
case protocol.Encryption1RTT:
|
|
||||||
return qtls.QUICEncryptionLevelApplication
|
|
||||||
case protocol.Encryption0RTT:
|
|
||||||
return qtls.QUICEncryptionLevelEarly
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("unexpected encryption level: %s", e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func FromTLSEncryptionLevel(e qtls.QUICEncryptionLevel) protocol.EncryptionLevel {
|
|
||||||
switch e {
|
|
||||||
case qtls.QUICEncryptionLevelInitial:
|
|
||||||
return protocol.EncryptionInitial
|
|
||||||
case qtls.QUICEncryptionLevelHandshake:
|
|
||||||
return protocol.EncryptionHandshake
|
|
||||||
case qtls.QUICEncryptionLevelApplication:
|
|
||||||
return protocol.Encryption1RTT
|
|
||||||
case qtls.QUICEncryptionLevelEarly:
|
|
||||||
return protocol.Encryption0RTT
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("unexpect encryption level: %s", e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-20.cipherSuitesTLS13
|
|
||||||
var cipherSuitesTLS13 []unsafe.Pointer
|
|
||||||
|
|
||||||
//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13
|
|
||||||
var defaultCipherSuitesTLS13 []uint16
|
|
||||||
|
|
||||||
//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13NoAES
|
|
||||||
var defaultCipherSuitesTLS13NoAES []uint16
|
|
||||||
|
|
||||||
var cipherSuitesModified bool
|
|
||||||
|
|
||||||
// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls
|
|
||||||
// such that it only contains the cipher suite with the chosen id.
|
|
||||||
// The reset function returned resets them back to the original value.
|
|
||||||
func SetCipherSuite(id uint16) (reset func()) {
|
|
||||||
if cipherSuitesModified {
|
|
||||||
panic("cipher suites modified multiple times without resetting")
|
|
||||||
}
|
|
||||||
cipherSuitesModified = true
|
|
||||||
|
|
||||||
origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...)
|
|
||||||
origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...)
|
|
||||||
origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...)
|
|
||||||
// The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls.
|
|
||||||
switch id {
|
|
||||||
case tls.TLS_AES_128_GCM_SHA256:
|
|
||||||
cipherSuitesTLS13 = cipherSuitesTLS13[:1]
|
|
||||||
case tls.TLS_CHACHA20_POLY1305_SHA256:
|
|
||||||
cipherSuitesTLS13 = cipherSuitesTLS13[1:2]
|
|
||||||
case tls.TLS_AES_256_GCM_SHA384:
|
|
||||||
cipherSuitesTLS13 = cipherSuitesTLS13[2:]
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("unexpected cipher suite: %d", id))
|
|
||||||
}
|
|
||||||
defaultCipherSuitesTLS13 = []uint16{id}
|
|
||||||
defaultCipherSuitesTLS13NoAES = []uint16{id}
|
|
||||||
|
|
||||||
return func() {
|
|
||||||
cipherSuitesTLS13 = origCipherSuitesTLS13
|
|
||||||
defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13
|
|
||||||
defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES
|
|
||||||
cipherSuitesModified = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func SendSessionTicket(c *QUICConn, allow0RTT bool) error {
|
|
||||||
return c.SendSessionTicket(allow0RTT)
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
//go:build !go1.20
|
|
||||||
|
|
||||||
package qtls
|
|
||||||
|
|
||||||
var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions."
|
|
|
@ -1,5 +1,3 @@
|
||||||
//go:build go1.21
|
|
||||||
|
|
||||||
package qtls
|
package qtls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -10,38 +8,7 @@ import (
|
||||||
"github.com/quic-go/quic-go/internal/protocol"
|
"github.com/quic-go/quic-go/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
|
||||||
QUICConn = tls.QUICConn
|
|
||||||
QUICConfig = tls.QUICConfig
|
|
||||||
QUICEvent = tls.QUICEvent
|
|
||||||
QUICEventKind = tls.QUICEventKind
|
|
||||||
QUICEncryptionLevel = tls.QUICEncryptionLevel
|
|
||||||
QUICSessionTicketOptions = tls.QUICSessionTicketOptions
|
|
||||||
AlertError = tls.AlertError
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial
|
|
||||||
QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly
|
|
||||||
QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake
|
|
||||||
QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
QUICNoEvent = tls.QUICNoEvent
|
|
||||||
QUICSetReadSecret = tls.QUICSetReadSecret
|
|
||||||
QUICSetWriteSecret = tls.QUICSetWriteSecret
|
|
||||||
QUICWriteData = tls.QUICWriteData
|
|
||||||
QUICTransportParameters = tls.QUICTransportParameters
|
|
||||||
QUICTransportParametersRequired = tls.QUICTransportParametersRequired
|
|
||||||
QUICRejectedEarlyData = tls.QUICRejectedEarlyData
|
|
||||||
QUICHandshakeDone = tls.QUICHandshakeDone
|
|
||||||
)
|
|
||||||
|
|
||||||
func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) }
|
|
||||||
func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) }
|
|
||||||
|
|
||||||
func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
|
|
||||||
conf := qconf.TLSConfig
|
conf := qconf.TLSConfig
|
||||||
|
|
||||||
// Workaround for https://github.com/golang/go/issues/60506.
|
// Workaround for https://github.com/golang/go/issues/60506.
|
||||||
|
@ -93,7 +60,11 @@ func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, hand
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte) bool) {
|
func SetupConfigForClient(
|
||||||
|
qconf *tls.QUICConfig,
|
||||||
|
getData func(earlyData bool) []byte,
|
||||||
|
setData func(data []byte, earlyData bool) (allowEarlyData bool),
|
||||||
|
) {
|
||||||
conf := qconf.TLSConfig
|
conf := qconf.TLSConfig
|
||||||
if conf.ClientSessionCache != nil {
|
if conf.ClientSessionCache != nil {
|
||||||
origCache := conf.ClientSessionCache
|
origCache := conf.ClientSessionCache
|
||||||
|
@ -151,9 +122,3 @@ func findExtraData(extras [][]byte) []byte {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendSessionTicket(c *QUICConn, allow0RTT bool) error {
|
|
||||||
return c.SendSessionTicket(tls.QUICSessionTicketOptions{
|
|
||||||
EarlyData: allow0RTT,
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -3,27 +3,11 @@ package utils
|
||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/exp/constraints"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// InfDuration is a duration of infinite length
|
// InfDuration is a duration of infinite length
|
||||||
const InfDuration = time.Duration(math.MaxInt64)
|
const InfDuration = time.Duration(math.MaxInt64)
|
||||||
|
|
||||||
func Max[T constraints.Ordered](a, b T) T {
|
|
||||||
if a < b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
func Min[T constraints.Ordered](a, b T) T {
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinNonZeroDuration return the minimum duration that's not zero.
|
// MinNonZeroDuration return the minimum duration that's not zero.
|
||||||
func MinNonZeroDuration(a, b time.Duration) time.Duration {
|
func MinNonZeroDuration(a, b time.Duration) time.Duration {
|
||||||
if a == 0 {
|
if a == 0 {
|
||||||
|
@ -32,15 +16,7 @@ func MinNonZeroDuration(a, b time.Duration) time.Duration {
|
||||||
if b == 0 {
|
if b == 0 {
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
return Min(a, b)
|
return min(a, b)
|
||||||
}
|
|
||||||
|
|
||||||
// AbsDuration returns the absolute value of a time duration
|
|
||||||
func AbsDuration(d time.Duration) time.Duration {
|
|
||||||
if d >= 0 {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
return -d
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MinTime returns the earlier time
|
// MinTime returns the earlier time
|
||||||
|
|
|
@ -8,7 +8,7 @@ type RingBuffer[T any] struct {
|
||||||
full bool
|
full bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init preallocs a buffer with a certain size.
|
// Init preallocates a buffer with a certain size.
|
||||||
func (r *RingBuffer[T]) Init(size int) {
|
func (r *RingBuffer[T]) Init(size int) {
|
||||||
r.ring = make([]T, size)
|
r.ring = make([]T, size)
|
||||||
}
|
}
|
||||||
|
@ -62,6 +62,16 @@ func (r *RingBuffer[T]) PopFront() T {
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeekFront returns the next element.
|
||||||
|
// It must not be called when the buffer is empty, that means that
|
||||||
|
// callers might need to check if there are elements in the buffer first.
|
||||||
|
func (r *RingBuffer[T]) PeekFront() T {
|
||||||
|
if r.Empty() {
|
||||||
|
panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: peek from an empty queue")
|
||||||
|
}
|
||||||
|
return r.ring[r.headPos]
|
||||||
|
}
|
||||||
|
|
||||||
// Grow the maximum size of the queue.
|
// Grow the maximum size of the queue.
|
||||||
// This method assume the queue is full.
|
// This method assume the queue is full.
|
||||||
func (r *RingBuffer[T]) grow() {
|
func (r *RingBuffer[T]) grow() {
|
||||||
|
|
|
@ -55,7 +55,7 @@ func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration {
|
||||||
if r.SmoothedRTT() == 0 {
|
if r.SmoothedRTT() == 0 {
|
||||||
return 2 * defaultInitialRTT
|
return 2 * defaultInitialRTT
|
||||||
}
|
}
|
||||||
pto := r.SmoothedRTT() + Max(4*r.MeanDeviation(), protocol.TimerGranularity)
|
pto := r.SmoothedRTT() + max(4*r.MeanDeviation(), protocol.TimerGranularity)
|
||||||
if includeMaxAckDelay {
|
if includeMaxAckDelay {
|
||||||
pto += r.MaxAckDelay()
|
pto += r.MaxAckDelay()
|
||||||
}
|
}
|
||||||
|
@ -90,7 +90,7 @@ func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
|
||||||
r.smoothedRTT = sample
|
r.smoothedRTT = sample
|
||||||
r.meanDeviation = sample / 2
|
r.meanDeviation = sample / 2
|
||||||
} else {
|
} else {
|
||||||
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
|
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32((r.smoothedRTT-sample).Abs()/time.Microsecond)) * time.Microsecond
|
||||||
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
|
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -126,6 +126,6 @@ func (r *RTTStats) OnConnectionMigration() {
|
||||||
// is larger. The mean deviation is increased to the most recent deviation if
|
// is larger. The mean deviation is increased to the most recent deviation if
|
||||||
// it's larger.
|
// it's larger.
|
||||||
func (r *RTTStats) ExpireSmoothedMetrics() {
|
func (r *RTTStats) ExpireSmoothedMetrics() {
|
||||||
r.meanDeviation = Max(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT))
|
r.meanDeviation = max(r.meanDeviation, (r.smoothedRTT - r.latestRTT).Abs())
|
||||||
r.smoothedRTT = Max(r.smoothedRTT, r.latestRTT)
|
r.smoothedRTT = max(r.smoothedRTT, r.latestRTT)
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
|
||||||
|
|
||||||
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
|
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
|
||||||
if delayTime < 0 {
|
if delayTime < 0 {
|
||||||
// If the delay time overflows, set it to the maximum encodable value.
|
// If the delay time overflows, set it to the maximum encode-able value.
|
||||||
delayTime = utils.InfDuration
|
delayTime = utils.InfDuration
|
||||||
}
|
}
|
||||||
frame.DelayTime = delayTime
|
frame.DelayTime = delayTime
|
||||||
|
@ -57,9 +57,9 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
|
||||||
return errors.New("invalid first ACK range")
|
return errors.New("invalid first ACK range")
|
||||||
}
|
}
|
||||||
smallest := largestAcked - ackBlock
|
smallest := largestAcked - ackBlock
|
||||||
|
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
|
||||||
|
|
||||||
// read all the other ACK ranges
|
// read all the other ACK ranges
|
||||||
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
|
|
||||||
for i := uint64(0); i < numBlocks; i++ {
|
for i := uint64(0); i < numBlocks; i++ {
|
||||||
g, err := quicvarint.Read(r)
|
g, err := quicvarint.Read(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -294,7 +294,7 @@ func (p *TransportParameters) readNumericTransportParameter(
|
||||||
return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount)
|
return fmt.Errorf("initial_max_streams_uni too large: %d (maximum %d)", p.MaxUniStreamNum, protocol.MaxStreamCount)
|
||||||
}
|
}
|
||||||
case maxIdleTimeoutParameterID:
|
case maxIdleTimeoutParameterID:
|
||||||
p.MaxIdleTimeout = utils.Max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond)
|
p.MaxIdleTimeout = max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond)
|
||||||
case maxUDPPayloadSizeParameterID:
|
case maxUDPPayloadSizeParameterID:
|
||||||
if val < 1200 {
|
if val < 1200 {
|
||||||
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
|
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
|
||||||
|
|
|
@ -34,6 +34,7 @@ type ConnectionTracer struct {
|
||||||
LossTimerExpired func(TimerType, EncryptionLevel)
|
LossTimerExpired func(TimerType, EncryptionLevel)
|
||||||
LossTimerCanceled func()
|
LossTimerCanceled func()
|
||||||
ECNStateUpdated func(state ECNState, trigger ECNStateTrigger)
|
ECNStateUpdated func(state ECNState, trigger ECNStateTrigger)
|
||||||
|
ChoseALPN func(protocol string)
|
||||||
// Close is called when the connection is closed.
|
// Close is called when the connection is closed.
|
||||||
Close func()
|
Close func()
|
||||||
Debug func(name, msg string)
|
Debug func(name, msg string)
|
||||||
|
@ -237,6 +238,13 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
ChoseALPN: func(protocol string) {
|
||||||
|
for _, t := range tracers {
|
||||||
|
if t.ChoseALPN != nil {
|
||||||
|
t.ChoseALPN(protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
Close: func() {
|
Close: func() {
|
||||||
for _, t := range tracers {
|
for _, t := range tracers {
|
||||||
if t.Close != nil {
|
if t.Close != nil {
|
||||||
|
|
|
@ -3,12 +3,12 @@
|
||||||
# Install Go manually, since oss-fuzz ships with an outdated Go version.
|
# Install Go manually, since oss-fuzz ships with an outdated Go version.
|
||||||
# See https://github.com/google/oss-fuzz/pull/10643.
|
# See https://github.com/google/oss-fuzz/pull/10643.
|
||||||
export CXX="${CXX} -lresolv" # required by Go 1.20
|
export CXX="${CXX} -lresolv" # required by Go 1.20
|
||||||
wget https://go.dev/dl/go1.20.5.linux-amd64.tar.gz \
|
wget https://go.dev/dl/go1.21.5.linux-amd64.tar.gz \
|
||||||
&& mkdir temp-go \
|
&& mkdir temp-go \
|
||||||
&& rm -rf /root/.go/* \
|
&& rm -rf /root/.go/* \
|
||||||
&& tar -C temp-go/ -xzf go1.20.5.linux-amd64.tar.gz \
|
&& tar -C temp-go/ -xzf go1.21.5.linux-amd64.tar.gz \
|
||||||
&& mv temp-go/go/* /root/.go/ \
|
&& mv temp-go/go/* /root/.go/ \
|
||||||
&& rm -rf temp-go go1.20.5.linux-amd64.tar.gz
|
&& rm -rf temp-go go1.21.5.linux-amd64.tar.gz
|
||||||
|
|
||||||
(
|
(
|
||||||
# fuzz qpack
|
# fuzz qpack
|
||||||
|
|
|
@ -606,11 +606,17 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
|
||||||
if p.datagramQueue != nil {
|
if p.datagramQueue != nil {
|
||||||
if f := p.datagramQueue.Peek(); f != nil {
|
if f := p.datagramQueue.Peek(); f != nil {
|
||||||
size := f.Length(v)
|
size := f.Length(v)
|
||||||
if size <= maxFrameSize-pl.length {
|
if size <= maxFrameSize-pl.length { // DATAGRAM frame fits
|
||||||
pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
|
pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
|
||||||
pl.length += size
|
pl.length += size
|
||||||
p.datagramQueue.Pop()
|
p.datagramQueue.Pop()
|
||||||
|
} else if !hasAck {
|
||||||
|
// The DATAGRAM frame doesn't fit, and the packet doesn't contain an ACK.
|
||||||
|
// Discard this frame. There's no point in retrying this in the next packet,
|
||||||
|
// as it's unlikely that the available packet size will increase.
|
||||||
|
p.datagramQueue.Pop()
|
||||||
}
|
}
|
||||||
|
// If the DATAGRAM frame was too large and the packet contained an ACK, we'll try to send it out later.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -640,7 +646,13 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
|
||||||
pl.length += lengthAdded
|
pl.length += lengthAdded
|
||||||
// add handlers for the control frames that were added
|
// add handlers for the control frames that were added
|
||||||
for i := startLen; i < len(pl.frames); i++ {
|
for i := startLen; i < len(pl.frames); i++ {
|
||||||
pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler()
|
switch pl.frames[i].Frame.(type) {
|
||||||
|
case *wire.PathChallengeFrame, *wire.PathResponseFrame:
|
||||||
|
// Path probing is currently not supported, therefore we don't need to set the OnAcked callback yet.
|
||||||
|
// PATH_CHALLENGE and PATH_RESPONSE are never retransmitted.
|
||||||
|
default:
|
||||||
|
pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v)
|
pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v)
|
||||||
|
|
|
@ -274,7 +274,7 @@ func (s *sendStream) popNewStreamFrame(maxBytes, sendWindow protocol.ByteCount,
|
||||||
nextFrame := s.nextFrame
|
nextFrame := s.nextFrame
|
||||||
s.nextFrame = nil
|
s.nextFrame = nil
|
||||||
|
|
||||||
maxDataLen := utils.Min(sendWindow, nextFrame.MaxDataLen(maxBytes, v))
|
maxDataLen := min(sendWindow, nextFrame.MaxDataLen(maxBytes, v))
|
||||||
if nextFrame.DataLen() > maxDataLen {
|
if nextFrame.DataLen() > maxDataLen {
|
||||||
s.nextFrame = wire.GetStreamFrame()
|
s.nextFrame = wire.GetStreamFrame()
|
||||||
s.nextFrame.StreamID = s.streamID
|
s.nextFrame.StreamID = s.streamID
|
||||||
|
@ -309,7 +309,7 @@ func (s *sendStream) popNewStreamFrameWithoutBuffer(f *wire.StreamFrame, maxByte
|
||||||
if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
|
if maxDataLen == 0 { // a STREAM frame must have at least one byte of data
|
||||||
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
||||||
}
|
}
|
||||||
s.getDataForWriting(f, utils.Min(maxDataLen, sendWindow))
|
s.getDataForWriting(f, min(maxDataLen, sendWindow))
|
||||||
|
|
||||||
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
return s.dataForWriting != nil || s.nextFrame != nil || s.finishedWriting
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package quic
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/quic-go/quic-go/internal/utils"
|
|
||||||
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -20,14 +19,14 @@ func newSingleOriginTokenStore(size int) *singleOriginTokenStore {
|
||||||
func (s *singleOriginTokenStore) Add(token *ClientToken) {
|
func (s *singleOriginTokenStore) Add(token *ClientToken) {
|
||||||
s.tokens[s.p] = token
|
s.tokens[s.p] = token
|
||||||
s.p = s.index(s.p + 1)
|
s.p = s.index(s.p + 1)
|
||||||
s.len = utils.Min(s.len+1, len(s.tokens))
|
s.len = min(s.len+1, len(s.tokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *singleOriginTokenStore) Pop() *ClientToken {
|
func (s *singleOriginTokenStore) Pop() *ClientToken {
|
||||||
s.p = s.index(s.p - 1)
|
s.p = s.index(s.p - 1)
|
||||||
token := s.tokens[s.p]
|
token := s.tokens[s.p]
|
||||||
s.tokens[s.p] = nil
|
s.tokens[s.p] = nil
|
||||||
s.len = utils.Max(s.len-1, 0)
|
s.len = max(s.len-1, 0)
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,825 +0,0 @@
|
||||||
// Copyright 2017 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package cryptobyte
|
|
||||||
|
|
||||||
import (
|
|
||||||
encoding_asn1 "encoding/asn1"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/cryptobyte/asn1"
|
|
||||||
)
|
|
||||||
|
|
||||||
// This file contains ASN.1-related methods for String and Builder.
|
|
||||||
|
|
||||||
// Builder
|
|
||||||
|
|
||||||
// AddASN1Int64 appends a DER-encoded ASN.1 INTEGER.
|
|
||||||
func (b *Builder) AddASN1Int64(v int64) {
|
|
||||||
b.addASN1Signed(asn1.INTEGER, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddASN1Int64WithTag appends a DER-encoded ASN.1 INTEGER with the
|
|
||||||
// given tag.
|
|
||||||
func (b *Builder) AddASN1Int64WithTag(v int64, tag asn1.Tag) {
|
|
||||||
b.addASN1Signed(tag, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddASN1Enum appends a DER-encoded ASN.1 ENUMERATION.
|
|
||||||
func (b *Builder) AddASN1Enum(v int64) {
|
|
||||||
b.addASN1Signed(asn1.ENUM, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) addASN1Signed(tag asn1.Tag, v int64) {
|
|
||||||
b.AddASN1(tag, func(c *Builder) {
|
|
||||||
length := 1
|
|
||||||
for i := v; i >= 0x80 || i < -0x80; i >>= 8 {
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
|
|
||||||
for ; length > 0; length-- {
|
|
||||||
i := v >> uint((length-1)*8) & 0xff
|
|
||||||
c.AddUint8(uint8(i))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddASN1Uint64 appends a DER-encoded ASN.1 INTEGER.
|
|
||||||
func (b *Builder) AddASN1Uint64(v uint64) {
|
|
||||||
b.AddASN1(asn1.INTEGER, func(c *Builder) {
|
|
||||||
length := 1
|
|
||||||
for i := v; i >= 0x80; i >>= 8 {
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
|
|
||||||
for ; length > 0; length-- {
|
|
||||||
i := v >> uint((length-1)*8) & 0xff
|
|
||||||
c.AddUint8(uint8(i))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddASN1BigInt appends a DER-encoded ASN.1 INTEGER.
|
|
||||||
func (b *Builder) AddASN1BigInt(n *big.Int) {
|
|
||||||
if b.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.AddASN1(asn1.INTEGER, func(c *Builder) {
|
|
||||||
if n.Sign() < 0 {
|
|
||||||
// A negative number has to be converted to two's-complement form. So we
|
|
||||||
// invert and subtract 1. If the most-significant-bit isn't set then
|
|
||||||
// we'll need to pad the beginning with 0xff in order to keep the number
|
|
||||||
// negative.
|
|
||||||
nMinus1 := new(big.Int).Neg(n)
|
|
||||||
nMinus1.Sub(nMinus1, bigOne)
|
|
||||||
bytes := nMinus1.Bytes()
|
|
||||||
for i := range bytes {
|
|
||||||
bytes[i] ^= 0xff
|
|
||||||
}
|
|
||||||
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
|
|
||||||
c.add(0xff)
|
|
||||||
}
|
|
||||||
c.add(bytes...)
|
|
||||||
} else if n.Sign() == 0 {
|
|
||||||
c.add(0)
|
|
||||||
} else {
|
|
||||||
bytes := n.Bytes()
|
|
||||||
if bytes[0]&0x80 != 0 {
|
|
||||||
c.add(0)
|
|
||||||
}
|
|
||||||
c.add(bytes...)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddASN1OctetString appends a DER-encoded ASN.1 OCTET STRING.
|
|
||||||
func (b *Builder) AddASN1OctetString(bytes []byte) {
|
|
||||||
b.AddASN1(asn1.OCTET_STRING, func(c *Builder) {
|
|
||||||
c.AddBytes(bytes)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const generalizedTimeFormatStr = "20060102150405Z0700"
|
|
||||||
|
|
||||||
// AddASN1GeneralizedTime appends a DER-encoded ASN.1 GENERALIZEDTIME.
|
|
||||||
func (b *Builder) AddASN1GeneralizedTime(t time.Time) {
|
|
||||||
if t.Year() < 0 || t.Year() > 9999 {
|
|
||||||
b.err = fmt.Errorf("cryptobyte: cannot represent %v as a GeneralizedTime", t)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.AddASN1(asn1.GeneralizedTime, func(c *Builder) {
|
|
||||||
c.AddBytes([]byte(t.Format(generalizedTimeFormatStr)))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddASN1UTCTime appends a DER-encoded ASN.1 UTCTime.
|
|
||||||
func (b *Builder) AddASN1UTCTime(t time.Time) {
|
|
||||||
b.AddASN1(asn1.UTCTime, func(c *Builder) {
|
|
||||||
// As utilized by the X.509 profile, UTCTime can only
|
|
||||||
// represent the years 1950 through 2049.
|
|
||||||
if t.Year() < 1950 || t.Year() >= 2050 {
|
|
||||||
b.err = fmt.Errorf("cryptobyte: cannot represent %v as a UTCTime", t)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.AddBytes([]byte(t.Format(defaultUTCTimeFormatStr)))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddASN1BitString appends a DER-encoded ASN.1 BIT STRING. This does not
|
|
||||||
// support BIT STRINGs that are not a whole number of bytes.
|
|
||||||
func (b *Builder) AddASN1BitString(data []byte) {
|
|
||||||
b.AddASN1(asn1.BIT_STRING, func(b *Builder) {
|
|
||||||
b.AddUint8(0)
|
|
||||||
b.AddBytes(data)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) addBase128Int(n int64) {
|
|
||||||
var length int
|
|
||||||
if n == 0 {
|
|
||||||
length = 1
|
|
||||||
} else {
|
|
||||||
for i := n; i > 0; i >>= 7 {
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := length - 1; i >= 0; i-- {
|
|
||||||
o := byte(n >> uint(i*7))
|
|
||||||
o &= 0x7f
|
|
||||||
if i != 0 {
|
|
||||||
o |= 0x80
|
|
||||||
}
|
|
||||||
|
|
||||||
b.add(o)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isValidOID(oid encoding_asn1.ObjectIdentifier) bool {
|
|
||||||
if len(oid) < 2 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if oid[0] > 2 || (oid[0] <= 1 && oid[1] >= 40) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, v := range oid {
|
|
||||||
if v < 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) AddASN1ObjectIdentifier(oid encoding_asn1.ObjectIdentifier) {
|
|
||||||
b.AddASN1(asn1.OBJECT_IDENTIFIER, func(b *Builder) {
|
|
||||||
if !isValidOID(oid) {
|
|
||||||
b.err = fmt.Errorf("cryptobyte: invalid OID: %v", oid)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.addBase128Int(int64(oid[0])*40 + int64(oid[1]))
|
|
||||||
for _, v := range oid[2:] {
|
|
||||||
b.addBase128Int(int64(v))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) AddASN1Boolean(v bool) {
|
|
||||||
b.AddASN1(asn1.BOOLEAN, func(b *Builder) {
|
|
||||||
if v {
|
|
||||||
b.AddUint8(0xff)
|
|
||||||
} else {
|
|
||||||
b.AddUint8(0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) AddASN1NULL() {
|
|
||||||
b.add(uint8(asn1.NULL), 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarshalASN1 calls encoding_asn1.Marshal on its input and appends the result if
|
|
||||||
// successful or records an error if one occurred.
|
|
||||||
func (b *Builder) MarshalASN1(v interface{}) {
|
|
||||||
// NOTE(martinkr): This is somewhat of a hack to allow propagation of
|
|
||||||
// encoding_asn1.Marshal errors into Builder.err. N.B. if you call MarshalASN1 with a
|
|
||||||
// value embedded into a struct, its tag information is lost.
|
|
||||||
if b.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bytes, err := encoding_asn1.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
b.err = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.AddBytes(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddASN1 appends an ASN.1 object. The object is prefixed with the given tag.
|
|
||||||
// Tags greater than 30 are not supported and result in an error (i.e.
|
|
||||||
// low-tag-number form only). The child builder passed to the
|
|
||||||
// BuilderContinuation can be used to build the content of the ASN.1 object.
|
|
||||||
func (b *Builder) AddASN1(tag asn1.Tag, f BuilderContinuation) {
|
|
||||||
if b.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Identifiers with the low five bits set indicate high-tag-number format
|
|
||||||
// (two or more octets), which we don't support.
|
|
||||||
if tag&0x1f == 0x1f {
|
|
||||||
b.err = fmt.Errorf("cryptobyte: high-tag number identifier octects not supported: 0x%x", tag)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.AddUint8(uint8(tag))
|
|
||||||
b.addLengthPrefixed(1, true, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// String
|
|
||||||
|
|
||||||
// ReadASN1Boolean decodes an ASN.1 BOOLEAN and converts it to a boolean
|
|
||||||
// representation into out and advances. It reports whether the read
|
|
||||||
// was successful.
|
|
||||||
func (s *String) ReadASN1Boolean(out *bool) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.BOOLEAN) || len(bytes) != 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
switch bytes[0] {
|
|
||||||
case 0:
|
|
||||||
*out = false
|
|
||||||
case 0xff:
|
|
||||||
*out = true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1Integer decodes an ASN.1 INTEGER into out and advances. If out does
|
|
||||||
// not point to an integer, to a big.Int, or to a []byte it panics. Only
|
|
||||||
// positive and zero values can be decoded into []byte, and they are returned as
|
|
||||||
// big-endian binary values that share memory with s. Positive values will have
|
|
||||||
// no leading zeroes, and zero will be returned as a single zero byte.
|
|
||||||
// ReadASN1Integer reports whether the read was successful.
|
|
||||||
func (s *String) ReadASN1Integer(out interface{}) bool {
|
|
||||||
switch out := out.(type) {
|
|
||||||
case *int, *int8, *int16, *int32, *int64:
|
|
||||||
var i int64
|
|
||||||
if !s.readASN1Int64(&i) || reflect.ValueOf(out).Elem().OverflowInt(i) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
reflect.ValueOf(out).Elem().SetInt(i)
|
|
||||||
return true
|
|
||||||
case *uint, *uint8, *uint16, *uint32, *uint64:
|
|
||||||
var u uint64
|
|
||||||
if !s.readASN1Uint64(&u) || reflect.ValueOf(out).Elem().OverflowUint(u) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
reflect.ValueOf(out).Elem().SetUint(u)
|
|
||||||
return true
|
|
||||||
case *big.Int:
|
|
||||||
return s.readASN1BigInt(out)
|
|
||||||
case *[]byte:
|
|
||||||
return s.readASN1Bytes(out)
|
|
||||||
default:
|
|
||||||
panic("out does not point to an integer type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkASN1Integer(bytes []byte) bool {
|
|
||||||
if len(bytes) == 0 {
|
|
||||||
// An INTEGER is encoded with at least one octet.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len(bytes) == 1 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if bytes[0] == 0 && bytes[1]&0x80 == 0 || bytes[0] == 0xff && bytes[1]&0x80 == 0x80 {
|
|
||||||
// Value is not minimally encoded.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var bigOne = big.NewInt(1)
|
|
||||||
|
|
||||||
func (s *String) readASN1BigInt(out *big.Int) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if bytes[0]&0x80 == 0x80 {
|
|
||||||
// Negative number.
|
|
||||||
neg := make([]byte, len(bytes))
|
|
||||||
for i, b := range bytes {
|
|
||||||
neg[i] = ^b
|
|
||||||
}
|
|
||||||
out.SetBytes(neg)
|
|
||||||
out.Add(out, bigOne)
|
|
||||||
out.Neg(out)
|
|
||||||
} else {
|
|
||||||
out.SetBytes(bytes)
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *String) readASN1Bytes(out *[]byte) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if bytes[0]&0x80 == 0x80 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for len(bytes) > 1 && bytes[0] == 0 {
|
|
||||||
bytes = bytes[1:]
|
|
||||||
}
|
|
||||||
*out = bytes
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *String) readASN1Int64(out *int64) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) || !asn1Signed(out, bytes) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func asn1Signed(out *int64, n []byte) bool {
|
|
||||||
length := len(n)
|
|
||||||
if length > 8 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := 0; i < length; i++ {
|
|
||||||
*out <<= 8
|
|
||||||
*out |= int64(n[i])
|
|
||||||
}
|
|
||||||
// Shift up and down in order to sign extend the result.
|
|
||||||
*out <<= 64 - uint8(length)*8
|
|
||||||
*out >>= 64 - uint8(length)*8
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *String) readASN1Uint64(out *uint64) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.INTEGER) || !checkASN1Integer(bytes) || !asn1Unsigned(out, bytes) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func asn1Unsigned(out *uint64, n []byte) bool {
|
|
||||||
length := len(n)
|
|
||||||
if length > 9 || length == 9 && n[0] != 0 {
|
|
||||||
// Too large for uint64.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if n[0]&0x80 != 0 {
|
|
||||||
// Negative number.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := 0; i < length; i++ {
|
|
||||||
*out <<= 8
|
|
||||||
*out |= uint64(n[i])
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1Int64WithTag decodes an ASN.1 INTEGER with the given tag into out
|
|
||||||
// and advances. It reports whether the read was successful and resulted in a
|
|
||||||
// value that can be represented in an int64.
|
|
||||||
func (s *String) ReadASN1Int64WithTag(out *int64, tag asn1.Tag) bool {
|
|
||||||
var bytes String
|
|
||||||
return s.ReadASN1(&bytes, tag) && checkASN1Integer(bytes) && asn1Signed(out, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1Enum decodes an ASN.1 ENUMERATION into out and advances. It reports
|
|
||||||
// whether the read was successful.
|
|
||||||
func (s *String) ReadASN1Enum(out *int) bool {
|
|
||||||
var bytes String
|
|
||||||
var i int64
|
|
||||||
if !s.ReadASN1(&bytes, asn1.ENUM) || !checkASN1Integer(bytes) || !asn1Signed(&i, bytes) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if int64(int(i)) != i {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = int(i)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *String) readBase128Int(out *int) bool {
|
|
||||||
ret := 0
|
|
||||||
for i := 0; len(*s) > 0; i++ {
|
|
||||||
if i == 5 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Avoid overflowing int on a 32-bit platform.
|
|
||||||
// We don't want different behavior based on the architecture.
|
|
||||||
if ret >= 1<<(31-7) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
ret <<= 7
|
|
||||||
b := s.read(1)[0]
|
|
||||||
|
|
||||||
// ITU-T X.690, section 8.19.2:
|
|
||||||
// The subidentifier shall be encoded in the fewest possible octets,
|
|
||||||
// that is, the leading octet of the subidentifier shall not have the value 0x80.
|
|
||||||
if i == 0 && b == 0x80 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
ret |= int(b & 0x7f)
|
|
||||||
if b&0x80 == 0 {
|
|
||||||
*out = ret
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false // truncated
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1ObjectIdentifier decodes an ASN.1 OBJECT IDENTIFIER into out and
|
|
||||||
// advances. It reports whether the read was successful.
|
|
||||||
func (s *String) ReadASN1ObjectIdentifier(out *encoding_asn1.ObjectIdentifier) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.OBJECT_IDENTIFIER) || len(bytes) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// In the worst case, we get two elements from the first byte (which is
|
|
||||||
// encoded differently) and then every varint is a single byte long.
|
|
||||||
components := make([]int, len(bytes)+1)
|
|
||||||
|
|
||||||
// The first varint is 40*value1 + value2:
|
|
||||||
// According to this packing, value1 can take the values 0, 1 and 2 only.
|
|
||||||
// When value1 = 0 or value1 = 1, then value2 is <= 39. When value1 = 2,
|
|
||||||
// then there are no restrictions on value2.
|
|
||||||
var v int
|
|
||||||
if !bytes.readBase128Int(&v) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if v < 80 {
|
|
||||||
components[0] = v / 40
|
|
||||||
components[1] = v % 40
|
|
||||||
} else {
|
|
||||||
components[0] = 2
|
|
||||||
components[1] = v - 80
|
|
||||||
}
|
|
||||||
|
|
||||||
i := 2
|
|
||||||
for ; len(bytes) > 0; i++ {
|
|
||||||
if !bytes.readBase128Int(&v) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
components[i] = v
|
|
||||||
}
|
|
||||||
*out = components[:i]
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1GeneralizedTime decodes an ASN.1 GENERALIZEDTIME into out and
|
|
||||||
// advances. It reports whether the read was successful.
|
|
||||||
func (s *String) ReadASN1GeneralizedTime(out *time.Time) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.GeneralizedTime) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
t := string(bytes)
|
|
||||||
res, err := time.Parse(generalizedTimeFormatStr, t)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if serialized := res.Format(generalizedTimeFormatStr); serialized != t {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = res
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
const defaultUTCTimeFormatStr = "060102150405Z0700"
|
|
||||||
|
|
||||||
// ReadASN1UTCTime decodes an ASN.1 UTCTime into out and advances.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
func (s *String) ReadASN1UTCTime(out *time.Time) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.UTCTime) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
t := string(bytes)
|
|
||||||
|
|
||||||
formatStr := defaultUTCTimeFormatStr
|
|
||||||
var err error
|
|
||||||
res, err := time.Parse(formatStr, t)
|
|
||||||
if err != nil {
|
|
||||||
// Fallback to minute precision if we can't parse second
|
|
||||||
// precision. If we are following X.509 or X.690 we shouldn't
|
|
||||||
// support this, but we do.
|
|
||||||
formatStr = "0601021504Z0700"
|
|
||||||
res, err = time.Parse(formatStr, t)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if serialized := res.Format(formatStr); serialized != t {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.Year() >= 2050 {
|
|
||||||
// UTCTime interprets the low order digits 50-99 as 1950-99.
|
|
||||||
// This only applies to its use in the X.509 profile.
|
|
||||||
// See https://tools.ietf.org/html/rfc5280#section-4.1.2.5.1
|
|
||||||
res = res.AddDate(-100, 0, 0)
|
|
||||||
}
|
|
||||||
*out = res
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1BitString decodes an ASN.1 BIT STRING into out and advances.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
func (s *String) ReadASN1BitString(out *encoding_asn1.BitString) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.BIT_STRING) || len(bytes) == 0 ||
|
|
||||||
len(bytes)*8/8 != len(bytes) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
paddingBits := bytes[0]
|
|
||||||
bytes = bytes[1:]
|
|
||||||
if paddingBits > 7 ||
|
|
||||||
len(bytes) == 0 && paddingBits != 0 ||
|
|
||||||
len(bytes) > 0 && bytes[len(bytes)-1]&(1<<paddingBits-1) != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
out.BitLength = len(bytes)*8 - int(paddingBits)
|
|
||||||
out.Bytes = bytes
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1BitStringAsBytes decodes an ASN.1 BIT STRING into out and advances. It is
|
|
||||||
// an error if the BIT STRING is not a whole number of bytes. It reports
|
|
||||||
// whether the read was successful.
|
|
||||||
func (s *String) ReadASN1BitStringAsBytes(out *[]byte) bool {
|
|
||||||
var bytes String
|
|
||||||
if !s.ReadASN1(&bytes, asn1.BIT_STRING) || len(bytes) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
paddingBits := bytes[0]
|
|
||||||
if paddingBits != 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = bytes[1:]
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1Bytes reads the contents of a DER-encoded ASN.1 element (not including
|
|
||||||
// tag and length bytes) into out, and advances. The element must match the
|
|
||||||
// given tag. It reports whether the read was successful.
|
|
||||||
func (s *String) ReadASN1Bytes(out *[]byte, tag asn1.Tag) bool {
|
|
||||||
return s.ReadASN1((*String)(out), tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1 reads the contents of a DER-encoded ASN.1 element (not including
|
|
||||||
// tag and length bytes) into out, and advances. The element must match the
|
|
||||||
// given tag. It reports whether the read was successful.
|
|
||||||
//
|
|
||||||
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
|
|
||||||
func (s *String) ReadASN1(out *String, tag asn1.Tag) bool {
|
|
||||||
var t asn1.Tag
|
|
||||||
if !s.ReadAnyASN1(out, &t) || t != tag {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadASN1Element reads the contents of a DER-encoded ASN.1 element (including
|
|
||||||
// tag and length bytes) into out, and advances. The element must match the
|
|
||||||
// given tag. It reports whether the read was successful.
|
|
||||||
//
|
|
||||||
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
|
|
||||||
func (s *String) ReadASN1Element(out *String, tag asn1.Tag) bool {
|
|
||||||
var t asn1.Tag
|
|
||||||
if !s.ReadAnyASN1Element(out, &t) || t != tag {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadAnyASN1 reads the contents of a DER-encoded ASN.1 element (not including
|
|
||||||
// tag and length bytes) into out, sets outTag to its tag, and advances.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
//
|
|
||||||
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
|
|
||||||
func (s *String) ReadAnyASN1(out *String, outTag *asn1.Tag) bool {
|
|
||||||
return s.readASN1(out, outTag, true /* skip header */)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadAnyASN1Element reads the contents of a DER-encoded ASN.1 element
|
|
||||||
// (including tag and length bytes) into out, sets outTag to is tag, and
|
|
||||||
// advances. It reports whether the read was successful.
|
|
||||||
//
|
|
||||||
// Tags greater than 30 are not supported (i.e. low-tag-number format only).
|
|
||||||
func (s *String) ReadAnyASN1Element(out *String, outTag *asn1.Tag) bool {
|
|
||||||
return s.readASN1(out, outTag, false /* include header */)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeekASN1Tag reports whether the next ASN.1 value on the string starts with
|
|
||||||
// the given tag.
|
|
||||||
func (s String) PeekASN1Tag(tag asn1.Tag) bool {
|
|
||||||
if len(s) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return asn1.Tag(s[0]) == tag
|
|
||||||
}
|
|
||||||
|
|
||||||
// SkipASN1 reads and discards an ASN.1 element with the given tag. It
|
|
||||||
// reports whether the operation was successful.
|
|
||||||
func (s *String) SkipASN1(tag asn1.Tag) bool {
|
|
||||||
var unused String
|
|
||||||
return s.ReadASN1(&unused, tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadOptionalASN1 attempts to read the contents of a DER-encoded ASN.1
|
|
||||||
// element (not including tag and length bytes) tagged with the given tag into
|
|
||||||
// out. It stores whether an element with the tag was found in outPresent,
|
|
||||||
// unless outPresent is nil. It reports whether the read was successful.
|
|
||||||
func (s *String) ReadOptionalASN1(out *String, outPresent *bool, tag asn1.Tag) bool {
|
|
||||||
present := s.PeekASN1Tag(tag)
|
|
||||||
if outPresent != nil {
|
|
||||||
*outPresent = present
|
|
||||||
}
|
|
||||||
if present && !s.ReadASN1(out, tag) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// SkipOptionalASN1 advances s over an ASN.1 element with the given tag, or
|
|
||||||
// else leaves s unchanged. It reports whether the operation was successful.
|
|
||||||
func (s *String) SkipOptionalASN1(tag asn1.Tag) bool {
|
|
||||||
if !s.PeekASN1Tag(tag) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
var unused String
|
|
||||||
return s.ReadASN1(&unused, tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadOptionalASN1Integer attempts to read an optional ASN.1 INTEGER explicitly
|
|
||||||
// tagged with tag into out and advances. If no element with a matching tag is
|
|
||||||
// present, it writes defaultValue into out instead. Otherwise, it behaves like
|
|
||||||
// ReadASN1Integer.
|
|
||||||
func (s *String) ReadOptionalASN1Integer(out interface{}, tag asn1.Tag, defaultValue interface{}) bool {
|
|
||||||
var present bool
|
|
||||||
var i String
|
|
||||||
if !s.ReadOptionalASN1(&i, &present, tag) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !present {
|
|
||||||
switch out.(type) {
|
|
||||||
case *int, *int8, *int16, *int32, *int64,
|
|
||||||
*uint, *uint8, *uint16, *uint32, *uint64, *[]byte:
|
|
||||||
reflect.ValueOf(out).Elem().Set(reflect.ValueOf(defaultValue))
|
|
||||||
case *big.Int:
|
|
||||||
if defaultValue, ok := defaultValue.(*big.Int); ok {
|
|
||||||
out.(*big.Int).Set(defaultValue)
|
|
||||||
} else {
|
|
||||||
panic("out points to big.Int, but defaultValue does not")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
panic("invalid integer type")
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if !i.ReadASN1Integer(out) || !i.Empty() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadOptionalASN1OctetString attempts to read an optional ASN.1 OCTET STRING
|
|
||||||
// explicitly tagged with tag into out and advances. If no element with a
|
|
||||||
// matching tag is present, it sets "out" to nil instead. It reports
|
|
||||||
// whether the read was successful.
|
|
||||||
func (s *String) ReadOptionalASN1OctetString(out *[]byte, outPresent *bool, tag asn1.Tag) bool {
|
|
||||||
var present bool
|
|
||||||
var child String
|
|
||||||
if !s.ReadOptionalASN1(&child, &present, tag) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if outPresent != nil {
|
|
||||||
*outPresent = present
|
|
||||||
}
|
|
||||||
if present {
|
|
||||||
var oct String
|
|
||||||
if !child.ReadASN1(&oct, asn1.OCTET_STRING) || !child.Empty() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = oct
|
|
||||||
} else {
|
|
||||||
*out = nil
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadOptionalASN1Boolean attempts to read an optional ASN.1 BOOLEAN
|
|
||||||
// explicitly tagged with tag into out and advances. If no element with a
|
|
||||||
// matching tag is present, it sets "out" to defaultValue instead. It reports
|
|
||||||
// whether the read was successful.
|
|
||||||
func (s *String) ReadOptionalASN1Boolean(out *bool, tag asn1.Tag, defaultValue bool) bool {
|
|
||||||
var present bool
|
|
||||||
var child String
|
|
||||||
if !s.ReadOptionalASN1(&child, &present, tag) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !present {
|
|
||||||
*out = defaultValue
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return child.ReadASN1Boolean(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *String) readASN1(out *String, outTag *asn1.Tag, skipHeader bool) bool {
|
|
||||||
if len(*s) < 2 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
tag, lenByte := (*s)[0], (*s)[1]
|
|
||||||
|
|
||||||
if tag&0x1f == 0x1f {
|
|
||||||
// ITU-T X.690 section 8.1.2
|
|
||||||
//
|
|
||||||
// An identifier octet with a tag part of 0x1f indicates a high-tag-number
|
|
||||||
// form identifier with two or more octets. We only support tags less than
|
|
||||||
// 31 (i.e. low-tag-number form, single octet identifier).
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if outTag != nil {
|
|
||||||
*outTag = asn1.Tag(tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ITU-T X.690 section 8.1.3
|
|
||||||
//
|
|
||||||
// Bit 8 of the first length byte indicates whether the length is short- or
|
|
||||||
// long-form.
|
|
||||||
var length, headerLen uint32 // length includes headerLen
|
|
||||||
if lenByte&0x80 == 0 {
|
|
||||||
// Short-form length (section 8.1.3.4), encoded in bits 1-7.
|
|
||||||
length = uint32(lenByte) + 2
|
|
||||||
headerLen = 2
|
|
||||||
} else {
|
|
||||||
// Long-form length (section 8.1.3.5). Bits 1-7 encode the number of octets
|
|
||||||
// used to encode the length.
|
|
||||||
lenLen := lenByte & 0x7f
|
|
||||||
var len32 uint32
|
|
||||||
|
|
||||||
if lenLen == 0 || lenLen > 4 || len(*s) < int(2+lenLen) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
lenBytes := String((*s)[2 : 2+lenLen])
|
|
||||||
if !lenBytes.readUnsigned(&len32, int(lenLen)) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// ITU-T X.690 section 10.1 (DER length forms) requires encoding the length
|
|
||||||
// with the minimum number of octets.
|
|
||||||
if len32 < 128 {
|
|
||||||
// Length should have used short-form encoding.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len32>>((lenLen-1)*8) == 0 {
|
|
||||||
// Leading octet is 0. Length should have been at least one byte shorter.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
headerLen = 2 + uint32(lenLen)
|
|
||||||
if headerLen+len32 < len32 {
|
|
||||||
// Overflow.
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
length = headerLen + len32
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(length) < 0 || !s.ReadBytes((*[]byte)(out), int(length)) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if skipHeader && !out.Skip(int(headerLen)) {
|
|
||||||
panic("cryptobyte: internal error")
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
|
@ -1,46 +0,0 @@
|
||||||
// Copyright 2017 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package asn1 contains supporting types for parsing and building ASN.1
|
|
||||||
// messages with the cryptobyte package.
|
|
||||||
package asn1 // import "golang.org/x/crypto/cryptobyte/asn1"
|
|
||||||
|
|
||||||
// Tag represents an ASN.1 identifier octet, consisting of a tag number
|
|
||||||
// (indicating a type) and class (such as context-specific or constructed).
|
|
||||||
//
|
|
||||||
// Methods in the cryptobyte package only support the low-tag-number form, i.e.
|
|
||||||
// a single identifier octet with bits 7-8 encoding the class and bits 1-6
|
|
||||||
// encoding the tag number.
|
|
||||||
type Tag uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
classConstructed = 0x20
|
|
||||||
classContextSpecific = 0x80
|
|
||||||
)
|
|
||||||
|
|
||||||
// Constructed returns t with the constructed class bit set.
|
|
||||||
func (t Tag) Constructed() Tag { return t | classConstructed }
|
|
||||||
|
|
||||||
// ContextSpecific returns t with the context-specific class bit set.
|
|
||||||
func (t Tag) ContextSpecific() Tag { return t | classContextSpecific }
|
|
||||||
|
|
||||||
// The following is a list of standard tag and class combinations.
|
|
||||||
const (
|
|
||||||
BOOLEAN = Tag(1)
|
|
||||||
INTEGER = Tag(2)
|
|
||||||
BIT_STRING = Tag(3)
|
|
||||||
OCTET_STRING = Tag(4)
|
|
||||||
NULL = Tag(5)
|
|
||||||
OBJECT_IDENTIFIER = Tag(6)
|
|
||||||
ENUM = Tag(10)
|
|
||||||
UTF8String = Tag(12)
|
|
||||||
SEQUENCE = Tag(16 | classConstructed)
|
|
||||||
SET = Tag(17 | classConstructed)
|
|
||||||
PrintableString = Tag(19)
|
|
||||||
T61String = Tag(20)
|
|
||||||
IA5String = Tag(22)
|
|
||||||
UTCTime = Tag(23)
|
|
||||||
GeneralizedTime = Tag(24)
|
|
||||||
GeneralString = Tag(27)
|
|
||||||
)
|
|
|
@ -1,350 +0,0 @@
|
||||||
// Copyright 2017 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package cryptobyte
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Builder builds byte strings from fixed-length and length-prefixed values.
|
|
||||||
// Builders either allocate space as needed, or are ‘fixed’, which means that
|
|
||||||
// they write into a given buffer and produce an error if it's exhausted.
|
|
||||||
//
|
|
||||||
// The zero value is a usable Builder that allocates space as needed.
|
|
||||||
//
|
|
||||||
// Simple values are marshaled and appended to a Builder using methods on the
|
|
||||||
// Builder. Length-prefixed values are marshaled by providing a
|
|
||||||
// BuilderContinuation, which is a function that writes the inner contents of
|
|
||||||
// the value to a given Builder. See the documentation for BuilderContinuation
|
|
||||||
// for details.
|
|
||||||
type Builder struct {
|
|
||||||
err error
|
|
||||||
result []byte
|
|
||||||
fixedSize bool
|
|
||||||
child *Builder
|
|
||||||
offset int
|
|
||||||
pendingLenLen int
|
|
||||||
pendingIsASN1 bool
|
|
||||||
inContinuation *bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBuilder creates a Builder that appends its output to the given buffer.
|
|
||||||
// Like append(), the slice will be reallocated if its capacity is exceeded.
|
|
||||||
// Use Bytes to get the final buffer.
|
|
||||||
func NewBuilder(buffer []byte) *Builder {
|
|
||||||
return &Builder{
|
|
||||||
result: buffer,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFixedBuilder creates a Builder that appends its output into the given
|
|
||||||
// buffer. This builder does not reallocate the output buffer. Writes that
|
|
||||||
// would exceed the buffer's capacity are treated as an error.
|
|
||||||
func NewFixedBuilder(buffer []byte) *Builder {
|
|
||||||
return &Builder{
|
|
||||||
result: buffer,
|
|
||||||
fixedSize: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetError sets the value to be returned as the error from Bytes. Writes
|
|
||||||
// performed after calling SetError are ignored.
|
|
||||||
func (b *Builder) SetError(err error) {
|
|
||||||
b.err = err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytes returns the bytes written by the builder or an error if one has
|
|
||||||
// occurred during building.
|
|
||||||
func (b *Builder) Bytes() ([]byte, error) {
|
|
||||||
if b.err != nil {
|
|
||||||
return nil, b.err
|
|
||||||
}
|
|
||||||
return b.result[b.offset:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BytesOrPanic returns the bytes written by the builder or panics if an error
|
|
||||||
// has occurred during building.
|
|
||||||
func (b *Builder) BytesOrPanic() []byte {
|
|
||||||
if b.err != nil {
|
|
||||||
panic(b.err)
|
|
||||||
}
|
|
||||||
return b.result[b.offset:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint8 appends an 8-bit value to the byte string.
|
|
||||||
func (b *Builder) AddUint8(v uint8) {
|
|
||||||
b.add(byte(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint16 appends a big-endian, 16-bit value to the byte string.
|
|
||||||
func (b *Builder) AddUint16(v uint16) {
|
|
||||||
b.add(byte(v>>8), byte(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint24 appends a big-endian, 24-bit value to the byte string. The highest
|
|
||||||
// byte of the 32-bit input value is silently truncated.
|
|
||||||
func (b *Builder) AddUint24(v uint32) {
|
|
||||||
b.add(byte(v>>16), byte(v>>8), byte(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint32 appends a big-endian, 32-bit value to the byte string.
|
|
||||||
func (b *Builder) AddUint32(v uint32) {
|
|
||||||
b.add(byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint48 appends a big-endian, 48-bit value to the byte string.
|
|
||||||
func (b *Builder) AddUint48(v uint64) {
|
|
||||||
b.add(byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint64 appends a big-endian, 64-bit value to the byte string.
|
|
||||||
func (b *Builder) AddUint64(v uint64) {
|
|
||||||
b.add(byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddBytes appends a sequence of bytes to the byte string.
|
|
||||||
func (b *Builder) AddBytes(v []byte) {
|
|
||||||
b.add(v...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuilderContinuation is a continuation-passing interface for building
|
|
||||||
// length-prefixed byte sequences. Builder methods for length-prefixed
|
|
||||||
// sequences (AddUint8LengthPrefixed etc) will invoke the BuilderContinuation
|
|
||||||
// supplied to them. The child builder passed to the continuation can be used
|
|
||||||
// to build the content of the length-prefixed sequence. For example:
|
|
||||||
//
|
|
||||||
// parent := cryptobyte.NewBuilder()
|
|
||||||
// parent.AddUint8LengthPrefixed(func (child *Builder) {
|
|
||||||
// child.AddUint8(42)
|
|
||||||
// child.AddUint8LengthPrefixed(func (grandchild *Builder) {
|
|
||||||
// grandchild.AddUint8(5)
|
|
||||||
// })
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// It is an error to write more bytes to the child than allowed by the reserved
|
|
||||||
// length prefix. After the continuation returns, the child must be considered
|
|
||||||
// invalid, i.e. users must not store any copies or references of the child
|
|
||||||
// that outlive the continuation.
|
|
||||||
//
|
|
||||||
// If the continuation panics with a value of type BuildError then the inner
|
|
||||||
// error will be returned as the error from Bytes. If the child panics
|
|
||||||
// otherwise then Bytes will repanic with the same value.
|
|
||||||
type BuilderContinuation func(child *Builder)
|
|
||||||
|
|
||||||
// BuildError wraps an error. If a BuilderContinuation panics with this value,
|
|
||||||
// the panic will be recovered and the inner error will be returned from
|
|
||||||
// Builder.Bytes.
|
|
||||||
type BuildError struct {
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint8LengthPrefixed adds a 8-bit length-prefixed byte sequence.
|
|
||||||
func (b *Builder) AddUint8LengthPrefixed(f BuilderContinuation) {
|
|
||||||
b.addLengthPrefixed(1, false, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint16LengthPrefixed adds a big-endian, 16-bit length-prefixed byte sequence.
|
|
||||||
func (b *Builder) AddUint16LengthPrefixed(f BuilderContinuation) {
|
|
||||||
b.addLengthPrefixed(2, false, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint24LengthPrefixed adds a big-endian, 24-bit length-prefixed byte sequence.
|
|
||||||
func (b *Builder) AddUint24LengthPrefixed(f BuilderContinuation) {
|
|
||||||
b.addLengthPrefixed(3, false, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUint32LengthPrefixed adds a big-endian, 32-bit length-prefixed byte sequence.
|
|
||||||
func (b *Builder) AddUint32LengthPrefixed(f BuilderContinuation) {
|
|
||||||
b.addLengthPrefixed(4, false, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) callContinuation(f BuilderContinuation, arg *Builder) {
|
|
||||||
if !*b.inContinuation {
|
|
||||||
*b.inContinuation = true
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
*b.inContinuation = false
|
|
||||||
|
|
||||||
r := recover()
|
|
||||||
if r == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if buildError, ok := r.(BuildError); ok {
|
|
||||||
b.err = buildError.Err
|
|
||||||
} else {
|
|
||||||
panic(r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
f(arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) addLengthPrefixed(lenLen int, isASN1 bool, f BuilderContinuation) {
|
|
||||||
// Subsequent writes can be ignored if the builder has encountered an error.
|
|
||||||
if b.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := len(b.result)
|
|
||||||
b.add(make([]byte, lenLen)...)
|
|
||||||
|
|
||||||
if b.inContinuation == nil {
|
|
||||||
b.inContinuation = new(bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.child = &Builder{
|
|
||||||
result: b.result,
|
|
||||||
fixedSize: b.fixedSize,
|
|
||||||
offset: offset,
|
|
||||||
pendingLenLen: lenLen,
|
|
||||||
pendingIsASN1: isASN1,
|
|
||||||
inContinuation: b.inContinuation,
|
|
||||||
}
|
|
||||||
|
|
||||||
b.callContinuation(f, b.child)
|
|
||||||
b.flushChild()
|
|
||||||
if b.child != nil {
|
|
||||||
panic("cryptobyte: internal error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) flushChild() {
|
|
||||||
if b.child == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.child.flushChild()
|
|
||||||
child := b.child
|
|
||||||
b.child = nil
|
|
||||||
|
|
||||||
if child.err != nil {
|
|
||||||
b.err = child.err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
length := len(child.result) - child.pendingLenLen - child.offset
|
|
||||||
|
|
||||||
if length < 0 {
|
|
||||||
panic("cryptobyte: internal error") // result unexpectedly shrunk
|
|
||||||
}
|
|
||||||
|
|
||||||
if child.pendingIsASN1 {
|
|
||||||
// For ASN.1, we reserved a single byte for the length. If that turned out
|
|
||||||
// to be incorrect, we have to move the contents along in order to make
|
|
||||||
// space.
|
|
||||||
if child.pendingLenLen != 1 {
|
|
||||||
panic("cryptobyte: internal error")
|
|
||||||
}
|
|
||||||
var lenLen, lenByte uint8
|
|
||||||
if int64(length) > 0xfffffffe {
|
|
||||||
b.err = errors.New("pending ASN.1 child too long")
|
|
||||||
return
|
|
||||||
} else if length > 0xffffff {
|
|
||||||
lenLen = 5
|
|
||||||
lenByte = 0x80 | 4
|
|
||||||
} else if length > 0xffff {
|
|
||||||
lenLen = 4
|
|
||||||
lenByte = 0x80 | 3
|
|
||||||
} else if length > 0xff {
|
|
||||||
lenLen = 3
|
|
||||||
lenByte = 0x80 | 2
|
|
||||||
} else if length > 0x7f {
|
|
||||||
lenLen = 2
|
|
||||||
lenByte = 0x80 | 1
|
|
||||||
} else {
|
|
||||||
lenLen = 1
|
|
||||||
lenByte = uint8(length)
|
|
||||||
length = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert the initial length byte, make space for successive length bytes,
|
|
||||||
// and adjust the offset.
|
|
||||||
child.result[child.offset] = lenByte
|
|
||||||
extraBytes := int(lenLen - 1)
|
|
||||||
if extraBytes != 0 {
|
|
||||||
child.add(make([]byte, extraBytes)...)
|
|
||||||
childStart := child.offset + child.pendingLenLen
|
|
||||||
copy(child.result[childStart+extraBytes:], child.result[childStart:])
|
|
||||||
}
|
|
||||||
child.offset++
|
|
||||||
child.pendingLenLen = extraBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
l := length
|
|
||||||
for i := child.pendingLenLen - 1; i >= 0; i-- {
|
|
||||||
child.result[child.offset+i] = uint8(l)
|
|
||||||
l >>= 8
|
|
||||||
}
|
|
||||||
if l != 0 {
|
|
||||||
b.err = fmt.Errorf("cryptobyte: pending child length %d exceeds %d-byte length prefix", length, child.pendingLenLen)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if b.fixedSize && &b.result[0] != &child.result[0] {
|
|
||||||
panic("cryptobyte: BuilderContinuation reallocated a fixed-size buffer")
|
|
||||||
}
|
|
||||||
|
|
||||||
b.result = child.result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Builder) add(bytes ...byte) {
|
|
||||||
if b.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if b.child != nil {
|
|
||||||
panic("cryptobyte: attempted write while child is pending")
|
|
||||||
}
|
|
||||||
if len(b.result)+len(bytes) < len(bytes) {
|
|
||||||
b.err = errors.New("cryptobyte: length overflow")
|
|
||||||
}
|
|
||||||
if b.fixedSize && len(b.result)+len(bytes) > cap(b.result) {
|
|
||||||
b.err = errors.New("cryptobyte: Builder is exceeding its fixed-size buffer")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.result = append(b.result, bytes...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unwrite rolls back non-negative n bytes written directly to the Builder.
|
|
||||||
// An attempt by a child builder passed to a continuation to unwrite bytes
|
|
||||||
// from its parent will panic.
|
|
||||||
func (b *Builder) Unwrite(n int) {
|
|
||||||
if b.err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if b.child != nil {
|
|
||||||
panic("cryptobyte: attempted unwrite while child is pending")
|
|
||||||
}
|
|
||||||
length := len(b.result) - b.pendingLenLen - b.offset
|
|
||||||
if length < 0 {
|
|
||||||
panic("cryptobyte: internal error")
|
|
||||||
}
|
|
||||||
if n < 0 {
|
|
||||||
panic("cryptobyte: attempted to unwrite negative number of bytes")
|
|
||||||
}
|
|
||||||
if n > length {
|
|
||||||
panic("cryptobyte: attempted to unwrite more than was written")
|
|
||||||
}
|
|
||||||
b.result = b.result[:len(b.result)-n]
|
|
||||||
}
|
|
||||||
|
|
||||||
// A MarshalingValue marshals itself into a Builder.
|
|
||||||
type MarshalingValue interface {
|
|
||||||
// Marshal is called by Builder.AddValue. It receives a pointer to a builder
|
|
||||||
// to marshal itself into. It may return an error that occurred during
|
|
||||||
// marshaling, such as unset or invalid values.
|
|
||||||
Marshal(b *Builder) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddValue calls Marshal on v, passing a pointer to the builder to append to.
|
|
||||||
// If Marshal returns an error, it is set on the Builder so that subsequent
|
|
||||||
// appends don't have an effect.
|
|
||||||
func (b *Builder) AddValue(v MarshalingValue) {
|
|
||||||
err := v.Marshal(b)
|
|
||||||
if err != nil {
|
|
||||||
b.err = err
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,183 +0,0 @@
|
||||||
// Copyright 2017 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package cryptobyte contains types that help with parsing and constructing
|
|
||||||
// length-prefixed, binary messages, including ASN.1 DER. (The asn1 subpackage
|
|
||||||
// contains useful ASN.1 constants.)
|
|
||||||
//
|
|
||||||
// The String type is for parsing. It wraps a []byte slice and provides helper
|
|
||||||
// functions for consuming structures, value by value.
|
|
||||||
//
|
|
||||||
// The Builder type is for constructing messages. It providers helper functions
|
|
||||||
// for appending values and also for appending length-prefixed submessages –
|
|
||||||
// without having to worry about calculating the length prefix ahead of time.
|
|
||||||
//
|
|
||||||
// See the documentation and examples for the Builder and String types to get
|
|
||||||
// started.
|
|
||||||
package cryptobyte // import "golang.org/x/crypto/cryptobyte"
|
|
||||||
|
|
||||||
// String represents a string of bytes. It provides methods for parsing
|
|
||||||
// fixed-length and length-prefixed values from it.
|
|
||||||
type String []byte
|
|
||||||
|
|
||||||
// read advances a String by n bytes and returns them. If less than n bytes
|
|
||||||
// remain, it returns nil.
|
|
||||||
func (s *String) read(n int) []byte {
|
|
||||||
if len(*s) < n || n < 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
v := (*s)[:n]
|
|
||||||
*s = (*s)[n:]
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip advances the String by n byte and reports whether it was successful.
|
|
||||||
func (s *String) Skip(n int) bool {
|
|
||||||
return s.read(n) != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint8 decodes an 8-bit value into out and advances over it.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
func (s *String) ReadUint8(out *uint8) bool {
|
|
||||||
v := s.read(1)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = uint8(v[0])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint16 decodes a big-endian, 16-bit value into out and advances over it.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
func (s *String) ReadUint16(out *uint16) bool {
|
|
||||||
v := s.read(2)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = uint16(v[0])<<8 | uint16(v[1])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint24 decodes a big-endian, 24-bit value into out and advances over it.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
func (s *String) ReadUint24(out *uint32) bool {
|
|
||||||
v := s.read(3)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = uint32(v[0])<<16 | uint32(v[1])<<8 | uint32(v[2])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint32 decodes a big-endian, 32-bit value into out and advances over it.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
func (s *String) ReadUint32(out *uint32) bool {
|
|
||||||
v := s.read(4)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = uint32(v[0])<<24 | uint32(v[1])<<16 | uint32(v[2])<<8 | uint32(v[3])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint48 decodes a big-endian, 48-bit value into out and advances over it.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
func (s *String) ReadUint48(out *uint64) bool {
|
|
||||||
v := s.read(6)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = uint64(v[0])<<40 | uint64(v[1])<<32 | uint64(v[2])<<24 | uint64(v[3])<<16 | uint64(v[4])<<8 | uint64(v[5])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint64 decodes a big-endian, 64-bit value into out and advances over it.
|
|
||||||
// It reports whether the read was successful.
|
|
||||||
func (s *String) ReadUint64(out *uint64) bool {
|
|
||||||
v := s.read(8)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = uint64(v[0])<<56 | uint64(v[1])<<48 | uint64(v[2])<<40 | uint64(v[3])<<32 | uint64(v[4])<<24 | uint64(v[5])<<16 | uint64(v[6])<<8 | uint64(v[7])
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *String) readUnsigned(out *uint32, length int) bool {
|
|
||||||
v := s.read(length)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var result uint32
|
|
||||||
for i := 0; i < length; i++ {
|
|
||||||
result <<= 8
|
|
||||||
result |= uint32(v[i])
|
|
||||||
}
|
|
||||||
*out = result
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *String) readLengthPrefixed(lenLen int, outChild *String) bool {
|
|
||||||
lenBytes := s.read(lenLen)
|
|
||||||
if lenBytes == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var length uint32
|
|
||||||
for _, b := range lenBytes {
|
|
||||||
length = length << 8
|
|
||||||
length = length | uint32(b)
|
|
||||||
}
|
|
||||||
v := s.read(int(length))
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*outChild = v
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint8LengthPrefixed reads the content of an 8-bit length-prefixed value
|
|
||||||
// into out and advances over it. It reports whether the read was successful.
|
|
||||||
func (s *String) ReadUint8LengthPrefixed(out *String) bool {
|
|
||||||
return s.readLengthPrefixed(1, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint16LengthPrefixed reads the content of a big-endian, 16-bit
|
|
||||||
// length-prefixed value into out and advances over it. It reports whether the
|
|
||||||
// read was successful.
|
|
||||||
func (s *String) ReadUint16LengthPrefixed(out *String) bool {
|
|
||||||
return s.readLengthPrefixed(2, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadUint24LengthPrefixed reads the content of a big-endian, 24-bit
|
|
||||||
// length-prefixed value into out and advances over it. It reports whether
|
|
||||||
// the read was successful.
|
|
||||||
func (s *String) ReadUint24LengthPrefixed(out *String) bool {
|
|
||||||
return s.readLengthPrefixed(3, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadBytes reads n bytes into out and advances over them. It reports
|
|
||||||
// whether the read was successful.
|
|
||||||
func (s *String) ReadBytes(out *[]byte, n int) bool {
|
|
||||||
v := s.read(n)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
*out = v
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// CopyBytes copies len(out) bytes into out and advances over them. It reports
|
|
||||||
// whether the copy operation was successful
|
|
||||||
func (s *String) CopyBytes(out []byte) bool {
|
|
||||||
n := len(out)
|
|
||||||
v := s.read(n)
|
|
||||||
if v == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return copy(out, v) == n
|
|
||||||
}
|
|
||||||
|
|
||||||
// Empty reports whether the string does not contain any bytes.
|
|
||||||
func (s String) Empty() bool {
|
|
||||||
return len(s) == 0
|
|
||||||
}
|
|
|
@ -1,50 +0,0 @@
|
||||||
// Copyright 2021 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package constraints defines a set of useful constraints to be used
|
|
||||||
// with type parameters.
|
|
||||||
package constraints
|
|
||||||
|
|
||||||
// Signed is a constraint that permits any signed integer type.
|
|
||||||
// If future releases of Go add new predeclared signed integer types,
|
|
||||||
// this constraint will be modified to include them.
|
|
||||||
type Signed interface {
|
|
||||||
~int | ~int8 | ~int16 | ~int32 | ~int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unsigned is a constraint that permits any unsigned integer type.
|
|
||||||
// If future releases of Go add new predeclared unsigned integer types,
|
|
||||||
// this constraint will be modified to include them.
|
|
||||||
type Unsigned interface {
|
|
||||||
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Integer is a constraint that permits any integer type.
|
|
||||||
// If future releases of Go add new predeclared integer types,
|
|
||||||
// this constraint will be modified to include them.
|
|
||||||
type Integer interface {
|
|
||||||
Signed | Unsigned
|
|
||||||
}
|
|
||||||
|
|
||||||
// Float is a constraint that permits any floating-point type.
|
|
||||||
// If future releases of Go add new predeclared floating-point types,
|
|
||||||
// this constraint will be modified to include them.
|
|
||||||
type Float interface {
|
|
||||||
~float32 | ~float64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Complex is a constraint that permits any complex numeric type.
|
|
||||||
// If future releases of Go add new predeclared complex numeric types,
|
|
||||||
// this constraint will be modified to include them.
|
|
||||||
type Complex interface {
|
|
||||||
~complex64 | ~complex128
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ordered is a constraint that permits any ordered type: any type
|
|
||||||
// that supports the operators < <= >= >.
|
|
||||||
// If future releases of Go add new ordered types,
|
|
||||||
// this constraint will be modified to include them.
|
|
||||||
type Ordered interface {
|
|
||||||
Integer | Float | ~string
|
|
||||||
}
|
|
|
@ -229,9 +229,8 @@ github.com/prometheus/procfs/internal/fs
|
||||||
github.com/prometheus/procfs/internal/util
|
github.com/prometheus/procfs/internal/util
|
||||||
# github.com/quic-go/qtls-go1-20 v0.4.1
|
# github.com/quic-go/qtls-go1-20 v0.4.1
|
||||||
## explicit; go 1.20
|
## explicit; go 1.20
|
||||||
github.com/quic-go/qtls-go1-20
|
# github.com/quic-go/quic-go v0.40.1-0.20240101045026-22b7f7744eb6
|
||||||
# github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55
|
## explicit; go 1.21
|
||||||
## explicit; go 1.20
|
|
||||||
github.com/quic-go/quic-go
|
github.com/quic-go/quic-go
|
||||||
github.com/quic-go/quic-go/internal/ackhandler
|
github.com/quic-go/quic-go/internal/ackhandler
|
||||||
github.com/quic-go/quic-go/internal/congestion
|
github.com/quic-go/quic-go/internal/congestion
|
||||||
|
@ -323,8 +322,6 @@ golang.org/x/crypto/blake2b
|
||||||
golang.org/x/crypto/blowfish
|
golang.org/x/crypto/blowfish
|
||||||
golang.org/x/crypto/chacha20
|
golang.org/x/crypto/chacha20
|
||||||
golang.org/x/crypto/chacha20poly1305
|
golang.org/x/crypto/chacha20poly1305
|
||||||
golang.org/x/crypto/cryptobyte
|
|
||||||
golang.org/x/crypto/cryptobyte/asn1
|
|
||||||
golang.org/x/crypto/curve25519
|
golang.org/x/crypto/curve25519
|
||||||
golang.org/x/crypto/curve25519/internal/field
|
golang.org/x/crypto/curve25519/internal/field
|
||||||
golang.org/x/crypto/hkdf
|
golang.org/x/crypto/hkdf
|
||||||
|
@ -338,7 +335,6 @@ golang.org/x/crypto/ssh
|
||||||
golang.org/x/crypto/ssh/internal/bcrypt_pbkdf
|
golang.org/x/crypto/ssh/internal/bcrypt_pbkdf
|
||||||
# golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
|
# golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
|
||||||
## explicit; go 1.18
|
## explicit; go 1.18
|
||||||
golang.org/x/exp/constraints
|
|
||||||
golang.org/x/exp/rand
|
golang.org/x/exp/rand
|
||||||
# golang.org/x/mod v0.11.0
|
# golang.org/x/mod v0.11.0
|
||||||
## explicit; go 1.17
|
## explicit; go 1.17
|
||||||
|
|
Loading…
Reference in New Issue