TUN-8006: Update quic-go to latest upstream
This commit is contained in:
parent
45236a1f7d
commit
8068cdebb6
|
@ -9,6 +9,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -623,9 +624,19 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, logger *zerolog.
|
|||
localIP = net.IPv4zero
|
||||
}
|
||||
|
||||
listenNetwork := "udp"
|
||||
// https://github.com/quic-go/quic-go/issues/3793 DF bit cannot be set for dual stack listener on OSX
|
||||
if runtime.GOOS == "darwin" {
|
||||
if localIP.To4() != nil {
|
||||
listenNetwork = "udp4"
|
||||
} else {
|
||||
listenNetwork = "udp6"
|
||||
}
|
||||
}
|
||||
|
||||
// if port was not set yet, it will be zero, so bind will randomly allocate one.
|
||||
if port, ok := portForConnIndex[connIndex]; ok {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: localIP, Port: port})
|
||||
udpConn, err := net.ListenUDP(listenNetwork, &net.UDPAddr{IP: localIP, Port: port})
|
||||
// if there wasn't an error, or if port was 0 (independently of error or not, just return)
|
||||
if err == nil {
|
||||
return udpConn, nil
|
||||
|
@ -635,7 +646,7 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, logger *zerolog.
|
|||
}
|
||||
|
||||
// if we reached here, then there was an error or port as not been allocated it.
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: localIP, Port: 0})
|
||||
udpConn, err := net.ListenUDP(listenNetwork, &net.UDPAddr{IP: localIP, Port: 0})
|
||||
if err == nil {
|
||||
udpAddr, ok := (udpConn.LocalAddr()).(*net.UDPAddr)
|
||||
if !ok {
|
||||
|
|
|
@ -621,7 +621,7 @@ func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.
|
|||
muxedPayload, err = quicpogs.SuffixType(muxedPayload, quicpogs.DatagramTypeUDP)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = edgeQUICSession.SendMessage(muxedPayload)
|
||||
err = edgeQUICSession.SendDatagram(muxedPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
readBuffer := make([]byte, len(payload)+1)
|
||||
|
|
27
go.mod
27
go.mod
|
@ -3,7 +3,6 @@ module github.com/cloudflare/cloudflared
|
|||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc
|
||||
github.com/coredns/coredns v1.10.0
|
||||
github.com/coreos/go-oidc/v3 v3.6.0
|
||||
github.com/coreos/go-systemd/v22 v22.5.0
|
||||
|
@ -25,7 +24,8 @@ require (
|
|||
github.com/pkg/errors v0.9.1
|
||||
github.com/prometheus/client_golang v1.13.0
|
||||
github.com/prometheus/client_model v0.2.0
|
||||
github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000
|
||||
github.com/quic-go/qtls-go1-20 v0.4.1
|
||||
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55
|
||||
github.com/rs/zerolog v1.20.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
github.com/urfave/cli/v2 v2.3.0
|
||||
|
@ -38,7 +38,7 @@ require (
|
|||
go.uber.org/automaxprocs v1.4.0
|
||||
golang.org/x/crypto v0.11.0
|
||||
golang.org/x/net v0.12.0
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.org/x/sync v0.2.0
|
||||
golang.org/x/sys v0.10.0
|
||||
golang.org/x/term v0.10.0
|
||||
google.golang.org/protobuf v1.28.1
|
||||
|
@ -62,13 +62,12 @@ require (
|
|||
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
|
||||
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 // indirect
|
||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
|
||||
github.com/go-logr/logr v1.2.3 // indirect
|
||||
github.com/go-logr/logr v1.2.4 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect
|
||||
|
@ -79,20 +78,18 @@ require (
|
|||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.4.0 // indirect
|
||||
github.com/onsi/gomega v1.23.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/opentracing/opentracing-go v1.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/common v0.37.0 // indirect
|
||||
github.com/prometheus/procfs v0.8.0 // indirect
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
go.uber.org/mock v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
|
||||
golang.org/x/mod v0.8.0 // indirect
|
||||
golang.org/x/mod v0.11.0 // indirect
|
||||
golang.org/x/oauth2 v0.6.0 // indirect
|
||||
golang.org/x/text v0.11.0 // indirect
|
||||
golang.org/x/tools v0.6.0 // indirect
|
||||
golang.org/x/tools v0.9.1 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto v0.0.0-20221202195650-67e5cbc046fd // indirect
|
||||
google.golang.org/grpc v1.51.0 // indirect
|
||||
|
@ -107,8 +104,6 @@ replace github.com/prometheus/golang_client => github.com/prometheus/golang_clie
|
|||
|
||||
replace gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1
|
||||
|
||||
replace github.com/quic-go/quic-go => github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7
|
||||
|
||||
// Post-quantum tunnel RTG-1339
|
||||
// Branches go1.20 on github.com/cloudflare/qtls-pq
|
||||
replace github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633
|
||||
replace github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20231024102457-5b458bcaf6d4
|
||||
|
|
49
go.sum
49
go.sum
|
@ -61,10 +61,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
|
|||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cloudflare/circl v1.2.1-0.20220809205628-0a9554f37a47 h1:YzpECHxZ9TzO7LpnKmPxItSd79lLgrR5heIlnqU4dTU=
|
||||
github.com/cloudflare/circl v1.2.1-0.20220809205628-0a9554f37a47/go.mod h1:qhx8gBILsYlbam7h09SvHDSkjpe3TfLA7b/z4rxJvkE=
|
||||
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc h1:Dvk3ySBsOm5EviLx6VCyILnafPcQinXGP5jbTdHUJgE=
|
||||
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc/go.mod h1:HlgKKR8V5a1wroIDDIz3/A+T+9Janfq+7n1P5sEFdi0=
|
||||
github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633 h1:ZTub2XMOBpxyBiJf6Q+UKqAi07yt1rZmFitriHvFd8M=
|
||||
github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633/go.mod h1:j/igSUc4PgBMayIsBGjAFu2i7g663rm6kZrKy4htb7E=
|
||||
github.com/cloudflare/qtls-pq v0.0.0-20231024102457-5b458bcaf6d4 h1:gkjG0LZZTHDehVlLbY8pGcaCgeNHawJqV2IHgOjlGDM=
|
||||
github.com/cloudflare/qtls-pq v0.0.0-20231024102457-5b458bcaf6d4/go.mod h1:bhkEYs+1JsfHM6xqs1h4eprhpmUk/UTjdmOZK3kLIpM=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI=
|
||||
|
@ -88,14 +86,13 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
|
|||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7 h1:qxyoKKPXmPsbvT7SZTcvhEgUaZhEttk4f6u8rIawKj0=
|
||||
github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
|
||||
github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0=
|
||||
github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51 h1:0JZ+dUmQeA8IIVUMzysrX4/AKuQwWhV2dYQuPZdvdSQ=
|
||||
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51/go.mod h1:Yg+htXGokKKdzcwhuNDwVvN+uBxDGXJ7G/VN1d8fa64=
|
||||
|
@ -137,8 +134,9 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V
|
|||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
||||
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
|
||||
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
|
@ -149,8 +147,8 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j
|
|||
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
|
||||
github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ=
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
|
@ -178,8 +176,6 @@ github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt
|
|||
github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
|
||||
github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
|
||||
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
|
@ -195,8 +191,9 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
|
|||
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
|
||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
|
@ -295,10 +292,9 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
|
|||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs=
|
||||
github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo=
|
||||
github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys=
|
||||
github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
|
||||
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
|
||||
github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg=
|
||||
|
@ -335,8 +331,10 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
|
|||
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
|
||||
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
|
||||
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
|
||||
github.com/quic-go/quic-go v0.40.0 h1:GYd1iznlKm7dpHD7pOVpUvItgMPo/jrMgDWZhMCecqw=
|
||||
github.com/quic-go/quic-go v0.40.0/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c=
|
||||
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55 h1:I4N3ZRnkZPbDN935Tg8QDf8fRpHp3bZ0U0/L42jBgNE=
|
||||
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
|
@ -396,6 +394,8 @@ go.opentelemetry.io/proto/otlp v0.15.0 h1:h0bKrvdrT/9sBwEJ6iWUqT/N/xPcS66bL4u3is
|
|||
go.opentelemetry.io/proto/otlp v0.15.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U=
|
||||
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
|
||||
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
|
||||
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
|
||||
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -439,8 +439,8 @@ golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzB
|
|||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
|
||||
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
|
@ -497,8 +497,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
|
|||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
|
||||
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
|
@ -605,10 +605,9 @@ golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roY
|
|||
golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
|
||||
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
|
||||
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
|
||||
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
|
|
@ -53,7 +53,7 @@ func (dm *DatagramMuxer) SendToSession(session *packet.Session) error {
|
|||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
|
||||
}
|
||||
if err := dm.session.SendMessage(payloadWithMetadata); err != nil {
|
||||
if err := dm.session.SendDatagram(payloadWithMetadata); err != nil {
|
||||
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||
}
|
||||
return nil
|
||||
|
@ -64,7 +64,7 @@ func (dm *DatagramMuxer) ServeReceive(ctx context.Context) error {
|
|||
// Extracts datagram session ID, then sends the session ID and payload to receiver
|
||||
// which determines how to proxy to the origin. It assumes the datagram session has already been
|
||||
// registered with receiver through other side channel
|
||||
msg, err := dm.session.ReceiveMessage()
|
||||
msg, err := dm.session.ReceiveDatagram(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -116,7 +117,6 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
|
|||
quicConfig := &quic.Config{
|
||||
KeepAlivePeriod: 5 * time.Millisecond,
|
||||
EnableDatagrams: true,
|
||||
MaxDatagramFrameSize: MaxDatagramFrameSize,
|
||||
}
|
||||
quicListener := newQUICListener(t, quicConfig)
|
||||
defer quicListener.Close()
|
||||
|
@ -182,8 +182,12 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
|
|||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// https://github.com/quic-go/quic-go/issues/3793 MTU discovery is disabled on OSX for dual stack listeners
|
||||
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
require.NoError(t, err)
|
||||
// Establish quic connection
|
||||
quicSession, err := quic.DialAddrEarly(ctx, quicListener.Addr().String(), tlsClientConfig, quicConfig)
|
||||
quicSession, err := quic.DialEarly(ctx, udpConn, quicListener.Addr(), tlsClientConfig, quicConfig)
|
||||
require.NoError(t, err)
|
||||
defer quicSession.CloseWithError(0, "")
|
||||
|
||||
|
|
|
@ -86,7 +86,7 @@ func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error {
|
|||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
|
||||
}
|
||||
if err := dm.session.SendMessage(msgWithIDAndType); err != nil {
|
||||
if err := dm.session.SendDatagram(msgWithIDAndType); err != nil {
|
||||
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||
}
|
||||
return nil
|
||||
|
@ -104,7 +104,7 @@ func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
|
|||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
|
||||
}
|
||||
if err := dm.session.SendMessage(payloadWithMetadataAndType); err != nil {
|
||||
if err := dm.session.SendDatagram(payloadWithMetadataAndType); err != nil {
|
||||
return errors.Wrap(err, "Failed to send datagram back to edge")
|
||||
}
|
||||
return nil
|
||||
|
@ -113,7 +113,7 @@ func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
|
|||
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
|
||||
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
|
||||
for {
|
||||
msg, err := dm.session.ReceiveMessage()
|
||||
msg, err := dm.session.ReceiveDatagram(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ type tracerConfig struct {
|
|||
index uint8
|
||||
}
|
||||
|
||||
func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
|
||||
func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context, logging.Perspective, logging.ConnectionID) *logging.ConnectionTracer {
|
||||
t := &tracer{
|
||||
logger: logger,
|
||||
config: &tracerConfig{
|
||||
|
@ -32,42 +32,61 @@ func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context,
|
|||
return t.TracerForConnection
|
||||
}
|
||||
|
||||
func NewServerTracer(logger *zerolog.Logger) logging.Tracer {
|
||||
return &tracer{
|
||||
logger: logger,
|
||||
config: &tracerConfig{
|
||||
isClient: false,
|
||||
},
|
||||
func NewServerTracer(logger *zerolog.Logger) *logging.Tracer {
|
||||
return &logging.Tracer{
|
||||
SentPacket: func(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {},
|
||||
SentVersionNegotiationPacket: func(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {},
|
||||
DroppedPacket: func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) {},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspective, _odcid logging.ConnectionID) logging.ConnectionTracer {
|
||||
func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspective, _odcid logging.ConnectionID) *logging.ConnectionTracer {
|
||||
if t.config.isClient {
|
||||
return newConnTracer(newClientCollector(t.config.index))
|
||||
}
|
||||
return newConnTracer(newServiceCollector())
|
||||
}
|
||||
|
||||
func (*tracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {
|
||||
}
|
||||
func (*tracer) SentVersionNegotiationPacket(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
|
||||
}
|
||||
func (*tracer) DroppedPacket(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) {
|
||||
}
|
||||
|
||||
var _ logging.Tracer = (*tracer)(nil)
|
||||
|
||||
// connTracer collects connection level metrics
|
||||
type connTracer struct {
|
||||
metricsCollector MetricsCollector
|
||||
}
|
||||
|
||||
var _ logging.ConnectionTracer = (*connTracer)(nil)
|
||||
|
||||
func newConnTracer(metricsCollector MetricsCollector) logging.ConnectionTracer {
|
||||
return &connTracer{
|
||||
func newConnTracer(metricsCollector MetricsCollector) *logging.ConnectionTracer {
|
||||
tracer := connTracer{
|
||||
metricsCollector: metricsCollector,
|
||||
}
|
||||
return &logging.ConnectionTracer{
|
||||
StartedConnection: tracer.StartedConnection,
|
||||
NegotiatedVersion: tracer.NegotiatedVersion,
|
||||
ClosedConnection: tracer.ClosedConnection,
|
||||
SentTransportParameters: tracer.SentTransportParameters,
|
||||
ReceivedTransportParameters: tracer.ReceivedTransportParameters,
|
||||
RestoredTransportParameters: tracer.RestoredTransportParameters,
|
||||
SentLongHeaderPacket: tracer.SentLongHeaderPacket,
|
||||
SentShortHeaderPacket: tracer.SentShortHeaderPacket,
|
||||
ReceivedVersionNegotiationPacket: tracer.ReceivedVersionNegotiationPacket,
|
||||
ReceivedRetry: tracer.ReceivedRetry,
|
||||
ReceivedLongHeaderPacket: tracer.ReceivedLongHeaderPacket,
|
||||
ReceivedShortHeaderPacket: tracer.ReceivedShortHeaderPacket,
|
||||
BufferedPacket: tracer.BufferedPacket,
|
||||
DroppedPacket: tracer.DroppedPacket,
|
||||
UpdatedMetrics: tracer.UpdatedMetrics,
|
||||
AcknowledgedPacket: tracer.AcknowledgedPacket,
|
||||
LostPacket: tracer.LostPacket,
|
||||
UpdatedCongestionState: tracer.UpdatedCongestionState,
|
||||
UpdatedPTOCount: tracer.UpdatedPTOCount,
|
||||
UpdatedKeyFromTLS: tracer.UpdatedKeyFromTLS,
|
||||
UpdatedKey: tracer.UpdatedKey,
|
||||
DroppedEncryptionLevel: tracer.DroppedEncryptionLevel,
|
||||
DroppedKey: tracer.DroppedKey,
|
||||
SetLossTimer: tracer.SetLossTimer,
|
||||
LossTimerExpired: tracer.LossTimerExpired,
|
||||
LossTimerCanceled: tracer.LossTimerCanceled,
|
||||
ECNStateUpdated: tracer.ECNStateUpdated,
|
||||
Close: tracer.Close,
|
||||
Debug: tracer.Debug,
|
||||
}
|
||||
}
|
||||
|
||||
func (ct *connTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) {
|
||||
|
@ -90,7 +109,7 @@ func (ct *connTracer) BufferedPacket(pt logging.PacketType, size logging.ByteCou
|
|||
ct.metricsCollector.bufferedPackets(pt)
|
||||
}
|
||||
|
||||
func (ct *connTracer) DroppedPacket(pt logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) {
|
||||
func (ct *connTracer) DroppedPacket(pt logging.PacketType, number logging.PacketNumber, size logging.ByteCount, reason logging.PacketDropReason) {
|
||||
ct.metricsCollector.droppedPackets(pt, size, reason)
|
||||
}
|
||||
|
||||
|
@ -114,10 +133,10 @@ func (ct *connTracer) ReceivedTransportParameters(parameters *logging.TransportP
|
|||
func (ct *connTracer) RestoredTransportParameters(parameters *logging.TransportParameters) {
|
||||
}
|
||||
|
||||
func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
|
||||
func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
|
||||
}
|
||||
|
||||
func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
|
||||
func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
|
||||
}
|
||||
|
||||
func (ct *connTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
|
||||
|
@ -126,10 +145,10 @@ func (ct *connTracer) ReceivedVersionNegotiationPacket(dest, src logging.Arbitra
|
|||
func (ct *connTracer) ReceivedRetry(header *logging.Header) {
|
||||
}
|
||||
|
||||
func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, frames []logging.Frame) {
|
||||
func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
|
||||
}
|
||||
|
||||
func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, frames []logging.Frame) {
|
||||
func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
|
||||
}
|
||||
|
||||
func (ct *connTracer) AcknowledgedPacket(level logging.EncryptionLevel, number logging.PacketNumber) {
|
||||
|
@ -162,6 +181,10 @@ func (ct *connTracer) LossTimerExpired(timerType logging.TimerType, level loggin
|
|||
func (ct *connTracer) LossTimerCanceled() {
|
||||
}
|
||||
|
||||
func (ct *connTracer) ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) {
|
||||
|
||||
}
|
||||
|
||||
func (ct *connTracer) Close() {
|
||||
}
|
||||
|
||||
|
|
|
@ -597,7 +597,6 @@ func (e *EdgeTunnelServer) serveQUIC(
|
|||
MaxIncomingStreams: quicpogs.MaxIncomingStreams,
|
||||
MaxIncomingUniStreams: quicpogs.MaxIncomingStreams,
|
||||
EnableDatagrams: true,
|
||||
MaxDatagramFrameSize: quicpogs.MaxDatagramFrameSize,
|
||||
Tracer: quicpogs.NewClientTracer(connLogger.Logger(), connIndex),
|
||||
DisablePathMTUDiscovery: e.config.DisableQUICPathMTUDiscovery,
|
||||
}
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
Copyright (c) 2013 CloudFlare, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of the CloudFlare, Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -1,3 +0,0 @@
|
|||
cover.out~
|
||||
benchmark/benchmark
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
# Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
RACE+=--race
|
||||
|
||||
PKGNAME=github.com/cloudflare/golibs/lrucache
|
||||
SKIPCOVER=list.go|list_extension.go|priorityqueue.go
|
||||
|
||||
.PHONY: all test bench cover clean
|
||||
|
||||
all:
|
||||
@echo "Targets:"
|
||||
@echo " test: run tests with race detector"
|
||||
@echo " cover: print test coverage"
|
||||
@echo " bench: run basic benchmarks"
|
||||
|
||||
test:
|
||||
@go test $(RACE) -bench=. -v $(PKGNAME)
|
||||
|
||||
COVEROUT=cover.out
|
||||
cover:
|
||||
@go test -coverprofile=$(COVEROUT) -v $(PKGNAME)
|
||||
@cat $(COVEROUT) | egrep -v '$(SKIPCOVER)' > $(COVEROUT)~
|
||||
@go tool cover -func=$(COVEROUT)~|sed 's|^.*/\([^/]*/[^/]*/[^/]*\)$$|\1|g'
|
||||
|
||||
bench:
|
||||
@echo "[*] Scalability of cache/lrucache"
|
||||
@echo "[ ] Operations in shared cache using one core"
|
||||
@GOMAXPROCS=1 go test -run=- -bench='.*LRUCache.*' $(PKGNAME) \
|
||||
| egrep -v "^PASS|^ok"
|
||||
|
||||
@echo "[*] Scalability of cache/multilru"
|
||||
@echo "[ ] Operations in four caches using four cores "
|
||||
@GOMAXPROCS=4 go test -run=- -bench='.*MultiLRU.*' $(PKGNAME) \
|
||||
| egrep -v "^PASS|^ok"
|
||||
|
||||
|
||||
@(cd benchmark; go build $(PKGNAME)/benchmark)
|
||||
@./benchmark/benchmark
|
||||
|
||||
clean:
|
||||
rm -rf $(COVEROUT) $(COVEROUT)~ benchmark/benchmark
|
|
@ -1,40 +0,0 @@
|
|||
LRU Cache
|
||||
---------
|
||||
|
||||
A `golang` implementation of last recently used cache data structure.
|
||||
|
||||
To install:
|
||||
|
||||
go get github.com/cloudflare/golibs/lrucache
|
||||
|
||||
To test:
|
||||
|
||||
cd $GOPATH/src/github.com/cloudflare/golibs/lrucache
|
||||
make test
|
||||
|
||||
For coverage:
|
||||
|
||||
make cover
|
||||
|
||||
Basic benchmarks:
|
||||
|
||||
$ make bench # As tested on my two core i5
|
||||
[*] Scalability of cache/lrucache
|
||||
[ ] Operations in shared cache using one core
|
||||
BenchmarkConcurrentGetLRUCache 5000000 450 ns/op
|
||||
BenchmarkConcurrentSetLRUCache 2000000 821 ns/op
|
||||
BenchmarkConcurrentSetNXLRUCache 5000000 664 ns/op
|
||||
|
||||
[*] Scalability of cache/multilru
|
||||
[ ] Operations in four caches using four cores
|
||||
BenchmarkConcurrentGetMultiLRU-4 5000000 475 ns/op
|
||||
BenchmarkConcurrentSetMultiLRU-4 2000000 809 ns/op
|
||||
BenchmarkConcurrentSetNXMultiLRU-4 5000000 643 ns/op
|
||||
|
||||
[*] Capacity=4096 Keys=30000 KeySpace=15625
|
||||
vitess LRUCache MultiLRUCache-4
|
||||
create 1.709us 1.626374ms 343.54us
|
||||
Get (miss) 144.266083ms 132.470397ms 177.277193ms
|
||||
SetNX #1 338.637977ms 380.733302ms 411.709204ms
|
||||
Get (hit) 195.896066ms 173.252112ms 234.109494ms
|
||||
SetNX #2 349.785951ms 367.255624ms 419.129127ms
|
|
@ -1,69 +0,0 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
// Package lrucache implements a last recently used cache data structure.
|
||||
//
|
||||
// This code tries to avoid dynamic memory allocations - all required
|
||||
// memory is allocated on creation. Access to the data structure is
|
||||
// O(1). Modification O(log(n)) if expiry is used, O(1)
|
||||
// otherwise.
|
||||
//
|
||||
// This package exports three things:
|
||||
// LRUCache: is the main implementation. It supports multithreading by
|
||||
// using guarding mutex lock.
|
||||
//
|
||||
// MultiLRUCache: is a sharded implementation. It supports the same
|
||||
// API as LRUCache and uses it internally, but is not limited to
|
||||
// a single CPU as every shard is separately locked. Use this
|
||||
// data structure instead of LRUCache if you have have lock
|
||||
// contention issues.
|
||||
//
|
||||
// Cache interface: Both implementations fulfill it.
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache interface is fulfilled by the LRUCache and MultiLRUCache
|
||||
// implementations.
|
||||
type Cache interface {
|
||||
// Methods not needing to know current time.
|
||||
//
|
||||
// Get a key from the cache, possibly stale. Update its LRU
|
||||
// score.
|
||||
Get(key string) (value interface{}, ok bool)
|
||||
// Get a key from the cache, possibly stale. Don't modify its LRU score. O(1)
|
||||
GetQuiet(key string) (value interface{}, ok bool)
|
||||
// Get and remove a key from the cache.
|
||||
Del(key string) (value interface{}, ok bool)
|
||||
// Evict all items from the cache.
|
||||
Clear() int
|
||||
// Number of entries used in the LRU
|
||||
Len() int
|
||||
// Get the total capacity of the LRU
|
||||
Capacity() int
|
||||
|
||||
// Methods use time.Now() when neccessary to determine expiry.
|
||||
//
|
||||
// Add an item to the cache overwriting existing one if it
|
||||
// exists.
|
||||
Set(key string, value interface{}, expire time.Time)
|
||||
// Get a key from the cache, make sure it's not stale. Update
|
||||
// its LRU score.
|
||||
GetNotStale(key string) (value interface{}, ok bool)
|
||||
// Evict all the expired items.
|
||||
Expire() int
|
||||
|
||||
// Methods allowing to explicitly specify time used to
|
||||
// determine if items are expired.
|
||||
//
|
||||
// Add an item to the cache overwriting existing one if it
|
||||
// exists. Allows specifing current time required to expire an
|
||||
// item when no more slots are used.
|
||||
SetNow(key string, value interface{}, expire time.Time, now time.Time)
|
||||
// Get a key from the cache, make sure it's not stale. Update
|
||||
// its LRU score.
|
||||
GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool)
|
||||
// Evict items that expire before Now.
|
||||
ExpireNow(now time.Time) int
|
||||
}
|
|
@ -1,238 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/* This file is a slightly modified file from the go package sources
|
||||
and is released on the following license:
|
||||
|
||||
Copyright (c) 2012 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
// Package list implements a doubly linked list.
|
||||
//
|
||||
// To iterate over a list (where l is a *List):
|
||||
// for e := l.Front(); e != nil; e = e.Next() {
|
||||
// // do something with e.Value
|
||||
// }
|
||||
//
|
||||
|
||||
package lrucache
|
||||
|
||||
// Element is an element of a linked list.
|
||||
type element struct {
|
||||
// Next and previous pointers in the doubly-linked list of elements.
|
||||
// To simplify the implementation, internally a list l is implemented
|
||||
// as a ring, such that &l.root is both the next element of the last
|
||||
// list element (l.Back()) and the previous element of the first list
|
||||
// element (l.Front()).
|
||||
next, prev *element
|
||||
|
||||
// The list to which this element belongs.
|
||||
list *list
|
||||
|
||||
// The value stored with this element.
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Next returns the next list element or nil.
|
||||
func (e *element) Next() *element {
|
||||
if p := e.next; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prev returns the previous list element or nil.
|
||||
func (e *element) Prev() *element {
|
||||
if p := e.prev; e.list != nil && p != &e.list.root {
|
||||
return p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List represents a doubly linked list.
|
||||
// The zero value for List is an empty list ready to use.
|
||||
type list struct {
|
||||
root element // sentinel list element, only &root, root.prev, and root.next are used
|
||||
len int // current list length excluding (this) sentinel element
|
||||
}
|
||||
|
||||
// Init initializes or clears list l.
|
||||
func (l *list) Init() *list {
|
||||
l.root.next = &l.root
|
||||
l.root.prev = &l.root
|
||||
l.len = 0
|
||||
return l
|
||||
}
|
||||
|
||||
// New returns an initialized list.
|
||||
// func New() *list { return new(list).Init() }
|
||||
|
||||
// Len returns the number of elements of list l.
|
||||
// The complexity is O(1).
|
||||
func (l *list) Len() int { return l.len }
|
||||
|
||||
// Front returns the first element of list l or nil
|
||||
func (l *list) Front() *element {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.next
|
||||
}
|
||||
|
||||
// Back returns the last element of list l or nil.
|
||||
func (l *list) Back() *element {
|
||||
if l.len == 0 {
|
||||
return nil
|
||||
}
|
||||
return l.root.prev
|
||||
}
|
||||
|
||||
// insert inserts e after at, increments l.len, and returns e.
|
||||
func (l *list) insert(e, at *element) *element {
|
||||
n := at.next
|
||||
at.next = e
|
||||
e.prev = at
|
||||
e.next = n
|
||||
n.prev = e
|
||||
e.list = l
|
||||
l.len++
|
||||
return e
|
||||
}
|
||||
|
||||
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
|
||||
func (l *list) insertValue(v interface{}, at *element) *element {
|
||||
return l.insert(&element{Value: v}, at)
|
||||
}
|
||||
|
||||
// remove removes e from its list, decrements l.len, and returns e.
|
||||
func (l *list) remove(e *element) *element {
|
||||
e.prev.next = e.next
|
||||
e.next.prev = e.prev
|
||||
e.next = nil // avoid memory leaks
|
||||
e.prev = nil // avoid memory leaks
|
||||
e.list = nil
|
||||
l.len--
|
||||
return e
|
||||
}
|
||||
|
||||
// Remove removes e from l if e is an element of list l.
|
||||
// It returns the element value e.Value.
|
||||
func (l *list) Remove(e *element) interface{} {
|
||||
if e.list == l {
|
||||
// if e.list == l, l must have been initialized when e was inserted
|
||||
// in l or l == nil (e is a zero Element) and l.remove will crash
|
||||
l.remove(e)
|
||||
}
|
||||
return e.Value
|
||||
}
|
||||
|
||||
// PushFront inserts a new element e with value v at the front of list l and returns e.
|
||||
func (l *list) PushFront(v interface{}) *element {
|
||||
return l.insertValue(v, &l.root)
|
||||
}
|
||||
|
||||
// PushBack inserts a new element e with value v at the back of list l and returns e.
|
||||
func (l *list) PushBack(v interface{}) *element {
|
||||
return l.insertValue(v, l.root.prev)
|
||||
}
|
||||
|
||||
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
func (l *list) InsertBefore(v interface{}, mark *element) *element {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark.prev)
|
||||
}
|
||||
|
||||
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
|
||||
// If mark is not an element of l, the list is not modified.
|
||||
func (l *list) InsertAfter(v interface{}, mark *element) *element {
|
||||
if mark.list != l {
|
||||
return nil
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
return l.insertValue(v, mark)
|
||||
}
|
||||
|
||||
// MoveToFront moves element e to the front of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
func (l *list) MoveToFront(e *element) {
|
||||
if e.list != l || l.root.next == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), &l.root)
|
||||
}
|
||||
|
||||
// MoveToBack moves element e to the back of list l.
|
||||
// If e is not an element of l, the list is not modified.
|
||||
func (l *list) MoveToBack(e *element) {
|
||||
if e.list != l || l.root.prev == e {
|
||||
return
|
||||
}
|
||||
// see comment in List.Remove about initialization of l
|
||||
l.insert(l.remove(e), l.root.prev)
|
||||
}
|
||||
|
||||
// MoveBefore moves element e to its new position before mark.
|
||||
// If e is not an element of l, or e == mark, the list is not modified.
|
||||
func (l *list) MoveBefore(e, mark *element) {
|
||||
if e.list != l || e == mark {
|
||||
return
|
||||
}
|
||||
l.insert(l.remove(e), mark.prev)
|
||||
}
|
||||
|
||||
// MoveAfter moves element e to its new position after mark.
|
||||
// If e is not an element of l, or e == mark, the list is not modified.
|
||||
func (l *list) MoveAfter(e, mark *element) {
|
||||
if e.list != l || e == mark {
|
||||
return
|
||||
}
|
||||
l.insert(l.remove(e), mark)
|
||||
}
|
||||
|
||||
// PushBackList inserts a copy of an other list at the back of list l.
|
||||
// The lists l and other may be the same.
|
||||
func (l *list) PushBackList(other *list) {
|
||||
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
|
||||
l.insertValue(e.Value, l.root.prev)
|
||||
}
|
||||
}
|
||||
|
||||
// PushFrontList inserts a copy of an other list at the front of list l.
|
||||
// The lists l and other may be the same.
|
||||
func (l *list) PushFrontList(other *list) {
|
||||
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
|
||||
l.insertValue(e.Value, &l.root)
|
||||
}
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
// Extensions to "container/list" that allowing reuse of Elements.
|
||||
|
||||
package lrucache
|
||||
|
||||
func (l *list) PushElementFront(e *element) *element {
|
||||
return l.insert(e, &l.root)
|
||||
}
|
||||
|
||||
func (l *list) PushElementBack(e *element) *element {
|
||||
return l.insert(e, l.root.prev)
|
||||
}
|
||||
|
||||
func (l *list) PopElementFront() *element {
|
||||
el := l.Front()
|
||||
l.Remove(el)
|
||||
return el
|
||||
}
|
||||
|
||||
func (l *list) PopFront() interface{} {
|
||||
el := l.Front()
|
||||
l.Remove(el)
|
||||
return el.Value
|
||||
}
|
|
@ -1,316 +0,0 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Every element in the cache is linked to three data structures:
|
||||
// Table map, PriorityQueue heap ordered by expiry and a LruList list
|
||||
// ordered by decreasing popularity.
|
||||
type entry struct {
|
||||
element element // list element. value is a pointer to this entry
|
||||
key string // key is a key!
|
||||
value interface{} //
|
||||
expire time.Time // time when the item is expired. it's okay to be stale.
|
||||
index int // index for priority queue needs. -1 if entry is free
|
||||
}
|
||||
|
||||
// LRUCache data structure. Never dereference it or copy it by
|
||||
// value. Always use it through a pointer.
|
||||
type LRUCache struct {
|
||||
lock sync.Mutex
|
||||
table map[string]*entry // all entries in table must be in lruList
|
||||
priorityQueue priorityQueue // some elements from table may be in priorityQueue
|
||||
lruList list // every entry is either used and resides in lruList
|
||||
freeList list // or free and is linked to freeList
|
||||
|
||||
ExpireGracePeriod time.Duration // time after an expired entry is purged from cache (unless pushed out of LRU)
|
||||
}
|
||||
|
||||
// Initialize the LRU cache instance. O(capacity)
|
||||
func (b *LRUCache) Init(capacity uint) {
|
||||
b.table = make(map[string]*entry, capacity)
|
||||
b.priorityQueue = make([]*entry, 0, capacity)
|
||||
b.lruList.Init()
|
||||
b.freeList.Init()
|
||||
heap.Init(&b.priorityQueue)
|
||||
|
||||
// Reserve all the entries in one giant continous block of memory
|
||||
arrayOfEntries := make([]entry, capacity)
|
||||
for i := uint(0); i < capacity; i++ {
|
||||
e := &arrayOfEntries[i]
|
||||
e.element.Value = e
|
||||
e.index = -1
|
||||
b.freeList.PushElementBack(&e.element)
|
||||
}
|
||||
}
|
||||
|
||||
// Create new LRU cache instance. Allocate all the needed memory. O(capacity)
|
||||
func NewLRUCache(capacity uint) *LRUCache {
|
||||
b := &LRUCache{}
|
||||
b.Init(capacity)
|
||||
return b
|
||||
}
|
||||
|
||||
// Give me the entry with lowest expiry field if it's before now.
|
||||
func (b *LRUCache) expiredEntry(now time.Time) *entry {
|
||||
if len(b.priorityQueue) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if now.IsZero() {
|
||||
// Fill it only when actually used.
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
if e := b.priorityQueue[0]; e.expire.Before(now) {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Give me the least used entry.
|
||||
func (b *LRUCache) leastUsedEntry() *entry {
|
||||
return b.lruList.Back().Value.(*entry)
|
||||
}
|
||||
|
||||
func (b *LRUCache) freeSomeEntry(now time.Time) (e *entry, used bool) {
|
||||
if b.freeList.Len() > 0 {
|
||||
return b.freeList.Front().Value.(*entry), false
|
||||
}
|
||||
|
||||
e = b.expiredEntry(now)
|
||||
if e != nil {
|
||||
return e, true
|
||||
}
|
||||
|
||||
if b.lruList.Len() == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return b.leastUsedEntry(), true
|
||||
}
|
||||
|
||||
// Move entry from used/lru list to a free list. Clear the entry as well.
|
||||
func (b *LRUCache) removeEntry(e *entry) {
|
||||
if e.element.list != &b.lruList {
|
||||
panic("list lruList")
|
||||
}
|
||||
|
||||
if e.index != -1 {
|
||||
heap.Remove(&b.priorityQueue, e.index)
|
||||
}
|
||||
b.lruList.Remove(&e.element)
|
||||
b.freeList.PushElementFront(&e.element)
|
||||
delete(b.table, e.key)
|
||||
e.key = ""
|
||||
e.value = nil
|
||||
}
|
||||
|
||||
func (b *LRUCache) insertEntry(e *entry) {
|
||||
if e.element.list != &b.freeList {
|
||||
panic("list freeList")
|
||||
}
|
||||
|
||||
if !e.expire.IsZero() {
|
||||
heap.Push(&b.priorityQueue, e)
|
||||
}
|
||||
b.freeList.Remove(&e.element)
|
||||
b.lruList.PushElementFront(&e.element)
|
||||
b.table[e.key] = e
|
||||
}
|
||||
|
||||
func (b *LRUCache) touchEntry(e *entry) {
|
||||
b.lruList.MoveToFront(&e.element)
|
||||
}
|
||||
|
||||
// Add an item to the cache overwriting existing one if it
|
||||
// exists. Allows specifing current time required to expire an item
|
||||
// when no more slots are used. O(log(n)) if expiry is set, O(1) when
|
||||
// clear.
|
||||
func (b *LRUCache) SetNow(key string, value interface{}, expire time.Time, now time.Time) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
var used bool
|
||||
|
||||
e := b.table[key]
|
||||
if e != nil {
|
||||
used = true
|
||||
} else {
|
||||
e, used = b.freeSomeEntry(now)
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if used {
|
||||
b.removeEntry(e)
|
||||
}
|
||||
|
||||
e.key = key
|
||||
e.value = value
|
||||
e.expire = expire
|
||||
b.insertEntry(e)
|
||||
}
|
||||
|
||||
// Add an item to the cache overwriting existing one if it
|
||||
// exists. O(log(n)) if expiry is set, O(1) when clear.
|
||||
func (b *LRUCache) Set(key string, value interface{}, expire time.Time) {
|
||||
b.SetNow(key, value, expire, time.Time{})
|
||||
}
|
||||
|
||||
// Get a key from the cache, possibly stale. Update its LRU score. O(1)
|
||||
func (b *LRUCache) Get(key string) (v interface{}, ok bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
b.touchEntry(e)
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
// Get a key from the cache, possibly stale. Don't modify its LRU score. O(1)
|
||||
func (b *LRUCache) GetQuiet(key string) (v interface{}, ok bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
// Get a key from the cache, make sure it's not stale. Update its
|
||||
// LRU score. O(log(n)) if the item is expired.
|
||||
func (b *LRUCache) GetNotStale(key string) (value interface{}, ok bool) {
|
||||
return b.GetNotStaleNow(key, time.Now())
|
||||
}
|
||||
|
||||
// Get a key from the cache, make sure it's not stale. Update its
|
||||
// LRU score. O(log(n)) if the item is expired.
|
||||
func (b *LRUCache) GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if e.expire.Before(now) {
|
||||
// Remove entries expired for more than a graceful period
|
||||
if b.ExpireGracePeriod == 0 || e.expire.Sub(now) > b.ExpireGracePeriod {
|
||||
b.removeEntry(e)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
b.touchEntry(e)
|
||||
return e.value, true
|
||||
}
|
||||
|
||||
// Get a key from the cache, possibly stale. Update its LRU
|
||||
// score. O(1) always.
|
||||
func (b *LRUCache) GetStale(key string) (value interface{}, ok, expired bool) {
|
||||
return b.GetStaleNow(key, time.Now())
|
||||
}
|
||||
|
||||
// Get a key from the cache, possibly stale. Update its LRU
|
||||
// score. O(1) always.
|
||||
func (b *LRUCache) GetStaleNow(key string, now time.Time) (value interface{}, ok, expired bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
b.touchEntry(e)
|
||||
return e.value, true, e.expire.Before(now)
|
||||
}
|
||||
|
||||
// Get and remove a key from the cache. O(log(n)) if the item is using expiry, O(1) otherwise.
|
||||
func (b *LRUCache) Del(key string) (v interface{}, ok bool) {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
e := b.table[key]
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
value := e.value
|
||||
b.removeEntry(e)
|
||||
return value, true
|
||||
}
|
||||
|
||||
// Evict all items from the cache. O(n*log(n))
|
||||
func (b *LRUCache) Clear() int {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
// First, remove entries that have expiry set
|
||||
l := len(b.priorityQueue)
|
||||
for i := 0; i < l; i++ {
|
||||
// This could be reduced to O(n).
|
||||
b.removeEntry(b.priorityQueue[0])
|
||||
}
|
||||
|
||||
// Second, remove all remaining entries
|
||||
r := b.lruList.Len()
|
||||
for i := 0; i < r; i++ {
|
||||
b.removeEntry(b.leastUsedEntry())
|
||||
}
|
||||
return l + r
|
||||
}
|
||||
|
||||
// Evict all the expired items. O(n*log(n))
|
||||
func (b *LRUCache) Expire() int {
|
||||
return b.ExpireNow(time.Now())
|
||||
}
|
||||
|
||||
// Evict items that expire before `now`. O(n*log(n))
|
||||
func (b *LRUCache) ExpireNow(now time.Time) int {
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
i := 0
|
||||
for {
|
||||
e := b.expiredEntry(now)
|
||||
if e == nil {
|
||||
break
|
||||
}
|
||||
b.removeEntry(e)
|
||||
i += 1
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// Number of entries used in the LRU
|
||||
func (b *LRUCache) Len() int {
|
||||
// yes. this stupid thing requires locking
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
return b.lruList.Len()
|
||||
}
|
||||
|
||||
// Get the total capacity of the LRU
|
||||
func (b *LRUCache) Capacity() int {
|
||||
// yes. this stupid thing requires locking
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
return b.lruList.Len() + b.freeList.Len()
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
package lrucache
|
||||
|
||||
import (
|
||||
"hash/crc32"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MultiLRUCache data structure. Never dereference it or copy it by
|
||||
// value. Always use it through a pointer.
|
||||
type MultiLRUCache struct {
|
||||
buckets uint
|
||||
cache []*LRUCache
|
||||
}
|
||||
|
||||
// Using this constructor is almost always wrong. Use NewMultiLRUCache instead.
|
||||
func (m *MultiLRUCache) Init(buckets, bucket_capacity uint) {
|
||||
m.buckets = buckets
|
||||
m.cache = make([]*LRUCache, buckets)
|
||||
for i := uint(0); i < buckets; i++ {
|
||||
m.cache[i] = NewLRUCache(bucket_capacity)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the stale expiry grace period for each cache in the multicache instance.
|
||||
func (m *MultiLRUCache) SetExpireGracePeriod(p time.Duration) {
|
||||
for _, c := range m.cache {
|
||||
c.ExpireGracePeriod = p
|
||||
}
|
||||
}
|
||||
|
||||
func NewMultiLRUCache(buckets, bucket_capacity uint) *MultiLRUCache {
|
||||
m := &MultiLRUCache{}
|
||||
m.Init(buckets, bucket_capacity)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) bucketNo(key string) uint {
|
||||
// Arbitrary choice. Any fast hash will do.
|
||||
return uint(crc32.ChecksumIEEE([]byte(key))) % m.buckets
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Set(key string, value interface{}, expire time.Time) {
|
||||
m.cache[m.bucketNo(key)].Set(key, value, expire)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) SetNow(key string, value interface{}, expire time.Time, now time.Time) {
|
||||
m.cache[m.bucketNo(key)].SetNow(key, value, expire, now)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Get(key string) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].Get(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetQuiet(key string) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].Get(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetNotStale(key string) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].GetNotStale(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].GetNotStaleNow(key, now)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetStale(key string) (value interface{}, ok, expired bool) {
|
||||
return m.cache[m.bucketNo(key)].GetStale(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) GetStaleNow(key string, now time.Time) (value interface{}, ok, expired bool) {
|
||||
return m.cache[m.bucketNo(key)].GetStaleNow(key, now)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Del(key string) (value interface{}, ok bool) {
|
||||
return m.cache[m.bucketNo(key)].Del(key)
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Clear() int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.Clear()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Len() int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.Len()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Capacity() int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.Capacity()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) Expire() int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.Expire()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *MultiLRUCache) ExpireNow(now time.Time) int {
|
||||
var s int
|
||||
for _, c := range m.cache {
|
||||
s += c.ExpireNow(now)
|
||||
}
|
||||
return s
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
// Copyright (c) 2013 CloudFlare, Inc.
|
||||
|
||||
// This code is based on golang example from "container/heap" package.
|
||||
|
||||
package lrucache
|
||||
|
||||
type priorityQueue []*entry
|
||||
|
||||
func (pq priorityQueue) Len() int {
|
||||
return len(pq)
|
||||
}
|
||||
|
||||
func (pq priorityQueue) Less(i, j int) bool {
|
||||
return pq[i].expire.Before(pq[j].expire)
|
||||
}
|
||||
|
||||
func (pq priorityQueue) Swap(i, j int) {
|
||||
pq[i], pq[j] = pq[j], pq[i]
|
||||
pq[i].index = i
|
||||
pq[j].index = j
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Push(e interface{}) {
|
||||
n := len(*pq)
|
||||
item := e.(*entry)
|
||||
item.index = n
|
||||
*pq = append(*pq, item)
|
||||
}
|
||||
|
||||
func (pq *priorityQueue) Pop() interface{} {
|
||||
old := *pq
|
||||
n := len(old)
|
||||
item := old[n-1]
|
||||
item.index = -1
|
||||
*pq = old[0 : n-1]
|
||||
return item
|
||||
}
|
|
@ -6,7 +6,6 @@ linters:
|
|||
disable-all: true
|
||||
enable:
|
||||
- asciicheck
|
||||
- deadcode
|
||||
- errcheck
|
||||
- forcetypeassert
|
||||
- gocritic
|
||||
|
@ -18,10 +17,8 @@ linters:
|
|||
- misspell
|
||||
- revive
|
||||
- staticcheck
|
||||
- structcheck
|
||||
- typecheck
|
||||
- unused
|
||||
- varcheck
|
||||
|
||||
issues:
|
||||
exclude-use-default: false
|
||||
|
|
|
@ -20,35 +20,5 @@ package logr
|
|||
// used whenever the caller is not interested in the logs. Logger instances
|
||||
// produced by this function always compare as equal.
|
||||
func Discard() Logger {
|
||||
return Logger{
|
||||
level: 0,
|
||||
sink: discardLogSink{},
|
||||
}
|
||||
}
|
||||
|
||||
// discardLogSink is a LogSink that discards all messages.
|
||||
type discardLogSink struct{}
|
||||
|
||||
// Verify that it actually implements the interface
|
||||
var _ LogSink = discardLogSink{}
|
||||
|
||||
func (l discardLogSink) Init(RuntimeInfo) {
|
||||
}
|
||||
|
||||
func (l discardLogSink) Enabled(int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (l discardLogSink) Info(int, string, ...interface{}) {
|
||||
}
|
||||
|
||||
func (l discardLogSink) Error(error, string, ...interface{}) {
|
||||
}
|
||||
|
||||
func (l discardLogSink) WithValues(...interface{}) LogSink {
|
||||
return l
|
||||
}
|
||||
|
||||
func (l discardLogSink) WithName(string) LogSink {
|
||||
return l
|
||||
return New(nil)
|
||||
}
|
||||
|
|
|
@ -21,13 +21,13 @@ limitations under the License.
|
|||
// github.com/go-logr/logr.LogSink with output through an arbitrary
|
||||
// "write" function. See New and NewJSON for details.
|
||||
//
|
||||
// Custom LogSinks
|
||||
// # Custom LogSinks
|
||||
//
|
||||
// For users who need more control, a funcr.Formatter can be embedded inside
|
||||
// your own custom LogSink implementation. This is useful when the LogSink
|
||||
// needs to implement additional methods, for example.
|
||||
//
|
||||
// Formatting
|
||||
// # Formatting
|
||||
//
|
||||
// This will respect logr.Marshaler, fmt.Stringer, and error interfaces for
|
||||
// values which are being logged. When rendering a struct, funcr will use Go's
|
||||
|
@ -37,6 +37,7 @@ package funcr
|
|||
import (
|
||||
"bytes"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
|
@ -217,7 +218,7 @@ func newFormatter(opts Options, outfmt outputFormat) Formatter {
|
|||
prefix: "",
|
||||
values: nil,
|
||||
depth: 0,
|
||||
opts: opts,
|
||||
opts: &opts,
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
@ -231,7 +232,7 @@ type Formatter struct {
|
|||
values []interface{}
|
||||
valuesStr string
|
||||
depth int
|
||||
opts Options
|
||||
opts *Options
|
||||
}
|
||||
|
||||
// outputFormat indicates which outputFormat to use.
|
||||
|
@ -447,6 +448,7 @@ func (f Formatter) prettyWithFlags(value interface{}, flags uint32, depth int) s
|
|||
if flags&flagRawStruct == 0 {
|
||||
buf.WriteByte('{')
|
||||
}
|
||||
printComma := false // testing i>0 is not enough because of JSON omitted fields
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
fld := t.Field(i)
|
||||
if fld.PkgPath != "" {
|
||||
|
@ -478,9 +480,10 @@ func (f Formatter) prettyWithFlags(value interface{}, flags uint32, depth int) s
|
|||
if omitempty && isEmpty(v.Field(i)) {
|
||||
continue
|
||||
}
|
||||
if i > 0 {
|
||||
if printComma {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
printComma = true // if we got here, we are rendering a field
|
||||
if fld.Anonymous && fld.Type.Kind() == reflect.Struct && name == "" {
|
||||
buf.WriteString(f.prettyWithFlags(v.Field(i).Interface(), flags|flagRawStruct, depth+1))
|
||||
continue
|
||||
|
@ -500,6 +503,20 @@ func (f Formatter) prettyWithFlags(value interface{}, flags uint32, depth int) s
|
|||
}
|
||||
return buf.String()
|
||||
case reflect.Slice, reflect.Array:
|
||||
// If this is outputing as JSON make sure this isn't really a json.RawMessage.
|
||||
// If so just emit "as-is" and don't pretty it as that will just print
|
||||
// it as [X,Y,Z,...] which isn't terribly useful vs the string form you really want.
|
||||
if f.outputFormat == outputJSON {
|
||||
if rm, ok := value.(json.RawMessage); ok {
|
||||
// If it's empty make sure we emit an empty value as the array style would below.
|
||||
if len(rm) > 0 {
|
||||
buf.Write(rm)
|
||||
} else {
|
||||
buf.WriteString("null")
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
}
|
||||
buf.WriteByte('[')
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if i > 0 {
|
||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||
// to back that API. Packages in the Go ecosystem can depend on this package,
|
||||
// while callers can implement logging with whatever backend is appropriate.
|
||||
//
|
||||
// Usage
|
||||
// # Usage
|
||||
//
|
||||
// Logging is done using a Logger instance. Logger is a concrete type with
|
||||
// methods, which defers the actual logging to a LogSink interface. The main
|
||||
|
@ -30,15 +30,19 @@ limitations under the License.
|
|||
// "structured logging".
|
||||
//
|
||||
// With Go's standard log package, we might write:
|
||||
//
|
||||
// log.Printf("setting target value %s", targetValue)
|
||||
//
|
||||
// With logr's structured logging, we'd write:
|
||||
//
|
||||
// logger.Info("setting target", "value", targetValue)
|
||||
//
|
||||
// Errors are much the same. Instead of:
|
||||
//
|
||||
// log.Printf("failed to open the pod bay door for user %s: %v", user, err)
|
||||
//
|
||||
// We'd write:
|
||||
//
|
||||
// logger.Error(err, "failed to open the pod bay door", "user", user)
|
||||
//
|
||||
// Info() and Error() are very similar, but they are separate methods so that
|
||||
|
@ -47,7 +51,7 @@ limitations under the License.
|
|||
// always logged, regardless of the current verbosity. If there is no error
|
||||
// instance available, passing nil is valid.
|
||||
//
|
||||
// Verbosity
|
||||
// # Verbosity
|
||||
//
|
||||
// Often we want to log information only when the application in "verbose
|
||||
// mode". To write log lines that are more verbose, Logger has a V() method.
|
||||
|
@ -58,14 +62,16 @@ limitations under the License.
|
|||
// Error messages do not have a verbosity level and are always logged.
|
||||
//
|
||||
// Where we might have written:
|
||||
//
|
||||
// if flVerbose >= 2 {
|
||||
// log.Printf("an unusual thing happened")
|
||||
// }
|
||||
//
|
||||
// We can write:
|
||||
//
|
||||
// logger.V(2).Info("an unusual thing happened")
|
||||
//
|
||||
// Logger Names
|
||||
// # Logger Names
|
||||
//
|
||||
// Logger instances can have name strings so that all messages logged through
|
||||
// that instance have additional context. For example, you might want to add
|
||||
|
@ -82,17 +88,19 @@ limitations under the License.
|
|||
// joining operation (e.g. whitespace, commas, periods, slashes, brackets,
|
||||
// quotes, etc).
|
||||
//
|
||||
// Saved Values
|
||||
// # Saved Values
|
||||
//
|
||||
// Logger instances can store any number of key/value pairs, which will be
|
||||
// logged alongside all messages logged through that instance. For example,
|
||||
// you might want to create a Logger instance per managed object:
|
||||
//
|
||||
// With the standard log package, we might write:
|
||||
//
|
||||
// log.Printf("decided to set field foo to value %q for object %s/%s",
|
||||
// targetValue, object.Namespace, object.Name)
|
||||
//
|
||||
// With logr we'd write:
|
||||
//
|
||||
// // Elsewhere: set up the logger to log the object name.
|
||||
// obj.logger = mainLogger.WithValues(
|
||||
// "name", obj.name, "namespace", obj.namespace)
|
||||
|
@ -100,7 +108,7 @@ limitations under the License.
|
|||
// // later on...
|
||||
// obj.logger.Info("setting foo", "value", targetValue)
|
||||
//
|
||||
// Best Practices
|
||||
// # Best Practices
|
||||
//
|
||||
// Logger has very few hard rules, with the goal that LogSink implementations
|
||||
// might have a lot of freedom to differentiate. There are, however, some
|
||||
|
@ -124,15 +132,15 @@ limitations under the License.
|
|||
// around. For cases where passing a logger is optional, a pointer to Logger
|
||||
// should be used.
|
||||
//
|
||||
// Key Naming Conventions
|
||||
// # Key Naming Conventions
|
||||
//
|
||||
// Keys are not strictly required to conform to any specification or regex, but
|
||||
// it is recommended that they:
|
||||
// * be human-readable and meaningful (not auto-generated or simple ordinals)
|
||||
// * be constant (not dependent on input data)
|
||||
// * contain only printable characters
|
||||
// * not contain whitespace or punctuation
|
||||
// * use lower case for simple keys and lowerCamelCase for more complex ones
|
||||
// - be human-readable and meaningful (not auto-generated or simple ordinals)
|
||||
// - be constant (not dependent on input data)
|
||||
// - contain only printable characters
|
||||
// - not contain whitespace or punctuation
|
||||
// - use lower case for simple keys and lowerCamelCase for more complex ones
|
||||
//
|
||||
// These guidelines help ensure that log data is processed properly regardless
|
||||
// of the log implementation. For example, log implementations will try to
|
||||
|
@ -141,24 +149,25 @@ limitations under the License.
|
|||
// While users are generally free to use key names of their choice, it's
|
||||
// generally best to avoid using the following keys, as they're frequently used
|
||||
// by implementations:
|
||||
// * "caller": the calling information (file/line) of a particular log line
|
||||
// * "error": the underlying error value in the `Error` method
|
||||
// * "level": the log level
|
||||
// * "logger": the name of the associated logger
|
||||
// * "msg": the log message
|
||||
// * "stacktrace": the stack trace associated with a particular log line or
|
||||
// - "caller": the calling information (file/line) of a particular log line
|
||||
// - "error": the underlying error value in the `Error` method
|
||||
// - "level": the log level
|
||||
// - "logger": the name of the associated logger
|
||||
// - "msg": the log message
|
||||
// - "stacktrace": the stack trace associated with a particular log line or
|
||||
// error (often from the `Error` message)
|
||||
// * "ts": the timestamp for a log line
|
||||
// - "ts": the timestamp for a log line
|
||||
//
|
||||
// Implementations are encouraged to make use of these keys to represent the
|
||||
// above concepts, when necessary (for example, in a pure-JSON output form, it
|
||||
// would be necessary to represent at least message and timestamp as ordinary
|
||||
// named values).
|
||||
//
|
||||
// Break Glass
|
||||
// # Break Glass
|
||||
//
|
||||
// Implementations may choose to give callers access to the underlying
|
||||
// logging implementation. The recommended pattern for this is:
|
||||
//
|
||||
// // Underlier exposes access to the underlying logging implementation.
|
||||
// // Since callers only have a logr.Logger, they have to know which
|
||||
// // implementation is in use, so this interface is less of an abstraction
|
||||
|
@ -168,8 +177,9 @@ limitations under the License.
|
|||
// }
|
||||
//
|
||||
// Logger grants access to the sink to enable type assertions like this:
|
||||
//
|
||||
// func DoSomethingWithImpl(log logr.Logger) {
|
||||
// if underlier, ok := log.GetSink()(impl.Underlier) {
|
||||
// if underlier, ok := log.GetSink().(impl.Underlier); ok {
|
||||
// implLogger := underlier.GetUnderlying()
|
||||
// ...
|
||||
// }
|
||||
|
@ -177,11 +187,12 @@ limitations under the License.
|
|||
//
|
||||
// Custom `With*` functions can be implemented by copying the complete
|
||||
// Logger struct and replacing the sink in the copy:
|
||||
//
|
||||
// // WithFooBar changes the foobar parameter in the log sink and returns a
|
||||
// // new logger with that modified sink. It does nothing for loggers where
|
||||
// // the sink doesn't support that parameter.
|
||||
// func WithFoobar(log logr.Logger, foobar int) logr.Logger {
|
||||
// if foobarLogSink, ok := log.GetSink()(FoobarSink); ok {
|
||||
// if foobarLogSink, ok := log.GetSink().(FoobarSink); ok {
|
||||
// log = log.WithSink(foobarLogSink.WithFooBar(foobar))
|
||||
// }
|
||||
// return log
|
||||
|
@ -201,11 +212,14 @@ import (
|
|||
)
|
||||
|
||||
// New returns a new Logger instance. This is primarily used by libraries
|
||||
// implementing LogSink, rather than end users.
|
||||
// implementing LogSink, rather than end users. Passing a nil sink will create
|
||||
// a Logger which discards all log lines.
|
||||
func New(sink LogSink) Logger {
|
||||
logger := Logger{}
|
||||
logger.setSink(sink)
|
||||
if sink != nil {
|
||||
sink.Init(runtimeInfo)
|
||||
}
|
||||
return logger
|
||||
}
|
||||
|
||||
|
@ -244,7 +258,7 @@ type Logger struct {
|
|||
// Enabled tests whether this Logger is enabled. For example, commandline
|
||||
// flags might be used to set the logging verbosity and disable some info logs.
|
||||
func (l Logger) Enabled() bool {
|
||||
return l.sink.Enabled(l.level)
|
||||
return l.sink != nil && l.sink.Enabled(l.level)
|
||||
}
|
||||
|
||||
// Info logs a non-error message with the given key/value pairs as context.
|
||||
|
@ -254,6 +268,9 @@ func (l Logger) Enabled() bool {
|
|||
// information. The key/value pairs must alternate string keys and arbitrary
|
||||
// values.
|
||||
func (l Logger) Info(msg string, keysAndValues ...interface{}) {
|
||||
if l.sink == nil {
|
||||
return
|
||||
}
|
||||
if l.Enabled() {
|
||||
if withHelper, ok := l.sink.(CallStackHelperLogSink); ok {
|
||||
withHelper.GetCallStackHelper()()
|
||||
|
@ -273,6 +290,9 @@ func (l Logger) Info(msg string, keysAndValues ...interface{}) {
|
|||
// triggered this log line, if present. The err parameter is optional
|
||||
// and nil may be passed instead of an error instance.
|
||||
func (l Logger) Error(err error, msg string, keysAndValues ...interface{}) {
|
||||
if l.sink == nil {
|
||||
return
|
||||
}
|
||||
if withHelper, ok := l.sink.(CallStackHelperLogSink); ok {
|
||||
withHelper.GetCallStackHelper()()
|
||||
}
|
||||
|
@ -284,6 +304,9 @@ func (l Logger) Error(err error, msg string, keysAndValues ...interface{}) {
|
|||
// level means a log message is less important. Negative V-levels are treated
|
||||
// as 0.
|
||||
func (l Logger) V(level int) Logger {
|
||||
if l.sink == nil {
|
||||
return l
|
||||
}
|
||||
if level < 0 {
|
||||
level = 0
|
||||
}
|
||||
|
@ -294,6 +317,9 @@ func (l Logger) V(level int) Logger {
|
|||
// WithValues returns a new Logger instance with additional key/value pairs.
|
||||
// See Info for documentation on how key/value pairs work.
|
||||
func (l Logger) WithValues(keysAndValues ...interface{}) Logger {
|
||||
if l.sink == nil {
|
||||
return l
|
||||
}
|
||||
l.setSink(l.sink.WithValues(keysAndValues...))
|
||||
return l
|
||||
}
|
||||
|
@ -304,6 +330,9 @@ func (l Logger) WithValues(keysAndValues ...interface{}) Logger {
|
|||
// contain only letters, digits, and hyphens (see the package documentation for
|
||||
// more information).
|
||||
func (l Logger) WithName(name string) Logger {
|
||||
if l.sink == nil {
|
||||
return l
|
||||
}
|
||||
l.setSink(l.sink.WithName(name))
|
||||
return l
|
||||
}
|
||||
|
@ -324,6 +353,9 @@ func (l Logger) WithName(name string) Logger {
|
|||
// WithCallDepth(1) because it works with implementions that support the
|
||||
// CallDepthLogSink and/or CallStackHelperLogSink interfaces.
|
||||
func (l Logger) WithCallDepth(depth int) Logger {
|
||||
if l.sink == nil {
|
||||
return l
|
||||
}
|
||||
if withCallDepth, ok := l.sink.(CallDepthLogSink); ok {
|
||||
l.setSink(withCallDepth.WithCallDepth(depth))
|
||||
}
|
||||
|
@ -345,6 +377,9 @@ func (l Logger) WithCallDepth(depth int) Logger {
|
|||
// implementation does not support either of these, the original Logger will be
|
||||
// returned.
|
||||
func (l Logger) WithCallStackHelper() (func(), Logger) {
|
||||
if l.sink == nil {
|
||||
return func() {}, l
|
||||
}
|
||||
var helper func()
|
||||
if withCallDepth, ok := l.sink.(CallDepthLogSink); ok {
|
||||
l.setSink(withCallDepth.WithCallDepth(1))
|
||||
|
@ -357,6 +392,11 @@ func (l Logger) WithCallStackHelper() (func(), Logger) {
|
|||
return helper, l
|
||||
}
|
||||
|
||||
// IsZero returns true if this logger is an uninitialized zero value
|
||||
func (l Logger) IsZero() bool {
|
||||
return l.sink == nil
|
||||
}
|
||||
|
||||
// contextKey is how we find Loggers in a context.Context.
|
||||
type contextKey struct{}
|
||||
|
||||
|
@ -442,7 +482,7 @@ type LogSink interface {
|
|||
WithName(name string) LogSink
|
||||
}
|
||||
|
||||
// CallDepthLogSink represents a Logger that knows how to climb the call stack
|
||||
// CallDepthLogSink represents a LogSink that knows how to climb the call stack
|
||||
// to identify the original call site and can offset the depth by a specified
|
||||
// number of frames. This is useful for users who have helper functions
|
||||
// between the "real" call site and the actual calls to Logger methods.
|
||||
|
@ -467,7 +507,7 @@ type CallDepthLogSink interface {
|
|||
WithCallDepth(depth int) LogSink
|
||||
}
|
||||
|
||||
// CallStackHelperLogSink represents a Logger that knows how to climb
|
||||
// CallStackHelperLogSink represents a LogSink that knows how to climb
|
||||
// the call stack to identify the original call site and can skip
|
||||
// intermediate helper functions if they mark themselves as
|
||||
// helper. Go's testing package uses that approach.
|
||||
|
|
|
@ -1,26 +0,0 @@
|
|||
// Copyright 2019 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// +build !go1.12
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
)
|
||||
|
||||
func printModuleVersion() {
|
||||
log.Printf("No version information is available for Mockgen compiled with " +
|
||||
"version 1.11")
|
||||
}
|
|
@ -386,8 +386,14 @@ func (u *Unmarshaler) unmarshalMessage(m protoreflect.Message, in []byte) error
|
|||
}
|
||||
|
||||
func isSingularWellKnownValue(fd protoreflect.FieldDescriptor) bool {
|
||||
if fd.Cardinality() == protoreflect.Repeated {
|
||||
return false
|
||||
}
|
||||
if md := fd.Message(); md != nil {
|
||||
return md.FullName() == "google.protobuf.Value" && fd.Cardinality() != protoreflect.Repeated
|
||||
return md.FullName() == "google.protobuf.Value"
|
||||
}
|
||||
if ed := fd.Enum(); ed != nil {
|
||||
return ed.FullName() == "google.protobuf.NullValue"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -50,6 +51,37 @@ func NewWithNoColorBool(noColor bool) Formatter {
|
|||
}
|
||||
|
||||
func New(colorMode ColorMode) Formatter {
|
||||
colorAliases := map[string]int{
|
||||
"black": 0,
|
||||
"red": 1,
|
||||
"green": 2,
|
||||
"yellow": 3,
|
||||
"blue": 4,
|
||||
"magenta": 5,
|
||||
"cyan": 6,
|
||||
"white": 7,
|
||||
}
|
||||
for colorAlias, n := range colorAliases {
|
||||
colorAliases[fmt.Sprintf("bright-%s", colorAlias)] = n + 8
|
||||
}
|
||||
|
||||
getColor := func(color, defaultEscapeCode string) string {
|
||||
color = strings.ToUpper(strings.ReplaceAll(color, "-", "_"))
|
||||
envVar := fmt.Sprintf("GINKGO_CLI_COLOR_%s", color)
|
||||
envVarColor := os.Getenv(envVar)
|
||||
if envVarColor == "" {
|
||||
return defaultEscapeCode
|
||||
}
|
||||
if colorCode, ok := colorAliases[envVarColor]; ok {
|
||||
return fmt.Sprintf("\x1b[38;5;%dm", colorCode)
|
||||
}
|
||||
colorCode, err := strconv.Atoi(envVarColor)
|
||||
if err != nil || colorCode < 0 || colorCode > 255 {
|
||||
return defaultEscapeCode
|
||||
}
|
||||
return fmt.Sprintf("\x1b[38;5;%dm", colorCode)
|
||||
}
|
||||
|
||||
f := Formatter{
|
||||
ColorMode: colorMode,
|
||||
colors: map[string]string{
|
||||
|
@ -57,18 +89,18 @@ func New(colorMode ColorMode) Formatter {
|
|||
"bold": "\x1b[1m",
|
||||
"underline": "\x1b[4m",
|
||||
|
||||
"red": "\x1b[38;5;9m",
|
||||
"orange": "\x1b[38;5;214m",
|
||||
"coral": "\x1b[38;5;204m",
|
||||
"magenta": "\x1b[38;5;13m",
|
||||
"green": "\x1b[38;5;10m",
|
||||
"dark-green": "\x1b[38;5;28m",
|
||||
"yellow": "\x1b[38;5;11m",
|
||||
"light-yellow": "\x1b[38;5;228m",
|
||||
"cyan": "\x1b[38;5;14m",
|
||||
"gray": "\x1b[38;5;243m",
|
||||
"light-gray": "\x1b[38;5;246m",
|
||||
"blue": "\x1b[38;5;12m",
|
||||
"red": getColor("red", "\x1b[38;5;9m"),
|
||||
"orange": getColor("orange", "\x1b[38;5;214m"),
|
||||
"coral": getColor("coral", "\x1b[38;5;204m"),
|
||||
"magenta": getColor("magenta", "\x1b[38;5;13m"),
|
||||
"green": getColor("green", "\x1b[38;5;10m"),
|
||||
"dark-green": getColor("dark-green", "\x1b[38;5;28m"),
|
||||
"yellow": getColor("yellow", "\x1b[38;5;11m"),
|
||||
"light-yellow": getColor("light-yellow", "\x1b[38;5;228m"),
|
||||
"cyan": getColor("cyan", "\x1b[38;5;14m"),
|
||||
"gray": getColor("gray", "\x1b[38;5;243m"),
|
||||
"light-gray": getColor("light-gray", "\x1b[38;5;246m"),
|
||||
"blue": getColor("blue", "\x1b[38;5;12m"),
|
||||
},
|
||||
}
|
||||
colors := []string{}
|
||||
|
@ -88,7 +120,10 @@ func (f Formatter) Fi(indentation uint, format string, args ...interface{}) stri
|
|||
}
|
||||
|
||||
func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string {
|
||||
out := fmt.Sprintf(f.style(format), args...)
|
||||
out := f.style(format)
|
||||
if len(args) > 0 {
|
||||
out = fmt.Sprintf(out, args...)
|
||||
}
|
||||
|
||||
if indentation == 0 && maxWidth == 0 {
|
||||
return out
|
||||
|
|
|
@ -2,6 +2,7 @@ package generators
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"text/template"
|
||||
|
@ -25,6 +26,9 @@ func BuildBootstrapCommand() command.Command {
|
|||
{Name: "template", KeyPath: "CustomTemplate",
|
||||
UsageArgument: "template-file",
|
||||
Usage: "If specified, generate will use the contents of the file passed as the bootstrap template"},
|
||||
{Name: "template-data", KeyPath: "CustomTemplateData",
|
||||
UsageArgument: "template-data-file",
|
||||
Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the bootstrap template"},
|
||||
},
|
||||
&conf,
|
||||
types.GinkgoFlagSections{},
|
||||
|
@ -57,6 +61,7 @@ type bootstrapData struct {
|
|||
GomegaImport string
|
||||
GinkgoPackage string
|
||||
GomegaPackage string
|
||||
CustomData map[string]any
|
||||
}
|
||||
|
||||
func generateBootstrap(conf GeneratorsConfig) {
|
||||
|
@ -95,17 +100,32 @@ func generateBootstrap(conf GeneratorsConfig) {
|
|||
tpl, err := os.ReadFile(conf.CustomTemplate)
|
||||
command.AbortIfError("Failed to read custom bootstrap file:", err)
|
||||
templateText = string(tpl)
|
||||
if conf.CustomTemplateData != "" {
|
||||
var tplCustomDataMap map[string]any
|
||||
tplCustomData, err := os.ReadFile(conf.CustomTemplateData)
|
||||
command.AbortIfError("Failed to read custom boostrap data file:", err)
|
||||
if !json.Valid([]byte(tplCustomData)) {
|
||||
command.AbortWith("Invalid JSON object in custom data file.")
|
||||
}
|
||||
//create map from the custom template data
|
||||
json.Unmarshal(tplCustomData, &tplCustomDataMap)
|
||||
data.CustomData = tplCustomDataMap
|
||||
}
|
||||
} else if conf.Agouti {
|
||||
templateText = agoutiBootstrapText
|
||||
} else {
|
||||
templateText = bootstrapText
|
||||
}
|
||||
|
||||
bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Parse(templateText)
|
||||
//Setting the option to explicitly fail if template is rendered trying to access missing key
|
||||
bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText)
|
||||
command.AbortIfError("Failed to parse bootstrap template:", err)
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
bootstrapTemplate.Execute(buf, data)
|
||||
//Being explicit about failing sooner during template rendering
|
||||
//when accessing custom data rather than during the go fmt command
|
||||
err = bootstrapTemplate.Execute(buf, data)
|
||||
command.AbortIfError("Failed to render bootstrap template:", err)
|
||||
|
||||
buf.WriteTo(f)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ package generators
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -28,6 +29,9 @@ func BuildGenerateCommand() command.Command {
|
|||
{Name: "template", KeyPath: "CustomTemplate",
|
||||
UsageArgument: "template-file",
|
||||
Usage: "If specified, generate will use the contents of the file passed as the test file template"},
|
||||
{Name: "template-data", KeyPath: "CustomTemplateData",
|
||||
UsageArgument: "template-data-file",
|
||||
Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the test file template"},
|
||||
},
|
||||
&conf,
|
||||
types.GinkgoFlagSections{},
|
||||
|
@ -64,6 +68,7 @@ type specData struct {
|
|||
GomegaImport string
|
||||
GinkgoPackage string
|
||||
GomegaPackage string
|
||||
CustomData map[string]any
|
||||
}
|
||||
|
||||
func generateTestFiles(conf GeneratorsConfig, args []string) {
|
||||
|
@ -122,16 +127,31 @@ func generateTestFileForSubject(subject string, conf GeneratorsConfig) {
|
|||
tpl, err := os.ReadFile(conf.CustomTemplate)
|
||||
command.AbortIfError("Failed to read custom template file:", err)
|
||||
templateText = string(tpl)
|
||||
if conf.CustomTemplateData != "" {
|
||||
var tplCustomDataMap map[string]any
|
||||
tplCustomData, err := os.ReadFile(conf.CustomTemplateData)
|
||||
command.AbortIfError("Failed to read custom template data file:", err)
|
||||
if !json.Valid([]byte(tplCustomData)) {
|
||||
command.AbortWith("Invalid JSON object in custom data file.")
|
||||
}
|
||||
//create map from the custom template data
|
||||
json.Unmarshal(tplCustomData, &tplCustomDataMap)
|
||||
data.CustomData = tplCustomDataMap
|
||||
}
|
||||
} else if conf.Agouti {
|
||||
templateText = agoutiSpecText
|
||||
} else {
|
||||
templateText = specText
|
||||
}
|
||||
|
||||
specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Parse(templateText)
|
||||
//Setting the option to explicitly fail if template is rendered trying to access missing key
|
||||
specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText)
|
||||
command.AbortIfError("Failed to read parse test template:", err)
|
||||
|
||||
specTemplate.Execute(f, data)
|
||||
//Being explicit about failing sooner during template rendering
|
||||
//when accessing custom data rather than during the go fmt command
|
||||
err = specTemplate.Execute(f, data)
|
||||
command.AbortIfError("Failed to render bootstrap template:", err)
|
||||
internal.GoFmt(targetFile)
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
type GeneratorsConfig struct {
|
||||
Agouti, NoDot, Internal bool
|
||||
CustomTemplate string
|
||||
CustomTemplateData string
|
||||
}
|
||||
|
||||
func getPackageAndFormattedName() (string, string, string) {
|
||||
|
|
|
@ -25,7 +25,16 @@ func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig) TestSuite
|
|||
return suite
|
||||
}
|
||||
|
||||
args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./")
|
||||
ginkgoInvocationPath, _ := os.Getwd()
|
||||
ginkgoInvocationPath, _ = filepath.Abs(ginkgoInvocationPath)
|
||||
packagePath := suite.AbsPath()
|
||||
pathToInvocationPath, err := filepath.Rel(packagePath, ginkgoInvocationPath)
|
||||
if err != nil {
|
||||
suite.State = TestSuiteStateFailedToCompile
|
||||
suite.CompilationError = fmt.Errorf("Failed to get relative path from package to the current working directory:\n%s", err.Error())
|
||||
return suite
|
||||
}
|
||||
args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./", pathToInvocationPath)
|
||||
if err != nil {
|
||||
suite.State = TestSuiteStateFailedToCompile
|
||||
suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error())
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
@ -63,6 +64,12 @@ func checkForNoTestsWarning(buf *bytes.Buffer) bool {
|
|||
}
|
||||
|
||||
func runGoTest(suite TestSuite, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) TestSuite {
|
||||
// As we run the go test from the suite directory, make sure the cover profile is absolute
|
||||
// and placed into the expected output directory when one is configured.
|
||||
if goFlagsConfig.Cover && !filepath.IsAbs(goFlagsConfig.CoverProfile) {
|
||||
goFlagsConfig.CoverProfile = AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, 0)
|
||||
}
|
||||
|
||||
args, err := types.GenerateGoTestRunArgs(goFlagsConfig)
|
||||
command.AbortIfError("Failed to generate test run arguments", err)
|
||||
cmd, buf := buildAndStartCommand(suite, args, true)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package outline
|
||||
|
||||
import (
|
||||
"github.com/onsi/ginkgo/v2/types"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"strconv"
|
||||
|
@ -28,6 +29,7 @@ type ginkgoMetadata struct {
|
|||
Spec bool `json:"spec"`
|
||||
Focused bool `json:"focused"`
|
||||
Pending bool `json:"pending"`
|
||||
Labels []string `json:"labels"`
|
||||
}
|
||||
|
||||
// ginkgoNode is used to construct the outline as a tree
|
||||
|
@ -145,27 +147,35 @@ func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackage
|
|||
case "It", "Specify", "Entry":
|
||||
n.Spec = true
|
||||
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
|
||||
n.Labels = labelFromCallExpr(ce)
|
||||
n.Pending = pendingFromCallExpr(ce)
|
||||
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
|
||||
case "FIt", "FSpecify", "FEntry":
|
||||
n.Spec = true
|
||||
n.Focused = true
|
||||
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
|
||||
n.Labels = labelFromCallExpr(ce)
|
||||
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
|
||||
case "PIt", "PSpecify", "XIt", "XSpecify", "PEntry", "XEntry":
|
||||
n.Spec = true
|
||||
n.Pending = true
|
||||
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
|
||||
n.Labels = labelFromCallExpr(ce)
|
||||
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
|
||||
case "Context", "Describe", "When", "DescribeTable":
|
||||
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
|
||||
n.Labels = labelFromCallExpr(ce)
|
||||
n.Pending = pendingFromCallExpr(ce)
|
||||
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
|
||||
case "FContext", "FDescribe", "FWhen", "FDescribeTable":
|
||||
n.Focused = true
|
||||
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
|
||||
n.Labels = labelFromCallExpr(ce)
|
||||
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
|
||||
case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen", "PDescribeTable", "XDescribeTable":
|
||||
n.Pending = true
|
||||
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
|
||||
n.Labels = labelFromCallExpr(ce)
|
||||
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
|
||||
case "By":
|
||||
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
|
||||
|
@ -216,3 +226,77 @@ func textFromCallExpr(ce *ast.CallExpr) (string, bool) {
|
|||
return text.Value, true
|
||||
}
|
||||
}
|
||||
|
||||
func labelFromCallExpr(ce *ast.CallExpr) []string {
|
||||
|
||||
labels := []string{}
|
||||
if len(ce.Args) < 2 {
|
||||
return labels
|
||||
}
|
||||
|
||||
for _, arg := range ce.Args[1:] {
|
||||
switch expr := arg.(type) {
|
||||
case *ast.CallExpr:
|
||||
id, ok := expr.Fun.(*ast.Ident)
|
||||
if !ok {
|
||||
// to skip over cases where the expr.Fun. is actually *ast.SelectorExpr
|
||||
continue
|
||||
}
|
||||
if id.Name == "Label" {
|
||||
ls := extractLabels(expr)
|
||||
for _, label := range ls {
|
||||
labels = append(labels, label)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return labels
|
||||
}
|
||||
|
||||
func extractLabels(expr *ast.CallExpr) []string {
|
||||
out := []string{}
|
||||
for _, arg := range expr.Args {
|
||||
switch expr := arg.(type) {
|
||||
case *ast.BasicLit:
|
||||
if expr.Kind == token.STRING {
|
||||
unquoted, err := strconv.Unquote(expr.Value)
|
||||
if err != nil {
|
||||
unquoted = expr.Value
|
||||
}
|
||||
validated, err := types.ValidateAndCleanupLabel(unquoted, types.CodeLocation{})
|
||||
if err == nil {
|
||||
out = append(out, validated)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func pendingFromCallExpr(ce *ast.CallExpr) bool {
|
||||
|
||||
pending := false
|
||||
if len(ce.Args) < 2 {
|
||||
return pending
|
||||
}
|
||||
|
||||
for _, arg := range ce.Args[1:] {
|
||||
switch expr := arg.(type) {
|
||||
case *ast.CallExpr:
|
||||
id, ok := expr.Fun.(*ast.Ident)
|
||||
if !ok {
|
||||
// to skip over cases where the expr.Fun. is actually *ast.SelectorExpr
|
||||
continue
|
||||
}
|
||||
if id.Name == "Pending" {
|
||||
pending = true
|
||||
}
|
||||
case *ast.Ident:
|
||||
if expr.Name == "Pending" {
|
||||
pending = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return pending
|
||||
}
|
||||
|
|
|
@ -85,12 +85,19 @@ func (o *outline) String() string {
|
|||
// one 'width' of spaces for every level of nesting.
|
||||
func (o *outline) StringIndent(width int) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("Name,Text,Start,End,Spec,Focused,Pending\n")
|
||||
b.WriteString("Name,Text,Start,End,Spec,Focused,Pending,Labels\n")
|
||||
|
||||
currentIndent := 0
|
||||
pre := func(n *ginkgoNode) {
|
||||
b.WriteString(fmt.Sprintf("%*s", currentIndent, ""))
|
||||
b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending))
|
||||
var labels string
|
||||
if len(n.Labels) == 1 {
|
||||
labels = n.Labels[0]
|
||||
} else {
|
||||
labels = strings.Join(n.Labels, ", ")
|
||||
}
|
||||
//enclosing labels in a double quoted comma separate listed so that when inmported into a CSV app the Labels column has comma separate strings
|
||||
b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t,\"%s\"\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending, labels))
|
||||
currentIndent += width
|
||||
}
|
||||
post := func(n *ginkgoNode) {
|
||||
|
|
23
vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go
generated
vendored
23
vendor/github.com/onsi/ginkgo/v2/internal/interrupt_handler/interrupt_handler.go
generated
vendored
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/onsi/ginkgo/v2/internal/parallel_support"
|
||||
)
|
||||
|
||||
const ABORT_POLLING_INTERVAL = 500 * time.Millisecond
|
||||
var ABORT_POLLING_INTERVAL = 500 * time.Millisecond
|
||||
|
||||
type InterruptCause uint
|
||||
|
||||
|
@ -69,6 +69,7 @@ type InterruptHandler struct {
|
|||
client parallel_support.Client
|
||||
stop chan interface{}
|
||||
signals []os.Signal
|
||||
requestAbortCheck chan interface{}
|
||||
}
|
||||
|
||||
func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *InterruptHandler {
|
||||
|
@ -79,6 +80,7 @@ func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *
|
|||
c: make(chan interface{}),
|
||||
lock: &sync.Mutex{},
|
||||
stop: make(chan interface{}),
|
||||
requestAbortCheck: make(chan interface{}),
|
||||
client: client,
|
||||
signals: signals,
|
||||
}
|
||||
|
@ -109,6 +111,12 @@ func (handler *InterruptHandler) registerForInterrupts() {
|
|||
pollTicker.Stop()
|
||||
return
|
||||
}
|
||||
case <-handler.requestAbortCheck:
|
||||
if handler.client.ShouldAbort() {
|
||||
close(abortChannel)
|
||||
pollTicker.Stop()
|
||||
return
|
||||
}
|
||||
case <-handler.stop:
|
||||
pollTicker.Stop()
|
||||
return
|
||||
|
@ -152,11 +160,18 @@ func (handler *InterruptHandler) registerForInterrupts() {
|
|||
|
||||
func (handler *InterruptHandler) Status() InterruptStatus {
|
||||
handler.lock.Lock()
|
||||
defer handler.lock.Unlock()
|
||||
|
||||
return InterruptStatus{
|
||||
status := InterruptStatus{
|
||||
Level: handler.level,
|
||||
Channel: handler.c,
|
||||
Cause: handler.cause,
|
||||
}
|
||||
handler.lock.Unlock()
|
||||
|
||||
if handler.client != nil && handler.client.ShouldAbort() && !status.Interrupted() {
|
||||
close(handler.requestAbortCheck)
|
||||
<-status.Channel
|
||||
return handler.Status()
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
|
|
@ -42,6 +42,8 @@ type Client interface {
|
|||
PostSuiteWillBegin(report types.Report) error
|
||||
PostDidRun(report types.SpecReport) error
|
||||
PostSuiteDidEnd(report types.Report) error
|
||||
PostReportBeforeSuiteCompleted(state types.SpecState) error
|
||||
BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error)
|
||||
PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error
|
||||
BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error)
|
||||
BlockUntilNonprimaryProcsHaveFinished() error
|
||||
|
|
|
@ -98,6 +98,19 @@ func (client *httpClient) PostEmitProgressReport(report types.ProgressReport) er
|
|||
return client.post("/progress-report", report)
|
||||
}
|
||||
|
||||
func (client *httpClient) PostReportBeforeSuiteCompleted(state types.SpecState) error {
|
||||
return client.post("/report-before-suite-completed", state)
|
||||
}
|
||||
|
||||
func (client *httpClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) {
|
||||
var state types.SpecState
|
||||
err := client.poll("/report-before-suite-state", &state)
|
||||
if err == ErrorGone {
|
||||
return types.SpecStateFailed, nil
|
||||
}
|
||||
return state, err
|
||||
}
|
||||
|
||||
func (client *httpClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error {
|
||||
beforeSuiteState := BeforeSuiteState{
|
||||
State: state,
|
||||
|
|
|
@ -26,7 +26,7 @@ type httpServer struct {
|
|||
handler *ServerHandler
|
||||
}
|
||||
|
||||
//Create a new server, automatically selecting a port
|
||||
// Create a new server, automatically selecting a port
|
||||
func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
|
@ -38,7 +38,7 @@ func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer,
|
|||
}, nil
|
||||
}
|
||||
|
||||
//Start the server. You don't need to `go s.Start()`, just `s.Start()`
|
||||
// Start the server. You don't need to `go s.Start()`, just `s.Start()`
|
||||
func (server *httpServer) Start() {
|
||||
httpServer := &http.Server{}
|
||||
mux := http.NewServeMux()
|
||||
|
@ -52,6 +52,8 @@ func (server *httpServer) Start() {
|
|||
mux.HandleFunc("/progress-report", server.emitProgressReport)
|
||||
|
||||
//synchronization endpoints
|
||||
mux.HandleFunc("/report-before-suite-completed", server.handleReportBeforeSuiteCompleted)
|
||||
mux.HandleFunc("/report-before-suite-state", server.handleReportBeforeSuiteState)
|
||||
mux.HandleFunc("/before-suite-completed", server.handleBeforeSuiteCompleted)
|
||||
mux.HandleFunc("/before-suite-state", server.handleBeforeSuiteState)
|
||||
mux.HandleFunc("/have-nonprimary-procs-finished", server.handleHaveNonprimaryProcsFinished)
|
||||
|
@ -63,12 +65,12 @@ func (server *httpServer) Start() {
|
|||
go httpServer.Serve(server.listener)
|
||||
}
|
||||
|
||||
//Stop the server
|
||||
// Stop the server
|
||||
func (server *httpServer) Close() {
|
||||
server.listener.Close()
|
||||
}
|
||||
|
||||
//The address the server can be reached it. Pass this into the `ForwardingReporter`.
|
||||
// The address the server can be reached it. Pass this into the `ForwardingReporter`.
|
||||
func (server *httpServer) Address() string {
|
||||
return "http://" + server.listener.Addr().String()
|
||||
}
|
||||
|
@ -93,7 +95,7 @@ func (server *httpServer) RegisterAlive(node int, alive func() bool) {
|
|||
// Streaming Endpoints
|
||||
//
|
||||
|
||||
//The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters`
|
||||
// The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters`
|
||||
func (server *httpServer) decode(writer http.ResponseWriter, request *http.Request, object interface{}) bool {
|
||||
defer request.Body.Close()
|
||||
if json.NewDecoder(request.Body).Decode(object) != nil {
|
||||
|
@ -164,6 +166,23 @@ func (server *httpServer) emitProgressReport(writer http.ResponseWriter, request
|
|||
server.handleError(server.handler.EmitProgressReport(report, voidReceiver), writer)
|
||||
}
|
||||
|
||||
func (server *httpServer) handleReportBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) {
|
||||
var state types.SpecState
|
||||
if !server.decode(writer, request, &state) {
|
||||
return
|
||||
}
|
||||
|
||||
server.handleError(server.handler.ReportBeforeSuiteCompleted(state, voidReceiver), writer)
|
||||
}
|
||||
|
||||
func (server *httpServer) handleReportBeforeSuiteState(writer http.ResponseWriter, request *http.Request) {
|
||||
var state types.SpecState
|
||||
if server.handleError(server.handler.ReportBeforeSuiteState(voidSender, &state), writer) {
|
||||
return
|
||||
}
|
||||
json.NewEncoder(writer).Encode(state)
|
||||
}
|
||||
|
||||
func (server *httpServer) handleBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) {
|
||||
var beforeSuiteState BeforeSuiteState
|
||||
if !server.decode(writer, request, &beforeSuiteState) {
|
||||
|
|
|
@ -76,6 +76,19 @@ func (client *rpcClient) PostEmitProgressReport(report types.ProgressReport) err
|
|||
return client.client.Call("Server.EmitProgressReport", report, voidReceiver)
|
||||
}
|
||||
|
||||
func (client *rpcClient) PostReportBeforeSuiteCompleted(state types.SpecState) error {
|
||||
return client.client.Call("Server.ReportBeforeSuiteCompleted", state, voidReceiver)
|
||||
}
|
||||
|
||||
func (client *rpcClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) {
|
||||
var state types.SpecState
|
||||
err := client.poll("Server.ReportBeforeSuiteState", &state)
|
||||
if err == ErrorGone {
|
||||
return types.SpecStateFailed, nil
|
||||
}
|
||||
return state, err
|
||||
}
|
||||
|
||||
func (client *rpcClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error {
|
||||
beforeSuiteState := BeforeSuiteState{
|
||||
State: state,
|
||||
|
|
|
@ -24,6 +24,7 @@ type ServerHandler struct {
|
|||
alives []func() bool
|
||||
lock *sync.Mutex
|
||||
beforeSuiteState BeforeSuiteState
|
||||
reportBeforeSuiteState types.SpecState
|
||||
parallelTotal int
|
||||
counter int
|
||||
counterLock *sync.Mutex
|
||||
|
@ -42,6 +43,7 @@ func newServerHandler(parallelTotal int, reporter reporters.Reporter) *ServerHan
|
|||
counterLock: &sync.Mutex{},
|
||||
alives: make([]func() bool, parallelTotal),
|
||||
beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid},
|
||||
|
||||
parallelTotal: parallelTotal,
|
||||
outputDestination: os.Stdout,
|
||||
done: make(chan interface{}),
|
||||
|
@ -140,6 +142,29 @@ func (handler *ServerHandler) haveNonprimaryProcsFinished() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (handler *ServerHandler) ReportBeforeSuiteCompleted(reportBeforeSuiteState types.SpecState, _ *Void) error {
|
||||
handler.lock.Lock()
|
||||
defer handler.lock.Unlock()
|
||||
handler.reportBeforeSuiteState = reportBeforeSuiteState
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (handler *ServerHandler) ReportBeforeSuiteState(_ Void, reportBeforeSuiteState *types.SpecState) error {
|
||||
proc1IsAlive := handler.procIsAlive(1)
|
||||
handler.lock.Lock()
|
||||
defer handler.lock.Unlock()
|
||||
if handler.reportBeforeSuiteState == types.SpecStateInvalid {
|
||||
if proc1IsAlive {
|
||||
return ErrorEarly
|
||||
} else {
|
||||
return ErrorGone
|
||||
}
|
||||
}
|
||||
*reportBeforeSuiteState = handler.reportBeforeSuiteState
|
||||
return nil
|
||||
}
|
||||
|
||||
func (handler *ServerHandler) BeforeSuiteCompleted(beforeSuiteState BeforeSuiteState, _ *Void) error {
|
||||
handler.lock.Lock()
|
||||
defer handler.lock.Unlock()
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/onsi/ginkgo/v2/formatter"
|
||||
|
@ -23,13 +24,16 @@ type DefaultReporter struct {
|
|||
writer io.Writer
|
||||
|
||||
// managing the emission stream
|
||||
lastChar string
|
||||
lastCharWasNewline bool
|
||||
lastEmissionWasDelimiter bool
|
||||
|
||||
// rendering
|
||||
specDenoter string
|
||||
retryDenoter string
|
||||
formatter formatter.Formatter
|
||||
|
||||
runningInParallel bool
|
||||
lock *sync.Mutex
|
||||
}
|
||||
|
||||
func NewDefaultReporterUnderTest(conf types.ReporterConfig, writer io.Writer) *DefaultReporter {
|
||||
|
@ -44,12 +48,13 @@ func NewDefaultReporter(conf types.ReporterConfig, writer io.Writer) *DefaultRep
|
|||
conf: conf,
|
||||
writer: writer,
|
||||
|
||||
lastChar: "\n",
|
||||
lastCharWasNewline: true,
|
||||
lastEmissionWasDelimiter: false,
|
||||
|
||||
specDenoter: "•",
|
||||
retryDenoter: "↺",
|
||||
formatter: formatter.NewWithNoColorBool(conf.NoColor),
|
||||
lock: &sync.Mutex{},
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
reporter.specDenoter = "+"
|
||||
|
@ -97,230 +102,10 @@ func (r *DefaultReporter) SuiteWillBegin(report types.Report) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) WillRun(report types.SpecReport) {
|
||||
if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) {
|
||||
return
|
||||
}
|
||||
|
||||
r.emitDelimiter()
|
||||
indentation := uint(0)
|
||||
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
|
||||
r.emitBlock(r.f("{{bold}}[%s] %s{{/}}", report.LeafNodeType.String(), report.LeafNodeText))
|
||||
} else {
|
||||
if len(report.ContainerHierarchyTexts) > 0 {
|
||||
r.emitBlock(r.cycleJoin(report.ContainerHierarchyTexts, " "))
|
||||
indentation = 1
|
||||
}
|
||||
line := r.fi(indentation, "{{bold}}%s{{/}}", report.LeafNodeText)
|
||||
labels := report.Labels()
|
||||
if len(labels) > 0 {
|
||||
line += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels, ", "))
|
||||
}
|
||||
r.emitBlock(line)
|
||||
}
|
||||
r.emitBlock(r.fi(indentation, "{{gray}}%s{{/}}", report.LeafNodeLocation))
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) DidRun(report types.SpecReport) {
|
||||
v := r.conf.Verbosity()
|
||||
var header, highlightColor string
|
||||
includeRuntime, emitGinkgoWriterOutput, stream, denoter := true, true, false, r.specDenoter
|
||||
succinctLocationBlock := v.Is(types.VerbosityLevelSuccinct)
|
||||
|
||||
hasGW := report.CapturedGinkgoWriterOutput != ""
|
||||
hasStd := report.CapturedStdOutErr != ""
|
||||
hasEmittableReports := report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways) || (report.ReportEntries.HasVisibility(types.ReportEntryVisibilityFailureOrVerbose) && (!report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose)))
|
||||
|
||||
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
|
||||
denoter = fmt.Sprintf("[%s]", report.LeafNodeType)
|
||||
}
|
||||
|
||||
highlightColor = r.highlightColorForState(report.State)
|
||||
|
||||
switch report.State {
|
||||
case types.SpecStatePassed:
|
||||
succinctLocationBlock = v.LT(types.VerbosityLevelVerbose)
|
||||
emitGinkgoWriterOutput = (r.conf.AlwaysEmitGinkgoWriter || v.GTE(types.VerbosityLevelVerbose)) && hasGW
|
||||
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
|
||||
if v.GTE(types.VerbosityLevelVerbose) || hasStd || hasEmittableReports {
|
||||
header = fmt.Sprintf("%s PASSED", denoter)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
header, stream = denoter, true
|
||||
if report.NumAttempts > 1 && report.MaxFlakeAttempts > 1 {
|
||||
header, stream = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), false
|
||||
}
|
||||
if report.RunTime > r.conf.SlowSpecThreshold {
|
||||
header, stream = fmt.Sprintf("%s [SLOW TEST]", header), false
|
||||
}
|
||||
}
|
||||
if hasStd || emitGinkgoWriterOutput || hasEmittableReports {
|
||||
stream = false
|
||||
}
|
||||
case types.SpecStatePending:
|
||||
includeRuntime, emitGinkgoWriterOutput = false, false
|
||||
if v.Is(types.VerbosityLevelSuccinct) {
|
||||
header, stream = "P", true
|
||||
} else {
|
||||
header, succinctLocationBlock = "P [PENDING]", v.LT(types.VerbosityLevelVeryVerbose)
|
||||
}
|
||||
case types.SpecStateSkipped:
|
||||
if report.Failure.Message != "" || v.Is(types.VerbosityLevelVeryVerbose) {
|
||||
header = "S [SKIPPED]"
|
||||
} else {
|
||||
header, stream = "S", true
|
||||
}
|
||||
case types.SpecStateFailed:
|
||||
header = fmt.Sprintf("%s [FAILED]", denoter)
|
||||
case types.SpecStateTimedout:
|
||||
header = fmt.Sprintf("%s [TIMEDOUT]", denoter)
|
||||
case types.SpecStatePanicked:
|
||||
header = fmt.Sprintf("%s! [PANICKED]", denoter)
|
||||
case types.SpecStateInterrupted:
|
||||
header = fmt.Sprintf("%s! [INTERRUPTED]", denoter)
|
||||
case types.SpecStateAborted:
|
||||
header = fmt.Sprintf("%s! [ABORTED]", denoter)
|
||||
}
|
||||
|
||||
if report.State.Is(types.SpecStateFailureStates) && report.MaxMustPassRepeatedly > 1 {
|
||||
header, stream = fmt.Sprintf("%s DURING REPETITION #%d", header, report.NumAttempts), false
|
||||
}
|
||||
// Emit stream and return
|
||||
if stream {
|
||||
r.emit(r.f(highlightColor + header + "{{/}}"))
|
||||
return
|
||||
}
|
||||
|
||||
// Emit header
|
||||
r.emitDelimiter()
|
||||
if includeRuntime {
|
||||
header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds())
|
||||
}
|
||||
r.emitBlock(r.f(highlightColor + header + "{{/}}"))
|
||||
|
||||
// Emit Code Location Block
|
||||
r.emitBlock(r.codeLocationBlock(report, highlightColor, succinctLocationBlock, false))
|
||||
|
||||
//Emit Stdout/Stderr Output
|
||||
if hasStd {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(1, "{{gray}}Begin Captured StdOut/StdErr Output >>{{/}}"))
|
||||
r.emitBlock(r.fi(2, "%s", report.CapturedStdOutErr))
|
||||
r.emitBlock(r.fi(1, "{{gray}}<< End Captured StdOut/StdErr Output{{/}}"))
|
||||
}
|
||||
|
||||
//Emit Captured GinkgoWriter Output
|
||||
if emitGinkgoWriterOutput && hasGW {
|
||||
r.emitBlock("\n")
|
||||
r.emitGinkgoWriterOutput(1, report.CapturedGinkgoWriterOutput, 0)
|
||||
}
|
||||
|
||||
if hasEmittableReports {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(1, "{{gray}}Begin Report Entries >>{{/}}"))
|
||||
reportEntries := report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways)
|
||||
if !report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose) {
|
||||
reportEntries = report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways, types.ReportEntryVisibilityFailureOrVerbose)
|
||||
}
|
||||
for _, entry := range reportEntries {
|
||||
r.emitBlock(r.fi(2, "{{bold}}"+entry.Name+"{{gray}} - %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT)))
|
||||
if representation := entry.StringRepresentation(); representation != "" {
|
||||
r.emitBlock(r.fi(3, representation))
|
||||
}
|
||||
}
|
||||
r.emitBlock(r.fi(1, "{{gray}}<< End Report Entries{{/}}"))
|
||||
}
|
||||
|
||||
// Emit Failure Message
|
||||
if !report.Failure.IsZero() {
|
||||
r.emitBlock("\n")
|
||||
r.EmitFailure(1, report.State, report.Failure, false)
|
||||
}
|
||||
|
||||
if len(report.AdditionalFailures) > 0 {
|
||||
if v.GTE(types.VerbosityLevelVerbose) {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(1, "{{bold}}There were additional failures detected after the initial failure:{{/}}"))
|
||||
for i, additionalFailure := range report.AdditionalFailures {
|
||||
r.EmitFailure(2, additionalFailure.State, additionalFailure.Failure, true)
|
||||
if i < len(report.AdditionalFailures)-1 {
|
||||
r.emitBlock(r.fi(2, "{{gray}}%s{{/}}", strings.Repeat("-", 10)))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(1, "{{bold}}There were additional failures detected after the initial failure. Here's a summary - for full details run Ginkgo in verbose mode:{{/}}"))
|
||||
for _, additionalFailure := range report.AdditionalFailures {
|
||||
r.emitBlock(r.fi(2, r.highlightColorForState(additionalFailure.State)+"[%s]{{/}} in [%s] at %s",
|
||||
r.humanReadableState(additionalFailure.State),
|
||||
additionalFailure.Failure.FailureNodeType,
|
||||
additionalFailure.Failure.Location,
|
||||
))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
r.emitDelimiter()
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) highlightColorForState(state types.SpecState) string {
|
||||
switch state {
|
||||
case types.SpecStatePassed:
|
||||
return "{{green}}"
|
||||
case types.SpecStatePending:
|
||||
return "{{yellow}}"
|
||||
case types.SpecStateSkipped:
|
||||
return "{{cyan}}"
|
||||
case types.SpecStateFailed:
|
||||
return "{{red}}"
|
||||
case types.SpecStateTimedout:
|
||||
return "{{orange}}"
|
||||
case types.SpecStatePanicked:
|
||||
return "{{magenta}}"
|
||||
case types.SpecStateInterrupted:
|
||||
return "{{orange}}"
|
||||
case types.SpecStateAborted:
|
||||
return "{{coral}}"
|
||||
default:
|
||||
return "{{gray}}"
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) humanReadableState(state types.SpecState) string {
|
||||
return strings.ToUpper(state.String())
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) EmitFailure(indent uint, state types.SpecState, failure types.Failure, includeState bool) {
|
||||
highlightColor := r.highlightColorForState(state)
|
||||
if includeState {
|
||||
r.emitBlock(r.fi(indent, highlightColor+"[%s]{{/}}", r.humanReadableState(state)))
|
||||
}
|
||||
r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.Message))
|
||||
r.emitBlock(r.fi(indent, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}}\n", failure.FailureNodeType, failure.Location))
|
||||
if failure.ForwardedPanic != "" {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.ForwardedPanic))
|
||||
}
|
||||
|
||||
if r.conf.FullTrace || failure.ForwardedPanic != "" {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(indent, highlightColor+"Full Stack Trace{{/}}"))
|
||||
r.emitBlock(r.fi(indent+1, "%s", failure.Location.FullStackTrace))
|
||||
}
|
||||
|
||||
if !failure.ProgressReport.IsZero() {
|
||||
r.emitBlock("\n")
|
||||
r.emitProgressReport(indent, false, failure.ProgressReport)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
|
||||
failures := report.SpecReports.WithState(types.SpecStateFailureStates)
|
||||
if len(failures) > 0 {
|
||||
r.emitBlock("\n\n")
|
||||
r.emitBlock("\n")
|
||||
if len(failures) > 1 {
|
||||
r.emitBlock(r.f("{{red}}{{bold}}Summarizing %d Failures:{{/}}", len(failures)))
|
||||
} else {
|
||||
|
@ -338,7 +123,7 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
|
|||
case types.SpecStateInterrupted:
|
||||
highlightColor, heading = "{{orange}}", "[INTERRUPTED]"
|
||||
}
|
||||
locationBlock := r.codeLocationBlock(specReport, highlightColor, true, true)
|
||||
locationBlock := r.codeLocationBlock(specReport, highlightColor, false, true)
|
||||
r.emitBlock(r.fi(1, highlightColor+"%s{{/}} %s", heading, locationBlock))
|
||||
}
|
||||
}
|
||||
|
@ -387,14 +172,271 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
|
|||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) WillRun(report types.SpecReport) {
|
||||
v := r.conf.Verbosity()
|
||||
if v.LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) || report.RunningInParallel {
|
||||
return
|
||||
}
|
||||
|
||||
r.emitDelimiter(0)
|
||||
r.emitBlock(r.f(r.codeLocationBlock(report, "{{/}}", v.Is(types.VerbosityLevelVeryVerbose), false)))
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) DidRun(report types.SpecReport) {
|
||||
v := r.conf.Verbosity()
|
||||
inParallel := report.RunningInParallel
|
||||
|
||||
header := r.specDenoter
|
||||
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
|
||||
header = fmt.Sprintf("[%s]", report.LeafNodeType)
|
||||
}
|
||||
highlightColor := r.highlightColorForState(report.State)
|
||||
|
||||
// have we already been streaming the timeline?
|
||||
timelineHasBeenStreaming := v.GTE(types.VerbosityLevelVerbose) && !inParallel
|
||||
|
||||
// should we show the timeline?
|
||||
var timeline types.Timeline
|
||||
showTimeline := !timelineHasBeenStreaming && (v.GTE(types.VerbosityLevelVerbose) || report.Failed())
|
||||
if showTimeline {
|
||||
timeline = report.Timeline().WithoutHiddenReportEntries()
|
||||
keepVeryVerboseSpecEvents := v.Is(types.VerbosityLevelVeryVerbose) ||
|
||||
(v.Is(types.VerbosityLevelVerbose) && r.conf.ShowNodeEvents) ||
|
||||
(report.Failed() && r.conf.ShowNodeEvents)
|
||||
if !keepVeryVerboseSpecEvents {
|
||||
timeline = timeline.WithoutVeryVerboseSpecEvents()
|
||||
}
|
||||
if len(timeline) == 0 && report.CapturedGinkgoWriterOutput == "" {
|
||||
// the timeline is completely empty - don't show it
|
||||
showTimeline = false
|
||||
}
|
||||
if v.LT(types.VerbosityLevelVeryVerbose) && report.CapturedGinkgoWriterOutput == "" && len(timeline) > 0 {
|
||||
//if we aren't -vv and the timeline only has a single failure, don't show it as it will appear at the end of the report
|
||||
failure, isFailure := timeline[0].(types.Failure)
|
||||
if isFailure && (len(timeline) == 1 || (len(timeline) == 2 && failure.AdditionalFailure != nil)) {
|
||||
showTimeline = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should we have a separate section for always-visible reports?
|
||||
showSeparateVisibilityAlwaysReportsSection := !timelineHasBeenStreaming && !showTimeline && report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways)
|
||||
|
||||
// should we have a separate section for captured stdout/stderr
|
||||
showSeparateStdSection := inParallel && (report.CapturedStdOutErr != "")
|
||||
|
||||
// given all that - do we have any actual content to show? or are we a single denoter in a stream?
|
||||
reportHasContent := v.Is(types.VerbosityLevelVeryVerbose) || showTimeline || showSeparateVisibilityAlwaysReportsSection || showSeparateStdSection || report.Failed() || (v.Is(types.VerbosityLevelVerbose) && !report.State.Is(types.SpecStateSkipped))
|
||||
|
||||
// should we show a runtime?
|
||||
includeRuntime := !report.State.Is(types.SpecStateSkipped|types.SpecStatePending) || (report.State.Is(types.SpecStateSkipped) && report.Failure.Message != "")
|
||||
|
||||
// should we show the codelocation block?
|
||||
showCodeLocation := !timelineHasBeenStreaming || !report.State.Is(types.SpecStatePassed)
|
||||
|
||||
switch report.State {
|
||||
case types.SpecStatePassed:
|
||||
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) && !reportHasContent {
|
||||
return
|
||||
}
|
||||
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
|
||||
header = fmt.Sprintf("%s PASSED", header)
|
||||
}
|
||||
if report.NumAttempts > 1 && report.MaxFlakeAttempts > 1 {
|
||||
header, reportHasContent = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), true
|
||||
}
|
||||
case types.SpecStatePending:
|
||||
header = "P"
|
||||
if v.GT(types.VerbosityLevelSuccinct) {
|
||||
header, reportHasContent = "P [PENDING]", true
|
||||
}
|
||||
case types.SpecStateSkipped:
|
||||
header = "S"
|
||||
if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && report.Failure.Message != "") {
|
||||
header, reportHasContent = "S [SKIPPED]", true
|
||||
}
|
||||
default:
|
||||
header = fmt.Sprintf("%s [%s]", header, r.humanReadableState(report.State))
|
||||
if report.MaxMustPassRepeatedly > 1 {
|
||||
header = fmt.Sprintf("%s DURING REPETITION #%d", header, report.NumAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
// If we have no content to show, jsut emit the header and return
|
||||
if !reportHasContent {
|
||||
r.emit(r.f(highlightColor + header + "{{/}}"))
|
||||
return
|
||||
}
|
||||
|
||||
if includeRuntime {
|
||||
header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds())
|
||||
}
|
||||
|
||||
// Emit header
|
||||
if !timelineHasBeenStreaming {
|
||||
r.emitDelimiter(0)
|
||||
}
|
||||
r.emitBlock(r.f(highlightColor + header + "{{/}}"))
|
||||
if showCodeLocation {
|
||||
r.emitBlock(r.codeLocationBlock(report, highlightColor, v.Is(types.VerbosityLevelVeryVerbose), false))
|
||||
}
|
||||
|
||||
//Emit Stdout/Stderr Output
|
||||
if showSeparateStdSection {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(1, "{{gray}}Captured StdOut/StdErr Output >>{{/}}"))
|
||||
r.emitBlock(r.fi(1, "%s", report.CapturedStdOutErr))
|
||||
r.emitBlock(r.fi(1, "{{gray}}<< Captured StdOut/StdErr Output{{/}}"))
|
||||
}
|
||||
|
||||
if showSeparateVisibilityAlwaysReportsSection {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(1, "{{gray}}Report Entries >>{{/}}"))
|
||||
for _, entry := range report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways) {
|
||||
r.emitReportEntry(1, entry)
|
||||
}
|
||||
r.emitBlock(r.fi(1, "{{gray}}<< Report Entries{{/}}"))
|
||||
}
|
||||
|
||||
if showTimeline {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(1, "{{gray}}Timeline >>{{/}}"))
|
||||
r.emitTimeline(1, report, timeline)
|
||||
r.emitBlock(r.fi(1, "{{gray}}<< Timeline{{/}}"))
|
||||
}
|
||||
|
||||
// Emit Failure Message
|
||||
if !report.Failure.IsZero() && !v.Is(types.VerbosityLevelVeryVerbose) {
|
||||
r.emitBlock("\n")
|
||||
r.emitFailure(1, report.State, report.Failure, true)
|
||||
if len(report.AdditionalFailures) > 0 {
|
||||
r.emitBlock(r.fi(1, "\nThere were {{bold}}{{red}}additional failures{{/}} detected. To view them in detail run {{bold}}ginkgo -vv{{/}}"))
|
||||
}
|
||||
}
|
||||
|
||||
r.emitDelimiter(0)
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) highlightColorForState(state types.SpecState) string {
|
||||
switch state {
|
||||
case types.SpecStatePassed:
|
||||
return "{{green}}"
|
||||
case types.SpecStatePending:
|
||||
return "{{yellow}}"
|
||||
case types.SpecStateSkipped:
|
||||
return "{{cyan}}"
|
||||
case types.SpecStateFailed:
|
||||
return "{{red}}"
|
||||
case types.SpecStateTimedout:
|
||||
return "{{orange}}"
|
||||
case types.SpecStatePanicked:
|
||||
return "{{magenta}}"
|
||||
case types.SpecStateInterrupted:
|
||||
return "{{orange}}"
|
||||
case types.SpecStateAborted:
|
||||
return "{{coral}}"
|
||||
default:
|
||||
return "{{gray}}"
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) humanReadableState(state types.SpecState) string {
|
||||
return strings.ToUpper(state.String())
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitTimeline(indent uint, report types.SpecReport, timeline types.Timeline) {
|
||||
isVeryVerbose := r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose)
|
||||
gw := report.CapturedGinkgoWriterOutput
|
||||
cursor := 0
|
||||
for _, entry := range timeline {
|
||||
tl := entry.GetTimelineLocation()
|
||||
if tl.Offset < len(gw) {
|
||||
r.emit(r.fi(indent, "%s", gw[cursor:tl.Offset]))
|
||||
cursor = tl.Offset
|
||||
} else if cursor < len(gw) {
|
||||
r.emit(r.fi(indent, "%s", gw[cursor:]))
|
||||
cursor = len(gw)
|
||||
}
|
||||
switch x := entry.(type) {
|
||||
case types.Failure:
|
||||
if isVeryVerbose {
|
||||
r.emitFailure(indent, report.State, x, false)
|
||||
} else {
|
||||
r.emitShortFailure(indent, report.State, x)
|
||||
}
|
||||
case types.AdditionalFailure:
|
||||
if isVeryVerbose {
|
||||
r.emitFailure(indent, x.State, x.Failure, true)
|
||||
} else {
|
||||
r.emitShortFailure(indent, x.State, x.Failure)
|
||||
}
|
||||
case types.ReportEntry:
|
||||
r.emitReportEntry(indent, x)
|
||||
case types.ProgressReport:
|
||||
r.emitProgressReport(indent, false, x)
|
||||
case types.SpecEvent:
|
||||
if isVeryVerbose || !x.IsOnlyVisibleAtVeryVerbose() || r.conf.ShowNodeEvents {
|
||||
r.emitSpecEvent(indent, x, isVeryVerbose)
|
||||
}
|
||||
}
|
||||
}
|
||||
if cursor < len(gw) {
|
||||
r.emit(r.fi(indent, "%s", gw[cursor:]))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) EmitFailure(state types.SpecState, failure types.Failure) {
|
||||
if r.conf.Verbosity().Is(types.VerbosityLevelVerbose) {
|
||||
r.emitShortFailure(1, state, failure)
|
||||
} else if r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose) {
|
||||
r.emitFailure(1, state, failure, true)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitShortFailure(indent uint, state types.SpecState, failure types.Failure) {
|
||||
r.emitBlock(r.fi(indent, r.highlightColorForState(state)+"[%s]{{/}} in [%s] - %s {{gray}}@ %s{{/}}",
|
||||
r.humanReadableState(state),
|
||||
failure.FailureNodeType,
|
||||
failure.Location,
|
||||
failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT),
|
||||
))
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitFailure(indent uint, state types.SpecState, failure types.Failure, includeAdditionalFailure bool) {
|
||||
highlightColor := r.highlightColorForState(state)
|
||||
r.emitBlock(r.fi(indent, highlightColor+"[%s] %s{{/}}", r.humanReadableState(state), failure.Message))
|
||||
r.emitBlock(r.fi(indent, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}} {{gray}}@ %s{{/}}\n", failure.FailureNodeType, failure.Location, failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
|
||||
if failure.ForwardedPanic != "" {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.ForwardedPanic))
|
||||
}
|
||||
|
||||
if r.conf.FullTrace || failure.ForwardedPanic != "" {
|
||||
r.emitBlock("\n")
|
||||
r.emitBlock(r.fi(indent, highlightColor+"Full Stack Trace{{/}}"))
|
||||
r.emitBlock(r.fi(indent+1, "%s", failure.Location.FullStackTrace))
|
||||
}
|
||||
|
||||
if !failure.ProgressReport.IsZero() {
|
||||
r.emitBlock("\n")
|
||||
r.emitProgressReport(indent, false, failure.ProgressReport)
|
||||
}
|
||||
|
||||
if failure.AdditionalFailure != nil && includeAdditionalFailure {
|
||||
r.emitBlock("\n")
|
||||
r.emitFailure(indent, failure.AdditionalFailure.State, failure.AdditionalFailure.Failure, true)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) EmitProgressReport(report types.ProgressReport) {
|
||||
r.emitDelimiter()
|
||||
r.emitDelimiter(1)
|
||||
|
||||
if report.RunningInParallel {
|
||||
r.emit(r.f("{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess))
|
||||
r.emit(r.fi(1, "{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess))
|
||||
}
|
||||
r.emitProgressReport(0, true, report)
|
||||
r.emitDelimiter()
|
||||
shouldEmitGW := report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)
|
||||
r.emitProgressReport(1, shouldEmitGW, report)
|
||||
r.emitDelimiter(1)
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput bool, report types.ProgressReport) {
|
||||
|
@ -409,7 +451,7 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
|
|||
r.emit(" ")
|
||||
subjectIndent = 0
|
||||
}
|
||||
r.emit(r.fi(subjectIndent, "{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time.Sub(report.SpecStartTime).Round(time.Millisecond)))
|
||||
r.emit(r.fi(subjectIndent, "{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time().Sub(report.SpecStartTime).Round(time.Millisecond)))
|
||||
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.LeafNodeLocation))
|
||||
indent += 1
|
||||
}
|
||||
|
@ -419,12 +461,12 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
|
|||
r.emit(r.f(" {{bold}}{{orange}}%s{{/}}", report.CurrentNodeText))
|
||||
}
|
||||
|
||||
r.emit(r.f(" (Node Runtime: %s)\n", report.Time.Sub(report.CurrentNodeStartTime).Round(time.Millisecond)))
|
||||
r.emit(r.f(" (Node Runtime: %s)\n", report.Time().Sub(report.CurrentNodeStartTime).Round(time.Millisecond)))
|
||||
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentNodeLocation))
|
||||
indent += 1
|
||||
}
|
||||
if report.CurrentStepText != "" {
|
||||
r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time.Sub(report.CurrentStepStartTime).Round(time.Millisecond)))
|
||||
r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time().Sub(report.CurrentStepStartTime).Round(time.Millisecond)))
|
||||
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentStepLocation))
|
||||
indent += 1
|
||||
}
|
||||
|
@ -433,9 +475,19 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
|
|||
indent -= 1
|
||||
}
|
||||
|
||||
if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" && (report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)) {
|
||||
if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" {
|
||||
r.emit("\n")
|
||||
r.emitGinkgoWriterOutput(indent, report.CapturedGinkgoWriterOutput, 10)
|
||||
r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}"))
|
||||
limit, lines := 10, strings.Split(report.CapturedGinkgoWriterOutput, "\n")
|
||||
if len(lines) <= limit {
|
||||
r.emitBlock(r.fi(indent+1, "%s", report.CapturedGinkgoWriterOutput))
|
||||
} else {
|
||||
r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}"))
|
||||
for _, line := range lines[len(lines)-limit-1:] {
|
||||
r.emitBlock(r.fi(indent+1, "%s", line))
|
||||
}
|
||||
}
|
||||
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}"))
|
||||
}
|
||||
|
||||
if !report.SpecGoroutine().IsZero() {
|
||||
|
@ -471,22 +523,48 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
|
|||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitGinkgoWriterOutput(indent uint, output string, limit int) {
|
||||
r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}"))
|
||||
if limit == 0 {
|
||||
r.emitBlock(r.fi(indent+1, "%s", output))
|
||||
} else {
|
||||
lines := strings.Split(output, "\n")
|
||||
if len(lines) <= limit {
|
||||
r.emitBlock(r.fi(indent+1, "%s", output))
|
||||
} else {
|
||||
r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}"))
|
||||
for _, line := range lines[len(lines)-limit-1:] {
|
||||
r.emitBlock(r.fi(indent+1, "%s", line))
|
||||
func (r *DefaultReporter) EmitReportEntry(entry types.ReportEntry) {
|
||||
if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || entry.Visibility == types.ReportEntryVisibilityNever {
|
||||
return
|
||||
}
|
||||
r.emitReportEntry(1, entry)
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitReportEntry(indent uint, entry types.ReportEntry) {
|
||||
r.emitBlock(r.fi(indent, "{{bold}}"+entry.Name+"{{gray}} "+fmt.Sprintf("- %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT))))
|
||||
if representation := entry.StringRepresentation(); representation != "" {
|
||||
r.emitBlock(r.fi(indent+1, representation))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) EmitSpecEvent(event types.SpecEvent) {
|
||||
v := r.conf.Verbosity()
|
||||
if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && (r.conf.ShowNodeEvents || !event.IsOnlyVisibleAtVeryVerbose())) {
|
||||
r.emitSpecEvent(1, event, r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitSpecEvent(indent uint, event types.SpecEvent, includeLocation bool) {
|
||||
location := ""
|
||||
if includeLocation {
|
||||
location = fmt.Sprintf("- %s ", event.CodeLocation.String())
|
||||
}
|
||||
switch event.SpecEventType {
|
||||
case types.SpecEventInvalid:
|
||||
return
|
||||
case types.SpecEventByStart:
|
||||
r.emitBlock(r.fi(indent, "{{bold}}STEP:{{/}} %s {{gray}}%s@ %s{{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
|
||||
case types.SpecEventByEnd:
|
||||
r.emitBlock(r.fi(indent, "{{bold}}END STEP:{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond)))
|
||||
case types.SpecEventNodeStart:
|
||||
r.emitBlock(r.fi(indent, "> Enter {{bold}}[%s]{{/}} %s {{gray}}%s@ %s{{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
|
||||
case types.SpecEventNodeEnd:
|
||||
r.emitBlock(r.fi(indent, "< Exit {{bold}}[%s]{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond)))
|
||||
case types.SpecEventSpecRepeat:
|
||||
r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{green}}Passed{{/}}{{bold}}. Repeating %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
|
||||
case types.SpecEventSpecRetry:
|
||||
r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{red}}Failed{{/}}{{bold}}. Retrying %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
|
||||
}
|
||||
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}"))
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitGoroutines(indent uint, goroutines ...types.Goroutine) {
|
||||
|
@ -544,31 +622,37 @@ func (r *DefaultReporter) emitSource(indent uint, fc types.FunctionCall) {
|
|||
|
||||
/* Emitting to the writer */
|
||||
func (r *DefaultReporter) emit(s string) {
|
||||
if len(s) > 0 {
|
||||
r.lastChar = s[len(s)-1:]
|
||||
r.lastEmissionWasDelimiter = false
|
||||
r.writer.Write([]byte(s))
|
||||
}
|
||||
r._emit(s, false, false)
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitBlock(s string) {
|
||||
if len(s) > 0 {
|
||||
if r.lastChar != "\n" {
|
||||
r.emit("\n")
|
||||
}
|
||||
r.emit(s)
|
||||
if r.lastChar != "\n" {
|
||||
r.emit("\n")
|
||||
}
|
||||
}
|
||||
r._emit(s, true, false)
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) emitDelimiter() {
|
||||
if r.lastEmissionWasDelimiter {
|
||||
func (r *DefaultReporter) emitDelimiter(indent uint) {
|
||||
r._emit(r.fi(indent, "{{gray}}%s{{/}}", strings.Repeat("-", 30)), true, true)
|
||||
}
|
||||
|
||||
// a bit ugly - but we're trying to minimize locking on this hot codepath
|
||||
func (r *DefaultReporter) _emit(s string, block bool, isDelimiter bool) {
|
||||
if len(s) == 0 {
|
||||
return
|
||||
}
|
||||
r.emitBlock(r.f("{{gray}}%s{{/}}", strings.Repeat("-", 30)))
|
||||
r.lastEmissionWasDelimiter = true
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
if isDelimiter && r.lastEmissionWasDelimiter {
|
||||
return
|
||||
}
|
||||
if block && !r.lastCharWasNewline {
|
||||
r.writer.Write([]byte("\n"))
|
||||
}
|
||||
r.lastCharWasNewline = (s[len(s)-1:] == "\n")
|
||||
r.writer.Write([]byte(s))
|
||||
if block && !r.lastCharWasNewline {
|
||||
r.writer.Write([]byte("\n"))
|
||||
r.lastCharWasNewline = true
|
||||
}
|
||||
r.lastEmissionWasDelimiter = isDelimiter
|
||||
}
|
||||
|
||||
/* Rendering text */
|
||||
|
@ -584,13 +668,14 @@ func (r *DefaultReporter) cycleJoin(elements []string, joiner string) string {
|
|||
return r.formatter.CycleJoin(elements, joiner, []string{"{{/}}", "{{gray}}"})
|
||||
}
|
||||
|
||||
func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, succinct bool, usePreciseFailureLocation bool) string {
|
||||
func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, veryVerbose bool, usePreciseFailureLocation bool) string {
|
||||
texts, locations, labels := []string{}, []types.CodeLocation{}, [][]string{}
|
||||
texts, locations, labels = append(texts, report.ContainerHierarchyTexts...), append(locations, report.ContainerHierarchyLocations...), append(labels, report.ContainerHierarchyLabels...)
|
||||
|
||||
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
|
||||
texts = append(texts, r.f("[%s] %s", report.LeafNodeType, report.LeafNodeText))
|
||||
} else {
|
||||
texts = append(texts, report.LeafNodeText)
|
||||
texts = append(texts, r.f(report.LeafNodeText))
|
||||
}
|
||||
labels = append(labels, report.LeafNodeLabels)
|
||||
locations = append(locations, report.LeafNodeLocation)
|
||||
|
@ -600,24 +685,58 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo
|
|||
failureLocation = report.Failure.Location
|
||||
}
|
||||
|
||||
highlightIndex := -1
|
||||
switch report.Failure.FailureNodeContext {
|
||||
case types.FailureNodeAtTopLevel:
|
||||
texts = append([]string{r.f(highlightColor+"{{bold}}TOP-LEVEL [%s]{{/}}", report.Failure.FailureNodeType)}, texts...)
|
||||
texts = append([]string{fmt.Sprintf("TOP-LEVEL [%s]", report.Failure.FailureNodeType)}, texts...)
|
||||
locations = append([]types.CodeLocation{failureLocation}, locations...)
|
||||
labels = append([][]string{{}}, labels...)
|
||||
highlightIndex = 0
|
||||
case types.FailureNodeInContainer:
|
||||
i := report.Failure.FailureNodeContainerIndex
|
||||
texts[i] = r.f(highlightColor+"{{bold}}%s [%s]{{/}}", texts[i], report.Failure.FailureNodeType)
|
||||
texts[i] = fmt.Sprintf("%s [%s]", texts[i], report.Failure.FailureNodeType)
|
||||
locations[i] = failureLocation
|
||||
highlightIndex = i
|
||||
case types.FailureNodeIsLeafNode:
|
||||
i := len(texts) - 1
|
||||
texts[i] = r.f(highlightColor+"{{bold}}[%s] %s{{/}}", report.LeafNodeType, report.LeafNodeText)
|
||||
texts[i] = fmt.Sprintf("[%s] %s", report.LeafNodeType, report.LeafNodeText)
|
||||
locations[i] = failureLocation
|
||||
highlightIndex = i
|
||||
default:
|
||||
//there is no failure, so we highlight the leaf ndoe
|
||||
highlightIndex = len(texts) - 1
|
||||
}
|
||||
|
||||
out := ""
|
||||
if succinct {
|
||||
out += r.f("%s", r.cycleJoin(texts, " "))
|
||||
if veryVerbose {
|
||||
for i := range texts {
|
||||
if i == highlightIndex {
|
||||
out += r.fi(uint(i), highlightColor+"{{bold}}%s{{/}}", texts[i])
|
||||
} else {
|
||||
out += r.fi(uint(i), "%s", texts[i])
|
||||
}
|
||||
if len(labels[i]) > 0 {
|
||||
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", "))
|
||||
}
|
||||
out += "\n"
|
||||
out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i])
|
||||
}
|
||||
} else {
|
||||
for i := range texts {
|
||||
style := "{{/}}"
|
||||
if i%2 == 1 {
|
||||
style = "{{gray}}"
|
||||
}
|
||||
if i == highlightIndex {
|
||||
style = highlightColor + "{{bold}}"
|
||||
}
|
||||
out += r.f(style+"%s", texts[i])
|
||||
if i < len(texts)-1 {
|
||||
out += " "
|
||||
} else {
|
||||
out += r.f("{{/}}")
|
||||
}
|
||||
}
|
||||
flattenedLabels := report.Labels()
|
||||
if len(flattenedLabels) > 0 {
|
||||
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(flattenedLabels, ", "))
|
||||
|
@ -626,17 +745,15 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo
|
|||
if usePreciseFailureLocation {
|
||||
out += r.f("{{gray}}%s{{/}}", failureLocation)
|
||||
} else {
|
||||
out += r.f("{{gray}}%s{{/}}", locations[len(locations)-1])
|
||||
}
|
||||
leafLocation := locations[len(locations)-1]
|
||||
if (report.Failure.FailureNodeLocation != types.CodeLocation{}) && (report.Failure.FailureNodeLocation != leafLocation) {
|
||||
out += r.fi(1, highlightColor+"[%s]{{/}} {{gray}}%s{{/}}\n", report.Failure.FailureNodeType, report.Failure.FailureNodeLocation)
|
||||
out += r.fi(1, "{{gray}}[%s] %s{{/}}", report.LeafNodeType, leafLocation)
|
||||
} else {
|
||||
for i := range texts {
|
||||
out += r.fi(uint(i), "%s", texts[i])
|
||||
if len(labels[i]) > 0 {
|
||||
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", "))
|
||||
out += r.f("{{gray}}%s{{/}}", leafLocation)
|
||||
}
|
||||
out += "\n"
|
||||
out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i])
|
||||
}
|
||||
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ func ReportViaDeprecatedReporter(reporter DeprecatedReporter, report types.Repor
|
|||
FailOnPending: report.SuiteConfig.FailOnPending,
|
||||
FailFast: report.SuiteConfig.FailFast,
|
||||
FlakeAttempts: report.SuiteConfig.FlakeAttempts,
|
||||
EmitSpecProgress: report.SuiteConfig.EmitSpecProgress,
|
||||
EmitSpecProgress: false,
|
||||
DryRun: report.SuiteConfig.DryRun,
|
||||
ParallelNode: report.SuiteConfig.ParallelProcess,
|
||||
ParallelTotal: report.SuiteConfig.ParallelTotal,
|
||||
|
|
|
@ -15,12 +15,32 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/onsi/ginkgo/v2/config"
|
||||
"github.com/onsi/ginkgo/v2/types"
|
||||
)
|
||||
|
||||
type JunitReportConfig struct {
|
||||
// Spec States for which no timeline should be emitted for system-err
|
||||
// set this to types.SpecStatePassed|types.SpecStateSkipped|types.SpecStatePending to only match failing specs
|
||||
OmitTimelinesForSpecState types.SpecState
|
||||
|
||||
// Enable OmitFailureMessageAttr to prevent failure messages appearing in the "message" attribute of the Failure and Error tags
|
||||
OmitFailureMessageAttr bool
|
||||
|
||||
//Enable OmitCapturedStdOutErr to prevent captured stdout/stderr appearing in system-out
|
||||
OmitCapturedStdOutErr bool
|
||||
|
||||
// Enable OmitSpecLabels to prevent labels from appearing in the spec name
|
||||
OmitSpecLabels bool
|
||||
|
||||
// Enable OmitLeafNodeType to prevent the spec leaf node type from appearing in the spec name
|
||||
OmitLeafNodeType bool
|
||||
|
||||
// Enable OmitSuiteSetupNodes to prevent the creation of testcase entries for setup nodes
|
||||
OmitSuiteSetupNodes bool
|
||||
}
|
||||
|
||||
type JUnitTestSuites struct {
|
||||
XMLName xml.Name `xml:"testsuites"`
|
||||
// Tests maps onto the total number of specs in all test suites (this includes any suite nodes such as BeforeSuite)
|
||||
|
@ -128,6 +148,10 @@ type JUnitFailure struct {
|
|||
}
|
||||
|
||||
func GenerateJUnitReport(report types.Report, dst string) error {
|
||||
return GenerateJUnitReportWithConfig(report, dst, JunitReportConfig{})
|
||||
}
|
||||
|
||||
func GenerateJUnitReportWithConfig(report types.Report, dst string, config JunitReportConfig) error {
|
||||
suite := JUnitTestSuite{
|
||||
Name: report.SuiteDescription,
|
||||
Package: report.SuitePath,
|
||||
|
@ -149,7 +173,6 @@ func GenerateJUnitReport(report types.Report, dst string) error {
|
|||
{"FailOnPending", fmt.Sprintf("%t", report.SuiteConfig.FailOnPending)},
|
||||
{"FailFast", fmt.Sprintf("%t", report.SuiteConfig.FailFast)},
|
||||
{"FlakeAttempts", fmt.Sprintf("%d", report.SuiteConfig.FlakeAttempts)},
|
||||
{"EmitSpecProgress", fmt.Sprintf("%t", report.SuiteConfig.EmitSpecProgress)},
|
||||
{"DryRun", fmt.Sprintf("%t", report.SuiteConfig.DryRun)},
|
||||
{"ParallelTotal", fmt.Sprintf("%d", report.SuiteConfig.ParallelTotal)},
|
||||
{"OutputInterceptorMode", report.SuiteConfig.OutputInterceptorMode},
|
||||
|
@ -157,22 +180,33 @@ func GenerateJUnitReport(report types.Report, dst string) error {
|
|||
},
|
||||
}
|
||||
for _, spec := range report.SpecReports {
|
||||
if config.OmitSuiteSetupNodes && spec.LeafNodeType != types.NodeTypeIt {
|
||||
continue
|
||||
}
|
||||
name := fmt.Sprintf("[%s]", spec.LeafNodeType)
|
||||
if config.OmitLeafNodeType {
|
||||
name = ""
|
||||
}
|
||||
if spec.FullText() != "" {
|
||||
name = name + " " + spec.FullText()
|
||||
}
|
||||
labels := spec.Labels()
|
||||
if len(labels) > 0 {
|
||||
if len(labels) > 0 && !config.OmitSpecLabels {
|
||||
name = name + " [" + strings.Join(labels, ", ") + "]"
|
||||
}
|
||||
name = strings.TrimSpace(name)
|
||||
|
||||
test := JUnitTestCase{
|
||||
Name: name,
|
||||
Classname: report.SuiteDescription,
|
||||
Status: spec.State.String(),
|
||||
Time: spec.RunTime.Seconds(),
|
||||
SystemOut: systemOutForUnstructuredReporters(spec),
|
||||
SystemErr: systemErrForUnstructuredReporters(spec),
|
||||
}
|
||||
if !spec.State.Is(config.OmitTimelinesForSpecState) {
|
||||
test.SystemErr = systemErrForUnstructuredReporters(spec)
|
||||
}
|
||||
if !config.OmitCapturedStdOutErr {
|
||||
test.SystemOut = systemOutForUnstructuredReporters(spec)
|
||||
}
|
||||
suite.Tests += 1
|
||||
|
||||
|
@ -193,6 +227,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
|
|||
Type: "failed",
|
||||
Description: failureDescriptionForUnstructuredReporters(spec),
|
||||
}
|
||||
if config.OmitFailureMessageAttr {
|
||||
test.Failure.Message = ""
|
||||
}
|
||||
suite.Failures += 1
|
||||
case types.SpecStateTimedout:
|
||||
test.Failure = &JUnitFailure{
|
||||
|
@ -200,6 +237,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
|
|||
Type: "timedout",
|
||||
Description: failureDescriptionForUnstructuredReporters(spec),
|
||||
}
|
||||
if config.OmitFailureMessageAttr {
|
||||
test.Failure.Message = ""
|
||||
}
|
||||
suite.Failures += 1
|
||||
case types.SpecStateInterrupted:
|
||||
test.Error = &JUnitError{
|
||||
|
@ -207,6 +247,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
|
|||
Type: "interrupted",
|
||||
Description: failureDescriptionForUnstructuredReporters(spec),
|
||||
}
|
||||
if config.OmitFailureMessageAttr {
|
||||
test.Error.Message = ""
|
||||
}
|
||||
suite.Errors += 1
|
||||
case types.SpecStateAborted:
|
||||
test.Failure = &JUnitFailure{
|
||||
|
@ -214,6 +257,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
|
|||
Type: "aborted",
|
||||
Description: failureDescriptionForUnstructuredReporters(spec),
|
||||
}
|
||||
if config.OmitFailureMessageAttr {
|
||||
test.Failure.Message = ""
|
||||
}
|
||||
suite.Errors += 1
|
||||
case types.SpecStatePanicked:
|
||||
test.Error = &JUnitError{
|
||||
|
@ -221,6 +267,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
|
|||
Type: "panicked",
|
||||
Description: failureDescriptionForUnstructuredReporters(spec),
|
||||
}
|
||||
if config.OmitFailureMessageAttr {
|
||||
test.Error.Message = ""
|
||||
}
|
||||
suite.Errors += 1
|
||||
}
|
||||
|
||||
|
@ -287,63 +336,25 @@ func MergeAndCleanupJUnitReports(sources []string, dst string) ([]string, error)
|
|||
|
||||
func failureDescriptionForUnstructuredReporters(spec types.SpecReport) string {
|
||||
out := &strings.Builder{}
|
||||
out.WriteString(spec.Failure.Location.String() + "\n")
|
||||
out.WriteString(spec.Failure.Location.FullStackTrace)
|
||||
if !spec.Failure.ProgressReport.IsZero() {
|
||||
out.WriteString("\n")
|
||||
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(spec.Failure.ProgressReport)
|
||||
}
|
||||
NewDefaultReporter(types.ReporterConfig{NoColor: true, VeryVerbose: true}, out).emitFailure(0, spec.State, spec.Failure, true)
|
||||
if len(spec.AdditionalFailures) > 0 {
|
||||
out.WriteString("\nThere were additional failures detected after the initial failure:\n")
|
||||
for i, additionalFailure := range spec.AdditionalFailures {
|
||||
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitFailure(0, additionalFailure.State, additionalFailure.Failure, true)
|
||||
if i < len(spec.AdditionalFailures)-1 {
|
||||
out.WriteString("----------\n")
|
||||
}
|
||||
}
|
||||
out.WriteString("\nThere were additional failures detected after the initial failure. These are visible in the timeline\n")
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func systemErrForUnstructuredReporters(spec types.SpecReport) string {
|
||||
return RenderTimeline(spec, true)
|
||||
}
|
||||
|
||||
func RenderTimeline(spec types.SpecReport, noColor bool) string {
|
||||
out := &strings.Builder{}
|
||||
gw := spec.CapturedGinkgoWriterOutput
|
||||
cursor := 0
|
||||
for _, pr := range spec.ProgressReports {
|
||||
if cursor < pr.GinkgoWriterOffset {
|
||||
if pr.GinkgoWriterOffset < len(gw) {
|
||||
out.WriteString(gw[cursor:pr.GinkgoWriterOffset])
|
||||
cursor = pr.GinkgoWriterOffset
|
||||
} else if cursor < len(gw) {
|
||||
out.WriteString(gw[cursor:])
|
||||
cursor = len(gw)
|
||||
}
|
||||
}
|
||||
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(pr)
|
||||
}
|
||||
|
||||
if cursor < len(gw) {
|
||||
out.WriteString(gw[cursor:])
|
||||
}
|
||||
|
||||
NewDefaultReporter(types.ReporterConfig{NoColor: noColor, VeryVerbose: true}, out).emitTimeline(0, spec, spec.Timeline())
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func systemOutForUnstructuredReporters(spec types.SpecReport) string {
|
||||
systemOut := spec.CapturedStdOutErr
|
||||
if len(spec.ReportEntries) > 0 {
|
||||
systemOut += "\nReport Entries:\n"
|
||||
for i, entry := range spec.ReportEntries {
|
||||
systemOut += fmt.Sprintf("%s\n%s\n%s\n", entry.Name, entry.Location, entry.Time.Format(time.RFC3339Nano))
|
||||
if representation := entry.StringRepresentation(); representation != "" {
|
||||
systemOut += representation + "\n"
|
||||
}
|
||||
if i+1 < len(spec.ReportEntries) {
|
||||
systemOut += "--\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
return systemOut
|
||||
return spec.CapturedStdOutErr
|
||||
}
|
||||
|
||||
// Deprecated JUnitReporter (so folks can still compile their suites)
|
||||
|
|
|
@ -9,7 +9,12 @@ type Reporter interface {
|
|||
WillRun(report types.SpecReport)
|
||||
DidRun(report types.SpecReport)
|
||||
SuiteDidEnd(report types.Report)
|
||||
|
||||
//Timeline emission
|
||||
EmitFailure(state types.SpecState, failure types.Failure)
|
||||
EmitProgressReport(progressReport types.ProgressReport)
|
||||
EmitReportEntry(entry types.ReportEntry)
|
||||
EmitSpecEvent(event types.SpecEvent)
|
||||
}
|
||||
|
||||
type NoopReporter struct{}
|
||||
|
@ -18,4 +23,7 @@ func (n NoopReporter) SuiteWillBegin(report types.Report) {}
|
|||
func (n NoopReporter) WillRun(report types.SpecReport) {}
|
||||
func (n NoopReporter) DidRun(report types.SpecReport) {}
|
||||
func (n NoopReporter) SuiteDidEnd(report types.Report) {}
|
||||
func (n NoopReporter) EmitFailure(state types.SpecState, failure types.Failure) {}
|
||||
func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {}
|
||||
func (n NoopReporter) EmitReportEntry(entry types.ReportEntry) {}
|
||||
func (n NoopReporter) EmitSpecEvent(event types.SpecEvent) {}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -6,6 +7,7 @@ import (
|
|||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type CodeLocation struct {
|
||||
|
@ -37,6 +39,73 @@ func (codeLocation CodeLocation) ContentsOfLine() string {
|
|||
return lines[codeLocation.LineNumber-1]
|
||||
}
|
||||
|
||||
type codeLocationLocator struct {
|
||||
pcs map[uintptr]bool
|
||||
helpers map[string]bool
|
||||
lock *sync.Mutex
|
||||
}
|
||||
|
||||
func (c *codeLocationLocator) addHelper(pc uintptr) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.pcs[pc] {
|
||||
return
|
||||
}
|
||||
c.lock.Unlock()
|
||||
f := runtime.FuncForPC(pc)
|
||||
c.lock.Lock()
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
c.helpers[f.Name()] = true
|
||||
c.pcs[pc] = true
|
||||
}
|
||||
|
||||
func (c *codeLocationLocator) hasHelper(name string) bool {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
return c.helpers[name]
|
||||
}
|
||||
|
||||
func (c *codeLocationLocator) getCodeLocation(skip int) CodeLocation {
|
||||
pc := make([]uintptr, 40)
|
||||
n := runtime.Callers(skip+2, pc)
|
||||
if n == 0 {
|
||||
return CodeLocation{}
|
||||
}
|
||||
pc = pc[:n]
|
||||
frames := runtime.CallersFrames(pc)
|
||||
for {
|
||||
frame, more := frames.Next()
|
||||
if !c.hasHelper(frame.Function) {
|
||||
return CodeLocation{FileName: frame.File, LineNumber: frame.Line}
|
||||
}
|
||||
if !more {
|
||||
break
|
||||
}
|
||||
}
|
||||
return CodeLocation{}
|
||||
}
|
||||
|
||||
var clLocator = &codeLocationLocator{
|
||||
pcs: map[uintptr]bool{},
|
||||
helpers: map[string]bool{},
|
||||
lock: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// MarkAsHelper is used by GinkgoHelper to mark the caller (appropriately offset by skip)as a helper. You can use this directly if you need to provide an optional `skip` to mark functions further up the call stack as helpers.
|
||||
func MarkAsHelper(optionalSkip ...int) {
|
||||
skip := 1
|
||||
if len(optionalSkip) > 0 {
|
||||
skip += optionalSkip[0]
|
||||
}
|
||||
pc, _, _, ok := runtime.Caller(skip)
|
||||
if ok {
|
||||
clLocator.addHelper(pc)
|
||||
}
|
||||
}
|
||||
|
||||
func NewCustomCodeLocation(message string) CodeLocation {
|
||||
return CodeLocation{
|
||||
CustomMessage: message,
|
||||
|
@ -44,14 +113,13 @@ func NewCustomCodeLocation(message string) CodeLocation {
|
|||
}
|
||||
|
||||
func NewCodeLocation(skip int) CodeLocation {
|
||||
_, file, line, _ := runtime.Caller(skip + 1)
|
||||
return CodeLocation{FileName: file, LineNumber: line}
|
||||
return clLocator.getCodeLocation(skip + 1)
|
||||
}
|
||||
|
||||
func NewCodeLocationWithStackTrace(skip int) CodeLocation {
|
||||
_, file, line, _ := runtime.Caller(skip + 1)
|
||||
stackTrace := PruneStack(string(debug.Stack()), skip+1)
|
||||
return CodeLocation{FileName: file, LineNumber: line, FullStackTrace: stackTrace}
|
||||
cl := clLocator.getCodeLocation(skip + 1)
|
||||
cl.FullStackTrace = PruneStack(string(debug.Stack()), skip+1)
|
||||
return cl
|
||||
}
|
||||
|
||||
// PruneStack removes references to functions that are internal to Ginkgo
|
||||
|
|
|
@ -8,6 +8,7 @@ package types
|
|||
import (
|
||||
"flag"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -26,11 +27,11 @@ type SuiteConfig struct {
|
|||
FailOnPending bool
|
||||
FailFast bool
|
||||
FlakeAttempts int
|
||||
EmitSpecProgress bool
|
||||
DryRun bool
|
||||
PollProgressAfter time.Duration
|
||||
PollProgressInterval time.Duration
|
||||
Timeout time.Duration
|
||||
EmitSpecProgress bool // this is deprecated but its removal is causing compile issue for some users that were setting it manually
|
||||
OutputInterceptorMode string
|
||||
SourceRoots []string
|
||||
GracePeriod time.Duration
|
||||
|
@ -82,12 +83,11 @@ func (vl VerbosityLevel) LT(comp VerbosityLevel) bool {
|
|||
// Configuration for Ginkgo's reporter
|
||||
type ReporterConfig struct {
|
||||
NoColor bool
|
||||
SlowSpecThreshold time.Duration
|
||||
Succinct bool
|
||||
Verbose bool
|
||||
VeryVerbose bool
|
||||
FullTrace bool
|
||||
AlwaysEmitGinkgoWriter bool
|
||||
ShowNodeEvents bool
|
||||
|
||||
JSONReport string
|
||||
JUnitReport string
|
||||
|
@ -110,9 +110,7 @@ func (rc ReporterConfig) WillGenerateReport() bool {
|
|||
}
|
||||
|
||||
func NewDefaultReporterConfig() ReporterConfig {
|
||||
return ReporterConfig{
|
||||
SlowSpecThreshold: 5 * time.Second,
|
||||
}
|
||||
return ReporterConfig{}
|
||||
}
|
||||
|
||||
// Configuration for the Ginkgo CLI
|
||||
|
@ -235,6 +233,9 @@ type deprecatedConfig struct {
|
|||
SlowSpecThresholdWithFLoatUnits float64
|
||||
Stream bool
|
||||
Notify bool
|
||||
EmitSpecProgress bool
|
||||
SlowSpecThreshold time.Duration
|
||||
AlwaysEmitGinkgoWriter bool
|
||||
}
|
||||
|
||||
// Flags
|
||||
|
@ -275,8 +276,6 @@ var SuiteConfigFlags = GinkgoFlags{
|
|||
|
||||
{KeyPath: "S.DryRun", Name: "dry-run", SectionKey: "debug", DeprecatedName: "dryRun", DeprecatedDocLink: "changed-command-line-flags",
|
||||
Usage: "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v."},
|
||||
{KeyPath: "S.EmitSpecProgress", Name: "progress", SectionKey: "debug",
|
||||
Usage: "If set, ginkgo will emit progress information as each spec runs to the GinkgoWriter."},
|
||||
{KeyPath: "S.PollProgressAfter", Name: "poll-progress-after", SectionKey: "debug", UsageDefaultValue: "0",
|
||||
Usage: "Emit node progress reports periodically if node hasn't completed after this duration."},
|
||||
{KeyPath: "S.PollProgressInterval", Name: "poll-progress-interval", SectionKey: "debug", UsageDefaultValue: "10s",
|
||||
|
@ -303,6 +302,8 @@ var SuiteConfigFlags = GinkgoFlags{
|
|||
|
||||
{KeyPath: "D.RegexScansFilePath", DeprecatedName: "regexScansFilePath", DeprecatedDocLink: "removed--regexscansfilepath", DeprecatedVersion: "2.0.0"},
|
||||
{KeyPath: "D.DebugParallel", DeprecatedName: "debug", DeprecatedDocLink: "removed--debug", DeprecatedVersion: "2.0.0"},
|
||||
{KeyPath: "D.EmitSpecProgress", DeprecatedName: "progress", SectionKey: "debug",
|
||||
DeprecatedVersion: "2.5.0", Usage: ". The functionality provided by --progress was confusing and is no longer needed. Use --show-node-events instead to see node entry and exit events included in the timeline of failed and verbose specs. Or you can run with -vv to always see all node events. Lastly, --poll-progress-after and the PollProgressAfter decorator now provide a better mechanism for debugging specs that tend to get stuck."},
|
||||
}
|
||||
|
||||
// ParallelConfigFlags provides flags for the Ginkgo test process (not the CLI)
|
||||
|
@ -319,8 +320,6 @@ var ParallelConfigFlags = GinkgoFlags{
|
|||
var ReporterConfigFlags = GinkgoFlags{
|
||||
{KeyPath: "R.NoColor", Name: "no-color", SectionKey: "output", DeprecatedName: "noColor", DeprecatedDocLink: "changed-command-line-flags",
|
||||
Usage: "If set, suppress color output in default reporter."},
|
||||
{KeyPath: "R.SlowSpecThreshold", Name: "slow-spec-threshold", SectionKey: "output", UsageArgument: "duration", UsageDefaultValue: "5s",
|
||||
Usage: "Specs that take longer to run than this threshold are flagged as slow by the default reporter."},
|
||||
{KeyPath: "R.Verbose", Name: "v", SectionKey: "output",
|
||||
Usage: "If set, emits more output including GinkgoWriter contents."},
|
||||
{KeyPath: "R.VeryVerbose", Name: "vv", SectionKey: "output",
|
||||
|
@ -329,8 +328,8 @@ var ReporterConfigFlags = GinkgoFlags{
|
|||
Usage: "If set, default reporter prints out a very succinct report"},
|
||||
{KeyPath: "R.FullTrace", Name: "trace", SectionKey: "output",
|
||||
Usage: "If set, default reporter prints out the full stack trace when a failure occurs"},
|
||||
{KeyPath: "R.AlwaysEmitGinkgoWriter", Name: "always-emit-ginkgo-writer", SectionKey: "output", DeprecatedName: "reportPassed", DeprecatedDocLink: "renamed--reportpassed",
|
||||
Usage: "If set, default reporter prints out captured output of passed tests."},
|
||||
{KeyPath: "R.ShowNodeEvents", Name: "show-node-events", SectionKey: "output",
|
||||
Usage: "If set, default reporter prints node > Enter and < Exit events when specs fail"},
|
||||
|
||||
{KeyPath: "R.JSONReport", Name: "json-report", UsageArgument: "filename.json", SectionKey: "output",
|
||||
Usage: "If set, Ginkgo will generate a JSON-formatted test report at the specified location."},
|
||||
|
@ -343,6 +342,8 @@ var ReporterConfigFlags = GinkgoFlags{
|
|||
Usage: "use --slow-spec-threshold instead and pass in a duration string (e.g. '5s', not '5.0')"},
|
||||
{KeyPath: "D.NoisyPendings", DeprecatedName: "noisyPendings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"},
|
||||
{KeyPath: "D.NoisySkippings", DeprecatedName: "noisySkippings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"},
|
||||
{KeyPath: "D.SlowSpecThreshold", DeprecatedName: "slow-spec-threshold", SectionKey: "output", Usage: "--slow-spec-threshold has been deprecated and will be removed in a future version of Ginkgo. This feature has proved to be more noisy than useful. You can use --poll-progress-after, instead, to get more actionable feedback about potentially slow specs and understand where they might be getting stuck.", DeprecatedVersion: "2.5.0"},
|
||||
{KeyPath: "D.AlwaysEmitGinkgoWriter", DeprecatedName: "always-emit-ginkgo-writer", SectionKey: "output", Usage: " - use -v instead, or one of Ginkgo's machine-readable report formats to get GinkgoWriter output for passing specs."},
|
||||
}
|
||||
|
||||
// BuildTestSuiteFlagSet attaches to the CommandLine flagset and provides flags for the Ginkgo test process
|
||||
|
@ -600,13 +601,29 @@ func VetAndInitializeCLIAndGoConfig(cliConfig CLIConfig, goFlagsConfig GoFlagsCo
|
|||
}
|
||||
|
||||
// GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test
|
||||
func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string) ([]string, error) {
|
||||
func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string, pathToInvocationPath string) ([]string, error) {
|
||||
// if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure
|
||||
// the built test binary can generate a coverprofile
|
||||
if goFlagsConfig.CoverProfile != "" {
|
||||
goFlagsConfig.Cover = true
|
||||
}
|
||||
|
||||
if goFlagsConfig.CoverPkg != "" {
|
||||
coverPkgs := strings.Split(goFlagsConfig.CoverPkg, ",")
|
||||
adjustedCoverPkgs := make([]string, len(coverPkgs))
|
||||
for i, coverPkg := range coverPkgs {
|
||||
coverPkg = strings.Trim(coverPkg, " ")
|
||||
if strings.HasPrefix(coverPkg, "./") {
|
||||
// this is a relative coverPkg - we need to reroot it
|
||||
adjustedCoverPkgs[i] = "./" + filepath.Join(pathToInvocationPath, strings.TrimPrefix(coverPkg, "./"))
|
||||
} else {
|
||||
// this is a package name - don't touch it
|
||||
adjustedCoverPkgs[i] = coverPkg
|
||||
}
|
||||
}
|
||||
goFlagsConfig.CoverPkg = strings.Join(adjustedCoverPkgs, ",")
|
||||
}
|
||||
|
||||
args := []string{"test", "-c", "-o", destination, packageToBuild}
|
||||
goArgs, err := GenerateFlagArgs(
|
||||
GoBuildFlags,
|
||||
|
|
|
@ -38,7 +38,7 @@ func (d deprecations) Async() Deprecation {
|
|||
|
||||
func (d deprecations) Measure() Deprecation {
|
||||
return Deprecation{
|
||||
Message: "Measure is deprecated and will be removed in Ginkgo V2. Please migrate to gomega/gmeasure.",
|
||||
Message: "Measure is deprecated and has been removed from Ginkgo V2. Any Measure tests in your spec will not run. Please migrate to gomega/gmeasure.",
|
||||
DocLink: "removed-measure",
|
||||
Version: "1.16.3",
|
||||
}
|
||||
|
@ -83,6 +83,13 @@ func (d deprecations) Nodot() Deprecation {
|
|||
}
|
||||
}
|
||||
|
||||
func (d deprecations) SuppressProgressReporting() Deprecation {
|
||||
return Deprecation{
|
||||
Message: "Improvements to how reporters emit timeline information means that SuppressProgressReporting is no longer necessary and has been deprecated.",
|
||||
Version: "2.5.0",
|
||||
}
|
||||
}
|
||||
|
||||
type DeprecationTracker struct {
|
||||
deprecations map[Deprecation][]CodeLocation
|
||||
lock *sync.Mutex
|
||||
|
|
|
@ -108,8 +108,8 @@ Please ensure all assertions are inside leaf nodes such as {{bold}}BeforeEach{{/
|
|||
|
||||
func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocation) error {
|
||||
docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite"
|
||||
if nodeType.Is(NodeTypeReportAfterSuite) {
|
||||
docLink = "reporting-nodes---reportaftersuite"
|
||||
if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) {
|
||||
docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite"
|
||||
}
|
||||
|
||||
return GinkgoError{
|
||||
|
@ -125,8 +125,8 @@ func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocatio
|
|||
|
||||
func (g ginkgoErrors) SuiteNodeDuringRunPhase(nodeType NodeType, cl CodeLocation) error {
|
||||
docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite"
|
||||
if nodeType.Is(NodeTypeReportAfterSuite) {
|
||||
docLink = "reporting-nodes---reportaftersuite"
|
||||
if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) {
|
||||
docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite"
|
||||
}
|
||||
|
||||
return GinkgoError{
|
||||
|
@ -298,6 +298,15 @@ func (g ginkgoErrors) SetupNodeNotInOrderedContainer(cl CodeLocation, nodeType N
|
|||
}
|
||||
}
|
||||
|
||||
func (g ginkgoErrors) InvalidContinueOnFailureDecoration(cl CodeLocation) error {
|
||||
return GinkgoError{
|
||||
Heading: "ContinueOnFailure not decorating an outermost Ordered Container",
|
||||
Message: "ContinueOnFailure can only decorate an Ordered container, and this Ordered container must be the outermost Ordered container.",
|
||||
CodeLocation: cl,
|
||||
DocLink: "ordered-containers",
|
||||
}
|
||||
}
|
||||
|
||||
/* DeferCleanup errors */
|
||||
func (g ginkgoErrors) DeferCleanupInvalidFunction(cl CodeLocation) error {
|
||||
return GinkgoError{
|
||||
|
@ -320,7 +329,7 @@ func (g ginkgoErrors) PushingCleanupNodeDuringTreeConstruction(cl CodeLocation)
|
|||
func (g ginkgoErrors) PushingCleanupInReportingNode(cl CodeLocation, nodeType NodeType) error {
|
||||
return GinkgoError{
|
||||
Heading: fmt.Sprintf("DeferCleanup cannot be called in %s", nodeType),
|
||||
Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a ReportAfterEach or ReportAfterSuite.",
|
||||
Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a Reporting node.",
|
||||
CodeLocation: cl,
|
||||
DocLink: "cleaning-up-our-cleanup-code-defercleanup",
|
||||
}
|
||||
|
|
|
@ -272,12 +272,23 @@ func tokenize(input string) func() (*treeNode, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func MustParseLabelFilter(input string) LabelFilter {
|
||||
filter, err := ParseLabelFilter(input)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func ParseLabelFilter(input string) (LabelFilter, error) {
|
||||
if DEBUG_LABEL_FILTER_PARSING {
|
||||
fmt.Println("\n==============")
|
||||
fmt.Println("Input: ", input)
|
||||
fmt.Print("Tokens: ")
|
||||
}
|
||||
if input == "" {
|
||||
return func(_ []string) bool { return true }, nil
|
||||
}
|
||||
nextToken := tokenize(input)
|
||||
|
||||
root := &treeNode{token: lfTokenRoot}
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
//ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports
|
||||
//and across the network connection when running in parallel
|
||||
// ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports
|
||||
// and across the network connection when running in parallel
|
||||
type ReportEntryValue struct {
|
||||
raw interface{} //unexported to prevent gob from freaking out about unregistered structs
|
||||
AsJSON string
|
||||
|
@ -85,10 +85,12 @@ func (rev *ReportEntryValue) GobDecode(data []byte) error {
|
|||
type ReportEntry struct {
|
||||
// Visibility captures the visibility policy for this ReportEntry
|
||||
Visibility ReportEntryVisibility
|
||||
// Time captures the time the AddReportEntry was called
|
||||
Time time.Time
|
||||
// Location captures the location of the AddReportEntry call
|
||||
Location CodeLocation
|
||||
|
||||
Time time.Time //need this for backwards compatibility
|
||||
TimelineLocation TimelineLocation
|
||||
|
||||
// Name captures the name of this report
|
||||
Name string
|
||||
// Value captures the (optional) object passed into AddReportEntry - this can be
|
||||
|
@ -120,7 +122,9 @@ func (entry ReportEntry) GetRawValue() interface{} {
|
|||
return entry.Value.GetRawValue()
|
||||
}
|
||||
|
||||
|
||||
func (entry ReportEntry) GetTimelineLocation() TimelineLocation {
|
||||
return entry.TimelineLocation
|
||||
}
|
||||
|
||||
type ReportEntries []ReportEntry
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ package types
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
@ -56,19 +58,20 @@ type Report struct {
|
|||
SuiteConfig SuiteConfig
|
||||
|
||||
//SpecReports is a list of all SpecReports generated by this test run
|
||||
//It is empty when the SuiteReport is provided to ReportBeforeSuite
|
||||
SpecReports SpecReports
|
||||
}
|
||||
|
||||
//PreRunStats contains a set of stats captured before the test run begins. This is primarily used
|
||||
//by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs)
|
||||
//and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters.
|
||||
// PreRunStats contains a set of stats captured before the test run begins. This is primarily used
|
||||
// by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs)
|
||||
// and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters.
|
||||
type PreRunStats struct {
|
||||
TotalSpecs int
|
||||
SpecsThatWillRun int
|
||||
}
|
||||
|
||||
//Add is used by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes
|
||||
//to form a complete final report.
|
||||
// Add is used by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes
|
||||
// to form a complete final report.
|
||||
func (report Report) Add(other Report) Report {
|
||||
report.SuiteSucceeded = report.SuiteSucceeded && other.SuiteSucceeded
|
||||
|
||||
|
@ -147,6 +150,9 @@ type SpecReport struct {
|
|||
// ParallelProcess captures the parallel process that this spec ran on
|
||||
ParallelProcess int
|
||||
|
||||
// RunningInParallel captures whether this spec is part of a suite that ran in parallel
|
||||
RunningInParallel bool
|
||||
|
||||
//Failure is populated if a spec has failed, panicked, been interrupted, or skipped by the user (e.g. calling Skip())
|
||||
//It includes detailed information about the Failure
|
||||
Failure Failure
|
||||
|
@ -178,6 +184,9 @@ type SpecReport struct {
|
|||
|
||||
// AdditionalFailures contains any failures that occurred after the initial spec failure. These typically occur in cleanup nodes after the initial failure and are only emitted when running in verbose mode.
|
||||
AdditionalFailures []AdditionalFailure
|
||||
|
||||
// SpecEvents capture additional events that occur during the spec run
|
||||
SpecEvents SpecEvents
|
||||
}
|
||||
|
||||
func (report SpecReport) MarshalJSON() ([]byte, error) {
|
||||
|
@ -204,6 +213,7 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
|
|||
ReportEntries ReportEntries `json:",omitempty"`
|
||||
ProgressReports []ProgressReport `json:",omitempty"`
|
||||
AdditionalFailures []AdditionalFailure `json:",omitempty"`
|
||||
SpecEvents SpecEvents `json:",omitempty"`
|
||||
}{
|
||||
ContainerHierarchyTexts: report.ContainerHierarchyTexts,
|
||||
ContainerHierarchyLocations: report.ContainerHierarchyLocations,
|
||||
|
@ -238,6 +248,9 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
|
|||
if len(report.AdditionalFailures) > 0 {
|
||||
out.AdditionalFailures = report.AdditionalFailures
|
||||
}
|
||||
if len(report.SpecEvents) > 0 {
|
||||
out.SpecEvents = report.SpecEvents
|
||||
}
|
||||
|
||||
return json.Marshal(out)
|
||||
}
|
||||
|
@ -255,13 +268,13 @@ func (report SpecReport) CombinedOutput() string {
|
|||
return report.CapturedStdOutErr + "\n" + report.CapturedGinkgoWriterOutput
|
||||
}
|
||||
|
||||
//Failed returns true if report.State is one of the SpecStateFailureStates
|
||||
// Failed returns true if report.State is one of the SpecStateFailureStates
|
||||
// (SpecStateFailed, SpecStatePanicked, SpecStateinterrupted, SpecStateAborted)
|
||||
func (report SpecReport) Failed() bool {
|
||||
return report.State.Is(SpecStateFailureStates)
|
||||
}
|
||||
|
||||
//FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText
|
||||
// FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText
|
||||
func (report SpecReport) FullText() string {
|
||||
texts := []string{}
|
||||
texts = append(texts, report.ContainerHierarchyTexts...)
|
||||
|
@ -271,7 +284,7 @@ func (report SpecReport) FullText() string {
|
|||
return strings.Join(texts, " ")
|
||||
}
|
||||
|
||||
//Labels returns a deduped set of all the spec's Labels.
|
||||
// Labels returns a deduped set of all the spec's Labels.
|
||||
func (report SpecReport) Labels() []string {
|
||||
out := []string{}
|
||||
seen := map[string]bool{}
|
||||
|
@ -293,7 +306,7 @@ func (report SpecReport) Labels() []string {
|
|||
return out
|
||||
}
|
||||
|
||||
//MatchesLabelFilter returns true if the spec satisfies the passed in label filter query
|
||||
// MatchesLabelFilter returns true if the spec satisfies the passed in label filter query
|
||||
func (report SpecReport) MatchesLabelFilter(query string) (bool, error) {
|
||||
filter, err := ParseLabelFilter(query)
|
||||
if err != nil {
|
||||
|
@ -302,29 +315,54 @@ func (report SpecReport) MatchesLabelFilter(query string) (bool, error) {
|
|||
return filter(report.Labels()), nil
|
||||
}
|
||||
|
||||
//FileName() returns the name of the file containing the spec
|
||||
// FileName() returns the name of the file containing the spec
|
||||
func (report SpecReport) FileName() string {
|
||||
return report.LeafNodeLocation.FileName
|
||||
}
|
||||
|
||||
//LineNumber() returns the line number of the leaf node
|
||||
// LineNumber() returns the line number of the leaf node
|
||||
func (report SpecReport) LineNumber() int {
|
||||
return report.LeafNodeLocation.LineNumber
|
||||
}
|
||||
|
||||
//FailureMessage() returns the failure message (or empty string if the test hasn't failed)
|
||||
// FailureMessage() returns the failure message (or empty string if the test hasn't failed)
|
||||
func (report SpecReport) FailureMessage() string {
|
||||
return report.Failure.Message
|
||||
}
|
||||
|
||||
//FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed)
|
||||
// FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed)
|
||||
func (report SpecReport) FailureLocation() CodeLocation {
|
||||
return report.Failure.Location
|
||||
}
|
||||
|
||||
// Timeline() returns a timeline view of the report
|
||||
func (report SpecReport) Timeline() Timeline {
|
||||
timeline := Timeline{}
|
||||
if !report.Failure.IsZero() {
|
||||
timeline = append(timeline, report.Failure)
|
||||
if report.Failure.AdditionalFailure != nil {
|
||||
timeline = append(timeline, *(report.Failure.AdditionalFailure))
|
||||
}
|
||||
}
|
||||
for _, additionalFailure := range report.AdditionalFailures {
|
||||
timeline = append(timeline, additionalFailure)
|
||||
}
|
||||
for _, reportEntry := range report.ReportEntries {
|
||||
timeline = append(timeline, reportEntry)
|
||||
}
|
||||
for _, progressReport := range report.ProgressReports {
|
||||
timeline = append(timeline, progressReport)
|
||||
}
|
||||
for _, specEvent := range report.SpecEvents {
|
||||
timeline = append(timeline, specEvent)
|
||||
}
|
||||
sort.Sort(timeline)
|
||||
return timeline
|
||||
}
|
||||
|
||||
type SpecReports []SpecReport
|
||||
|
||||
//WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes
|
||||
// WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes
|
||||
func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports {
|
||||
count := 0
|
||||
for i := range reports {
|
||||
|
@ -344,7 +382,7 @@ func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports {
|
|||
return out
|
||||
}
|
||||
|
||||
//WithState returns the subset of SpecReports with State matching one of the requested SpecStates
|
||||
// WithState returns the subset of SpecReports with State matching one of the requested SpecStates
|
||||
func (reports SpecReports) WithState(states SpecState) SpecReports {
|
||||
count := 0
|
||||
for i := range reports {
|
||||
|
@ -363,7 +401,7 @@ func (reports SpecReports) WithState(states SpecState) SpecReports {
|
|||
return out
|
||||
}
|
||||
|
||||
//CountWithState returns the number of SpecReports with State matching one of the requested SpecStates
|
||||
// CountWithState returns the number of SpecReports with State matching one of the requested SpecStates
|
||||
func (reports SpecReports) CountWithState(states SpecState) int {
|
||||
n := 0
|
||||
for i := range reports {
|
||||
|
@ -374,7 +412,7 @@ func (reports SpecReports) CountWithState(states SpecState) int {
|
|||
return n
|
||||
}
|
||||
|
||||
//If the Spec passes, CountOfFlakedSpecs returns the number of SpecReports that failed after multiple attempts.
|
||||
// If the Spec passes, CountOfFlakedSpecs returns the number of SpecReports that failed after multiple attempts.
|
||||
func (reports SpecReports) CountOfFlakedSpecs() int {
|
||||
n := 0
|
||||
for i := range reports {
|
||||
|
@ -385,7 +423,7 @@ func (reports SpecReports) CountOfFlakedSpecs() int {
|
|||
return n
|
||||
}
|
||||
|
||||
//If the Spec fails, CountOfRepeatedSpecs returns the number of SpecReports that passed after multiple attempts
|
||||
// If the Spec fails, CountOfRepeatedSpecs returns the number of SpecReports that passed after multiple attempts
|
||||
func (reports SpecReports) CountOfRepeatedSpecs() int {
|
||||
n := 0
|
||||
for i := range reports {
|
||||
|
@ -396,6 +434,53 @@ func (reports SpecReports) CountOfRepeatedSpecs() int {
|
|||
return n
|
||||
}
|
||||
|
||||
// TimelineLocation captures the location of an event in the spec's timeline
|
||||
type TimelineLocation struct {
|
||||
//Offset is the offset (in bytes) of the event relative to the GinkgoWriter stream
|
||||
Offset int `json:",omitempty"`
|
||||
|
||||
//Order is the order of the event with respect to other events. The absolute value of Order
|
||||
//is irrelevant. All that matters is that an event with a lower Order occurs before ane vent with a higher Order
|
||||
Order int `json:",omitempty"`
|
||||
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
// TimelineEvent represent an event on the timeline
|
||||
// consumers of Timeline will need to check the concrete type of each entry to determine how to handle it
|
||||
type TimelineEvent interface {
|
||||
GetTimelineLocation() TimelineLocation
|
||||
}
|
||||
|
||||
type Timeline []TimelineEvent
|
||||
|
||||
func (t Timeline) Len() int { return len(t) }
|
||||
func (t Timeline) Less(i, j int) bool {
|
||||
return t[i].GetTimelineLocation().Order < t[j].GetTimelineLocation().Order
|
||||
}
|
||||
func (t Timeline) Swap(i, j int) { t[i], t[j] = t[j], t[i] }
|
||||
func (t Timeline) WithoutHiddenReportEntries() Timeline {
|
||||
out := Timeline{}
|
||||
for _, event := range t {
|
||||
if reportEntry, isReportEntry := event.(ReportEntry); isReportEntry && reportEntry.Visibility == ReportEntryVisibilityNever {
|
||||
continue
|
||||
}
|
||||
out = append(out, event)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (t Timeline) WithoutVeryVerboseSpecEvents() Timeline {
|
||||
out := Timeline{}
|
||||
for _, event := range t {
|
||||
if specEvent, isSpecEvent := event.(SpecEvent); isSpecEvent && specEvent.IsOnlyVisibleAtVeryVerbose() {
|
||||
continue
|
||||
}
|
||||
out = append(out, event)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Failure captures failure information for an individual test
|
||||
type Failure struct {
|
||||
// Message - the failure message passed into Fail(...). When using a matcher library
|
||||
|
@ -408,6 +493,8 @@ type Failure struct {
|
|||
// This CodeLocation will include a fully-populated StackTrace
|
||||
Location CodeLocation
|
||||
|
||||
TimelineLocation TimelineLocation
|
||||
|
||||
// ForwardedPanic - if the failure represents a captured panic (i.e. Summary.State == SpecStatePanicked)
|
||||
// then ForwardedPanic will be populated with a string representation of the captured panic.
|
||||
ForwardedPanic string `json:",omitempty"`
|
||||
|
@ -420,19 +507,29 @@ type Failure struct {
|
|||
// FailureNodeType will contain the NodeType of the node in which the failure occurred.
|
||||
// FailureNodeLocation will contain the CodeLocation of the node in which the failure occurred.
|
||||
// If populated, FailureNodeContainerIndex will be the index into SpecReport.ContainerHierarchyTexts and SpecReport.ContainerHierarchyLocations that represents the parent container of the node in which the failure occurred.
|
||||
FailureNodeContext FailureNodeContext
|
||||
FailureNodeType NodeType
|
||||
FailureNodeLocation CodeLocation
|
||||
FailureNodeContainerIndex int
|
||||
FailureNodeContext FailureNodeContext `json:",omitempty"`
|
||||
|
||||
FailureNodeType NodeType `json:",omitempty"`
|
||||
|
||||
FailureNodeLocation CodeLocation `json:",omitempty"`
|
||||
|
||||
FailureNodeContainerIndex int `json:",omitempty"`
|
||||
|
||||
//ProgressReport is populated if the spec was interrupted or timed out
|
||||
ProgressReport ProgressReport
|
||||
ProgressReport ProgressReport `json:",omitempty"`
|
||||
|
||||
//AdditionalFailure is non-nil if a follow-on failure occurred within the same node after the primary failure. This only happens when a node has timed out or been interrupted. In such cases the AdditionalFailure can include information about where/why the spec was stuck.
|
||||
AdditionalFailure *AdditionalFailure `json:",omitempty"`
|
||||
}
|
||||
|
||||
func (f Failure) IsZero() bool {
|
||||
return f.Message == "" && (f.Location == CodeLocation{})
|
||||
}
|
||||
|
||||
func (f Failure) GetTimelineLocation() TimelineLocation {
|
||||
return f.TimelineLocation
|
||||
}
|
||||
|
||||
// FailureNodeContext captures the location context for the node containing the failing line of code
|
||||
type FailureNodeContext uint
|
||||
|
||||
|
@ -471,6 +568,10 @@ type AdditionalFailure struct {
|
|||
Failure Failure
|
||||
}
|
||||
|
||||
func (f AdditionalFailure) GetTimelineLocation() TimelineLocation {
|
||||
return f.Failure.TimelineLocation
|
||||
}
|
||||
|
||||
// SpecState captures the state of a spec
|
||||
// To determine if a given `state` represents a failure state, use `state.Is(SpecStateFailureStates)`
|
||||
type SpecState uint
|
||||
|
@ -503,6 +604,9 @@ var ssEnumSupport = NewEnumSupport(map[uint]string{
|
|||
func (ss SpecState) String() string {
|
||||
return ssEnumSupport.String(uint(ss))
|
||||
}
|
||||
func (ss SpecState) GomegaString() string {
|
||||
return ssEnumSupport.String(uint(ss))
|
||||
}
|
||||
func (ss *SpecState) UnmarshalJSON(b []byte) error {
|
||||
out, err := ssEnumSupport.UnmarshJSON(b)
|
||||
*ss = SpecState(out)
|
||||
|
@ -520,38 +624,40 @@ func (ss SpecState) Is(states SpecState) bool {
|
|||
|
||||
// ProgressReport captures the progress of the current spec. It is, effectively, a structured Ginkgo-aware stack trace
|
||||
type ProgressReport struct {
|
||||
Message string
|
||||
ParallelProcess int
|
||||
RunningInParallel bool
|
||||
Message string `json:",omitempty"`
|
||||
ParallelProcess int `json:",omitempty"`
|
||||
RunningInParallel bool `json:",omitempty"`
|
||||
|
||||
Time time.Time
|
||||
ContainerHierarchyTexts []string `json:",omitempty"`
|
||||
LeafNodeText string `json:",omitempty"`
|
||||
LeafNodeLocation CodeLocation `json:",omitempty"`
|
||||
SpecStartTime time.Time `json:",omitempty"`
|
||||
|
||||
ContainerHierarchyTexts []string
|
||||
LeafNodeText string
|
||||
LeafNodeLocation CodeLocation
|
||||
SpecStartTime time.Time
|
||||
CurrentNodeType NodeType `json:",omitempty"`
|
||||
CurrentNodeText string `json:",omitempty"`
|
||||
CurrentNodeLocation CodeLocation `json:",omitempty"`
|
||||
CurrentNodeStartTime time.Time `json:",omitempty"`
|
||||
|
||||
CurrentNodeType NodeType
|
||||
CurrentNodeText string
|
||||
CurrentNodeLocation CodeLocation
|
||||
CurrentNodeStartTime time.Time
|
||||
CurrentStepText string `json:",omitempty"`
|
||||
CurrentStepLocation CodeLocation `json:",omitempty"`
|
||||
CurrentStepStartTime time.Time `json:",omitempty"`
|
||||
|
||||
CurrentStepText string
|
||||
CurrentStepLocation CodeLocation
|
||||
CurrentStepStartTime time.Time
|
||||
|
||||
AdditionalReports []string
|
||||
AdditionalReports []string `json:",omitempty"`
|
||||
|
||||
CapturedGinkgoWriterOutput string `json:",omitempty"`
|
||||
GinkgoWriterOffset int
|
||||
TimelineLocation TimelineLocation `json:",omitempty"`
|
||||
|
||||
Goroutines []Goroutine
|
||||
Goroutines []Goroutine `json:",omitempty"`
|
||||
}
|
||||
|
||||
func (pr ProgressReport) IsZero() bool {
|
||||
return pr.CurrentNodeType == NodeTypeInvalid
|
||||
}
|
||||
|
||||
func (pr ProgressReport) Time() time.Time {
|
||||
return pr.TimelineLocation.Time
|
||||
}
|
||||
|
||||
func (pr ProgressReport) SpecGoroutine() Goroutine {
|
||||
for _, goroutine := range pr.Goroutines {
|
||||
if goroutine.IsSpecGoroutine {
|
||||
|
@ -589,6 +695,22 @@ func (pr ProgressReport) WithoutCapturedGinkgoWriterOutput() ProgressReport {
|
|||
return out
|
||||
}
|
||||
|
||||
func (pr ProgressReport) WithoutOtherGoroutines() ProgressReport {
|
||||
out := pr
|
||||
filteredGoroutines := []Goroutine{}
|
||||
for _, goroutine := range pr.Goroutines {
|
||||
if goroutine.IsSpecGoroutine || goroutine.HasHighlights() {
|
||||
filteredGoroutines = append(filteredGoroutines, goroutine)
|
||||
}
|
||||
}
|
||||
out.Goroutines = filteredGoroutines
|
||||
return out
|
||||
}
|
||||
|
||||
func (pr ProgressReport) GetTimelineLocation() TimelineLocation {
|
||||
return pr.TimelineLocation
|
||||
}
|
||||
|
||||
type Goroutine struct {
|
||||
ID uint64
|
||||
State string
|
||||
|
@ -643,6 +765,7 @@ const (
|
|||
|
||||
NodeTypeReportBeforeEach
|
||||
NodeTypeReportAfterEach
|
||||
NodeTypeReportBeforeSuite
|
||||
NodeTypeReportAfterSuite
|
||||
|
||||
NodeTypeCleanupInvalid
|
||||
|
@ -652,9 +775,9 @@ const (
|
|||
)
|
||||
|
||||
var NodeTypesForContainerAndIt = NodeTypeContainer | NodeTypeIt
|
||||
var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite
|
||||
var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite
|
||||
var NodeTypesAllowedDuringCleanupInterrupt = NodeTypeAfterEach | NodeTypeJustAfterEach | NodeTypeAfterAll | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeCleanupAfterEach | NodeTypeCleanupAfterAll | NodeTypeCleanupAfterSuite
|
||||
var NodeTypesAllowedDuringReportInterrupt = NodeTypeReportBeforeEach | NodeTypeReportAfterEach | NodeTypeReportAfterSuite
|
||||
var NodeTypesAllowedDuringReportInterrupt = NodeTypeReportBeforeEach | NodeTypeReportAfterEach | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite
|
||||
|
||||
var ntEnumSupport = NewEnumSupport(map[uint]string{
|
||||
uint(NodeTypeInvalid): "INVALID NODE TYPE",
|
||||
|
@ -672,6 +795,7 @@ var ntEnumSupport = NewEnumSupport(map[uint]string{
|
|||
uint(NodeTypeSynchronizedAfterSuite): "SynchronizedAfterSuite",
|
||||
uint(NodeTypeReportBeforeEach): "ReportBeforeEach",
|
||||
uint(NodeTypeReportAfterEach): "ReportAfterEach",
|
||||
uint(NodeTypeReportBeforeSuite): "ReportBeforeSuite",
|
||||
uint(NodeTypeReportAfterSuite): "ReportAfterSuite",
|
||||
uint(NodeTypeCleanupInvalid): "DeferCleanup",
|
||||
uint(NodeTypeCleanupAfterEach): "DeferCleanup (Each)",
|
||||
|
@ -694,3 +818,99 @@ func (nt NodeType) MarshalJSON() ([]byte, error) {
|
|||
func (nt NodeType) Is(nodeTypes NodeType) bool {
|
||||
return nt&nodeTypes != 0
|
||||
}
|
||||
|
||||
/*
|
||||
SpecEvent captures a vareity of events that can occur when specs run. See SpecEventType for the list of available events.
|
||||
*/
|
||||
type SpecEvent struct {
|
||||
SpecEventType SpecEventType
|
||||
|
||||
CodeLocation CodeLocation
|
||||
TimelineLocation TimelineLocation
|
||||
|
||||
Message string `json:",omitempty"`
|
||||
Duration time.Duration `json:",omitempty"`
|
||||
NodeType NodeType `json:",omitempty"`
|
||||
Attempt int `json:",omitempty"`
|
||||
}
|
||||
|
||||
func (se SpecEvent) GetTimelineLocation() TimelineLocation {
|
||||
return se.TimelineLocation
|
||||
}
|
||||
|
||||
func (se SpecEvent) IsOnlyVisibleAtVeryVerbose() bool {
|
||||
return se.SpecEventType.Is(SpecEventByEnd | SpecEventNodeStart | SpecEventNodeEnd)
|
||||
}
|
||||
|
||||
func (se SpecEvent) GomegaString() string {
|
||||
out := &strings.Builder{}
|
||||
out.WriteString("[" + se.SpecEventType.String() + " SpecEvent] ")
|
||||
if se.Message != "" {
|
||||
out.WriteString("Message=")
|
||||
out.WriteString(`"` + se.Message + `",`)
|
||||
}
|
||||
if se.Duration != 0 {
|
||||
out.WriteString("Duration=" + se.Duration.String() + ",")
|
||||
}
|
||||
if se.NodeType != NodeTypeInvalid {
|
||||
out.WriteString("NodeType=" + se.NodeType.String() + ",")
|
||||
}
|
||||
if se.Attempt != 0 {
|
||||
out.WriteString(fmt.Sprintf("Attempt=%d", se.Attempt) + ",")
|
||||
}
|
||||
out.WriteString("CL=" + se.CodeLocation.String() + ",")
|
||||
out.WriteString(fmt.Sprintf("TL.Offset=%d", se.TimelineLocation.Offset))
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
||||
type SpecEvents []SpecEvent
|
||||
|
||||
func (se SpecEvents) WithType(seType SpecEventType) SpecEvents {
|
||||
out := SpecEvents{}
|
||||
for _, event := range se {
|
||||
if event.SpecEventType.Is(seType) {
|
||||
out = append(out, event)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type SpecEventType uint
|
||||
|
||||
const (
|
||||
SpecEventInvalid SpecEventType = 0
|
||||
|
||||
SpecEventByStart SpecEventType = 1 << iota
|
||||
SpecEventByEnd
|
||||
SpecEventNodeStart
|
||||
SpecEventNodeEnd
|
||||
SpecEventSpecRepeat
|
||||
SpecEventSpecRetry
|
||||
)
|
||||
|
||||
var seEnumSupport = NewEnumSupport(map[uint]string{
|
||||
uint(SpecEventInvalid): "INVALID SPEC EVENT",
|
||||
uint(SpecEventByStart): "By",
|
||||
uint(SpecEventByEnd): "By (End)",
|
||||
uint(SpecEventNodeStart): "Node",
|
||||
uint(SpecEventNodeEnd): "Node (End)",
|
||||
uint(SpecEventSpecRepeat): "Repeat",
|
||||
uint(SpecEventSpecRetry): "Retry",
|
||||
})
|
||||
|
||||
func (se SpecEventType) String() string {
|
||||
return seEnumSupport.String(uint(se))
|
||||
}
|
||||
func (se *SpecEventType) UnmarshalJSON(b []byte) error {
|
||||
out, err := seEnumSupport.UnmarshJSON(b)
|
||||
*se = SpecEventType(out)
|
||||
return err
|
||||
}
|
||||
func (se SpecEventType) MarshalJSON() ([]byte, error) {
|
||||
return seEnumSupport.MarshJSON(uint(se))
|
||||
}
|
||||
|
||||
func (se SpecEventType) Is(specEventTypes SpecEventType) bool {
|
||||
return se&specEventTypes != 0
|
||||
}
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
package types
|
||||
|
||||
const VERSION = "2.4.0"
|
||||
const VERSION = "2.9.5"
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -1,6 +0,0 @@
|
|||
# qtls
|
||||
|
||||
[![Go Reference](https://pkg.go.dev/badge/github.com/quic-go/qtls-go1-19.svg)](https://pkg.go.dev/github.com/quic-go/qtls-go1-19)
|
||||
[![.github/workflows/go-test.yml](https://github.com/quic-go/qtls-go1-19/actions/workflows/go-test.yml/badge.svg)](https://github.com/quic-go/qtls-go1-19/actions/workflows/go-test.yml)
|
||||
|
||||
This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go).
|
|
@ -1,102 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import "strconv"
|
||||
|
||||
type alert uint8
|
||||
|
||||
// Alert is a TLS alert
|
||||
type Alert = alert
|
||||
|
||||
const (
|
||||
// alert level
|
||||
alertLevelWarning = 1
|
||||
alertLevelError = 2
|
||||
)
|
||||
|
||||
const (
|
||||
alertCloseNotify alert = 0
|
||||
alertUnexpectedMessage alert = 10
|
||||
alertBadRecordMAC alert = 20
|
||||
alertDecryptionFailed alert = 21
|
||||
alertRecordOverflow alert = 22
|
||||
alertDecompressionFailure alert = 30
|
||||
alertHandshakeFailure alert = 40
|
||||
alertBadCertificate alert = 42
|
||||
alertUnsupportedCertificate alert = 43
|
||||
alertCertificateRevoked alert = 44
|
||||
alertCertificateExpired alert = 45
|
||||
alertCertificateUnknown alert = 46
|
||||
alertIllegalParameter alert = 47
|
||||
alertUnknownCA alert = 48
|
||||
alertAccessDenied alert = 49
|
||||
alertDecodeError alert = 50
|
||||
alertDecryptError alert = 51
|
||||
alertExportRestriction alert = 60
|
||||
alertProtocolVersion alert = 70
|
||||
alertInsufficientSecurity alert = 71
|
||||
alertInternalError alert = 80
|
||||
alertInappropriateFallback alert = 86
|
||||
alertUserCanceled alert = 90
|
||||
alertNoRenegotiation alert = 100
|
||||
alertMissingExtension alert = 109
|
||||
alertUnsupportedExtension alert = 110
|
||||
alertCertificateUnobtainable alert = 111
|
||||
alertUnrecognizedName alert = 112
|
||||
alertBadCertificateStatusResponse alert = 113
|
||||
alertBadCertificateHashValue alert = 114
|
||||
alertUnknownPSKIdentity alert = 115
|
||||
alertCertificateRequired alert = 116
|
||||
alertNoApplicationProtocol alert = 120
|
||||
)
|
||||
|
||||
var alertText = map[alert]string{
|
||||
alertCloseNotify: "close notify",
|
||||
alertUnexpectedMessage: "unexpected message",
|
||||
alertBadRecordMAC: "bad record MAC",
|
||||
alertDecryptionFailed: "decryption failed",
|
||||
alertRecordOverflow: "record overflow",
|
||||
alertDecompressionFailure: "decompression failure",
|
||||
alertHandshakeFailure: "handshake failure",
|
||||
alertBadCertificate: "bad certificate",
|
||||
alertUnsupportedCertificate: "unsupported certificate",
|
||||
alertCertificateRevoked: "revoked certificate",
|
||||
alertCertificateExpired: "expired certificate",
|
||||
alertCertificateUnknown: "unknown certificate",
|
||||
alertIllegalParameter: "illegal parameter",
|
||||
alertUnknownCA: "unknown certificate authority",
|
||||
alertAccessDenied: "access denied",
|
||||
alertDecodeError: "error decoding message",
|
||||
alertDecryptError: "error decrypting message",
|
||||
alertExportRestriction: "export restriction",
|
||||
alertProtocolVersion: "protocol version not supported",
|
||||
alertInsufficientSecurity: "insufficient security level",
|
||||
alertInternalError: "internal error",
|
||||
alertInappropriateFallback: "inappropriate fallback",
|
||||
alertUserCanceled: "user canceled",
|
||||
alertNoRenegotiation: "no renegotiation",
|
||||
alertMissingExtension: "missing extension",
|
||||
alertUnsupportedExtension: "unsupported extension",
|
||||
alertCertificateUnobtainable: "certificate unobtainable",
|
||||
alertUnrecognizedName: "unrecognized name",
|
||||
alertBadCertificateStatusResponse: "bad certificate status response",
|
||||
alertBadCertificateHashValue: "bad certificate hash value",
|
||||
alertUnknownPSKIdentity: "unknown PSK identity",
|
||||
alertCertificateRequired: "certificate required",
|
||||
alertNoApplicationProtocol: "no application protocol",
|
||||
}
|
||||
|
||||
func (e alert) String() string {
|
||||
s, ok := alertText[e]
|
||||
if ok {
|
||||
return "tls: " + s
|
||||
}
|
||||
return "tls: alert(" + strconv.Itoa(int(e)) + ")"
|
||||
}
|
||||
|
||||
func (e alert) Error() string {
|
||||
return e.String()
|
||||
}
|
|
@ -1,293 +0,0 @@
|
|||
// Copyright 2017 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
)
|
||||
|
||||
// verifyHandshakeSignature verifies a signature against pre-hashed
|
||||
// (if required) handshake contents.
|
||||
func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error {
|
||||
switch sigType {
|
||||
case signatureECDSA:
|
||||
pubKey, ok := pubkey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected an ECDSA public key, got %T", pubkey)
|
||||
}
|
||||
if !ecdsa.VerifyASN1(pubKey, signed, sig) {
|
||||
return errors.New("ECDSA verification failure")
|
||||
}
|
||||
case signatureEd25519:
|
||||
pubKey, ok := pubkey.(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey)
|
||||
}
|
||||
if !ed25519.Verify(pubKey, signed, sig) {
|
||||
return errors.New("Ed25519 verification failure")
|
||||
}
|
||||
case signaturePKCS1v15:
|
||||
pubKey, ok := pubkey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
|
||||
}
|
||||
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil {
|
||||
return err
|
||||
}
|
||||
case signatureRSAPSS:
|
||||
pubKey, ok := pubkey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
|
||||
}
|
||||
signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
|
||||
if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("internal error: unknown signature type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
serverSignatureContext = "TLS 1.3, server CertificateVerify\x00"
|
||||
clientSignatureContext = "TLS 1.3, client CertificateVerify\x00"
|
||||
)
|
||||
|
||||
var signaturePadding = []byte{
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
}
|
||||
|
||||
// signedMessage returns the pre-hashed (if necessary) message to be signed by
|
||||
// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3.
|
||||
func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte {
|
||||
if sigHash == directSigning {
|
||||
b := &bytes.Buffer{}
|
||||
b.Write(signaturePadding)
|
||||
io.WriteString(b, context)
|
||||
b.Write(transcript.Sum(nil))
|
||||
return b.Bytes()
|
||||
}
|
||||
h := sigHash.New()
|
||||
h.Write(signaturePadding)
|
||||
io.WriteString(h, context)
|
||||
h.Write(transcript.Sum(nil))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// typeAndHashFromSignatureScheme returns the corresponding signature type and
|
||||
// crypto.Hash for a given TLS SignatureScheme.
|
||||
func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) {
|
||||
switch signatureAlgorithm {
|
||||
case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512:
|
||||
sigType = signaturePKCS1v15
|
||||
case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512:
|
||||
sigType = signatureRSAPSS
|
||||
case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512:
|
||||
sigType = signatureECDSA
|
||||
case Ed25519:
|
||||
sigType = signatureEd25519
|
||||
default:
|
||||
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
|
||||
}
|
||||
switch signatureAlgorithm {
|
||||
case PKCS1WithSHA1, ECDSAWithSHA1:
|
||||
hash = crypto.SHA1
|
||||
case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256:
|
||||
hash = crypto.SHA256
|
||||
case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384:
|
||||
hash = crypto.SHA384
|
||||
case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512:
|
||||
hash = crypto.SHA512
|
||||
case Ed25519:
|
||||
hash = directSigning
|
||||
default:
|
||||
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
|
||||
}
|
||||
return sigType, hash, nil
|
||||
}
|
||||
|
||||
// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for
|
||||
// a given public key used with TLS 1.0 and 1.1, before the introduction of
|
||||
// signature algorithm negotiation.
|
||||
func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) {
|
||||
switch pub.(type) {
|
||||
case *rsa.PublicKey:
|
||||
return signaturePKCS1v15, crypto.MD5SHA1, nil
|
||||
case *ecdsa.PublicKey:
|
||||
return signatureECDSA, crypto.SHA1, nil
|
||||
case ed25519.PublicKey:
|
||||
// RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1,
|
||||
// but it requires holding on to a handshake transcript to do a
|
||||
// full signature, and not even OpenSSL bothers with the
|
||||
// complexity, so we can't even test it properly.
|
||||
return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2")
|
||||
default:
|
||||
return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub)
|
||||
}
|
||||
}
|
||||
|
||||
var rsaSignatureSchemes = []struct {
|
||||
scheme SignatureScheme
|
||||
minModulusBytes int
|
||||
maxVersion uint16
|
||||
}{
|
||||
// RSA-PSS is used with PSSSaltLengthEqualsHash, and requires
|
||||
// emLen >= hLen + sLen + 2
|
||||
{PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13},
|
||||
{PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13},
|
||||
{PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13},
|
||||
// PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires
|
||||
// emLen >= len(prefix) + hLen + 11
|
||||
// TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS.
|
||||
{PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12},
|
||||
{PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12},
|
||||
{PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12},
|
||||
{PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12},
|
||||
}
|
||||
|
||||
// signatureSchemesForCertificate returns the list of supported SignatureSchemes
|
||||
// for a given certificate, based on the public key and the protocol version,
|
||||
// and optionally filtered by its explicit SupportedSignatureAlgorithms.
|
||||
//
|
||||
// This function must be kept in sync with supportedSignatureAlgorithms.
|
||||
// FIPS filtering is applied in the caller, selectSignatureScheme.
|
||||
func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme {
|
||||
priv, ok := cert.PrivateKey.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var sigAlgs []SignatureScheme
|
||||
switch pub := priv.Public().(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
if version != VersionTLS13 {
|
||||
// In TLS 1.2 and earlier, ECDSA algorithms are not
|
||||
// constrained to a single curve.
|
||||
sigAlgs = []SignatureScheme{
|
||||
ECDSAWithP256AndSHA256,
|
||||
ECDSAWithP384AndSHA384,
|
||||
ECDSAWithP521AndSHA512,
|
||||
ECDSAWithSHA1,
|
||||
}
|
||||
break
|
||||
}
|
||||
switch pub.Curve {
|
||||
case elliptic.P256():
|
||||
sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256}
|
||||
case elliptic.P384():
|
||||
sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384}
|
||||
case elliptic.P521():
|
||||
sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
case *rsa.PublicKey:
|
||||
size := pub.Size()
|
||||
sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes))
|
||||
for _, candidate := range rsaSignatureSchemes {
|
||||
if size >= candidate.minModulusBytes && version <= candidate.maxVersion {
|
||||
sigAlgs = append(sigAlgs, candidate.scheme)
|
||||
}
|
||||
}
|
||||
case ed25519.PublicKey:
|
||||
sigAlgs = []SignatureScheme{Ed25519}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if cert.SupportedSignatureAlgorithms != nil {
|
||||
var filteredSigAlgs []SignatureScheme
|
||||
for _, sigAlg := range sigAlgs {
|
||||
if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) {
|
||||
filteredSigAlgs = append(filteredSigAlgs, sigAlg)
|
||||
}
|
||||
}
|
||||
return filteredSigAlgs
|
||||
}
|
||||
return sigAlgs
|
||||
}
|
||||
|
||||
// selectSignatureScheme picks a SignatureScheme from the peer's preference list
|
||||
// that works with the selected certificate. It's only called for protocol
|
||||
// versions that support signature algorithms, so TLS 1.2 and 1.3.
|
||||
func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) {
|
||||
supportedAlgs := signatureSchemesForCertificate(vers, c)
|
||||
if len(supportedAlgs) == 0 {
|
||||
return 0, unsupportedCertificateError(c)
|
||||
}
|
||||
if len(peerAlgs) == 0 && vers == VersionTLS12 {
|
||||
// For TLS 1.2, if the client didn't send signature_algorithms then we
|
||||
// can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1.
|
||||
peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1}
|
||||
}
|
||||
// Pick signature scheme in the peer's preference order, as our
|
||||
// preference order is not configurable.
|
||||
for _, preferredAlg := range peerAlgs {
|
||||
if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) {
|
||||
continue
|
||||
}
|
||||
if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) {
|
||||
return preferredAlg, nil
|
||||
}
|
||||
}
|
||||
return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms")
|
||||
}
|
||||
|
||||
// unsupportedCertificateError returns a helpful error for certificates with
|
||||
// an unsupported private key.
|
||||
func unsupportedCertificateError(cert *Certificate) error {
|
||||
switch cert.PrivateKey.(type) {
|
||||
case rsa.PrivateKey, ecdsa.PrivateKey:
|
||||
return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T",
|
||||
cert.PrivateKey, cert.PrivateKey)
|
||||
case *ed25519.PrivateKey:
|
||||
return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey")
|
||||
}
|
||||
|
||||
signer, ok := cert.PrivateKey.(crypto.Signer)
|
||||
if !ok {
|
||||
return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer",
|
||||
cert.PrivateKey)
|
||||
}
|
||||
|
||||
switch pub := signer.Public().(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
switch pub.Curve {
|
||||
case elliptic.P256():
|
||||
case elliptic.P384():
|
||||
case elliptic.P521():
|
||||
default:
|
||||
return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name)
|
||||
}
|
||||
case *rsa.PublicKey:
|
||||
return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms")
|
||||
case ed25519.PublicKey:
|
||||
default:
|
||||
return fmt.Errorf("tls: unsupported certificate key (%T)", pub)
|
||||
}
|
||||
|
||||
if cert.SupportedSignatureAlgorithms != nil {
|
||||
return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms")
|
||||
}
|
||||
|
||||
return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey)
|
||||
}
|
|
@ -1,693 +0,0 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/des"
|
||||
"crypto/hmac"
|
||||
"crypto/rc4"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"hash"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// CipherSuite is a TLS cipher suite. Note that most functions in this package
|
||||
// accept and expose cipher suite IDs instead of this type.
|
||||
type CipherSuite struct {
|
||||
ID uint16
|
||||
Name string
|
||||
|
||||
// Supported versions is the list of TLS protocol versions that can
|
||||
// negotiate this cipher suite.
|
||||
SupportedVersions []uint16
|
||||
|
||||
// Insecure is true if the cipher suite has known security issues
|
||||
// due to its primitives, design, or implementation.
|
||||
Insecure bool
|
||||
}
|
||||
|
||||
var (
|
||||
supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12}
|
||||
supportedOnlyTLS12 = []uint16{VersionTLS12}
|
||||
supportedOnlyTLS13 = []uint16{VersionTLS13}
|
||||
)
|
||||
|
||||
// CipherSuites returns a list of cipher suites currently implemented by this
|
||||
// package, excluding those with security issues, which are returned by
|
||||
// InsecureCipherSuites.
|
||||
//
|
||||
// The list is sorted by ID. Note that the default cipher suites selected by
|
||||
// this package might depend on logic that can't be captured by a static list,
|
||||
// and might not match those returned by this function.
|
||||
func CipherSuites() []*CipherSuite {
|
||||
return []*CipherSuite{
|
||||
{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
|
||||
{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
|
||||
{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
|
||||
{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
|
||||
|
||||
{TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false},
|
||||
{TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false},
|
||||
{TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false},
|
||||
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
|
||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
|
||||
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
|
||||
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
|
||||
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
|
||||
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
|
||||
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
|
||||
}
|
||||
}
|
||||
|
||||
// InsecureCipherSuites returns a list of cipher suites currently implemented by
|
||||
// this package and which have security issues.
|
||||
//
|
||||
// Most applications should not use the cipher suites in this list, and should
|
||||
// only use those returned by CipherSuites.
|
||||
func InsecureCipherSuites() []*CipherSuite {
|
||||
// This list includes RC4, CBC_SHA256, and 3DES cipher suites. See
|
||||
// cipherSuitesPreferenceOrder for details.
|
||||
return []*CipherSuite{
|
||||
{TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
|
||||
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
|
||||
{TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
|
||||
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
|
||||
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
|
||||
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
|
||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
|
||||
}
|
||||
}
|
||||
|
||||
// CipherSuiteName returns the standard name for the passed cipher suite ID
|
||||
// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation
|
||||
// of the ID value if the cipher suite is not implemented by this package.
|
||||
func CipherSuiteName(id uint16) string {
|
||||
for _, c := range CipherSuites() {
|
||||
if c.ID == id {
|
||||
return c.Name
|
||||
}
|
||||
}
|
||||
for _, c := range InsecureCipherSuites() {
|
||||
if c.ID == id {
|
||||
return c.Name
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("0x%04X", id)
|
||||
}
|
||||
|
||||
const (
|
||||
// suiteECDHE indicates that the cipher suite involves elliptic curve
|
||||
// Diffie-Hellman. This means that it should only be selected when the
|
||||
// client indicates that it supports ECC with a curve and point format
|
||||
// that we're happy with.
|
||||
suiteECDHE = 1 << iota
|
||||
// suiteECSign indicates that the cipher suite involves an ECDSA or
|
||||
// EdDSA signature and therefore may only be selected when the server's
|
||||
// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
|
||||
// is RSA based.
|
||||
suiteECSign
|
||||
// suiteTLS12 indicates that the cipher suite should only be advertised
|
||||
// and accepted when using TLS 1.2.
|
||||
suiteTLS12
|
||||
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
|
||||
// handshake hash.
|
||||
suiteSHA384
|
||||
)
|
||||
|
||||
// A cipherSuite is a TLS 1.0–1.2 cipher suite, and defines the key exchange
|
||||
// mechanism, as well as the cipher+MAC pair or the AEAD.
|
||||
type cipherSuite struct {
|
||||
id uint16
|
||||
// the lengths, in bytes, of the key material needed for each component.
|
||||
keyLen int
|
||||
macLen int
|
||||
ivLen int
|
||||
ka func(version uint16) keyAgreement
|
||||
// flags is a bitmask of the suite* values, above.
|
||||
flags int
|
||||
cipher func(key, iv []byte, isRead bool) any
|
||||
mac func(key []byte) hash.Hash
|
||||
aead func(key, fixedNonce []byte) aead
|
||||
}
|
||||
|
||||
var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter.
|
||||
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
|
||||
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
|
||||
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM},
|
||||
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
|
||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil},
|
||||
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
|
||||
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
|
||||
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
|
||||
{TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM},
|
||||
{TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
|
||||
{TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil},
|
||||
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
|
||||
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
|
||||
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
|
||||
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
|
||||
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
|
||||
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
|
||||
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil},
|
||||
}
|
||||
|
||||
// selectCipherSuite returns the first TLS 1.0–1.2 cipher suite from ids which
|
||||
// is also in supportedIDs and passes the ok filter.
|
||||
func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite {
|
||||
for _, id := range ids {
|
||||
candidate := cipherSuiteByID(id)
|
||||
if candidate == nil || !ok(candidate) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, suppID := range supportedIDs {
|
||||
if id == suppID {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
|
||||
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
|
||||
type cipherSuiteTLS13 struct {
|
||||
id uint16
|
||||
keyLen int
|
||||
aead func(key, fixedNonce []byte) aead
|
||||
hash crypto.Hash
|
||||
}
|
||||
|
||||
type CipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
Hash crypto.Hash
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
func (c *CipherSuiteTLS13) IVLen() int {
|
||||
return aeadNonceLength
|
||||
}
|
||||
|
||||
var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
|
||||
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
|
||||
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
|
||||
{TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384},
|
||||
}
|
||||
|
||||
// cipherSuitesPreferenceOrder is the order in which we'll select (on the
|
||||
// server) or advertise (on the client) TLS 1.0–1.2 cipher suites.
|
||||
//
|
||||
// Cipher suites are filtered but not reordered based on the application and
|
||||
// peer's preferences, meaning we'll never select a suite lower in this list if
|
||||
// any higher one is available. This makes it more defensible to keep weaker
|
||||
// cipher suites enabled, especially on the server side where we get the last
|
||||
// word, since there are no known downgrade attacks on cipher suites selection.
|
||||
//
|
||||
// The list is sorted by applying the following priority rules, stopping at the
|
||||
// first (most important) applicable one:
|
||||
//
|
||||
// - Anything else comes before RC4
|
||||
//
|
||||
// RC4 has practically exploitable biases. See https://www.rc4nomore.com.
|
||||
//
|
||||
// - Anything else comes before CBC_SHA256
|
||||
//
|
||||
// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13
|
||||
// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and
|
||||
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
|
||||
//
|
||||
// - Anything else comes before 3DES
|
||||
//
|
||||
// 3DES has 64-bit blocks, which makes it fundamentally susceptible to
|
||||
// birthday attacks. See https://sweet32.info.
|
||||
//
|
||||
// - ECDHE comes before anything else
|
||||
//
|
||||
// Once we got the broken stuff out of the way, the most important
|
||||
// property a cipher suite can have is forward secrecy. We don't
|
||||
// implement FFDHE, so that means ECDHE.
|
||||
//
|
||||
// - AEADs come before CBC ciphers
|
||||
//
|
||||
// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites
|
||||
// are fundamentally fragile, and suffered from an endless sequence of
|
||||
// padding oracle attacks. See https://eprint.iacr.org/2015/1129,
|
||||
// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and
|
||||
// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/.
|
||||
//
|
||||
// - AES comes before ChaCha20
|
||||
//
|
||||
// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster
|
||||
// than ChaCha20Poly1305.
|
||||
//
|
||||
// When AES hardware is not available, AES-128-GCM is one or more of: much
|
||||
// slower, way more complex, and less safe (because not constant time)
|
||||
// than ChaCha20Poly1305.
|
||||
//
|
||||
// We use this list if we think both peers have AES hardware, and
|
||||
// cipherSuitesPreferenceOrderNoAES otherwise.
|
||||
//
|
||||
// - AES-128 comes before AES-256
|
||||
//
|
||||
// The only potential advantages of AES-256 are better multi-target
|
||||
// margins, and hypothetical post-quantum properties. Neither apply to
|
||||
// TLS, and AES-256 is slower due to its four extra rounds (which don't
|
||||
// contribute to the advantages above).
|
||||
//
|
||||
// - ECDSA comes before RSA
|
||||
//
|
||||
// The relative order of ECDSA and RSA cipher suites doesn't matter,
|
||||
// as they depend on the certificate. Pick one to get a stable order.
|
||||
var cipherSuitesPreferenceOrder = []uint16{
|
||||
// AEADs w/ ECDHE
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
|
||||
// CBC w/ ECDHE
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
|
||||
// AEADs w/o ECDHE
|
||||
TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
|
||||
// CBC w/o ECDHE
|
||||
TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
|
||||
// 3DES
|
||||
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
|
||||
// CBC_SHA256
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
|
||||
// RC4
|
||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||
TLS_RSA_WITH_RC4_128_SHA,
|
||||
}
|
||||
|
||||
var cipherSuitesPreferenceOrderNoAES = []uint16{
|
||||
// ChaCha20Poly1305
|
||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
|
||||
// AES-GCM w/ ECDHE
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
|
||||
// The rest of cipherSuitesPreferenceOrder.
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||
TLS_RSA_WITH_RC4_128_SHA,
|
||||
}
|
||||
|
||||
// disabledCipherSuites are not used unless explicitly listed in
|
||||
// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder.
|
||||
var disabledCipherSuites = []uint16{
|
||||
// CBC_SHA256
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
|
||||
TLS_RSA_WITH_AES_128_CBC_SHA256,
|
||||
|
||||
// RC4
|
||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
|
||||
TLS_RSA_WITH_RC4_128_SHA,
|
||||
}
|
||||
|
||||
var (
|
||||
defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
|
||||
defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen]
|
||||
)
|
||||
|
||||
// defaultCipherSuitesTLS13 is also the preference order, since there are no
|
||||
// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as
|
||||
// cipherSuitesPreferenceOrder applies.
|
||||
var defaultCipherSuitesTLS13 = []uint16{
|
||||
TLS_AES_128_GCM_SHA256,
|
||||
TLS_AES_256_GCM_SHA384,
|
||||
TLS_CHACHA20_POLY1305_SHA256,
|
||||
}
|
||||
|
||||
var defaultCipherSuitesTLS13NoAES = []uint16{
|
||||
TLS_CHACHA20_POLY1305_SHA256,
|
||||
TLS_AES_128_GCM_SHA256,
|
||||
TLS_AES_256_GCM_SHA384,
|
||||
}
|
||||
|
||||
var aesgcmCiphers = map[uint16]bool{
|
||||
// TLS 1.2
|
||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true,
|
||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true,
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true,
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true,
|
||||
// TLS 1.3
|
||||
TLS_AES_128_GCM_SHA256: true,
|
||||
TLS_AES_256_GCM_SHA384: true,
|
||||
}
|
||||
|
||||
var nonAESGCMAEADCiphers = map[uint16]bool{
|
||||
// TLS 1.2
|
||||
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true,
|
||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true,
|
||||
// TLS 1.3
|
||||
TLS_CHACHA20_POLY1305_SHA256: true,
|
||||
}
|
||||
|
||||
// aesgcmPreferred returns whether the first known cipher in the preference list
|
||||
// is an AES-GCM cipher, implying the peer has hardware support for it.
|
||||
func aesgcmPreferred(ciphers []uint16) bool {
|
||||
for _, cID := range ciphers {
|
||||
if c := cipherSuiteByID(cID); c != nil {
|
||||
return aesgcmCiphers[cID]
|
||||
}
|
||||
if c := cipherSuiteTLS13ByID(cID); c != nil {
|
||||
return aesgcmCiphers[cID]
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cipherRC4(key, iv []byte, isRead bool) any {
|
||||
cipher, _ := rc4.NewCipher(key)
|
||||
return cipher
|
||||
}
|
||||
|
||||
func cipher3DES(key, iv []byte, isRead bool) any {
|
||||
block, _ := des.NewTripleDESCipher(key)
|
||||
if isRead {
|
||||
return cipher.NewCBCDecrypter(block, iv)
|
||||
}
|
||||
return cipher.NewCBCEncrypter(block, iv)
|
||||
}
|
||||
|
||||
func cipherAES(key, iv []byte, isRead bool) any {
|
||||
block, _ := aes.NewCipher(key)
|
||||
if isRead {
|
||||
return cipher.NewCBCDecrypter(block, iv)
|
||||
}
|
||||
return cipher.NewCBCEncrypter(block, iv)
|
||||
}
|
||||
|
||||
// macSHA1 returns a SHA-1 based constant time MAC.
|
||||
func macSHA1(key []byte) hash.Hash {
|
||||
h := sha1.New
|
||||
h = newConstantTimeHash(h)
|
||||
return hmac.New(h, key)
|
||||
}
|
||||
|
||||
// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
|
||||
// is currently only used in disabled-by-default cipher suites.
|
||||
func macSHA256(key []byte) hash.Hash {
|
||||
return hmac.New(sha256.New, key)
|
||||
}
|
||||
|
||||
type aead interface {
|
||||
cipher.AEAD
|
||||
|
||||
// explicitNonceLen returns the number of bytes of explicit nonce
|
||||
// included in each record. This is eight for older AEADs and
|
||||
// zero for modern ones.
|
||||
explicitNonceLen() int
|
||||
}
|
||||
|
||||
const (
|
||||
aeadNonceLength = 12
|
||||
noncePrefixLength = 4
|
||||
)
|
||||
|
||||
// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
|
||||
// each call.
|
||||
type prefixNonceAEAD struct {
|
||||
// nonce contains the fixed part of the nonce in the first four bytes.
|
||||
nonce [aeadNonceLength]byte
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength }
|
||||
func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
||||
func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() }
|
||||
|
||||
func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
||||
copy(f.nonce[4:], nonce)
|
||||
return f.aead.Seal(out, f.nonce[:], plaintext, additionalData)
|
||||
}
|
||||
|
||||
func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
||||
copy(f.nonce[4:], nonce)
|
||||
return f.aead.Open(out, f.nonce[:], ciphertext, additionalData)
|
||||
}
|
||||
|
||||
// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
|
||||
// before each call.
|
||||
type xorNonceAEAD struct {
|
||||
nonceMask [aeadNonceLength]byte
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
|
||||
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
|
||||
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
|
||||
|
||||
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
|
||||
for i, b := range nonce {
|
||||
f.nonceMask[4+i] ^= b
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func aeadAESGCM(key, noncePrefix []byte) aead {
|
||||
if len(noncePrefix) != noncePrefixLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aes, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var aead cipher.AEAD
|
||||
aead, err = cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &prefixNonceAEAD{aead: aead}
|
||||
copy(ret.nonce[:], noncePrefix)
|
||||
return ret
|
||||
}
|
||||
|
||||
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
|
||||
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
|
||||
return aeadAESGCMTLS13(key, fixedNonce)
|
||||
}
|
||||
|
||||
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aes, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(aes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
func aeadChaCha20Poly1305(key, nonceMask []byte) aead {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
}
|
||||
aead, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ret := &xorNonceAEAD{aead: aead}
|
||||
copy(ret.nonceMask[:], nonceMask)
|
||||
return ret
|
||||
}
|
||||
|
||||
type constantTimeHash interface {
|
||||
hash.Hash
|
||||
ConstantTimeSum(b []byte) []byte
|
||||
}
|
||||
|
||||
// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces
|
||||
// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC.
|
||||
type cthWrapper struct {
|
||||
h constantTimeHash
|
||||
}
|
||||
|
||||
func (c *cthWrapper) Size() int { return c.h.Size() }
|
||||
func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() }
|
||||
func (c *cthWrapper) Reset() { c.h.Reset() }
|
||||
func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) }
|
||||
func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) }
|
||||
|
||||
func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
|
||||
return func() hash.Hash {
|
||||
return &cthWrapper{h().(constantTimeHash)}
|
||||
}
|
||||
}
|
||||
|
||||
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
|
||||
func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
|
||||
h.Reset()
|
||||
h.Write(seq)
|
||||
h.Write(header)
|
||||
h.Write(data)
|
||||
res := h.Sum(out)
|
||||
if extra != nil {
|
||||
h.Write(extra)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func rsaKA(version uint16) keyAgreement {
|
||||
return rsaKeyAgreement{}
|
||||
}
|
||||
|
||||
func ecdheECDSAKA(version uint16) keyAgreement {
|
||||
return &ecdheKeyAgreement{
|
||||
isRSA: false,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
func ecdheRSAKA(version uint16) keyAgreement {
|
||||
return &ecdheKeyAgreement{
|
||||
isRSA: true,
|
||||
version: version,
|
||||
}
|
||||
}
|
||||
|
||||
// mutualCipherSuite returns a cipherSuite given a list of supported
|
||||
// ciphersuites and the id requested by the peer.
|
||||
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
|
||||
for _, id := range have {
|
||||
if id == want {
|
||||
return cipherSuiteByID(id)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cipherSuiteByID(id uint16) *cipherSuite {
|
||||
for _, cipherSuite := range cipherSuites {
|
||||
if cipherSuite.id == id {
|
||||
return cipherSuite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 {
|
||||
for _, id := range have {
|
||||
if id == want {
|
||||
return cipherSuiteTLS13ByID(id)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 {
|
||||
for _, cipherSuite := range cipherSuitesTLS13 {
|
||||
if cipherSuite.id == id {
|
||||
return cipherSuite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A list of cipher suite IDs that are, or have been, implemented by this
|
||||
// package.
|
||||
//
|
||||
// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
|
||||
const (
|
||||
// TLS 1.0 - 1.2 cipher suites.
|
||||
TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
|
||||
TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a
|
||||
TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
|
||||
TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
|
||||
TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c
|
||||
TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c
|
||||
TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d
|
||||
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a
|
||||
TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011
|
||||
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012
|
||||
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013
|
||||
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023
|
||||
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027
|
||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f
|
||||
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
|
||||
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030
|
||||
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c
|
||||
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8
|
||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9
|
||||
|
||||
// TLS 1.3 cipher suites.
|
||||
TLS_AES_128_GCM_SHA256 uint16 = 0x1301
|
||||
TLS_AES_256_GCM_SHA384 uint16 = 0x1302
|
||||
TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303
|
||||
|
||||
// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
|
||||
// that the client is doing version fallback. See RFC 7507.
|
||||
TLS_FALLBACK_SCSV uint16 = 0x5600
|
||||
|
||||
// Legacy names for the corresponding cipher suites with the correct _SHA256
|
||||
// suffix, retained for backward compatibility.
|
||||
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
|
||||
)
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -1,22 +0,0 @@
|
|||
//go:build !js
|
||||
// +build !js
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
var (
|
||||
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
|
||||
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
|
||||
// Keep in sync with crypto/aes/cipher_s390x.go.
|
||||
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
|
||||
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
|
||||
|
||||
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
|
||||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
|
||||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
|
||||
)
|
|
@ -1,12 +0,0 @@
|
|||
//go:build js
|
||||
// +build js
|
||||
|
||||
package qtls
|
||||
|
||||
var (
|
||||
hasGCMAsmAMD64 = false
|
||||
hasGCMAsmARM64 = false
|
||||
hasGCMAsmS390X = false
|
||||
|
||||
hasAESGCMHardwareSupport = false
|
||||
)
|
File diff suppressed because it is too large
Load Diff
|
@ -1,755 +0,0 @@
|
|||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/rsa"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
)
|
||||
|
||||
type clientHandshakeStateTLS13 struct {
|
||||
c *Conn
|
||||
ctx context.Context
|
||||
serverHello *serverHelloMsg
|
||||
hello *clientHelloMsg
|
||||
ecdheParams ecdheParameters
|
||||
|
||||
session *clientSessionState
|
||||
earlySecret []byte
|
||||
binderKey []byte
|
||||
|
||||
certReq *certificateRequestMsgTLS13
|
||||
usingPSK bool
|
||||
sentDummyCCS bool
|
||||
suite *cipherSuiteTLS13
|
||||
transcript hash.Hash
|
||||
masterSecret []byte
|
||||
trafficSecret []byte // client_application_traffic_secret_0
|
||||
}
|
||||
|
||||
// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and,
|
||||
// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
|
||||
func (hs *clientHandshakeStateTLS13) handshake() error {
|
||||
c := hs.c
|
||||
|
||||
if needFIPS() {
|
||||
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
|
||||
}
|
||||
|
||||
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
|
||||
// sections 4.1.2 and 4.1.3.
|
||||
if c.handshakes > 0 {
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
return errors.New("tls: server selected TLS 1.3 in a renegotiation")
|
||||
}
|
||||
|
||||
// Consistency check on the presence of a keyShare and its parameters.
|
||||
if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
|
||||
if err := hs.checkServerHelloOrHRR(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hs.transcript = hs.suite.hash.New()
|
||||
|
||||
if err := transcriptMsg(hs.hello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
|
||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.processHelloRetryRequest(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.buffering = true
|
||||
if err := hs.processServerHello(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateConnectionState()
|
||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.establishHandshakeKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readServerParameters(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readServerCertificate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateConnectionState()
|
||||
if err := hs.readServerFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendClientCertificate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendClientFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.handshakeStatus, 1)
|
||||
c.updateConnectionState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkServerHelloOrHRR does validity checks that apply to both ServerHello and
|
||||
// HelloRetryRequest messages. It sets hs.suite.
|
||||
func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
|
||||
c := hs.c
|
||||
|
||||
if hs.serverHello.supportedVersion == 0 {
|
||||
c.sendAlert(alertMissingExtension)
|
||||
return errors.New("tls: server selected TLS 1.3 using the legacy version field")
|
||||
}
|
||||
|
||||
if hs.serverHello.supportedVersion != VersionTLS13 {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server selected an invalid version after a HelloRetryRequest")
|
||||
}
|
||||
|
||||
if hs.serverHello.vers != VersionTLS12 {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server sent an incorrect legacy version")
|
||||
}
|
||||
|
||||
if hs.serverHello.ocspStapling ||
|
||||
hs.serverHello.ticketSupported ||
|
||||
hs.serverHello.secureRenegotiationSupported ||
|
||||
len(hs.serverHello.secureRenegotiation) != 0 ||
|
||||
len(hs.serverHello.alpnProtocol) != 0 ||
|
||||
len(hs.serverHello.scts) != 0 {
|
||||
c.sendAlert(alertUnsupportedExtension)
|
||||
return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3")
|
||||
}
|
||||
|
||||
if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server did not echo the legacy session ID")
|
||||
}
|
||||
|
||||
if hs.serverHello.compressionMethod != compressionNone {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server selected unsupported compression format")
|
||||
}
|
||||
|
||||
selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite)
|
||||
if hs.suite != nil && selectedSuite != hs.suite {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server changed cipher suite after a HelloRetryRequest")
|
||||
}
|
||||
if selectedSuite == nil {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server chose an unconfigured cipher suite")
|
||||
}
|
||||
hs.suite = selectedSuite
|
||||
c.cipherSuite = hs.suite.id
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
||||
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
||||
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
||||
if hs.sentDummyCCS {
|
||||
return nil
|
||||
}
|
||||
hs.sentDummyCCS = true
|
||||
|
||||
return hs.c.writeChangeCipherRecord()
|
||||
}
|
||||
|
||||
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
|
||||
// resends hs.hello, and reads the new ServerHello into hs.serverHello.
|
||||
func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
|
||||
c := hs.c
|
||||
|
||||
// The first ClientHello gets double-hashed into the transcript upon a
|
||||
// HelloRetryRequest. (The idea is that the server might offload transcript
|
||||
// storage to the client in the cookie.) See RFC 8446, Section 4.4.1.
|
||||
chHash := hs.transcript.Sum(nil)
|
||||
hs.transcript.Reset()
|
||||
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
|
||||
hs.transcript.Write(chHash)
|
||||
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The only HelloRetryRequest extensions we support are key_share and
|
||||
// cookie, and clients must abort the handshake if the HRR would not result
|
||||
// in any change in the ClientHello.
|
||||
if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
|
||||
}
|
||||
|
||||
if hs.serverHello.cookie != nil {
|
||||
hs.hello.cookie = hs.serverHello.cookie
|
||||
}
|
||||
|
||||
if hs.serverHello.serverShare.group != 0 {
|
||||
c.sendAlert(alertDecodeError)
|
||||
return errors.New("tls: received malformed key_share extension")
|
||||
}
|
||||
|
||||
// If the server sent a key_share extension selecting a group, ensure it's
|
||||
// a group we advertised but did not send a key share for, and send a key
|
||||
// share for it this time.
|
||||
if curveID := hs.serverHello.selectedGroup; curveID != 0 {
|
||||
curveOK := false
|
||||
for _, id := range hs.hello.supportedCurves {
|
||||
if id == curveID {
|
||||
curveOK = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !curveOK {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server selected unsupported group")
|
||||
}
|
||||
if hs.ecdheParams.CurveID() == curveID {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
|
||||
}
|
||||
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("tls: CurvePreferences includes unsupported curve")
|
||||
}
|
||||
params, err := generateECDHEParameters(c.config.rand(), curveID)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
hs.ecdheParams = params
|
||||
hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}}
|
||||
}
|
||||
|
||||
hs.hello.raw = nil
|
||||
if len(hs.hello.pskIdentities) > 0 {
|
||||
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
|
||||
if pskSuite == nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
if pskSuite.hash == hs.suite.hash {
|
||||
// Update binders and obfuscated_ticket_age.
|
||||
ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond)
|
||||
hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd
|
||||
|
||||
transcript := hs.suite.hash.New()
|
||||
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
|
||||
transcript.Write(chHash)
|
||||
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
helloBytes, err := hs.hello.marshalWithoutBinders()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
transcript.Write(helloBytes)
|
||||
pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
|
||||
if err := hs.hello.updateBinders(pskBinders); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Server selected a cipher suite incompatible with the PSK.
|
||||
hs.hello.pskIdentities = nil
|
||||
hs.hello.pskBinders = nil
|
||||
}
|
||||
}
|
||||
|
||||
if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil {
|
||||
c.extraConfig.Rejected0RTT()
|
||||
}
|
||||
hs.hello.earlyData = false // disable 0-RTT
|
||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// serverHelloMsg is not included in the transcript
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serverHello, ok := msg.(*serverHelloMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(serverHello, msg)
|
||||
}
|
||||
hs.serverHello = serverHello
|
||||
|
||||
if err := hs.checkServerHelloOrHRR(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeStateTLS13) processServerHello() error {
|
||||
c := hs.c
|
||||
|
||||
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return errors.New("tls: server sent two HelloRetryRequest messages")
|
||||
}
|
||||
|
||||
if len(hs.serverHello.cookie) != 0 {
|
||||
c.sendAlert(alertUnsupportedExtension)
|
||||
return errors.New("tls: server sent a cookie in a normal ServerHello")
|
||||
}
|
||||
|
||||
if hs.serverHello.selectedGroup != 0 {
|
||||
c.sendAlert(alertDecodeError)
|
||||
return errors.New("tls: malformed key_share extension")
|
||||
}
|
||||
|
||||
if hs.serverHello.serverShare.group == 0 {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server did not send a key share")
|
||||
}
|
||||
if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server selected unsupported group")
|
||||
}
|
||||
|
||||
if !hs.serverHello.selectedIdentityPresent {
|
||||
return nil
|
||||
}
|
||||
|
||||
if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server selected an invalid PSK")
|
||||
}
|
||||
|
||||
if len(hs.hello.pskIdentities) != 1 || hs.session == nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
|
||||
if pskSuite == nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
if pskSuite.hash != hs.suite.hash {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: server selected an invalid PSK and cipher suite pair")
|
||||
}
|
||||
|
||||
hs.usingPSK = true
|
||||
c.didResume = true
|
||||
c.peerCertificates = hs.session.serverCertificates
|
||||
c.verifiedChains = hs.session.verifiedChains
|
||||
c.ocspResponse = hs.session.ocspResponse
|
||||
c.scts = hs.session.scts
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
|
||||
c := hs.c
|
||||
|
||||
sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data)
|
||||
if sharedKey == nil {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: invalid server key share")
|
||||
}
|
||||
|
||||
earlySecret := hs.earlySecret
|
||||
if !hs.usingPSK {
|
||||
earlySecret = hs.suite.extract(nil, nil)
|
||||
}
|
||||
|
||||
handshakeSecret := hs.suite.extract(sharedKey,
|
||||
hs.suite.deriveSecret(earlySecret, "derived", nil))
|
||||
|
||||
clientSecret := hs.suite.deriveSecret(handshakeSecret,
|
||||
clientHandshakeTrafficLabel, hs.transcript)
|
||||
c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret)
|
||||
c.out.setTrafficSecret(hs.suite, clientSecret)
|
||||
serverSecret := hs.suite.deriveSecret(handshakeSecret,
|
||||
serverHandshakeTrafficLabel, hs.transcript)
|
||||
c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret)
|
||||
c.in.setTrafficSecret(hs.suite, serverSecret)
|
||||
|
||||
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
|
||||
hs.masterSecret = hs.suite.extract(nil,
|
||||
hs.suite.deriveSecret(handshakeSecret, "derived", nil))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
|
||||
c := hs.c
|
||||
|
||||
msg, err := c.readHandshake(hs.transcript)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(encryptedExtensions, msg)
|
||||
}
|
||||
// Notify the caller if 0-RTT was rejected.
|
||||
if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil {
|
||||
c.extraConfig.Rejected0RTT()
|
||||
}
|
||||
c.used0RTT = encryptedExtensions.earlyData
|
||||
if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil {
|
||||
hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions)
|
||||
}
|
||||
|
||||
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
|
||||
c.sendAlert(alertUnsupportedExtension)
|
||||
return err
|
||||
}
|
||||
c.clientProtocol = encryptedExtensions.alpnProtocol
|
||||
|
||||
if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection {
|
||||
if len(encryptedExtensions.alpnProtocol) == 0 {
|
||||
// the server didn't select an ALPN
|
||||
c.sendAlert(alertNoApplicationProtocol)
|
||||
return errors.New("ALPN negotiation failed. Server didn't offer any protocols")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
|
||||
c := hs.c
|
||||
|
||||
// Either a PSK or a certificate is always used, but not both.
|
||||
// See RFC 8446, Section 4.1.1.
|
||||
if hs.usingPSK {
|
||||
// Make sure the connection is still being verified whether or not this
|
||||
// is a resumption. Resumptions currently don't reverify certificates so
|
||||
// they don't call verifyServerCertificate. See Issue 31641.
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
msg, err := c.readHandshake(hs.transcript)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certReq, ok := msg.(*certificateRequestMsgTLS13)
|
||||
if ok {
|
||||
hs.certReq = certReq
|
||||
|
||||
msg, err = c.readHandshake(hs.transcript)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
certMsg, ok := msg.(*certificateMsgTLS13)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(certMsg, msg)
|
||||
}
|
||||
if len(certMsg.certificate.Certificate) == 0 {
|
||||
c.sendAlert(alertDecodeError)
|
||||
return errors.New("tls: received empty certificates message")
|
||||
}
|
||||
|
||||
c.scts = certMsg.certificate.SignedCertificateTimestamps
|
||||
c.ocspResponse = certMsg.certificate.OCSPStaple
|
||||
|
||||
if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// certificateVerifyMsg is included in the transcript, but not until
|
||||
// after we verify the handshake signature, since the state before
|
||||
// this message was sent is used.
|
||||
msg, err = c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certVerify, ok := msg.(*certificateVerifyMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(certVerify, msg)
|
||||
}
|
||||
|
||||
// See RFC 8446, Section 4.4.3.
|
||||
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: certificate used with invalid signature algorithm")
|
||||
}
|
||||
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
|
||||
if err != nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: certificate used with invalid signature algorithm")
|
||||
}
|
||||
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
|
||||
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
|
||||
sigHash, signed, certVerify.signature); err != nil {
|
||||
c.sendAlert(alertDecryptError)
|
||||
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
|
||||
}
|
||||
|
||||
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
|
||||
c := hs.c
|
||||
|
||||
// finishedMsg is included in the transcript, but not until after we
|
||||
// check the client version, since the state before this message was
|
||||
// sent is used during verification.
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finished, ok := msg.(*finishedMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(finished, msg)
|
||||
}
|
||||
|
||||
expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
|
||||
if !hmac.Equal(expectedMAC, finished.verifyData) {
|
||||
c.sendAlert(alertDecryptError)
|
||||
return errors.New("tls: invalid server finished hash")
|
||||
}
|
||||
|
||||
if err := transcriptMsg(finished, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive secrets that take context through the server Finished.
|
||||
|
||||
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
|
||||
clientApplicationTrafficLabel, hs.transcript)
|
||||
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
||||
serverApplicationTrafficLabel, hs.transcript)
|
||||
c.in.exportKey(EncryptionApplication, hs.suite, serverSecret)
|
||||
c.in.setTrafficSecret(hs.suite, serverSecret)
|
||||
|
||||
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
|
||||
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
|
||||
c := hs.c
|
||||
|
||||
if hs.certReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{
|
||||
AcceptableCAs: hs.certReq.certificateAuthorities,
|
||||
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
|
||||
Version: c.vers,
|
||||
ctx: hs.ctx,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certMsg := new(certificateMsgTLS13)
|
||||
|
||||
certMsg.certificate = *cert
|
||||
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
|
||||
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we sent an empty certificate message, skip the CertificateVerify.
|
||||
if len(cert.Certificate) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
certVerifyMsg := new(certificateVerifyMsg)
|
||||
certVerifyMsg.hasSignatureAlgorithm = true
|
||||
|
||||
certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms)
|
||||
if err != nil {
|
||||
// getClientCertificate returned a certificate incompatible with the
|
||||
// CertificateRequestInfo supported signature algorithms.
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return err
|
||||
}
|
||||
|
||||
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm)
|
||||
if err != nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
|
||||
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
|
||||
signOpts := crypto.SignerOpts(sigHash)
|
||||
if sigType == signatureRSAPSS {
|
||||
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
|
||||
}
|
||||
sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("tls: failed to sign handshake: " + err.Error())
|
||||
}
|
||||
certVerifyMsg.signature = sig
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
|
||||
c := hs.c
|
||||
|
||||
finished := &finishedMsg{
|
||||
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
|
||||
}
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret)
|
||||
c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
|
||||
|
||||
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
|
||||
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
||||
resumptionLabel, hs.transcript)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
|
||||
if !c.isClient {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return errors.New("tls: received new session ticket from a client")
|
||||
}
|
||||
|
||||
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// See RFC 8446, Section 4.6.1.
|
||||
if msg.lifetime == 0 {
|
||||
return nil
|
||||
}
|
||||
lifetime := time.Duration(msg.lifetime) * time.Second
|
||||
if lifetime > maxSessionTicketLifetime {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: received a session ticket with invalid lifetime")
|
||||
}
|
||||
|
||||
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
|
||||
if cipherSuite == nil || c.resumptionSecret == nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
|
||||
// We need to save the max_early_data_size that the server sent us, in order
|
||||
// to decide if we're going to try 0-RTT with this ticket.
|
||||
// However, at the same time, the qtls.ClientSessionTicket needs to be equal to
|
||||
// the tls.ClientSessionTicket, so we can't just add a new field to the struct.
|
||||
// We therefore abuse the nonce field (which is a byte slice)
|
||||
nonceWithEarlyData := make([]byte, len(msg.nonce)+4)
|
||||
binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData)
|
||||
copy(nonceWithEarlyData[4:], msg.nonce)
|
||||
|
||||
var appData []byte
|
||||
if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil {
|
||||
appData = c.extraConfig.GetAppDataForSessionState()
|
||||
}
|
||||
var b cryptobyte.Builder
|
||||
b.AddUint16(clientSessionStateVersion) // revision
|
||||
b.AddUint32(msg.maxEarlyData)
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(appData)
|
||||
})
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(msg.nonce)
|
||||
})
|
||||
|
||||
// Save the resumption_master_secret and nonce instead of deriving the PSK
|
||||
// to do the least amount of work on NewSessionTicket messages before we
|
||||
// know if the ticket will be used. Forward secrecy of resumed connections
|
||||
// is guaranteed by the requirement for pskModeDHE.
|
||||
session := &clientSessionState{
|
||||
sessionTicket: msg.label,
|
||||
vers: c.vers,
|
||||
cipherSuite: c.cipherSuite,
|
||||
masterSecret: c.resumptionSecret,
|
||||
serverCertificates: c.peerCertificates,
|
||||
verifiedChains: c.verifiedChains,
|
||||
receivedAt: c.config.time(),
|
||||
nonce: b.BytesOrPanic(),
|
||||
useBy: c.config.time().Add(lifetime),
|
||||
ageAdd: msg.ageAdd,
|
||||
ocspResponse: c.ocspResponse,
|
||||
scts: c.scts,
|
||||
}
|
||||
|
||||
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
||||
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session))
|
||||
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -1,922 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// serverHandshakeState contains details of a server handshake in progress.
|
||||
// It's discarded once the handshake has completed.
|
||||
type serverHandshakeState struct {
|
||||
c *Conn
|
||||
ctx context.Context
|
||||
clientHello *clientHelloMsg
|
||||
hello *serverHelloMsg
|
||||
suite *cipherSuite
|
||||
ecdheOk bool
|
||||
ecSignOk bool
|
||||
rsaDecryptOk bool
|
||||
rsaSignOk bool
|
||||
sessionState *sessionState
|
||||
finishedHash finishedHash
|
||||
masterSecret []byte
|
||||
cert *Certificate
|
||||
}
|
||||
|
||||
// serverHandshake performs a TLS handshake as a server.
|
||||
func (c *Conn) serverHandshake(ctx context.Context) error {
|
||||
c.setAlternativeRecordLayer()
|
||||
|
||||
clientHello, err := c.readClientHello(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.vers == VersionTLS13 {
|
||||
hs := serverHandshakeStateTLS13{
|
||||
c: c,
|
||||
ctx: ctx,
|
||||
clientHello: clientHello,
|
||||
}
|
||||
return hs.handshake()
|
||||
} else if c.extraConfig.usesAlternativeRecordLayer() {
|
||||
// This should already have been caught by the check that the ClientHello doesn't
|
||||
// offer any (supported) versions older than TLS 1.3.
|
||||
// Check again to make sure we can't be tricked into using an older version.
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
return errors.New("tls: negotiated TLS < 1.3 when using QUIC")
|
||||
}
|
||||
|
||||
hs := serverHandshakeState{
|
||||
c: c,
|
||||
ctx: ctx,
|
||||
clientHello: clientHello,
|
||||
}
|
||||
return hs.handshake()
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) handshake() error {
|
||||
c := hs.c
|
||||
|
||||
if err := hs.processClientHello(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For an overview of TLS handshaking, see RFC 5246, Section 7.3.
|
||||
c.buffering = true
|
||||
if hs.checkForResumption() {
|
||||
// The client has included a session ticket and so we do an abbreviated handshake.
|
||||
c.didResume = true
|
||||
if err := hs.doResumeHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendSessionTicket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendFinished(c.serverFinished[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.clientFinishedIsFirst = false
|
||||
if err := hs.readFinished(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// The client didn't include a session ticket, or it wasn't
|
||||
// valid so we do a full handshake.
|
||||
if err := hs.pickCipherSuite(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.doFullHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readFinished(c.clientFinished[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
c.clientFinishedIsFirst = true
|
||||
c.buffering = true
|
||||
if err := hs.sendSessionTicket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendFinished(nil); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
|
||||
atomic.StoreUint32(&c.handshakeStatus, 1)
|
||||
|
||||
c.updateConnectionState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// readClientHello reads a ClientHello message and selects the protocol version.
|
||||
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
|
||||
// clientHelloMsg is included in the transcript, but we haven't initialized
|
||||
// it yet. The respective handshake functions will record it themselves.
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientHello, ok := msg.(*clientHelloMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return nil, unexpectedMessageError(clientHello, msg)
|
||||
}
|
||||
|
||||
var configForClient *config
|
||||
originalConfig := c.config
|
||||
if c.config.GetConfigForClient != nil {
|
||||
chi := newClientHelloInfo(ctx, c, clientHello)
|
||||
if cfc, err := c.config.GetConfigForClient(chi); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return nil, err
|
||||
} else if cfc != nil {
|
||||
configForClient = fromConfig(cfc)
|
||||
c.config = configForClient
|
||||
}
|
||||
}
|
||||
c.ticketKeys = originalConfig.ticketKeys(configForClient)
|
||||
|
||||
clientVersions := clientHello.supportedVersions
|
||||
if len(clientHello.supportedVersions) == 0 {
|
||||
clientVersions = supportedVersionsFromMax(clientHello.vers)
|
||||
}
|
||||
if c.extraConfig.usesAlternativeRecordLayer() {
|
||||
// In QUIC, the client MUST NOT offer any old TLS versions.
|
||||
// Here, we can only check that none of the other supported versions of this library
|
||||
// (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here.
|
||||
for _, ver := range clientVersions {
|
||||
if ver == VersionTLS13 {
|
||||
continue
|
||||
}
|
||||
for _, v := range supportedVersions {
|
||||
if ver == v {
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Make the config we're using allows us to use TLS 1.3.
|
||||
if c.config.maxSupportedVersion(roleServer) < VersionTLS13 {
|
||||
c.sendAlert(alertInternalError)
|
||||
return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3")
|
||||
}
|
||||
}
|
||||
c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
|
||||
if !ok {
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
|
||||
}
|
||||
c.haveVers = true
|
||||
c.in.version = c.vers
|
||||
c.out.version = c.vers
|
||||
|
||||
return clientHello, nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) processClientHello() error {
|
||||
c := hs.c
|
||||
|
||||
hs.hello = new(serverHelloMsg)
|
||||
hs.hello.vers = c.vers
|
||||
|
||||
foundCompression := false
|
||||
// We only support null compression, so check that the client offered it.
|
||||
for _, compression := range hs.clientHello.compressionMethods {
|
||||
if compression == compressionNone {
|
||||
foundCompression = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundCompression {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("tls: client does not support uncompressed connections")
|
||||
}
|
||||
|
||||
hs.hello.random = make([]byte, 32)
|
||||
serverRandom := hs.hello.random
|
||||
// Downgrade protection canaries. See RFC 8446, Section 4.1.3.
|
||||
maxVers := c.config.maxSupportedVersion(roleServer)
|
||||
if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary {
|
||||
if c.vers == VersionTLS12 {
|
||||
copy(serverRandom[24:], downgradeCanaryTLS12)
|
||||
} else {
|
||||
copy(serverRandom[24:], downgradeCanaryTLS11)
|
||||
}
|
||||
serverRandom = serverRandom[:24]
|
||||
}
|
||||
_, err := io.ReadFull(c.config.rand(), serverRandom)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hs.clientHello.secureRenegotiation) != 0 {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("tls: initial handshake had non-empty renegotiation extension")
|
||||
}
|
||||
|
||||
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
|
||||
hs.hello.compressionMethod = compressionNone
|
||||
if len(hs.clientHello.serverName) > 0 {
|
||||
c.serverName = hs.clientHello.serverName
|
||||
}
|
||||
|
||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
|
||||
if err != nil {
|
||||
c.sendAlert(alertNoApplicationProtocol)
|
||||
return err
|
||||
}
|
||||
hs.hello.alpnProtocol = selectedProto
|
||||
c.clientProtocol = selectedProto
|
||||
|
||||
hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello))
|
||||
if err != nil {
|
||||
if err == errNoCertificates {
|
||||
c.sendAlert(alertUnrecognizedName)
|
||||
} else {
|
||||
c.sendAlert(alertInternalError)
|
||||
}
|
||||
return err
|
||||
}
|
||||
if hs.clientHello.scts {
|
||||
hs.hello.scts = hs.cert.SignedCertificateTimestamps
|
||||
}
|
||||
|
||||
hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints)
|
||||
|
||||
if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 {
|
||||
// Although omitting the ec_point_formats extension is permitted, some
|
||||
// old OpenSSL version will refuse to handshake if not present.
|
||||
//
|
||||
// Per RFC 4492, section 5.1.2, implementations MUST support the
|
||||
// uncompressed point format. See golang.org/issue/31943.
|
||||
hs.hello.supportedPoints = []uint8{pointFormatUncompressed}
|
||||
}
|
||||
|
||||
if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok {
|
||||
switch priv.Public().(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
hs.ecSignOk = true
|
||||
case ed25519.PublicKey:
|
||||
hs.ecSignOk = true
|
||||
case *rsa.PublicKey:
|
||||
hs.rsaSignOk = true
|
||||
default:
|
||||
c.sendAlert(alertInternalError)
|
||||
return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public())
|
||||
}
|
||||
}
|
||||
if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok {
|
||||
switch priv.Public().(type) {
|
||||
case *rsa.PublicKey:
|
||||
hs.rsaDecryptOk = true
|
||||
default:
|
||||
c.sendAlert(alertInternalError)
|
||||
return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// negotiateALPN picks a shared ALPN protocol that both sides support in server
|
||||
// preference order. If ALPN is not configured or the peer doesn't support it,
|
||||
// it returns "" and no error.
|
||||
func negotiateALPN(serverProtos, clientProtos []string) (string, error) {
|
||||
if len(serverProtos) == 0 || len(clientProtos) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
var http11fallback bool
|
||||
for _, s := range serverProtos {
|
||||
for _, c := range clientProtos {
|
||||
if s == c {
|
||||
return s, nil
|
||||
}
|
||||
if s == "h2" && c == "http/1.1" {
|
||||
http11fallback = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// As a special case, let http/1.1 clients connect to h2 servers as if they
|
||||
// didn't support ALPN. We used not to enforce protocol overlap, so over
|
||||
// time a number of HTTP servers were configured with only "h2", but
|
||||
// expected to accept connections from "http/1.1" clients. See Issue 46310.
|
||||
if http11fallback {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos)
|
||||
}
|
||||
|
||||
// supportsECDHE returns whether ECDHE key exchanges can be used with this
|
||||
// pre-TLS 1.3 client.
|
||||
func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool {
|
||||
supportsCurve := false
|
||||
for _, curve := range supportedCurves {
|
||||
if c.supportsCurve(curve) {
|
||||
supportsCurve = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
supportsPointFormat := false
|
||||
for _, pointFormat := range supportedPoints {
|
||||
if pointFormat == pointFormatUncompressed {
|
||||
supportsPointFormat = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is
|
||||
// missing, uncompressed points are supported. If supportedPoints is empty,
|
||||
// the extension must be missing, as an empty extension body is rejected by
|
||||
// the parser. See https://go.dev/issue/49126.
|
||||
if len(supportedPoints) == 0 {
|
||||
supportsPointFormat = true
|
||||
}
|
||||
|
||||
return supportsCurve && supportsPointFormat
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) pickCipherSuite() error {
|
||||
c := hs.c
|
||||
|
||||
preferenceOrder := cipherSuitesPreferenceOrder
|
||||
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
|
||||
preferenceOrder = cipherSuitesPreferenceOrderNoAES
|
||||
}
|
||||
|
||||
configCipherSuites := c.config.cipherSuites()
|
||||
preferenceList := make([]uint16, 0, len(configCipherSuites))
|
||||
for _, suiteID := range preferenceOrder {
|
||||
for _, id := range configCipherSuites {
|
||||
if id == suiteID {
|
||||
preferenceList = append(preferenceList, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk)
|
||||
if hs.suite == nil {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("tls: no cipher suite supported by both client and server")
|
||||
}
|
||||
c.cipherSuite = hs.suite.id
|
||||
|
||||
for _, id := range hs.clientHello.cipherSuites {
|
||||
if id == TLS_FALLBACK_SCSV {
|
||||
// The client is doing a fallback connection. See RFC 7507.
|
||||
if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) {
|
||||
c.sendAlert(alertInappropriateFallback)
|
||||
return errors.New("tls: client using inappropriate protocol fallback")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
|
||||
if c.flags&suiteECDHE != 0 {
|
||||
if !hs.ecdheOk {
|
||||
return false
|
||||
}
|
||||
if c.flags&suiteECSign != 0 {
|
||||
if !hs.ecSignOk {
|
||||
return false
|
||||
}
|
||||
} else if !hs.rsaSignOk {
|
||||
return false
|
||||
}
|
||||
} else if !hs.rsaDecryptOk {
|
||||
return false
|
||||
}
|
||||
if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkForResumption reports whether we should perform resumption on this connection.
|
||||
func (hs *serverHandshakeState) checkForResumption() bool {
|
||||
c := hs.c
|
||||
|
||||
if c.config.SessionTicketsDisabled {
|
||||
return false
|
||||
}
|
||||
|
||||
plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket)
|
||||
if plaintext == nil {
|
||||
return false
|
||||
}
|
||||
hs.sessionState = &sessionState{usedOldKey: usedOldKey}
|
||||
ok := hs.sessionState.unmarshal(plaintext)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
createdAt := time.Unix(int64(hs.sessionState.createdAt), 0)
|
||||
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
|
||||
return false
|
||||
}
|
||||
|
||||
// Never resume a session for a different TLS version.
|
||||
if c.vers != hs.sessionState.vers {
|
||||
return false
|
||||
}
|
||||
|
||||
cipherSuiteOk := false
|
||||
// Check that the client is still offering the ciphersuite in the session.
|
||||
for _, id := range hs.clientHello.cipherSuites {
|
||||
if id == hs.sessionState.cipherSuite {
|
||||
cipherSuiteOk = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !cipherSuiteOk {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check that we also support the ciphersuite from the session.
|
||||
hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite},
|
||||
c.config.cipherSuites(), hs.cipherSuiteOk)
|
||||
if hs.suite == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
sessionHasClientCerts := len(hs.sessionState.certificates) != 0
|
||||
needClientCerts := requiresClientCert(c.config.ClientAuth)
|
||||
if needClientCerts && !sessionHasClientCerts {
|
||||
return false
|
||||
}
|
||||
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) doResumeHandshake() error {
|
||||
c := hs.c
|
||||
|
||||
hs.hello.cipherSuite = hs.suite.id
|
||||
c.cipherSuite = hs.suite.id
|
||||
// We echo the client's session ID in the ServerHello to let it know
|
||||
// that we're doing a resumption.
|
||||
hs.hello.sessionId = hs.clientHello.sessionId
|
||||
hs.hello.ticketSupported = hs.sessionState.usedOldKey
|
||||
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
|
||||
hs.finishedHash.discardHandshakeBuffer()
|
||||
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.processCertsFromClient(Certificate{
|
||||
Certificate: hs.sessionState.certificates,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hs.masterSecret = hs.sessionState.masterSecret
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) doFullHandshake() error {
|
||||
c := hs.c
|
||||
|
||||
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
|
||||
hs.hello.ocspStapling = true
|
||||
}
|
||||
|
||||
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
|
||||
hs.hello.cipherSuite = hs.suite.id
|
||||
|
||||
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
|
||||
if c.config.ClientAuth == NoClientCert {
|
||||
// No need to keep a full record of the handshake if client
|
||||
// certificates won't be used.
|
||||
hs.finishedHash.discardHandshakeBuffer()
|
||||
}
|
||||
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certMsg := new(certificateMsg)
|
||||
certMsg.certificates = hs.cert.Certificate
|
||||
if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hs.hello.ocspStapling {
|
||||
certStatus := new(certificateStatusMsg)
|
||||
certStatus.response = hs.cert.OCSPStaple
|
||||
if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
keyAgreement := hs.suite.ka(c.vers)
|
||||
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
|
||||
if err != nil {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return err
|
||||
}
|
||||
if skx != nil {
|
||||
if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var certReq *certificateRequestMsg
|
||||
if c.config.ClientAuth >= RequestClientCert {
|
||||
// Request a client certificate
|
||||
certReq = new(certificateRequestMsg)
|
||||
certReq.certificateTypes = []byte{
|
||||
byte(certTypeRSASign),
|
||||
byte(certTypeECDSASign),
|
||||
}
|
||||
if c.vers >= VersionTLS12 {
|
||||
certReq.hasSignatureAlgorithm = true
|
||||
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
|
||||
}
|
||||
|
||||
// An empty list of certificateAuthorities signals to
|
||||
// the client that it may send any certificate in response
|
||||
// to our request. When we know the CAs we trust, then
|
||||
// we can send them down, so that the client can choose
|
||||
// an appropriate certificate to give to us.
|
||||
if c.config.ClientCAs != nil {
|
||||
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
|
||||
}
|
||||
if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
helloDone := new(serverHelloDoneMsg)
|
||||
if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := c.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pub crypto.PublicKey // public key for client auth, if any
|
||||
|
||||
msg, err := c.readHandshake(&hs.finishedHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If we requested a client certificate, then the client must send a
|
||||
// certificate message, even if it's empty.
|
||||
if c.config.ClientAuth >= RequestClientCert {
|
||||
certMsg, ok := msg.(*certificateMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(certMsg, msg)
|
||||
}
|
||||
|
||||
if err := c.processCertsFromClient(Certificate{
|
||||
Certificate: certMsg.certificates,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(certMsg.certificates) != 0 {
|
||||
pub = c.peerCertificates[0].PublicKey
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake(&hs.finishedHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Get client key exchange
|
||||
ckx, ok := msg.(*clientKeyExchangeMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(ckx, msg)
|
||||
}
|
||||
|
||||
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
|
||||
if err != nil {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return err
|
||||
}
|
||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
|
||||
if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
|
||||
// If we received a client cert in response to our certificate request message,
|
||||
// the client will send us a certificateVerifyMsg immediately after the
|
||||
// clientKeyExchangeMsg. This message is a digest of all preceding
|
||||
// handshake-layer messages that is signed using the private key corresponding
|
||||
// to the client's certificate. This allows us to verify that the client is in
|
||||
// possession of the private key of the certificate.
|
||||
if len(c.peerCertificates) > 0 {
|
||||
// certificateVerifyMsg is included in the transcript, but not until
|
||||
// after we verify the handshake signature, since the state before
|
||||
// this message was sent is used.
|
||||
msg, err = c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certVerify, ok := msg.(*certificateVerifyMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(certVerify, msg)
|
||||
}
|
||||
|
||||
var sigType uint8
|
||||
var sigHash crypto.Hash
|
||||
if c.vers >= VersionTLS12 {
|
||||
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client certificate used with invalid signature algorithm")
|
||||
}
|
||||
sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
|
||||
if err != nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
} else {
|
||||
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub)
|
||||
if err != nil {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret)
|
||||
if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
|
||||
c.sendAlert(alertDecryptError)
|
||||
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
|
||||
}
|
||||
|
||||
if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hs.finishedHash.discardHandshakeBuffer()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) establishKeys() error {
|
||||
c := hs.c
|
||||
|
||||
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
|
||||
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
|
||||
|
||||
var clientCipher, serverCipher any
|
||||
var clientHash, serverHash hash.Hash
|
||||
|
||||
if hs.suite.aead == nil {
|
||||
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
|
||||
clientHash = hs.suite.mac(clientMAC)
|
||||
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
|
||||
serverHash = hs.suite.mac(serverMAC)
|
||||
} else {
|
||||
clientCipher = hs.suite.aead(clientKey, clientIV)
|
||||
serverCipher = hs.suite.aead(serverKey, serverIV)
|
||||
}
|
||||
|
||||
c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
|
||||
c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) readFinished(out []byte) error {
|
||||
c := hs.c
|
||||
|
||||
if err := c.readChangeCipherSpec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// finishedMsg is included in the transcript, but not until after we
|
||||
// check the client version, since the state before this message was
|
||||
// sent is used during verification.
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientFinished, ok := msg.(*finishedMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(clientFinished, msg)
|
||||
}
|
||||
|
||||
verify := hs.finishedHash.clientSum(hs.masterSecret)
|
||||
if len(verify) != len(clientFinished.verifyData) ||
|
||||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("tls: client's Finished message is incorrect")
|
||||
}
|
||||
|
||||
if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
copy(out, verify)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) sendSessionTicket() error {
|
||||
// ticketSupported is set in a resumption handshake if the
|
||||
// ticket from the client was encrypted with an old session
|
||||
// ticket key and thus a refreshed ticket should be sent.
|
||||
if !hs.hello.ticketSupported {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := hs.c
|
||||
m := new(newSessionTicketMsg)
|
||||
|
||||
createdAt := uint64(c.config.time().Unix())
|
||||
if hs.sessionState != nil {
|
||||
// If this is re-wrapping an old key, then keep
|
||||
// the original time it was created.
|
||||
createdAt = hs.sessionState.createdAt
|
||||
}
|
||||
|
||||
var certsFromClient [][]byte
|
||||
for _, cert := range c.peerCertificates {
|
||||
certsFromClient = append(certsFromClient, cert.Raw)
|
||||
}
|
||||
state := sessionState{
|
||||
vers: c.vers,
|
||||
cipherSuite: hs.suite.id,
|
||||
createdAt: createdAt,
|
||||
masterSecret: hs.masterSecret,
|
||||
certificates: certsFromClient,
|
||||
}
|
||||
stateBytes, err := state.marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ticket, err = c.encryptTicket(stateBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) sendFinished(out []byte) error {
|
||||
c := hs.c
|
||||
|
||||
if err := c.writeChangeCipherRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
|
||||
if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
copy(out, finished.verifyData)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processCertsFromClient takes a chain of client certificates either from a
|
||||
// Certificates message or from a sessionState and verifies them. It returns
|
||||
// the public key of the leaf certificate.
|
||||
func (c *Conn) processCertsFromClient(certificate Certificate) error {
|
||||
certificates := certificate.Certificate
|
||||
certs := make([]*x509.Certificate, len(certificates))
|
||||
var err error
|
||||
for i, asn1Data := range certificates {
|
||||
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("tls: failed to parse client certificate: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("tls: client didn't provide a certificate")
|
||||
}
|
||||
|
||||
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: c.config.ClientCAs,
|
||||
CurrentTime: c.config.time(),
|
||||
Intermediates: x509.NewCertPool(),
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
|
||||
for _, cert := range certs[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
|
||||
chains, err := certs[0].Verify(opts)
|
||||
if err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("tls: failed to verify client certificate: " + err.Error())
|
||||
}
|
||||
|
||||
c.verifiedChains = chains
|
||||
}
|
||||
|
||||
c.peerCertificates = certs
|
||||
c.ocspResponse = certificate.OCSPStaple
|
||||
c.scts = certificate.SignedCertificateTimestamps
|
||||
|
||||
if len(certs) > 0 {
|
||||
switch certs[0].PublicKey.(type) {
|
||||
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
|
||||
default:
|
||||
c.sendAlert(alertUnsupportedCertificate)
|
||||
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.VerifyPeerCertificate != nil {
|
||||
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
|
||||
supportedVersions := clientHello.supportedVersions
|
||||
if len(clientHello.supportedVersions) == 0 {
|
||||
supportedVersions = supportedVersionsFromMax(clientHello.vers)
|
||||
}
|
||||
|
||||
return toClientHelloInfo(&clientHelloInfo{
|
||||
CipherSuites: clientHello.cipherSuites,
|
||||
ServerName: clientHello.serverName,
|
||||
SupportedCurves: clientHello.supportedCurves,
|
||||
SupportedPoints: clientHello.supportedPoints,
|
||||
SignatureSchemes: clientHello.supportedSignatureAlgorithms,
|
||||
SupportedProtos: clientHello.alpnProtocols,
|
||||
SupportedVersions: supportedVersions,
|
||||
Conn: c.conn,
|
||||
config: toConfig(c.config),
|
||||
ctx: ctx,
|
||||
})
|
||||
}
|
|
@ -1,903 +0,0 @@
|
|||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"hash"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// maxClientPSKIdentities is the number of client PSK identities the server will
|
||||
// attempt to validate. It will ignore the rest not to let cheap ClientHello
|
||||
// messages cause too much work in session ticket decryption attempts.
|
||||
const maxClientPSKIdentities = 5
|
||||
|
||||
type serverHandshakeStateTLS13 struct {
|
||||
c *Conn
|
||||
ctx context.Context
|
||||
clientHello *clientHelloMsg
|
||||
hello *serverHelloMsg
|
||||
alpnNegotiationErr error
|
||||
encryptedExtensions *encryptedExtensionsMsg
|
||||
sentDummyCCS bool
|
||||
usingPSK bool
|
||||
suite *cipherSuiteTLS13
|
||||
cert *Certificate
|
||||
sigAlg SignatureScheme
|
||||
earlySecret []byte
|
||||
sharedKey []byte
|
||||
handshakeSecret []byte
|
||||
masterSecret []byte
|
||||
trafficSecret []byte // client_application_traffic_secret_0
|
||||
transcript hash.Hash
|
||||
clientFinished []byte
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) handshake() error {
|
||||
c := hs.c
|
||||
|
||||
if needFIPS() {
|
||||
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
|
||||
}
|
||||
|
||||
// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
|
||||
if err := hs.processClientHello(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.checkForResumption(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateConnectionState()
|
||||
if err := hs.pickCertificate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.buffering = true
|
||||
if err := hs.sendServerParameters(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendServerCertificate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendServerFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Note that at this point we could start sending application data without
|
||||
// waiting for the client's second flight, but the application might not
|
||||
// expect the lack of replay protection of the ClientHello parameters.
|
||||
if _, err := c.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readClientCertificate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateConnectionState()
|
||||
if err := hs.readClientFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&c.handshakeStatus, 1)
|
||||
c.updateConnectionState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) processClientHello() error {
|
||||
c := hs.c
|
||||
|
||||
hs.hello = new(serverHelloMsg)
|
||||
hs.encryptedExtensions = new(encryptedExtensionsMsg)
|
||||
|
||||
// TLS 1.3 froze the ServerHello.legacy_version field, and uses
|
||||
// supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1.
|
||||
hs.hello.vers = VersionTLS12
|
||||
hs.hello.supportedVersion = c.vers
|
||||
|
||||
if len(hs.clientHello.supportedVersions) == 0 {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client used the legacy version field to negotiate TLS 1.3")
|
||||
}
|
||||
|
||||
// Abort if the client is doing a fallback and landing lower than what we
|
||||
// support. See RFC 7507, which however does not specify the interaction
|
||||
// with supported_versions. The only difference is that with
|
||||
// supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4]
|
||||
// handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case,
|
||||
// it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to
|
||||
// TLS 1.2, because a TLS 1.3 server would abort here. The situation before
|
||||
// supported_versions was not better because there was just no way to do a
|
||||
// TLS 1.4 handshake without risking the server selecting TLS 1.3.
|
||||
for _, id := range hs.clientHello.cipherSuites {
|
||||
if id == TLS_FALLBACK_SCSV {
|
||||
// Use c.vers instead of max(supported_versions) because an attacker
|
||||
// could defeat this by adding an arbitrary high version otherwise.
|
||||
if c.vers < c.config.maxSupportedVersion(roleServer) {
|
||||
c.sendAlert(alertInappropriateFallback)
|
||||
return errors.New("tls: client using inappropriate protocol fallback")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(hs.clientHello.compressionMethods) != 1 ||
|
||||
hs.clientHello.compressionMethods[0] != compressionNone {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: TLS 1.3 client supports illegal compression methods")
|
||||
}
|
||||
|
||||
hs.hello.random = make([]byte, 32)
|
||||
if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(hs.clientHello.secureRenegotiation) != 0 {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("tls: initial handshake had non-empty renegotiation extension")
|
||||
}
|
||||
|
||||
hs.hello.sessionId = hs.clientHello.sessionId
|
||||
hs.hello.compressionMethod = compressionNone
|
||||
|
||||
preferenceList := defaultCipherSuitesTLS13
|
||||
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
|
||||
preferenceList = defaultCipherSuitesTLS13NoAES
|
||||
}
|
||||
for _, suiteID := range preferenceList {
|
||||
hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID)
|
||||
if hs.suite != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if hs.suite == nil {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("tls: no cipher suite supported by both client and server")
|
||||
}
|
||||
c.cipherSuite = hs.suite.id
|
||||
hs.hello.cipherSuite = hs.suite.id
|
||||
hs.transcript = hs.suite.hash.New()
|
||||
|
||||
// Pick the ECDHE group in server preference order, but give priority to
|
||||
// groups with a key share, to avoid a HelloRetryRequest round-trip.
|
||||
var selectedGroup CurveID
|
||||
var clientKeyShare *keyShare
|
||||
GroupSelection:
|
||||
for _, preferredGroup := range c.config.curvePreferences() {
|
||||
for _, ks := range hs.clientHello.keyShares {
|
||||
if ks.group == preferredGroup {
|
||||
selectedGroup = ks.group
|
||||
clientKeyShare = &ks
|
||||
break GroupSelection
|
||||
}
|
||||
}
|
||||
if selectedGroup != 0 {
|
||||
continue
|
||||
}
|
||||
for _, group := range hs.clientHello.supportedCurves {
|
||||
if group == preferredGroup {
|
||||
selectedGroup = group
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if selectedGroup == 0 {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("tls: no ECDHE curve supported by both client and server")
|
||||
}
|
||||
if clientKeyShare == nil {
|
||||
if err := hs.doHelloRetryRequest(selectedGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
clientKeyShare = &hs.clientHello.keyShares[0]
|
||||
}
|
||||
|
||||
if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("tls: CurvePreferences includes unsupported curve")
|
||||
}
|
||||
params, err := generateECDHEParameters(c.config.rand(), selectedGroup)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()}
|
||||
hs.sharedKey = params.SharedKey(clientKeyShare.data)
|
||||
if hs.sharedKey == nil {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: invalid client key share")
|
||||
}
|
||||
|
||||
c.serverName = hs.clientHello.serverName
|
||||
|
||||
if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil {
|
||||
c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions)
|
||||
}
|
||||
|
||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
|
||||
if err != nil {
|
||||
hs.alpnNegotiationErr = err
|
||||
}
|
||||
hs.encryptedExtensions.alpnProtocol = selectedProto
|
||||
c.clientProtocol = selectedProto
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) checkForResumption() error {
|
||||
c := hs.c
|
||||
|
||||
if c.config.SessionTicketsDisabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
modeOK := false
|
||||
for _, mode := range hs.clientHello.pskModes {
|
||||
if mode == pskModeDHE {
|
||||
modeOK = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !modeOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: invalid or missing PSK binders")
|
||||
}
|
||||
if len(hs.clientHello.pskIdentities) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, identity := range hs.clientHello.pskIdentities {
|
||||
if i >= maxClientPSKIdentities {
|
||||
break
|
||||
}
|
||||
|
||||
plaintext, _ := c.decryptTicket(identity.label)
|
||||
if plaintext == nil {
|
||||
continue
|
||||
}
|
||||
sessionState := new(sessionStateTLS13)
|
||||
if ok := sessionState.unmarshal(plaintext); !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if hs.clientHello.earlyData {
|
||||
if sessionState.maxEarlyData == 0 {
|
||||
c.sendAlert(alertUnsupportedExtension)
|
||||
return errors.New("tls: client sent unexpected early data")
|
||||
}
|
||||
|
||||
if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol &&
|
||||
c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 &&
|
||||
c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) {
|
||||
hs.encryptedExtensions.earlyData = true
|
||||
c.used0RTT = true
|
||||
}
|
||||
}
|
||||
|
||||
createdAt := time.Unix(int64(sessionState.createdAt), 0)
|
||||
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
|
||||
continue
|
||||
}
|
||||
|
||||
// We don't check the obfuscated ticket age because it's affected by
|
||||
// clock skew and it's only a freshness signal useful for shrinking the
|
||||
// window for replay attacks, which don't affect us as we don't do 0-RTT.
|
||||
|
||||
pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite)
|
||||
if pskSuite == nil || pskSuite.hash != hs.suite.hash {
|
||||
continue
|
||||
}
|
||||
|
||||
// PSK connections don't re-establish client certificates, but carry
|
||||
// them over in the session ticket. Ensure the presence of client certs
|
||||
// in the ticket is consistent with the configured requirements.
|
||||
sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0
|
||||
needClientCerts := requiresClientCert(c.config.ClientAuth)
|
||||
if needClientCerts && !sessionHasClientCerts {
|
||||
continue
|
||||
}
|
||||
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
|
||||
continue
|
||||
}
|
||||
|
||||
psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption",
|
||||
nil, hs.suite.hash.Size())
|
||||
hs.earlySecret = hs.suite.extract(psk, nil)
|
||||
binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil)
|
||||
// Clone the transcript in case a HelloRetryRequest was recorded.
|
||||
transcript := cloneHash(hs.transcript, hs.suite.hash)
|
||||
if transcript == nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("tls: internal error: failed to clone hash")
|
||||
}
|
||||
clientHelloBytes, err := hs.clientHello.marshalWithoutBinders()
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
transcript.Write(clientHelloBytes)
|
||||
pskBinder := hs.suite.finishedHash(binderKey, transcript)
|
||||
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
|
||||
c.sendAlert(alertDecryptError)
|
||||
return errors.New("tls: invalid PSK binder")
|
||||
}
|
||||
|
||||
c.didResume = true
|
||||
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h := cloneHash(hs.transcript, hs.suite.hash)
|
||||
clientHelloWithBindersBytes, err := hs.clientHello.marshal()
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
h.Write(clientHelloWithBindersBytes)
|
||||
if hs.encryptedExtensions.earlyData {
|
||||
clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h)
|
||||
c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret)
|
||||
if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hs.hello.selectedIdentityPresent = true
|
||||
hs.hello.selectedIdentity = uint16(i)
|
||||
hs.usingPSK = true
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
|
||||
// interfaces implemented by standard library hashes to clone the state of in
|
||||
// to a new instance of h. It returns nil if the operation fails.
|
||||
func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash {
|
||||
// Recreate the interface to avoid importing encoding.
|
||||
type binaryMarshaler interface {
|
||||
MarshalBinary() (data []byte, err error)
|
||||
UnmarshalBinary(data []byte) error
|
||||
}
|
||||
marshaler, ok := in.(binaryMarshaler)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
state, err := marshaler.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
out := h.New()
|
||||
unmarshaler, ok := out.(binaryMarshaler)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := unmarshaler.UnmarshalBinary(state); err != nil {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) pickCertificate() error {
|
||||
c := hs.c
|
||||
|
||||
// Only one of PSK and certificates are used at a time.
|
||||
if hs.usingPSK {
|
||||
return nil
|
||||
}
|
||||
|
||||
// signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3.
|
||||
if len(hs.clientHello.supportedSignatureAlgorithms) == 0 {
|
||||
return c.sendAlert(alertMissingExtension)
|
||||
}
|
||||
|
||||
certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello))
|
||||
if err != nil {
|
||||
if err == errNoCertificates {
|
||||
c.sendAlert(alertUnrecognizedName)
|
||||
} else {
|
||||
c.sendAlert(alertInternalError)
|
||||
}
|
||||
return err
|
||||
}
|
||||
hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms)
|
||||
if err != nil {
|
||||
// getCertificate returned a certificate that is unsupported or
|
||||
// incompatible with the client's signature algorithms.
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return err
|
||||
}
|
||||
hs.cert = certificate
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
||||
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
||||
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
||||
if hs.sentDummyCCS {
|
||||
return nil
|
||||
}
|
||||
hs.sentDummyCCS = true
|
||||
|
||||
return hs.c.writeChangeCipherRecord()
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
|
||||
c := hs.c
|
||||
|
||||
// The first ClientHello gets double-hashed into the transcript upon a
|
||||
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
|
||||
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
chHash := hs.transcript.Sum(nil)
|
||||
hs.transcript.Reset()
|
||||
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
|
||||
hs.transcript.Write(chHash)
|
||||
|
||||
helloRetryRequest := &serverHelloMsg{
|
||||
vers: hs.hello.vers,
|
||||
random: helloRetryRequestRandom,
|
||||
sessionId: hs.hello.sessionId,
|
||||
cipherSuite: hs.hello.cipherSuite,
|
||||
compressionMethod: hs.hello.compressionMethod,
|
||||
supportedVersion: hs.hello.supportedVersion,
|
||||
selectedGroup: selectedGroup,
|
||||
}
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// clientHelloMsg is not included in the transcript.
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientHello, ok := msg.(*clientHelloMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(clientHello, msg)
|
||||
}
|
||||
|
||||
if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client sent invalid key share in second ClientHello")
|
||||
}
|
||||
|
||||
if clientHello.earlyData {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client indicated early data in second ClientHello")
|
||||
}
|
||||
|
||||
if illegalClientHelloChange(clientHello, hs.clientHello) {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client illegally modified second ClientHello")
|
||||
}
|
||||
|
||||
if clientHello.earlyData {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client offered 0-RTT data in second ClientHello")
|
||||
}
|
||||
|
||||
hs.clientHello = clientHello
|
||||
return nil
|
||||
}
|
||||
|
||||
// illegalClientHelloChange reports whether the two ClientHello messages are
|
||||
// different, with the exception of the changes allowed before and after a
|
||||
// HelloRetryRequest. See RFC 8446, Section 4.1.2.
|
||||
func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
|
||||
if len(ch.supportedVersions) != len(ch1.supportedVersions) ||
|
||||
len(ch.cipherSuites) != len(ch1.cipherSuites) ||
|
||||
len(ch.supportedCurves) != len(ch1.supportedCurves) ||
|
||||
len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) ||
|
||||
len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) ||
|
||||
len(ch.alpnProtocols) != len(ch1.alpnProtocols) {
|
||||
return true
|
||||
}
|
||||
for i := range ch.supportedVersions {
|
||||
if ch.supportedVersions[i] != ch1.supportedVersions[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for i := range ch.cipherSuites {
|
||||
if ch.cipherSuites[i] != ch1.cipherSuites[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for i := range ch.supportedCurves {
|
||||
if ch.supportedCurves[i] != ch1.supportedCurves[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for i := range ch.supportedSignatureAlgorithms {
|
||||
if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for i := range ch.supportedSignatureAlgorithmsCert {
|
||||
if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for i := range ch.alpnProtocols {
|
||||
if ch.alpnProtocols[i] != ch1.alpnProtocols[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return ch.vers != ch1.vers ||
|
||||
!bytes.Equal(ch.random, ch1.random) ||
|
||||
!bytes.Equal(ch.sessionId, ch1.sessionId) ||
|
||||
!bytes.Equal(ch.compressionMethods, ch1.compressionMethods) ||
|
||||
ch.serverName != ch1.serverName ||
|
||||
ch.ocspStapling != ch1.ocspStapling ||
|
||||
!bytes.Equal(ch.supportedPoints, ch1.supportedPoints) ||
|
||||
ch.ticketSupported != ch1.ticketSupported ||
|
||||
!bytes.Equal(ch.sessionTicket, ch1.sessionTicket) ||
|
||||
ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported ||
|
||||
!bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) ||
|
||||
ch.scts != ch1.scts ||
|
||||
!bytes.Equal(ch.cookie, ch1.cookie) ||
|
||||
!bytes.Equal(ch.pskModes, ch1.pskModes)
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
|
||||
c := hs.c
|
||||
|
||||
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
earlySecret := hs.earlySecret
|
||||
if earlySecret == nil {
|
||||
earlySecret = hs.suite.extract(nil, nil)
|
||||
}
|
||||
hs.handshakeSecret = hs.suite.extract(hs.sharedKey,
|
||||
hs.suite.deriveSecret(earlySecret, "derived", nil))
|
||||
|
||||
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
||||
clientHandshakeTrafficLabel, hs.transcript)
|
||||
c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret)
|
||||
c.in.setTrafficSecret(hs.suite, clientSecret)
|
||||
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
||||
serverHandshakeTrafficLabel, hs.transcript)
|
||||
c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret)
|
||||
c.out.setTrafficSecret(hs.suite, serverSecret)
|
||||
|
||||
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
|
||||
if hs.alpnNegotiationErr != nil {
|
||||
c.sendAlert(alertNoApplicationProtocol)
|
||||
return hs.alpnNegotiationErr
|
||||
}
|
||||
if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil {
|
||||
hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions)
|
||||
}
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) requestClientCert() bool {
|
||||
return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
|
||||
c := hs.c
|
||||
|
||||
// Only one of PSK and certificates are used at a time.
|
||||
if hs.usingPSK {
|
||||
return nil
|
||||
}
|
||||
|
||||
if hs.requestClientCert() {
|
||||
// Request a client certificate
|
||||
certReq := new(certificateRequestMsgTLS13)
|
||||
certReq.ocspStapling = true
|
||||
certReq.scts = true
|
||||
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
|
||||
if c.config.ClientCAs != nil {
|
||||
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
|
||||
}
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
certMsg := new(certificateMsgTLS13)
|
||||
|
||||
certMsg.certificate = *hs.cert
|
||||
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
|
||||
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certVerifyMsg := new(certificateVerifyMsg)
|
||||
certVerifyMsg.hasSignatureAlgorithm = true
|
||||
certVerifyMsg.signatureAlgorithm = hs.sigAlg
|
||||
|
||||
sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg)
|
||||
if err != nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
|
||||
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
|
||||
signOpts := crypto.SignerOpts(sigHash)
|
||||
if sigType == signatureRSAPSS {
|
||||
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
|
||||
}
|
||||
sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
|
||||
if err != nil {
|
||||
public := hs.cert.PrivateKey.(crypto.Signer).Public()
|
||||
if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS &&
|
||||
rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
} else {
|
||||
c.sendAlert(alertInternalError)
|
||||
}
|
||||
return errors.New("tls: failed to sign handshake: " + err.Error())
|
||||
}
|
||||
certVerifyMsg.signature = sig
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
|
||||
c := hs.c
|
||||
|
||||
finished := &finishedMsg{
|
||||
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
|
||||
}
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Derive secrets that take context through the server Finished.
|
||||
|
||||
hs.masterSecret = hs.suite.extract(nil,
|
||||
hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil))
|
||||
|
||||
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
|
||||
clientApplicationTrafficLabel, hs.transcript)
|
||||
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
||||
serverApplicationTrafficLabel, hs.transcript)
|
||||
c.out.exportKey(EncryptionApplication, hs.suite, serverSecret)
|
||||
c.out.setTrafficSecret(hs.suite, serverSecret)
|
||||
|
||||
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
|
||||
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
|
||||
|
||||
// If we did not request client certificates, at this point we can
|
||||
// precompute the client finished and roll the transcript forward to send
|
||||
// session tickets in our first flight.
|
||||
if !hs.requestClientCert() {
|
||||
if err := hs.sendSessionTickets(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
|
||||
if hs.c.config.SessionTicketsDisabled {
|
||||
return false
|
||||
}
|
||||
|
||||
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
|
||||
for _, pskMode := range hs.clientHello.pskModes {
|
||||
if pskMode == pskModeDHE {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
|
||||
c := hs.c
|
||||
|
||||
hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
|
||||
finishedMsg := &finishedMsg{
|
||||
verifyData: hs.clientFinished,
|
||||
}
|
||||
if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !hs.shouldSendSessionTickets() {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
||||
resumptionLabel, hs.transcript)
|
||||
|
||||
// Don't send session tickets when the alternative record layer is set.
|
||||
// Instead, save the resumption secret on the Conn.
|
||||
// Session tickets can then be generated by calling Conn.GetSessionTicket().
|
||||
if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m, err := hs.c.getSessionTicketMsg(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := c.writeHandshakeRecord(m, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
|
||||
c := hs.c
|
||||
|
||||
if !hs.requestClientCert() {
|
||||
// Make sure the connection is still being verified whether or not
|
||||
// the server requested a client certificate.
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we requested a client certificate, then the client must send a
|
||||
// certificate message. If it's empty, no CertificateVerify is sent.
|
||||
|
||||
msg, err := c.readHandshake(hs.transcript)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certMsg, ok := msg.(*certificateMsgTLS13)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(certMsg, msg)
|
||||
}
|
||||
|
||||
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.config.VerifyConnection != nil {
|
||||
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(certMsg.certificate.Certificate) != 0 {
|
||||
// certificateVerifyMsg is included in the transcript, but not until
|
||||
// after we verify the handshake signature, since the state before
|
||||
// this message was sent is used.
|
||||
msg, err = c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certVerify, ok := msg.(*certificateVerifyMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(certVerify, msg)
|
||||
}
|
||||
|
||||
// See RFC 8446, Section 4.4.3.
|
||||
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client certificate used with invalid signature algorithm")
|
||||
}
|
||||
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
|
||||
if err != nil {
|
||||
return c.sendAlert(alertInternalError)
|
||||
}
|
||||
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client certificate used with invalid signature algorithm")
|
||||
}
|
||||
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
|
||||
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
|
||||
sigHash, signed, certVerify.signature); err != nil {
|
||||
c.sendAlert(alertDecryptError)
|
||||
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
|
||||
}
|
||||
|
||||
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If we waited until the client certificates to send session tickets, we
|
||||
// are ready to do it now.
|
||||
if err := hs.sendSessionTickets(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
|
||||
c := hs.c
|
||||
|
||||
// finishedMsg is not included in the transcript.
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finished, ok := msg.(*finishedMsg)
|
||||
if !ok {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(finished, msg)
|
||||
}
|
||||
|
||||
if !hmac.Equal(hs.clientFinished, finished.verifyData) {
|
||||
c.sendAlert(alertDecryptError)
|
||||
return errors.New("tls: invalid client finished hash")
|
||||
}
|
||||
|
||||
c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret)
|
||||
c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,357 +0,0 @@
|
|||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/md5"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// a keyAgreement implements the client and server side of a TLS key agreement
|
||||
// protocol by generating and processing key exchange messages.
|
||||
type keyAgreement interface {
|
||||
// On the server side, the first two methods are called in order.
|
||||
|
||||
// In the case that the key agreement protocol doesn't use a
|
||||
// ServerKeyExchange message, generateServerKeyExchange can return nil,
|
||||
// nil.
|
||||
generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
|
||||
processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
|
||||
|
||||
// On the client side, the next two methods are called in order.
|
||||
|
||||
// This method may not be called if the server doesn't send a
|
||||
// ServerKeyExchange message.
|
||||
processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error
|
||||
generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
|
||||
}
|
||||
|
||||
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
|
||||
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
|
||||
|
||||
// rsaKeyAgreement implements the standard TLS key agreement where the client
|
||||
// encrypts the pre-master secret to the server's public key.
|
||||
type rsaKeyAgreement struct{}
|
||||
|
||||
func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
|
||||
if len(ckx.ciphertext) < 2 {
|
||||
return nil, errClientKeyExchange
|
||||
}
|
||||
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
|
||||
if ciphertextLen != len(ckx.ciphertext)-2 {
|
||||
return nil, errClientKeyExchange
|
||||
}
|
||||
ciphertext := ckx.ciphertext[2:]
|
||||
|
||||
priv, ok := cert.PrivateKey.(crypto.Decrypter)
|
||||
if !ok {
|
||||
return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
|
||||
}
|
||||
// Perform constant time RSA PKCS #1 v1.5 decryption
|
||||
preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// We don't check the version number in the premaster secret. For one,
|
||||
// by checking it, we would leak information about the validity of the
|
||||
// encrypted pre-master secret. Secondly, it provides only a small
|
||||
// benefit against a downgrade attack and some implementations send the
|
||||
// wrong version anyway. See the discussion at the end of section
|
||||
// 7.4.7.1 of RFC 4346.
|
||||
return preMasterSecret, nil
|
||||
}
|
||||
|
||||
func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
||||
return errors.New("tls: unexpected ServerKeyExchange")
|
||||
}
|
||||
|
||||
func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
|
||||
preMasterSecret := make([]byte, 48)
|
||||
preMasterSecret[0] = byte(clientHello.vers >> 8)
|
||||
preMasterSecret[1] = byte(clientHello.vers)
|
||||
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
|
||||
}
|
||||
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ckx := new(clientKeyExchangeMsg)
|
||||
ckx.ciphertext = make([]byte, len(encrypted)+2)
|
||||
ckx.ciphertext[0] = byte(len(encrypted) >> 8)
|
||||
ckx.ciphertext[1] = byte(len(encrypted))
|
||||
copy(ckx.ciphertext[2:], encrypted)
|
||||
return preMasterSecret, ckx, nil
|
||||
}
|
||||
|
||||
// sha1Hash calculates a SHA1 hash over the given byte slices.
|
||||
func sha1Hash(slices [][]byte) []byte {
|
||||
hsha1 := sha1.New()
|
||||
for _, slice := range slices {
|
||||
hsha1.Write(slice)
|
||||
}
|
||||
return hsha1.Sum(nil)
|
||||
}
|
||||
|
||||
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
|
||||
// concatenation of an MD5 and SHA1 hash.
|
||||
func md5SHA1Hash(slices [][]byte) []byte {
|
||||
md5sha1 := make([]byte, md5.Size+sha1.Size)
|
||||
hmd5 := md5.New()
|
||||
for _, slice := range slices {
|
||||
hmd5.Write(slice)
|
||||
}
|
||||
copy(md5sha1, hmd5.Sum(nil))
|
||||
copy(md5sha1[md5.Size:], sha1Hash(slices))
|
||||
return md5sha1
|
||||
}
|
||||
|
||||
// hashForServerKeyExchange hashes the given slices and returns their digest
|
||||
// using the given hash function (for >= TLS 1.2) or using a default based on
|
||||
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
|
||||
// do pre-hashing, it returns the concatenation of the slices.
|
||||
func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
|
||||
if sigType == signatureEd25519 {
|
||||
var signed []byte
|
||||
for _, slice := range slices {
|
||||
signed = append(signed, slice...)
|
||||
}
|
||||
return signed
|
||||
}
|
||||
if version >= VersionTLS12 {
|
||||
h := hashFunc.New()
|
||||
for _, slice := range slices {
|
||||
h.Write(slice)
|
||||
}
|
||||
digest := h.Sum(nil)
|
||||
return digest
|
||||
}
|
||||
if sigType == signatureECDSA {
|
||||
return sha1Hash(slices)
|
||||
}
|
||||
return md5SHA1Hash(slices)
|
||||
}
|
||||
|
||||
// ecdheKeyAgreement implements a TLS key agreement where the server
|
||||
// generates an ephemeral EC public/private key pair and signs it. The
|
||||
// pre-master secret is then calculated using ECDH. The signature may
|
||||
// be ECDSA, Ed25519 or RSA.
|
||||
type ecdheKeyAgreement struct {
|
||||
version uint16
|
||||
isRSA bool
|
||||
params ecdheParameters
|
||||
|
||||
// ckx and preMasterSecret are generated in processServerKeyExchange
|
||||
// and returned in generateClientKeyExchange.
|
||||
ckx *clientKeyExchangeMsg
|
||||
preMasterSecret []byte
|
||||
}
|
||||
|
||||
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
|
||||
var curveID CurveID
|
||||
for _, c := range clientHello.supportedCurves {
|
||||
if config.supportsCurve(c) {
|
||||
curveID = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if curveID == 0 {
|
||||
return nil, errors.New("tls: no supported elliptic curves offered")
|
||||
}
|
||||
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
|
||||
return nil, errors.New("tls: CurvePreferences includes unsupported curve")
|
||||
}
|
||||
|
||||
params, err := generateECDHEParameters(config.rand(), curveID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ka.params = params
|
||||
|
||||
// See RFC 4492, Section 5.4.
|
||||
ecdhePublic := params.PublicKey()
|
||||
serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic))
|
||||
serverECDHEParams[0] = 3 // named curve
|
||||
serverECDHEParams[1] = byte(curveID >> 8)
|
||||
serverECDHEParams[2] = byte(curveID)
|
||||
serverECDHEParams[3] = byte(len(ecdhePublic))
|
||||
copy(serverECDHEParams[4:], ecdhePublic)
|
||||
|
||||
priv, ok := cert.PrivateKey.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey)
|
||||
}
|
||||
|
||||
var signatureAlgorithm SignatureScheme
|
||||
var sigType uint8
|
||||
var sigHash crypto.Hash
|
||||
if ka.version >= VersionTLS12 {
|
||||
signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
|
||||
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
|
||||
}
|
||||
|
||||
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams)
|
||||
|
||||
signOpts := crypto.SignerOpts(sigHash)
|
||||
if sigType == signatureRSAPSS {
|
||||
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
|
||||
}
|
||||
sig, err := priv.Sign(config.rand(), signed, signOpts)
|
||||
if err != nil {
|
||||
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
|
||||
}
|
||||
|
||||
skx := new(serverKeyExchangeMsg)
|
||||
sigAndHashLen := 0
|
||||
if ka.version >= VersionTLS12 {
|
||||
sigAndHashLen = 2
|
||||
}
|
||||
skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig))
|
||||
copy(skx.key, serverECDHEParams)
|
||||
k := skx.key[len(serverECDHEParams):]
|
||||
if ka.version >= VersionTLS12 {
|
||||
k[0] = byte(signatureAlgorithm >> 8)
|
||||
k[1] = byte(signatureAlgorithm)
|
||||
k = k[2:]
|
||||
}
|
||||
k[0] = byte(len(sig) >> 8)
|
||||
k[1] = byte(len(sig))
|
||||
copy(k[2:], sig)
|
||||
|
||||
return skx, nil
|
||||
}
|
||||
|
||||
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
|
||||
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
|
||||
return nil, errClientKeyExchange
|
||||
}
|
||||
|
||||
preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:])
|
||||
if preMasterSecret == nil {
|
||||
return nil, errClientKeyExchange
|
||||
}
|
||||
|
||||
return preMasterSecret, nil
|
||||
}
|
||||
|
||||
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
||||
if len(skx.key) < 4 {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
if skx.key[0] != 3 { // named curve
|
||||
return errors.New("tls: server selected unsupported curve")
|
||||
}
|
||||
curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
|
||||
|
||||
publicLen := int(skx.key[3])
|
||||
if publicLen+4 > len(skx.key) {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
serverECDHEParams := skx.key[:4+publicLen]
|
||||
publicKey := serverECDHEParams[4:]
|
||||
|
||||
sig := skx.key[4+publicLen:]
|
||||
if len(sig) < 2 {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
|
||||
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
|
||||
return errors.New("tls: server selected unsupported curve")
|
||||
}
|
||||
|
||||
params, err := generateECDHEParameters(config.rand(), curveID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ka.params = params
|
||||
|
||||
ka.preMasterSecret = params.SharedKey(publicKey)
|
||||
if ka.preMasterSecret == nil {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
|
||||
ourPublicKey := params.PublicKey()
|
||||
ka.ckx = new(clientKeyExchangeMsg)
|
||||
ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey))
|
||||
ka.ckx.ciphertext[0] = byte(len(ourPublicKey))
|
||||
copy(ka.ckx.ciphertext[1:], ourPublicKey)
|
||||
|
||||
var sigType uint8
|
||||
var sigHash crypto.Hash
|
||||
if ka.version >= VersionTLS12 {
|
||||
signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
|
||||
sig = sig[2:]
|
||||
if len(sig) < 2 {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
|
||||
if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
|
||||
return errors.New("tls: certificate used with invalid signature algorithm")
|
||||
}
|
||||
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
|
||||
sigLen := int(sig[0])<<8 | int(sig[1])
|
||||
if sigLen+2 != len(sig) {
|
||||
return errServerKeyExchange
|
||||
}
|
||||
sig = sig[2:]
|
||||
|
||||
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams)
|
||||
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
|
||||
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
|
||||
if ka.ckx == nil {
|
||||
return nil, nil, errors.New("tls: missing ServerKeyExchange message")
|
||||
}
|
||||
|
||||
return ka.preMasterSecret, ka.ckx, nil
|
||||
}
|
|
@ -1,216 +0,0 @@
|
|||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"math/big"
|
||||
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// This file contains the functions necessary to compute the TLS 1.3 key
|
||||
// schedule. See RFC 8446, Section 7.
|
||||
|
||||
const (
|
||||
resumptionBinderLabel = "res binder"
|
||||
clientHandshakeTrafficLabel = "c hs traffic"
|
||||
serverHandshakeTrafficLabel = "s hs traffic"
|
||||
clientApplicationTrafficLabel = "c ap traffic"
|
||||
serverApplicationTrafficLabel = "s ap traffic"
|
||||
exporterLabel = "exp master"
|
||||
resumptionLabel = "res master"
|
||||
trafficUpdateLabel = "traffic upd"
|
||||
)
|
||||
|
||||
// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
|
||||
func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte {
|
||||
var hkdfLabel cryptobyte.Builder
|
||||
hkdfLabel.AddUint16(uint16(length))
|
||||
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes([]byte("tls13 "))
|
||||
b.AddBytes([]byte(label))
|
||||
})
|
||||
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(context)
|
||||
})
|
||||
hkdfLabelBytes, err := hkdfLabel.Bytes()
|
||||
if err != nil {
|
||||
// Rather than calling BytesOrPanic, we explicitly handle this error, in
|
||||
// order to provide a reasonable error message. It should be basically
|
||||
// impossible for this to panic, and routing errors back through the
|
||||
// tree rooted in this function is quite painful. The labels are fixed
|
||||
// size, and the context is either a fixed-length computed hash, or
|
||||
// parsed from a field which has the same length limitation. As such, an
|
||||
// error here is likely to only be caused during development.
|
||||
//
|
||||
// NOTE: another reasonable approach here might be to return a
|
||||
// randomized slice if we encounter an error, which would break the
|
||||
// connection, but avoid panicking. This would perhaps be safer but
|
||||
// significantly more confusing to users.
|
||||
panic(fmt.Errorf("failed to construct HKDF label: %s", err))
|
||||
}
|
||||
out := make([]byte, length)
|
||||
n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out)
|
||||
if err != nil || n != length {
|
||||
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1.
|
||||
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
|
||||
if transcript == nil {
|
||||
transcript = c.hash.New()
|
||||
}
|
||||
return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size())
|
||||
}
|
||||
|
||||
// extract implements HKDF-Extract with the cipher suite hash.
|
||||
func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte {
|
||||
if newSecret == nil {
|
||||
newSecret = make([]byte, c.hash.Size())
|
||||
}
|
||||
return hkdf.Extract(c.hash.New, newSecret, currentSecret)
|
||||
}
|
||||
|
||||
// nextTrafficSecret generates the next traffic secret, given the current one,
|
||||
// according to RFC 8446, Section 7.2.
|
||||
func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte {
|
||||
return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size())
|
||||
}
|
||||
|
||||
// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
|
||||
func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) {
|
||||
key = c.expandLabel(trafficSecret, "key", nil, c.keyLen)
|
||||
iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength)
|
||||
return
|
||||
}
|
||||
|
||||
// finishedHash generates the Finished verify_data or PskBinderEntry according
|
||||
// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey
|
||||
// selection.
|
||||
func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte {
|
||||
finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size())
|
||||
verifyData := hmac.New(c.hash.New, finishedKey)
|
||||
verifyData.Write(transcript.Sum(nil))
|
||||
return verifyData.Sum(nil)
|
||||
}
|
||||
|
||||
// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to
|
||||
// RFC 8446, Section 7.5.
|
||||
func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) {
|
||||
expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript)
|
||||
return func(label string, context []byte, length int) ([]byte, error) {
|
||||
secret := c.deriveSecret(expMasterSecret, label, nil)
|
||||
h := c.hash.New()
|
||||
h.Write(context)
|
||||
return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil
|
||||
}
|
||||
}
|
||||
|
||||
// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519,
|
||||
// according to RFC 8446, Section 4.2.8.2.
|
||||
type ecdheParameters interface {
|
||||
CurveID() CurveID
|
||||
PublicKey() []byte
|
||||
SharedKey(peerPublicKey []byte) []byte
|
||||
}
|
||||
|
||||
func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) {
|
||||
if curveID == X25519 {
|
||||
privateKey := make([]byte, curve25519.ScalarSize)
|
||||
if _, err := io.ReadFull(rand, privateKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil
|
||||
}
|
||||
|
||||
curve, ok := curveForCurveID(curveID)
|
||||
if !ok {
|
||||
return nil, errors.New("tls: internal error: unsupported curve")
|
||||
}
|
||||
|
||||
p := &nistParameters{curveID: curveID}
|
||||
var err error
|
||||
p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func curveForCurveID(id CurveID) (elliptic.Curve, bool) {
|
||||
switch id {
|
||||
case CurveP256:
|
||||
return elliptic.P256(), true
|
||||
case CurveP384:
|
||||
return elliptic.P384(), true
|
||||
case CurveP521:
|
||||
return elliptic.P521(), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
type nistParameters struct {
|
||||
privateKey []byte
|
||||
x, y *big.Int // public key
|
||||
curveID CurveID
|
||||
}
|
||||
|
||||
func (p *nistParameters) CurveID() CurveID {
|
||||
return p.curveID
|
||||
}
|
||||
|
||||
func (p *nistParameters) PublicKey() []byte {
|
||||
curve, _ := curveForCurveID(p.curveID)
|
||||
return elliptic.Marshal(curve, p.x, p.y)
|
||||
}
|
||||
|
||||
func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte {
|
||||
curve, _ := curveForCurveID(p.curveID)
|
||||
// Unmarshal also checks whether the given point is on the curve.
|
||||
x, y := elliptic.Unmarshal(curve, peerPublicKey)
|
||||
if x == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
xShared, _ := curve.ScalarMult(x, y, p.privateKey)
|
||||
sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
|
||||
return xShared.FillBytes(sharedKey)
|
||||
}
|
||||
|
||||
type x25519Parameters struct {
|
||||
privateKey []byte
|
||||
publicKey []byte
|
||||
}
|
||||
|
||||
func (p *x25519Parameters) CurveID() CurveID {
|
||||
return X25519
|
||||
}
|
||||
|
||||
func (p *x25519Parameters) PublicKey() []byte {
|
||||
return p.publicKey[:]
|
||||
}
|
||||
|
||||
func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte {
|
||||
sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return sharedKey
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
// Copyright 2022 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
func needFIPS() bool { return false }
|
||||
|
||||
func supportedSignatureAlgorithms() []SignatureScheme {
|
||||
return defaultSupportedSignatureAlgorithms
|
||||
}
|
||||
|
||||
func fipsMinVersion(c *config) uint16 { panic("fipsMinVersion") }
|
||||
func fipsMaxVersion(c *config) uint16 { panic("fipsMaxVersion") }
|
||||
func fipsCurvePreferences(c *config) []CurveID { panic("fipsCurvePreferences") }
|
||||
func fipsCipherSuites(c *config) []uint16 { panic("fipsCipherSuites") }
|
||||
|
||||
var fipsSupportedSignatureAlgorithms []SignatureScheme
|
|
@ -1,283 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
)
|
||||
|
||||
// Split a premaster secret in two as specified in RFC 4346, Section 5.
|
||||
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
|
||||
s1 = secret[0 : (len(secret)+1)/2]
|
||||
s2 = secret[len(secret)/2:]
|
||||
return
|
||||
}
|
||||
|
||||
// pHash implements the P_hash function, as defined in RFC 4346, Section 5.
|
||||
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
|
||||
h := hmac.New(hash, secret)
|
||||
h.Write(seed)
|
||||
a := h.Sum(nil)
|
||||
|
||||
j := 0
|
||||
for j < len(result) {
|
||||
h.Reset()
|
||||
h.Write(a)
|
||||
h.Write(seed)
|
||||
b := h.Sum(nil)
|
||||
copy(result[j:], b)
|
||||
j += len(b)
|
||||
|
||||
h.Reset()
|
||||
h.Write(a)
|
||||
a = h.Sum(nil)
|
||||
}
|
||||
}
|
||||
|
||||
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
|
||||
func prf10(result, secret, label, seed []byte) {
|
||||
hashSHA1 := sha1.New
|
||||
hashMD5 := md5.New
|
||||
|
||||
labelAndSeed := make([]byte, len(label)+len(seed))
|
||||
copy(labelAndSeed, label)
|
||||
copy(labelAndSeed[len(label):], seed)
|
||||
|
||||
s1, s2 := splitPreMasterSecret(secret)
|
||||
pHash(result, s1, labelAndSeed, hashMD5)
|
||||
result2 := make([]byte, len(result))
|
||||
pHash(result2, s2, labelAndSeed, hashSHA1)
|
||||
|
||||
for i, b := range result2 {
|
||||
result[i] ^= b
|
||||
}
|
||||
}
|
||||
|
||||
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5.
|
||||
func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
|
||||
return func(result, secret, label, seed []byte) {
|
||||
labelAndSeed := make([]byte, len(label)+len(seed))
|
||||
copy(labelAndSeed, label)
|
||||
copy(labelAndSeed[len(label):], seed)
|
||||
|
||||
pHash(result, secret, labelAndSeed, hashFunc)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
masterSecretLength = 48 // Length of a master secret in TLS 1.1.
|
||||
finishedVerifyLength = 12 // Length of verify_data in a Finished message.
|
||||
)
|
||||
|
||||
var masterSecretLabel = []byte("master secret")
|
||||
var keyExpansionLabel = []byte("key expansion")
|
||||
var clientFinishedLabel = []byte("client finished")
|
||||
var serverFinishedLabel = []byte("server finished")
|
||||
|
||||
func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
|
||||
switch version {
|
||||
case VersionTLS10, VersionTLS11:
|
||||
return prf10, crypto.Hash(0)
|
||||
case VersionTLS12:
|
||||
if suite.flags&suiteSHA384 != 0 {
|
||||
return prf12(sha512.New384), crypto.SHA384
|
||||
}
|
||||
return prf12(sha256.New), crypto.SHA256
|
||||
default:
|
||||
panic("unknown version")
|
||||
}
|
||||
}
|
||||
|
||||
func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
|
||||
prf, _ := prfAndHashForVersion(version, suite)
|
||||
return prf
|
||||
}
|
||||
|
||||
// masterFromPreMasterSecret generates the master secret from the pre-master
|
||||
// secret. See RFC 5246, Section 8.1.
|
||||
func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
|
||||
seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
|
||||
seed = append(seed, clientRandom...)
|
||||
seed = append(seed, serverRandom...)
|
||||
|
||||
masterSecret := make([]byte, masterSecretLength)
|
||||
prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed)
|
||||
return masterSecret
|
||||
}
|
||||
|
||||
// keysFromMasterSecret generates the connection keys from the master
|
||||
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
|
||||
// RFC 2246, Section 6.3.
|
||||
func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
|
||||
seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
|
||||
seed = append(seed, serverRandom...)
|
||||
seed = append(seed, clientRandom...)
|
||||
|
||||
n := 2*macLen + 2*keyLen + 2*ivLen
|
||||
keyMaterial := make([]byte, n)
|
||||
prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed)
|
||||
clientMAC = keyMaterial[:macLen]
|
||||
keyMaterial = keyMaterial[macLen:]
|
||||
serverMAC = keyMaterial[:macLen]
|
||||
keyMaterial = keyMaterial[macLen:]
|
||||
clientKey = keyMaterial[:keyLen]
|
||||
keyMaterial = keyMaterial[keyLen:]
|
||||
serverKey = keyMaterial[:keyLen]
|
||||
keyMaterial = keyMaterial[keyLen:]
|
||||
clientIV = keyMaterial[:ivLen]
|
||||
keyMaterial = keyMaterial[ivLen:]
|
||||
serverIV = keyMaterial[:ivLen]
|
||||
return
|
||||
}
|
||||
|
||||
func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
|
||||
var buffer []byte
|
||||
if version >= VersionTLS12 {
|
||||
buffer = []byte{}
|
||||
}
|
||||
|
||||
prf, hash := prfAndHashForVersion(version, cipherSuite)
|
||||
if hash != 0 {
|
||||
return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
|
||||
}
|
||||
|
||||
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
|
||||
}
|
||||
|
||||
// A finishedHash calculates the hash of a set of handshake messages suitable
|
||||
// for including in a Finished message.
|
||||
type finishedHash struct {
|
||||
client hash.Hash
|
||||
server hash.Hash
|
||||
|
||||
// Prior to TLS 1.2, an additional MD5 hash is required.
|
||||
clientMD5 hash.Hash
|
||||
serverMD5 hash.Hash
|
||||
|
||||
// In TLS 1.2, a full buffer is sadly required.
|
||||
buffer []byte
|
||||
|
||||
version uint16
|
||||
prf func(result, secret, label, seed []byte)
|
||||
}
|
||||
|
||||
func (h *finishedHash) Write(msg []byte) (n int, err error) {
|
||||
h.client.Write(msg)
|
||||
h.server.Write(msg)
|
||||
|
||||
if h.version < VersionTLS12 {
|
||||
h.clientMD5.Write(msg)
|
||||
h.serverMD5.Write(msg)
|
||||
}
|
||||
|
||||
if h.buffer != nil {
|
||||
h.buffer = append(h.buffer, msg...)
|
||||
}
|
||||
|
||||
return len(msg), nil
|
||||
}
|
||||
|
||||
func (h finishedHash) Sum() []byte {
|
||||
if h.version >= VersionTLS12 {
|
||||
return h.client.Sum(nil)
|
||||
}
|
||||
|
||||
out := make([]byte, 0, md5.Size+sha1.Size)
|
||||
out = h.clientMD5.Sum(out)
|
||||
return h.client.Sum(out)
|
||||
}
|
||||
|
||||
// clientSum returns the contents of the verify_data member of a client's
|
||||
// Finished message.
|
||||
func (h finishedHash) clientSum(masterSecret []byte) []byte {
|
||||
out := make([]byte, finishedVerifyLength)
|
||||
h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
|
||||
return out
|
||||
}
|
||||
|
||||
// serverSum returns the contents of the verify_data member of a server's
|
||||
// Finished message.
|
||||
func (h finishedHash) serverSum(masterSecret []byte) []byte {
|
||||
out := make([]byte, finishedVerifyLength)
|
||||
h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
|
||||
return out
|
||||
}
|
||||
|
||||
// hashForClientCertificate returns the handshake messages so far, pre-hashed if
|
||||
// necessary, suitable for signing by a TLS client certificate.
|
||||
func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte {
|
||||
if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil {
|
||||
panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer")
|
||||
}
|
||||
|
||||
if sigType == signatureEd25519 {
|
||||
return h.buffer
|
||||
}
|
||||
|
||||
if h.version >= VersionTLS12 {
|
||||
hash := hashAlg.New()
|
||||
hash.Write(h.buffer)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
if sigType == signatureECDSA {
|
||||
return h.server.Sum(nil)
|
||||
}
|
||||
|
||||
return h.Sum()
|
||||
}
|
||||
|
||||
// discardHandshakeBuffer is called when there is no more need to
|
||||
// buffer the entirety of the handshake messages.
|
||||
func (h *finishedHash) discardHandshakeBuffer() {
|
||||
h.buffer = nil
|
||||
}
|
||||
|
||||
// noExportedKeyingMaterial is used as a value of
|
||||
// ConnectionState.ekm when renegotiation is enabled and thus
|
||||
// we wish to fail all key-material export requests.
|
||||
func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
|
||||
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
|
||||
}
|
||||
|
||||
// ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
|
||||
func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
|
||||
return func(label string, context []byte, length int) ([]byte, error) {
|
||||
switch label {
|
||||
case "client finished", "server finished", "master secret", "key expansion":
|
||||
// These values are reserved and may not be used.
|
||||
return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label)
|
||||
}
|
||||
|
||||
seedLen := len(serverRandom) + len(clientRandom)
|
||||
if context != nil {
|
||||
seedLen += 2 + len(context)
|
||||
}
|
||||
seed := make([]byte, 0, seedLen)
|
||||
|
||||
seed = append(seed, clientRandom...)
|
||||
seed = append(seed, serverRandom...)
|
||||
|
||||
if context != nil {
|
||||
if len(context) >= 1<<16 {
|
||||
return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long")
|
||||
}
|
||||
seed = append(seed, byte(len(context)>>8), byte(len(context)))
|
||||
seed = append(seed, context...)
|
||||
}
|
||||
|
||||
keyMaterial := make([]byte, length)
|
||||
prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed)
|
||||
return keyMaterial, nil
|
||||
}
|
||||
}
|
|
@ -1,277 +0,0 @@
|
|||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
)
|
||||
|
||||
// sessionState contains the information that is serialized into a session
|
||||
// ticket in order to later resume a connection.
|
||||
type sessionState struct {
|
||||
vers uint16
|
||||
cipherSuite uint16
|
||||
createdAt uint64
|
||||
masterSecret []byte // opaque master_secret<1..2^16-1>;
|
||||
// struct { opaque certificate<1..2^24-1> } Certificate;
|
||||
certificates [][]byte // Certificate certificate_list<0..2^24-1>;
|
||||
|
||||
// usedOldKey is true if the ticket from which this session came from
|
||||
// was encrypted with an older key and thus should be refreshed.
|
||||
usedOldKey bool
|
||||
}
|
||||
|
||||
func (m *sessionState) marshal() ([]byte, error) {
|
||||
var b cryptobyte.Builder
|
||||
b.AddUint16(m.vers)
|
||||
b.AddUint16(m.cipherSuite)
|
||||
addUint64(&b, m.createdAt)
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(m.masterSecret)
|
||||
})
|
||||
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
for _, cert := range m.certificates {
|
||||
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(cert)
|
||||
})
|
||||
}
|
||||
})
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (m *sessionState) unmarshal(data []byte) bool {
|
||||
*m = sessionState{usedOldKey: m.usedOldKey}
|
||||
s := cryptobyte.String(data)
|
||||
if ok := s.ReadUint16(&m.vers) &&
|
||||
s.ReadUint16(&m.cipherSuite) &&
|
||||
readUint64(&s, &m.createdAt) &&
|
||||
readUint16LengthPrefixed(&s, &m.masterSecret) &&
|
||||
len(m.masterSecret) != 0; !ok {
|
||||
return false
|
||||
}
|
||||
var certList cryptobyte.String
|
||||
if !s.ReadUint24LengthPrefixed(&certList) {
|
||||
return false
|
||||
}
|
||||
for !certList.Empty() {
|
||||
var cert []byte
|
||||
if !readUint24LengthPrefixed(&certList, &cert) {
|
||||
return false
|
||||
}
|
||||
m.certificates = append(m.certificates, cert)
|
||||
}
|
||||
return s.Empty()
|
||||
}
|
||||
|
||||
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
|
||||
// version (revision = 0) doesn't carry any of the information needed for 0-RTT
|
||||
// validation and the nonce is always empty.
|
||||
// version (revision = 1) carries the max_early_data_size sent in the ticket.
|
||||
// version (revision = 2) carries the ALPN sent in the ticket.
|
||||
type sessionStateTLS13 struct {
|
||||
// uint8 version = 0x0304;
|
||||
// uint8 revision = 2;
|
||||
cipherSuite uint16
|
||||
createdAt uint64
|
||||
resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>;
|
||||
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
|
||||
maxEarlyData uint32
|
||||
alpn string
|
||||
|
||||
appData []byte
|
||||
}
|
||||
|
||||
func (m *sessionStateTLS13) marshal() ([]byte, error) {
|
||||
var b cryptobyte.Builder
|
||||
b.AddUint16(VersionTLS13)
|
||||
b.AddUint8(2) // revision
|
||||
b.AddUint16(m.cipherSuite)
|
||||
addUint64(&b, m.createdAt)
|
||||
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(m.resumptionSecret)
|
||||
})
|
||||
marshalCertificate(&b, m.certificate)
|
||||
b.AddUint32(m.maxEarlyData)
|
||||
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes([]byte(m.alpn))
|
||||
})
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(m.appData)
|
||||
})
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (m *sessionStateTLS13) unmarshal(data []byte) bool {
|
||||
*m = sessionStateTLS13{}
|
||||
s := cryptobyte.String(data)
|
||||
var version uint16
|
||||
var revision uint8
|
||||
var alpn []byte
|
||||
ret := s.ReadUint16(&version) &&
|
||||
version == VersionTLS13 &&
|
||||
s.ReadUint8(&revision) &&
|
||||
revision == 2 &&
|
||||
s.ReadUint16(&m.cipherSuite) &&
|
||||
readUint64(&s, &m.createdAt) &&
|
||||
readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
|
||||
len(m.resumptionSecret) != 0 &&
|
||||
unmarshalCertificate(&s, &m.certificate) &&
|
||||
s.ReadUint32(&m.maxEarlyData) &&
|
||||
readUint8LengthPrefixed(&s, &alpn) &&
|
||||
readUint16LengthPrefixed(&s, &m.appData) &&
|
||||
s.Empty()
|
||||
m.alpn = string(alpn)
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
|
||||
if len(c.ticketKeys) == 0 {
|
||||
return nil, errors.New("tls: internal error: session ticket keys unavailable")
|
||||
}
|
||||
|
||||
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
|
||||
keyName := encrypted[:ticketKeyNameLen]
|
||||
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
|
||||
macBytes := encrypted[len(encrypted)-sha256.Size:]
|
||||
|
||||
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := c.ticketKeys[0]
|
||||
copy(keyName, key.keyName[:])
|
||||
block, err := aes.NewCipher(key.aesKey[:])
|
||||
if err != nil {
|
||||
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
|
||||
}
|
||||
cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state)
|
||||
|
||||
mac := hmac.New(sha256.New, key.hmacKey[:])
|
||||
mac.Write(encrypted[:len(encrypted)-sha256.Size])
|
||||
mac.Sum(macBytes[:0])
|
||||
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
|
||||
if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
keyName := encrypted[:ticketKeyNameLen]
|
||||
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
|
||||
macBytes := encrypted[len(encrypted)-sha256.Size:]
|
||||
ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
|
||||
|
||||
keyIndex := -1
|
||||
for i, candidateKey := range c.ticketKeys {
|
||||
if bytes.Equal(keyName, candidateKey.keyName[:]) {
|
||||
keyIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if keyIndex == -1 {
|
||||
return nil, false
|
||||
}
|
||||
key := &c.ticketKeys[keyIndex]
|
||||
|
||||
mac := hmac.New(sha256.New, key.hmacKey[:])
|
||||
mac.Write(encrypted[:len(encrypted)-sha256.Size])
|
||||
expected := mac.Sum(nil)
|
||||
|
||||
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key.aesKey[:])
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
plaintext = make([]byte, len(ciphertext))
|
||||
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
|
||||
|
||||
return plaintext, keyIndex > 0
|
||||
}
|
||||
|
||||
func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) {
|
||||
m := new(newSessionTicketMsgTLS13)
|
||||
|
||||
var certsFromClient [][]byte
|
||||
for _, cert := range c.peerCertificates {
|
||||
certsFromClient = append(certsFromClient, cert.Raw)
|
||||
}
|
||||
state := sessionStateTLS13{
|
||||
cipherSuite: c.cipherSuite,
|
||||
createdAt: uint64(c.config.time().Unix()),
|
||||
resumptionSecret: c.resumptionSecret,
|
||||
certificate: Certificate{
|
||||
Certificate: certsFromClient,
|
||||
OCSPStaple: c.ocspResponse,
|
||||
SignedCertificateTimestamps: c.scts,
|
||||
},
|
||||
appData: appData,
|
||||
alpn: c.clientProtocol,
|
||||
}
|
||||
if c.extraConfig != nil {
|
||||
state.maxEarlyData = c.extraConfig.MaxEarlyData
|
||||
}
|
||||
stateBytes, err := state.marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.label, err = c.encryptTicket(stateBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
|
||||
|
||||
// ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1
|
||||
// The value is not stored anywhere; we never need to check the ticket age
|
||||
// because 0-RTT is not supported.
|
||||
ageAdd := make([]byte, 4)
|
||||
_, err = c.config.rand().Read(ageAdd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.ageAdd = binary.LittleEndian.Uint32(ageAdd)
|
||||
|
||||
// ticket_nonce, which must be unique per connection, is always left at
|
||||
// zero because we only ever send one ticket per connection.
|
||||
|
||||
if c.extraConfig != nil {
|
||||
m.maxEarlyData = c.extraConfig.MaxEarlyData
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetSessionTicket generates a new session ticket.
|
||||
// It should only be called after the handshake completes.
|
||||
// It can only be used for servers, and only if the alternative record layer is set.
|
||||
// The ticket may be nil if config.SessionTicketsDisabled is set,
|
||||
// or if the client isn't able to receive session tickets.
|
||||
func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) {
|
||||
if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil {
|
||||
return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.")
|
||||
}
|
||||
if c.config.SessionTicketsDisabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
m, err := c.getSessionTicketMsg(appData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.marshal()
|
||||
}
|
|
@ -1,362 +0,0 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// package qtls partially implements TLS 1.2, as specified in RFC 5246,
|
||||
// and TLS 1.3, as specified in RFC 8446.
|
||||
package qtls
|
||||
|
||||
// BUG(agl): The crypto/tls package only implements some countermeasures
|
||||
// against Lucky13 attacks on CBC-mode encryption, and only on SHA1
|
||||
// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
|
||||
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Server returns a new TLS server side connection
|
||||
// using conn as the underlying transport.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
c := &Conn{
|
||||
conn: conn,
|
||||
config: fromConfig(config),
|
||||
extraConfig: extraConfig,
|
||||
}
|
||||
c.handshakeFn = c.serverHandshake
|
||||
return c
|
||||
}
|
||||
|
||||
// Client returns a new TLS client side connection
|
||||
// using conn as the underlying transport.
|
||||
// The config cannot be nil: users must set either ServerName or
|
||||
// InsecureSkipVerify in the config.
|
||||
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
c := &Conn{
|
||||
conn: conn,
|
||||
config: fromConfig(config),
|
||||
extraConfig: extraConfig,
|
||||
isClient: true,
|
||||
}
|
||||
c.handshakeFn = c.clientHandshake
|
||||
return c
|
||||
}
|
||||
|
||||
// A listener implements a network listener (net.Listener) for TLS connections.
|
||||
type listener struct {
|
||||
net.Listener
|
||||
config *Config
|
||||
extraConfig *ExtraConfig
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next incoming TLS connection.
|
||||
// The returned connection is of type *Conn.
|
||||
func (l *listener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Server(c, l.config, l.extraConfig), nil
|
||||
}
|
||||
|
||||
// NewListener creates a Listener which accepts connections from an inner
|
||||
// Listener and wraps each connection with Server.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener {
|
||||
l := new(listener)
|
||||
l.Listener = inner
|
||||
l.config = config
|
||||
l.extraConfig = extraConfig
|
||||
return l
|
||||
}
|
||||
|
||||
// Listen creates a TLS listener accepting connections on the
|
||||
// given network address using net.Listen.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) {
|
||||
if config == nil || len(config.Certificates) == 0 &&
|
||||
config.GetCertificate == nil && config.GetConfigForClient == nil {
|
||||
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
|
||||
}
|
||||
l, err := net.Listen(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewListener(l, config, extraConfig), nil
|
||||
}
|
||||
|
||||
type timeoutError struct{}
|
||||
|
||||
func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
|
||||
func (timeoutError) Timeout() bool { return true }
|
||||
func (timeoutError) Temporary() bool { return true }
|
||||
|
||||
// DialWithDialer connects to the given network address using dialer.Dial and
|
||||
// then initiates a TLS handshake, returning the resulting TLS connection. Any
|
||||
// timeout or deadline given in the dialer apply to connection and TLS
|
||||
// handshake as a whole.
|
||||
//
|
||||
// DialWithDialer interprets a nil configuration as equivalent to the zero
|
||||
// configuration; see the documentation of Config for the defaults.
|
||||
//
|
||||
// DialWithDialer uses context.Background internally; to specify the context,
|
||||
// use Dialer.DialContext with NetDialer set to the desired dialer.
|
||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
|
||||
return dial(context.Background(), dialer, network, addr, config, extraConfig)
|
||||
}
|
||||
|
||||
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
|
||||
if netDialer.Timeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if !netDialer.Deadline.IsZero() {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
rawConn, err := netDialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
colonPos := strings.LastIndex(addr, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(addr)
|
||||
}
|
||||
hostname := addr[:colonPos]
|
||||
|
||||
if config == nil {
|
||||
config = defaultConfig()
|
||||
}
|
||||
// If no ServerName is set, infer the ServerName
|
||||
// from the hostname we're connecting to.
|
||||
if config.ServerName == "" {
|
||||
// Make a copy to avoid polluting argument or default.
|
||||
c := config.Clone()
|
||||
c.ServerName = hostname
|
||||
config = c
|
||||
}
|
||||
|
||||
conn := Client(rawConn, config, extraConfig)
|
||||
if err := conn.HandshakeContext(ctx); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial connects to the given network address using net.Dial
|
||||
// and then initiates a TLS handshake, returning the resulting
|
||||
// TLS connection.
|
||||
// Dial interprets a nil configuration as equivalent to
|
||||
// the zero configuration; see the documentation of Config
|
||||
// for the defaults.
|
||||
func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
|
||||
return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig)
|
||||
}
|
||||
|
||||
// Dialer dials TLS connections given a configuration and a Dialer for the
|
||||
// underlying connection.
|
||||
type Dialer struct {
|
||||
// NetDialer is the optional dialer to use for the TLS connections'
|
||||
// underlying TCP connections.
|
||||
// A nil NetDialer is equivalent to the net.Dialer zero value.
|
||||
NetDialer *net.Dialer
|
||||
|
||||
// Config is the TLS configuration to use for new connections.
|
||||
// A nil configuration is equivalent to the zero
|
||||
// configuration; see the documentation of Config for the
|
||||
// defaults.
|
||||
Config *Config
|
||||
|
||||
ExtraConfig *ExtraConfig
|
||||
}
|
||||
|
||||
// Dial connects to the given network address and initiates a TLS
|
||||
// handshake, returning the resulting TLS connection.
|
||||
//
|
||||
// The returned Conn, if any, will always be of type *Conn.
|
||||
//
|
||||
// Dial uses context.Background internally; to specify the context,
|
||||
// use DialContext.
|
||||
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, addr)
|
||||
}
|
||||
|
||||
func (d *Dialer) netDialer() *net.Dialer {
|
||||
if d.NetDialer != nil {
|
||||
return d.NetDialer
|
||||
}
|
||||
return new(net.Dialer)
|
||||
}
|
||||
|
||||
// DialContext connects to the given network address and initiates a TLS
|
||||
// handshake, returning the resulting TLS connection.
|
||||
//
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
//
|
||||
// The returned Conn, if any, will always be of type *Conn.
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig)
|
||||
if err != nil {
|
||||
// Don't return c (a typed nil) in an interface.
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// LoadX509KeyPair reads and parses a public/private key pair from a pair
|
||||
// of files. The files must contain PEM encoded data. The certificate file
|
||||
// may contain intermediate certificates following the leaf certificate to
|
||||
// form a certificate chain. On successful return, Certificate.Leaf will
|
||||
// be nil because the parsed form of the certificate is not retained.
|
||||
func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
|
||||
certPEMBlock, err := os.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
keyPEMBlock, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
return X509KeyPair(certPEMBlock, keyPEMBlock)
|
||||
}
|
||||
|
||||
// X509KeyPair parses a public/private key pair from a pair of
|
||||
// PEM encoded data. On successful return, Certificate.Leaf will be nil because
|
||||
// the parsed form of the certificate is not retained.
|
||||
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
|
||||
fail := func(err error) (Certificate, error) { return Certificate{}, err }
|
||||
|
||||
var cert Certificate
|
||||
var skippedBlockTypes []string
|
||||
for {
|
||||
var certDERBlock *pem.Block
|
||||
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
|
||||
if certDERBlock == nil {
|
||||
break
|
||||
}
|
||||
if certDERBlock.Type == "CERTIFICATE" {
|
||||
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
|
||||
} else {
|
||||
skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cert.Certificate) == 0 {
|
||||
if len(skippedBlockTypes) == 0 {
|
||||
return fail(errors.New("tls: failed to find any PEM data in certificate input"))
|
||||
}
|
||||
if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
|
||||
return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
|
||||
}
|
||||
return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
|
||||
}
|
||||
|
||||
skippedBlockTypes = skippedBlockTypes[:0]
|
||||
var keyDERBlock *pem.Block
|
||||
for {
|
||||
keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
|
||||
if keyDERBlock == nil {
|
||||
if len(skippedBlockTypes) == 0 {
|
||||
return fail(errors.New("tls: failed to find any PEM data in key input"))
|
||||
}
|
||||
if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
|
||||
return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
|
||||
}
|
||||
return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
|
||||
}
|
||||
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
|
||||
break
|
||||
}
|
||||
skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
|
||||
}
|
||||
|
||||
// We don't need to parse the public key for TLS, but we so do anyway
|
||||
// to check that it looks sane and matches the private key.
|
||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return fail(err)
|
||||
}
|
||||
|
||||
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
|
||||
if err != nil {
|
||||
return fail(err)
|
||||
}
|
||||
|
||||
switch pub := x509Cert.PublicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return fail(errors.New("tls: private key type does not match public key type"))
|
||||
}
|
||||
if pub.N.Cmp(priv.N) != 0 {
|
||||
return fail(errors.New("tls: private key does not match public key"))
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return fail(errors.New("tls: private key type does not match public key type"))
|
||||
}
|
||||
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
|
||||
return fail(errors.New("tls: private key does not match public key"))
|
||||
}
|
||||
case ed25519.PublicKey:
|
||||
priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
|
||||
if !ok {
|
||||
return fail(errors.New("tls: private key type does not match public key type"))
|
||||
}
|
||||
if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
|
||||
return fail(errors.New("tls: private key does not match public key"))
|
||||
}
|
||||
default:
|
||||
return fail(errors.New("tls: unknown public key algorithm"))
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
|
||||
// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
|
||||
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
|
||||
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
|
||||
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
|
||||
return key, nil
|
||||
}
|
||||
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
|
||||
switch key := key.(type) {
|
||||
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
|
||||
return key, nil
|
||||
default:
|
||||
return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
|
||||
}
|
||||
}
|
||||
if key, err := x509.ParseECPrivateKey(der); err == nil {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("tls: failed to parse private key")
|
||||
}
|
|
@ -1,96 +0,0 @@
|
|||
package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !structsEqual(&tls.ConnectionState{}, &connectionState{}) {
|
||||
panic("qtls.ConnectionState doesn't match")
|
||||
}
|
||||
if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) {
|
||||
panic("qtls.ClientSessionState doesn't match")
|
||||
}
|
||||
if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) {
|
||||
panic("qtls.CertificateRequestInfo doesn't match")
|
||||
}
|
||||
if !structsEqual(&tls.Config{}, &config{}) {
|
||||
panic("qtls.Config doesn't match")
|
||||
}
|
||||
if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) {
|
||||
panic("qtls.ClientHelloInfo doesn't match")
|
||||
}
|
||||
}
|
||||
|
||||
func toConnectionState(c connectionState) ConnectionState {
|
||||
return *(*ConnectionState)(unsafe.Pointer(&c))
|
||||
}
|
||||
|
||||
func toClientSessionState(s *clientSessionState) *ClientSessionState {
|
||||
return (*ClientSessionState)(unsafe.Pointer(s))
|
||||
}
|
||||
|
||||
func fromClientSessionState(s *ClientSessionState) *clientSessionState {
|
||||
return (*clientSessionState)(unsafe.Pointer(s))
|
||||
}
|
||||
|
||||
func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo {
|
||||
return (*CertificateRequestInfo)(unsafe.Pointer(i))
|
||||
}
|
||||
|
||||
func toConfig(c *config) *Config {
|
||||
return (*Config)(unsafe.Pointer(c))
|
||||
}
|
||||
|
||||
func fromConfig(c *Config) *config {
|
||||
return (*config)(unsafe.Pointer(c))
|
||||
}
|
||||
|
||||
func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo {
|
||||
return (*ClientHelloInfo)(unsafe.Pointer(chi))
|
||||
}
|
||||
|
||||
func structsEqual(a, b interface{}) bool {
|
||||
return compare(reflect.ValueOf(a), reflect.ValueOf(b))
|
||||
}
|
||||
|
||||
func compare(a, b reflect.Value) bool {
|
||||
sa := a.Elem()
|
||||
sb := b.Elem()
|
||||
if sa.NumField() != sb.NumField() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < sa.NumField(); i++ {
|
||||
fa := sa.Type().Field(i)
|
||||
fb := sb.Type().Field(i)
|
||||
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
|
||||
if fa.Type.Kind() != fb.Type.Kind() {
|
||||
return false
|
||||
}
|
||||
if fa.Type.Kind() == reflect.Slice {
|
||||
if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func compareStruct(a, b reflect.Type) bool {
|
||||
if a.NumField() != b.NumField() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < a.NumField(); i++ {
|
||||
fa := a.Field(i)
|
||||
fb := b.Field(i)
|
||||
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -6,10 +6,17 @@ package qtls
|
|||
|
||||
import "strconv"
|
||||
|
||||
type alert uint8
|
||||
// An AlertError is a TLS alert.
|
||||
//
|
||||
// When using a QUIC transport, QUICConn methods will return an error
|
||||
// which wraps AlertError rather than sending a TLS alert.
|
||||
type AlertError uint8
|
||||
|
||||
// Alert is a TLS alert
|
||||
type Alert = alert
|
||||
func (e AlertError) Error() string {
|
||||
return alert(e).String()
|
||||
}
|
||||
|
||||
type alert uint8
|
||||
|
||||
const (
|
||||
// alert level
|
||||
|
|
|
@ -15,8 +15,10 @@ import (
|
|||
"crypto/sha256"
|
||||
"fmt"
|
||||
"hash"
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
// CipherSuite is a TLS cipher suite. Note that most functions in this package
|
||||
|
@ -195,17 +197,6 @@ type cipherSuiteTLS13 struct {
|
|||
hash crypto.Hash
|
||||
}
|
||||
|
||||
type CipherSuiteTLS13 struct {
|
||||
ID uint16
|
||||
KeyLen int
|
||||
Hash crypto.Hash
|
||||
AEAD func(key, fixedNonce []byte) cipher.AEAD
|
||||
}
|
||||
|
||||
func (c *CipherSuiteTLS13) IVLen() int {
|
||||
return aeadNonceLength
|
||||
}
|
||||
|
||||
var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
|
||||
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
|
||||
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
|
||||
|
@ -362,6 +353,18 @@ var defaultCipherSuitesTLS13NoAES = []uint16{
|
|||
TLS_AES_256_GCM_SHA384,
|
||||
}
|
||||
|
||||
var (
|
||||
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
|
||||
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
|
||||
// Keep in sync with crypto/aes/cipher_s390x.go.
|
||||
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
|
||||
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
|
||||
|
||||
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
|
||||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
|
||||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
|
||||
)
|
||||
|
||||
var aesgcmCiphers = map[uint16]bool{
|
||||
// TLS 1.2
|
||||
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true,
|
||||
|
@ -519,11 +522,6 @@ func aeadAESGCM(key, noncePrefix []byte) aead {
|
|||
return ret
|
||||
}
|
||||
|
||||
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
|
||||
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
|
||||
return aeadAESGCMTLS13(key, fixedNonce)
|
||||
}
|
||||
|
||||
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
|
||||
if len(nonceMask) != aeadNonceLength {
|
||||
panic("tls: internal error: wrong nonce length")
|
||||
|
|
|
@ -82,11 +82,6 @@ const (
|
|||
compressionNone uint8 = 0
|
||||
)
|
||||
|
||||
type Extension struct {
|
||||
Type uint16
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// TLS extension numbers
|
||||
const (
|
||||
extensionServerName uint16 = 0
|
||||
|
@ -105,6 +100,7 @@ const (
|
|||
extensionCertificateAuthorities uint16 = 47
|
||||
extensionSignatureAlgorithmsCert uint16 = 50
|
||||
extensionKeyShare uint16 = 51
|
||||
extensionQUICTransportParameters uint16 = 57
|
||||
extensionRenegotiationInfo uint16 = 0xff01
|
||||
)
|
||||
|
||||
|
@ -113,14 +109,6 @@ const (
|
|||
scsvRenegotiation uint16 = 0x00ff
|
||||
)
|
||||
|
||||
type EncryptionLevel uint8
|
||||
|
||||
const (
|
||||
EncryptionHandshake EncryptionLevel = iota
|
||||
Encryption0RTT
|
||||
EncryptionApplication
|
||||
)
|
||||
|
||||
// CurveID is a tls.CurveID
|
||||
type CurveID = tls.CurveID
|
||||
|
||||
|
@ -294,12 +282,6 @@ type connectionState struct {
|
|||
ekm func(label string, context []byte, length int) ([]byte, error)
|
||||
}
|
||||
|
||||
type ConnectionStateWith0RTT struct {
|
||||
ConnectionState
|
||||
|
||||
Used0RTT bool // true if 0-RTT was both offered and accepted
|
||||
}
|
||||
|
||||
// ClientAuthType is tls.ClientAuthType
|
||||
type ClientAuthType = tls.ClientAuthType
|
||||
|
||||
|
@ -349,8 +331,6 @@ type clientSessionState struct {
|
|||
// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not
|
||||
// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which
|
||||
// are supported via this interface.
|
||||
//
|
||||
//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/quic-go/qtls-go1-20 ClientSessionCache"
|
||||
type ClientSessionCache = tls.ClientSessionCache
|
||||
|
||||
// SignatureScheme is a tls.SignatureScheme
|
||||
|
@ -736,64 +716,22 @@ type config struct {
|
|||
autoSessionTicketKeys []ticketKey
|
||||
}
|
||||
|
||||
// A RecordLayer handles encrypting and decrypting of TLS messages.
|
||||
type RecordLayer interface {
|
||||
SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte)
|
||||
SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte)
|
||||
ReadHandshakeMessage() ([]byte, error)
|
||||
WriteRecord([]byte) (int, error)
|
||||
SendAlert(uint8)
|
||||
}
|
||||
|
||||
type ExtraConfig struct {
|
||||
// GetExtensions, if not nil, is called before a message that allows
|
||||
// sending of extensions is sent.
|
||||
// Currently only implemented for the ClientHello message (for the client)
|
||||
// and for the EncryptedExtensions message (for the server).
|
||||
// Only valid for TLS 1.3.
|
||||
GetExtensions func(handshakeMessageType uint8) []Extension
|
||||
|
||||
// ReceivedExtensions, if not nil, is called when a message that allows the
|
||||
// inclusion of extensions is received.
|
||||
// It is called with an empty slice of extensions, if the message didn't
|
||||
// contain any extensions.
|
||||
// Currently only implemented for the ClientHello message (sent by the
|
||||
// client) and for the EncryptedExtensions message (sent by the server).
|
||||
// Only valid for TLS 1.3.
|
||||
ReceivedExtensions func(handshakeMessageType uint8, exts []Extension)
|
||||
|
||||
// AlternativeRecordLayer is used by QUIC
|
||||
AlternativeRecordLayer RecordLayer
|
||||
|
||||
// Enforce the selection of a supported application protocol.
|
||||
// Only works for TLS 1.3.
|
||||
// If enabled, client and server have to agree on an application protocol.
|
||||
// Otherwise, connection establishment fails.
|
||||
EnforceNextProtoSelection bool
|
||||
|
||||
// If MaxEarlyData is greater than 0, the client will be allowed to send early
|
||||
// data when resuming a session.
|
||||
// Requires the AlternativeRecordLayer to be set.
|
||||
// If Enable0RTT is enabled, the client will be allowed to send early data when resuming a session.
|
||||
//
|
||||
// It has no meaning on the client.
|
||||
MaxEarlyData uint32
|
||||
Enable0RTT bool
|
||||
|
||||
// GetAppDataForSessionTicket requests application data to be sent with a session ticket.
|
||||
//
|
||||
// It has no meaning on the client.
|
||||
GetAppDataForSessionTicket func() []byte
|
||||
|
||||
// The Accept0RTT callback is called when the client offers 0-RTT.
|
||||
// The server then has to decide if it wants to accept or reject 0-RTT.
|
||||
// It is only used for servers.
|
||||
Accept0RTT func(appData []byte) bool
|
||||
|
||||
// 0RTTRejected is called when the server rejectes 0-RTT.
|
||||
// It is only used for clients.
|
||||
Rejected0RTT func()
|
||||
|
||||
// If set, the client will export the 0-RTT key when resuming a session that
|
||||
// allows sending of early data.
|
||||
// Requires the AlternativeRecordLayer to be set.
|
||||
//
|
||||
// It has no meaning to the server.
|
||||
Enable0RTT bool
|
||||
|
||||
// Is called when the client saves a session ticket to the session ticket.
|
||||
// This gives the application the opportunity to save some data along with the ticket,
|
||||
// which can be restored when the session ticket is used.
|
||||
|
@ -801,29 +739,20 @@ type ExtraConfig struct {
|
|||
|
||||
// Is called when the client uses a session ticket.
|
||||
// Restores the application data that was saved earlier on GetAppDataForSessionTicket.
|
||||
SetAppDataFromSessionState func([]byte)
|
||||
SetAppDataFromSessionState func([]byte) (allowEarlyData bool)
|
||||
}
|
||||
|
||||
// Clone clones.
|
||||
func (c *ExtraConfig) Clone() *ExtraConfig {
|
||||
return &ExtraConfig{
|
||||
GetExtensions: c.GetExtensions,
|
||||
ReceivedExtensions: c.ReceivedExtensions,
|
||||
AlternativeRecordLayer: c.AlternativeRecordLayer,
|
||||
EnforceNextProtoSelection: c.EnforceNextProtoSelection,
|
||||
MaxEarlyData: c.MaxEarlyData,
|
||||
Enable0RTT: c.Enable0RTT,
|
||||
GetAppDataForSessionTicket: c.GetAppDataForSessionTicket,
|
||||
Accept0RTT: c.Accept0RTT,
|
||||
Rejected0RTT: c.Rejected0RTT,
|
||||
GetAppDataForSessionState: c.GetAppDataForSessionState,
|
||||
SetAppDataFromSessionState: c.SetAppDataFromSessionState,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ExtraConfig) usesAlternativeRecordLayer() bool {
|
||||
return c != nil && c.AlternativeRecordLayer != nil
|
||||
}
|
||||
|
||||
const (
|
||||
// ticketKeyNameLen is the number of bytes of identifier that is prepended to
|
||||
// an encrypted session ticket in order to identify the key used to encrypt it.
|
||||
|
@ -1384,7 +1313,6 @@ func (c *config) BuildNameToCertificate() {
|
|||
|
||||
const (
|
||||
keyLogLabelTLS12 = "CLIENT_RANDOM"
|
||||
keyLogLabelEarlyTraffic = "CLIENT_EARLY_TRAFFIC_SECRET"
|
||||
keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"
|
||||
keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET"
|
||||
keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0"
|
||||
|
@ -1523,16 +1451,4 @@ func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlg
|
|||
}
|
||||
|
||||
// CertificateVerificationError is returned when certificate verification fails during the handshake.
|
||||
type CertificateVerificationError struct {
|
||||
// UnverifiedCertificates and its contents should not be modified.
|
||||
UnverifiedCertificates []*x509.Certificate
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *CertificateVerificationError) Error() string {
|
||||
return fmt.Sprintf("tls: failed to verify certificate: %s", e.Err)
|
||||
}
|
||||
|
||||
func (e *CertificateVerificationError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
type CertificateVerificationError = tls.CertificateVerificationError
|
||||
|
|
|
@ -29,6 +29,7 @@ type Conn struct {
|
|||
conn net.Conn
|
||||
isClient bool
|
||||
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
|
||||
quic *quicState // nil for non-QUIC connections
|
||||
|
||||
// isHandshakeComplete is true if the connection is currently transferring
|
||||
// application data (i.e. is not currently processing a handshake).
|
||||
|
@ -40,11 +41,10 @@ type Conn struct {
|
|||
vers uint16 // TLS version
|
||||
haveVers bool // version has been negotiated
|
||||
config *config // configuration passed to constructor
|
||||
extraConfig *ExtraConfig
|
||||
// handshakes counts the number of handshakes performed on the
|
||||
// connection so far. If renegotiation is disabled then this is either
|
||||
// zero or one.
|
||||
extraConfig *ExtraConfig
|
||||
|
||||
handshakes int
|
||||
didResume bool // whether this connection was a session resumption
|
||||
cipherSuite uint16
|
||||
|
@ -65,13 +65,8 @@ type Conn struct {
|
|||
secureRenegotiation bool
|
||||
// ekm is a closure for exporting keying material.
|
||||
ekm func(label string, context []byte, length int) ([]byte, error)
|
||||
// For the client:
|
||||
// resumptionSecret is the resumption_master_secret for handling
|
||||
// NewSessionTicket messages. nil if config.SessionTicketsDisabled.
|
||||
// For the server:
|
||||
// resumptionSecret is the resumption_master_secret for generating
|
||||
// NewSessionTicket messages. Only used when the alternative record
|
||||
// layer is set. nil if config.SessionTicketsDisabled.
|
||||
// or sending NewSessionTicket messages.
|
||||
resumptionSecret []byte
|
||||
|
||||
// ticketKeys is the set of active session ticket keys for this
|
||||
|
@ -123,12 +118,7 @@ type Conn struct {
|
|||
// the rest of the bits are the number of goroutines in Conn.Write.
|
||||
activeCall atomic.Int32
|
||||
|
||||
used0RTT bool
|
||||
|
||||
tmp [16]byte
|
||||
|
||||
connStateMutex sync.Mutex
|
||||
connState ConnectionStateWith0RTT
|
||||
}
|
||||
|
||||
// Access to net.Conn methods.
|
||||
|
@ -188,9 +178,8 @@ type halfConn struct {
|
|||
nextCipher any // next encryption state
|
||||
nextMac hash.Hash // next MAC algorithm
|
||||
|
||||
level QUICEncryptionLevel // current QUIC encryption level
|
||||
trafficSecret []byte // current TLS 1.3 traffic secret
|
||||
|
||||
setKeyCallback func(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte)
|
||||
}
|
||||
|
||||
type permanentError struct {
|
||||
|
@ -235,20 +224,9 @@ func (hc *halfConn) changeCipherSpec() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (hc *halfConn) exportKey(encLevel EncryptionLevel, suite *cipherSuiteTLS13, trafficSecret []byte) {
|
||||
if hc.setKeyCallback != nil {
|
||||
s := &CipherSuiteTLS13{
|
||||
ID: suite.id,
|
||||
KeyLen: suite.keyLen,
|
||||
Hash: suite.hash,
|
||||
AEAD: func(key, fixedNonce []byte) cipher.AEAD { return suite.aead(key, fixedNonce) },
|
||||
}
|
||||
hc.setKeyCallback(encLevel, s, trafficSecret)
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
|
||||
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
|
||||
hc.trafficSecret = secret
|
||||
hc.level = level
|
||||
key, iv := suite.trafficKey(secret)
|
||||
hc.cipher = suite.aead(key, iv)
|
||||
for i := range hc.seq {
|
||||
|
@ -481,13 +459,6 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
|
|||
return plaintext, typ, nil
|
||||
}
|
||||
|
||||
func (c *Conn) setAlternativeRecordLayer() {
|
||||
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
|
||||
c.in.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetReadKey
|
||||
c.out.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetWriteKey
|
||||
}
|
||||
}
|
||||
|
||||
// sliceForAppend extends the input slice by n bytes. head is the full extended
|
||||
// slice, while tail is the appended part. If the original slice has sufficient
|
||||
// capacity no allocation is performed.
|
||||
|
@ -646,6 +617,10 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
|
|||
}
|
||||
c.input.Reset(nil)
|
||||
|
||||
if c.quic != nil {
|
||||
return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
|
||||
}
|
||||
|
||||
// Read header, payload.
|
||||
if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
|
||||
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
|
||||
|
@ -729,6 +704,9 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
|
|||
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
|
||||
case recordTypeAlert:
|
||||
if c.quic != nil {
|
||||
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
if len(data) != 2 {
|
||||
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
}
|
||||
|
@ -846,6 +824,9 @@ func (c *Conn) readFromUntil(r io.Reader, n int) error {
|
|||
|
||||
// sendAlert sends a TLS alert message.
|
||||
func (c *Conn) sendAlertLocked(err alert) error {
|
||||
if c.quic != nil {
|
||||
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
|
||||
}
|
||||
switch err {
|
||||
case alertNoRenegotiation, alertCloseNotify:
|
||||
c.tmp[0] = alertLevelWarning
|
||||
|
@ -865,11 +846,6 @@ func (c *Conn) sendAlertLocked(err alert) error {
|
|||
|
||||
// sendAlert sends a TLS alert message.
|
||||
func (c *Conn) sendAlert(err alert) error {
|
||||
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
|
||||
c.extraConfig.AlternativeRecordLayer.SendAlert(uint8(err))
|
||||
return &net.OpError{Op: "local error", Err: err}
|
||||
}
|
||||
|
||||
c.out.Lock()
|
||||
defer c.out.Unlock()
|
||||
return c.sendAlertLocked(err)
|
||||
|
@ -985,6 +961,19 @@ var outBufPool = sync.Pool{
|
|||
// writeRecordLocked writes a TLS record with the given type and payload to the
|
||||
// connection and updates the record layer state.
|
||||
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
|
||||
if c.quic != nil {
|
||||
if typ != recordTypeHandshake {
|
||||
return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
|
||||
}
|
||||
c.quicWriteCryptoData(c.out.level, data)
|
||||
if !c.buffering {
|
||||
if _, err := c.flush(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
outBufPtr := outBufPool.Get().(*[]byte)
|
||||
outBuf := *outBufPtr
|
||||
defer func() {
|
||||
|
@ -1046,69 +1035,63 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
|
|||
// the record layer state. If transcript is non-nil the marshalled message is
|
||||
// written to it.
|
||||
func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
|
||||
c.out.Lock()
|
||||
defer c.out.Unlock()
|
||||
|
||||
data, err := msg.marshal()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
c.out.Lock()
|
||||
defer c.out.Unlock()
|
||||
|
||||
if transcript != nil {
|
||||
transcript.Write(data)
|
||||
}
|
||||
|
||||
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
|
||||
return c.extraConfig.AlternativeRecordLayer.WriteRecord(data)
|
||||
}
|
||||
|
||||
return c.writeRecordLocked(recordTypeHandshake, data)
|
||||
}
|
||||
|
||||
// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
|
||||
// updates the record layer state.
|
||||
func (c *Conn) writeChangeCipherRecord() error {
|
||||
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.out.Lock()
|
||||
defer c.out.Unlock()
|
||||
_, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
|
||||
return err
|
||||
}
|
||||
|
||||
// readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
|
||||
func (c *Conn) readHandshakeBytes(n int) error {
|
||||
if c.quic != nil {
|
||||
return c.quicReadHandshakeBytes(n)
|
||||
}
|
||||
for c.hand.Len() < n {
|
||||
if err := c.readRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readHandshake reads the next handshake message from
|
||||
// the record layer. If transcript is non-nil, the message
|
||||
// is written to the passed transcriptHash.
|
||||
func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
|
||||
var data []byte
|
||||
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
|
||||
var err error
|
||||
data, err = c.extraConfig.AlternativeRecordLayer.ReadHandshakeMessage()
|
||||
if err != nil {
|
||||
if err := c.readHandshakeBytes(4); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
for c.hand.Len() < 4 {
|
||||
if err := c.readRecord(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
data = c.hand.Bytes()
|
||||
data := c.hand.Bytes()
|
||||
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
if n > maxHandshake {
|
||||
c.sendAlertLocked(alertInternalError)
|
||||
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
|
||||
}
|
||||
for c.hand.Len() < 4+n {
|
||||
if err := c.readRecord(); err != nil {
|
||||
if err := c.readHandshakeBytes(4 + n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
data = c.hand.Next(4 + n)
|
||||
}
|
||||
return c.unmarshalHandshakeMessage(data, transcript)
|
||||
}
|
||||
|
||||
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
|
||||
var m handshakeMessage
|
||||
switch data[0] {
|
||||
case typeHelloRequest:
|
||||
|
@ -1288,10 +1271,6 @@ func (c *Conn) handleRenegotiation() error {
|
|||
return c.handshakeErr
|
||||
}
|
||||
|
||||
func (c *Conn) HandlePostHandshakeMessage() error {
|
||||
return c.handlePostHandshakeMessage()
|
||||
}
|
||||
|
||||
// handlePostHandshakeMessage processes a handshake message arrived after the
|
||||
// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
|
||||
func (c *Conn) handlePostHandshakeMessage() error {
|
||||
|
@ -1303,7 +1282,6 @@ func (c *Conn) handlePostHandshakeMessage() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.retryCount++
|
||||
if c.retryCount > maxUselessRecords {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
|
@ -1315,20 +1293,28 @@ func (c *Conn) handlePostHandshakeMessage() error {
|
|||
return c.handleNewSessionTicket(msg)
|
||||
case *keyUpdateMsg:
|
||||
return c.handleKeyUpdate(msg)
|
||||
default:
|
||||
}
|
||||
// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
|
||||
// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
|
||||
// unexpected_message alert here doesn't provide it with enough information to distinguish
|
||||
// this condition from other unexpected messages. This is probably fine.
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
|
||||
if c.quic != nil {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
|
||||
}
|
||||
|
||||
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
|
||||
if cipherSuite == nil {
|
||||
return c.in.setErrorLocked(c.sendAlert(alertInternalError))
|
||||
}
|
||||
|
||||
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
|
||||
c.in.setTrafficSecret(cipherSuite, newSecret)
|
||||
c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
|
||||
|
||||
if keyUpdate.updateRequested {
|
||||
c.out.Lock()
|
||||
|
@ -1347,7 +1333,7 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
|
|||
}
|
||||
|
||||
newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
|
||||
c.out.setTrafficSecret(cipherSuite, newSecret)
|
||||
c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1508,12 +1494,15 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
|||
// this cancellation. In the former case, we need to close the connection.
|
||||
defer cancel()
|
||||
|
||||
if c.quic != nil {
|
||||
c.quic.cancelc = handshakeCtx.Done()
|
||||
c.quic.cancel = cancel
|
||||
} else if ctx.Done() != nil {
|
||||
// Start the "interrupter" goroutine, if this context might be canceled.
|
||||
// (The background context cannot).
|
||||
//
|
||||
// The interrupter goroutine waits for the input context to be done and
|
||||
// closes the connection if this happens before the function returns.
|
||||
if ctx.Done() != nil {
|
||||
done := make(chan struct{})
|
||||
interruptRes := make(chan error, 1)
|
||||
defer func() {
|
||||
|
@ -1564,21 +1553,38 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
|||
panic("tls: internal error: handshake returned an error but is marked successful")
|
||||
}
|
||||
|
||||
if c.quic != nil {
|
||||
if c.handshakeErr == nil {
|
||||
c.quicHandshakeComplete()
|
||||
// Provide the 1-RTT read secret now that the handshake is complete.
|
||||
// The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
|
||||
// the handshake (RFC 9001, Section 5.7).
|
||||
c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
|
||||
} else {
|
||||
var a alert
|
||||
c.out.Lock()
|
||||
if !errors.As(c.out.err, &a) {
|
||||
a = alertInternalError
|
||||
}
|
||||
c.out.Unlock()
|
||||
// Return an error which wraps both the handshake error and
|
||||
// any alert error we may have sent, or alertInternalError
|
||||
// if we didn't send an alert.
|
||||
// Truncate the text of the alert to 0 characters.
|
||||
c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
|
||||
}
|
||||
close(c.quic.blockedc)
|
||||
close(c.quic.signalc)
|
||||
}
|
||||
|
||||
return c.handshakeErr
|
||||
}
|
||||
|
||||
// ConnectionState returns basic TLS details about the connection.
|
||||
func (c *Conn) ConnectionState() ConnectionState {
|
||||
c.connStateMutex.Lock()
|
||||
defer c.connStateMutex.Unlock()
|
||||
return c.connState.ConnectionState
|
||||
}
|
||||
|
||||
// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection.
|
||||
func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT {
|
||||
c.connStateMutex.Lock()
|
||||
defer c.connStateMutex.Unlock()
|
||||
return c.connState
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
return c.connectionStateLocked()
|
||||
}
|
||||
|
||||
func (c *Conn) connectionStateLocked() ConnectionState {
|
||||
|
@ -1609,15 +1615,6 @@ func (c *Conn) connectionStateLocked() ConnectionState {
|
|||
return toConnectionState(state)
|
||||
}
|
||||
|
||||
func (c *Conn) updateConnectionState() {
|
||||
c.connStateMutex.Lock()
|
||||
defer c.connStateMutex.Unlock()
|
||||
c.connState = ConnectionStateWith0RTT{
|
||||
Used0RTT: c.used0RTT,
|
||||
ConnectionState: c.connectionStateLocked(),
|
||||
}
|
||||
}
|
||||
|
||||
// OCSPResponse returns the stapled OCSP response from the TLS server, if
|
||||
// any. (Only valid for client connections.)
|
||||
func (c *Conn) OCSPResponse() []byte {
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
//go:build !js
|
||||
// +build !js
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
var (
|
||||
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
|
||||
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
|
||||
// Keep in sync with crypto/aes/cipher_s390x.go.
|
||||
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
|
||||
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
|
||||
|
||||
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
|
||||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
|
||||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
|
||||
)
|
|
@ -1,12 +0,0 @@
|
|||
//go:build js
|
||||
// +build js
|
||||
|
||||
package qtls
|
||||
|
||||
var (
|
||||
hasGCMAsmAMD64 = false
|
||||
hasGCMAsmARM64 = false
|
||||
hasGCMAsmS390X = false
|
||||
|
||||
hasAESGCMHardwareSupport = false
|
||||
)
|
|
@ -57,23 +57,12 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
|
|||
return nil, nil, errors.New("tls: NextProtos values too large")
|
||||
}
|
||||
|
||||
var supportedVersions []uint16
|
||||
var clientHelloVersion uint16
|
||||
if c.extraConfig.usesAlternativeRecordLayer() {
|
||||
if config.maxSupportedVersion(roleClient) < VersionTLS13 {
|
||||
return nil, nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3")
|
||||
}
|
||||
// Only offer TLS 1.3 when QUIC is used.
|
||||
supportedVersions = []uint16{VersionTLS13}
|
||||
clientHelloVersion = VersionTLS13
|
||||
} else {
|
||||
supportedVersions = config.supportedVersions(roleClient)
|
||||
supportedVersions := config.supportedVersions(roleClient)
|
||||
if len(supportedVersions) == 0 {
|
||||
return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
|
||||
}
|
||||
clientHelloVersion = config.maxSupportedVersion(roleClient)
|
||||
}
|
||||
|
||||
clientHelloVersion := config.maxSupportedVersion(roleClient)
|
||||
// The version at the beginning of the ClientHello was capped at TLS 1.2
|
||||
// for compatibility reasons. The supported_versions extension is used
|
||||
// to negotiate versions now. See RFC 8446, Section 4.2.1.
|
||||
|
@ -127,7 +116,9 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
|
|||
// A random session ID is used to detect when the server accepted a ticket
|
||||
// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
|
||||
// a compatibility measure (see RFC 8446, Section 4.1.2).
|
||||
if c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil {
|
||||
//
|
||||
// The session ID is not set for QUIC connections (see RFC 9001, Section 8.4).
|
||||
if c.quic == nil {
|
||||
hello.sessionId = make([]byte, 32)
|
||||
if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
|
||||
return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
|
||||
|
@ -143,6 +134,9 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
|
|||
|
||||
var secret clientKeySharePrivate
|
||||
if hello.supportedVersions[0] == VersionTLS13 {
|
||||
if len(hello.supportedVersions) == 1 {
|
||||
hello.cipherSuites = hello.cipherSuites[:0]
|
||||
}
|
||||
if hasAESGCMHardwareSupport {
|
||||
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...)
|
||||
} else {
|
||||
|
@ -176,8 +170,15 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
|
|||
}
|
||||
}
|
||||
|
||||
if hello.supportedVersions[0] == VersionTLS13 && c.extraConfig != nil && c.extraConfig.GetExtensions != nil {
|
||||
hello.additionalExtensions = c.extraConfig.GetExtensions(typeClientHello)
|
||||
if c.quic != nil {
|
||||
p, err := c.quicGetTransportParameters()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if p == nil {
|
||||
p = []byte{}
|
||||
}
|
||||
hello.quicTransportParameters = p
|
||||
}
|
||||
|
||||
return hello, secret, nil
|
||||
|
@ -187,7 +188,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
|||
if c.config == nil {
|
||||
c.config = fromConfig(defaultConfig())
|
||||
}
|
||||
c.setAlternativeRecordLayer()
|
||||
|
||||
// This may be a renegotiation handshake, in which case some fields
|
||||
// need to be reset.
|
||||
|
@ -204,27 +204,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
|||
return err
|
||||
}
|
||||
if cacheKey != "" && session != nil {
|
||||
var deletedTicket bool
|
||||
if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT {
|
||||
// don't reuse a session ticket that enabled 0-RTT
|
||||
c.config.ClientSessionCache.Put(cacheKey, nil)
|
||||
deletedTicket = true
|
||||
|
||||
if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil {
|
||||
h := suite.hash.New()
|
||||
helloBytes, err := hello.marshal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Write(helloBytes)
|
||||
clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h)
|
||||
c.out.exportKey(Encryption0RTT, suite, clientEarlySecret)
|
||||
if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if !deletedTicket {
|
||||
defer func() {
|
||||
// If we got a handshake failure when resuming a session, throw away
|
||||
// the session ticket. See RFC 5077, Section 3.2.
|
||||
|
@ -237,12 +216,21 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
|||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hello.earlyData {
|
||||
suite := cipherSuiteTLS13ByID(session.cipherSuite)
|
||||
transcript := suite.hash.New()
|
||||
if err := transcriptMsg(hello, transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript)
|
||||
c.quicSetWriteSecret(QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret)
|
||||
}
|
||||
|
||||
// serverHelloMsg is not included in the transcript
|
||||
msg, err := c.readHandshake(nil)
|
||||
if err != nil {
|
||||
|
@ -305,7 +293,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
|||
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session))
|
||||
}
|
||||
|
||||
c.updateConnectionState()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -358,7 +345,10 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
|
|||
}
|
||||
|
||||
// Try to resume a previously negotiated TLS session, if available.
|
||||
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
||||
cacheKey = c.clientSessionCacheKey()
|
||||
if cacheKey == "" {
|
||||
return "", nil, nil, nil, nil
|
||||
}
|
||||
sess, ok := c.config.ClientSessionCache.Get(cacheKey)
|
||||
if !ok || sess == nil {
|
||||
return cacheKey, nil, nil, nil, nil
|
||||
|
@ -442,6 +432,17 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
|
|||
return cacheKey, nil, nil, nil, nil
|
||||
}
|
||||
|
||||
if c.quic != nil && maxEarlyData > 0 {
|
||||
var earlyData bool
|
||||
if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil {
|
||||
earlyData = c.extraConfig.SetAppDataFromSessionState(appData)
|
||||
}
|
||||
// For 0-RTT, the cipher suite has to match exactly.
|
||||
if earlyData && mutualCipherSuiteTLS13(hello.cipherSuites, session.cipherSuite) != nil {
|
||||
hello.earlyData = true
|
||||
}
|
||||
}
|
||||
|
||||
// Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1.
|
||||
ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond)
|
||||
identity := pskIdentity{
|
||||
|
@ -456,9 +457,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
|
|||
session.nonce, cipherSuite.hash.Size())
|
||||
earlySecret = cipherSuite.extract(psk, nil)
|
||||
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
|
||||
if c.extraConfig != nil {
|
||||
hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0
|
||||
}
|
||||
transcript := cipherSuite.hash.New()
|
||||
helloBytes, err := hello.marshalWithoutBinders()
|
||||
if err != nil {
|
||||
|
@ -470,9 +468,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
|
|||
return "", nil, nil, nil, err
|
||||
}
|
||||
|
||||
if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil {
|
||||
c.extraConfig.SetAppDataFromSessionState(appData)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -827,7 +822,7 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil {
|
||||
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil {
|
||||
c.sendAlert(alertUnsupportedExtension)
|
||||
return false, err
|
||||
}
|
||||
|
@ -865,8 +860,12 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
|
|||
|
||||
// checkALPN ensure that the server's choice of ALPN protocol is compatible with
|
||||
// the protocols that we advertised in the Client Hello.
|
||||
func checkALPN(clientProtos []string, serverProto string) error {
|
||||
func checkALPN(clientProtos []string, serverProto string, quic bool) error {
|
||||
if serverProto == "" {
|
||||
if quic && len(clientProtos) > 0 {
|
||||
// RFC 9001, Section 8.1
|
||||
return errors.New("tls: server did not select an ALPN protocol")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if len(clientProtos) == 0 {
|
||||
|
@ -962,6 +961,10 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// maxRSAKeySize is the maximum RSA key size in bits that we are willing
|
||||
// to verify the signatures of during a TLS handshake.
|
||||
const maxRSAKeySize = 8192
|
||||
|
||||
// verifyServerCertificate parses and verifies the provided chain, setting
|
||||
// c.verifiedChains and c.peerCertificates or sending the appropriate alert.
|
||||
func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
|
||||
|
@ -973,6 +976,10 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
|
|||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("tls: failed to parse certificate from server: " + err.Error())
|
||||
}
|
||||
if cert.cert.PublicKeyAlgorithm == x509.RSA && cert.cert.PublicKey.(*rsa.PublicKey).N.BitLen() > maxRSAKeySize {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return fmt.Errorf("tls: server sent certificate containing RSA key larger than %d bits", maxRSAKeySize)
|
||||
}
|
||||
activeHandles[i] = cert
|
||||
certs[i] = cert.cert
|
||||
}
|
||||
|
@ -1106,15 +1113,16 @@ func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate,
|
|||
return new(Certificate), nil
|
||||
}
|
||||
|
||||
const clientSessionCacheKeyPrefix = "qtls-"
|
||||
|
||||
// clientSessionCacheKey returns a key used to cache sessionTickets that could
|
||||
// be used to resume previously negotiated TLS sessions with a server.
|
||||
func clientSessionCacheKey(serverAddr net.Addr, config *config) string {
|
||||
if len(config.ServerName) > 0 {
|
||||
return clientSessionCacheKeyPrefix + config.ServerName
|
||||
func (c *Conn) clientSessionCacheKey() string {
|
||||
if len(c.config.ServerName) > 0 {
|
||||
return c.config.ServerName
|
||||
}
|
||||
return clientSessionCacheKeyPrefix + serverAddr.String()
|
||||
if c.conn != nil {
|
||||
return c.conn.RemoteAddr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// hostnameInSNI converts name into an appropriate hostname for SNI.
|
||||
|
|
|
@ -91,7 +91,6 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
|
|||
if err := hs.processServerHello(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateConnectionState()
|
||||
if err := hs.sendDummyChangeCipherSpec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -104,7 +103,6 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
|
|||
if err := hs.readServerCertificate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateConnectionState()
|
||||
if err := hs.readServerFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -125,7 +123,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
|
|||
})
|
||||
|
||||
c.isHandshakeComplete.Store(true)
|
||||
c.updateConnectionState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -187,6 +185,9 @@ func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
|
|||
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
||||
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
||||
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
||||
if hs.c.quic != nil {
|
||||
return nil
|
||||
}
|
||||
if hs.sentDummyCCS {
|
||||
return nil
|
||||
}
|
||||
|
@ -293,7 +294,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
|
|||
transcript := hs.suite.hash.New()
|
||||
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
|
||||
transcript.Write(chHash)
|
||||
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
|
||||
if err := transcriptMsg(hs.serverHello, transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
helloBytes, err := hs.hello.marshalWithoutBinders()
|
||||
|
@ -312,10 +313,11 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
|
|||
}
|
||||
}
|
||||
|
||||
if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil {
|
||||
c.extraConfig.Rejected0RTT()
|
||||
if hs.hello.earlyData {
|
||||
hs.hello.earlyData = false
|
||||
c.quicRejectedEarlyData()
|
||||
}
|
||||
hs.hello.earlyData = false // disable 0-RTT
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -430,12 +432,18 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
|
|||
|
||||
clientSecret := hs.suite.deriveSecret(handshakeSecret,
|
||||
clientHandshakeTrafficLabel, hs.transcript)
|
||||
c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret)
|
||||
c.out.setTrafficSecret(hs.suite, clientSecret)
|
||||
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
|
||||
serverSecret := hs.suite.deriveSecret(handshakeSecret,
|
||||
serverHandshakeTrafficLabel, hs.transcript)
|
||||
c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret)
|
||||
c.in.setTrafficSecret(hs.suite, serverSecret)
|
||||
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
|
||||
|
||||
if c.quic != nil {
|
||||
if c.hand.Len() != 0 {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
|
||||
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
|
||||
}
|
||||
|
||||
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
|
||||
if err != nil {
|
||||
|
@ -467,28 +475,35 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
|
|||
c.sendAlert(alertUnexpectedMessage)
|
||||
return unexpectedMessageError(encryptedExtensions, msg)
|
||||
}
|
||||
// Notify the caller if 0-RTT was rejected.
|
||||
if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil {
|
||||
c.extraConfig.Rejected0RTT()
|
||||
}
|
||||
c.used0RTT = encryptedExtensions.earlyData
|
||||
if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil {
|
||||
hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions)
|
||||
}
|
||||
|
||||
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
|
||||
c.sendAlert(alertUnsupportedExtension)
|
||||
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
|
||||
// RFC 8446 specifies that no_application_protocol is sent by servers, but
|
||||
// does not specify how clients handle the selection of an incompatible protocol.
|
||||
// RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
|
||||
// in this case. Always sending no_application_protocol seems reasonable.
|
||||
c.sendAlert(alertNoApplicationProtocol)
|
||||
return err
|
||||
}
|
||||
c.clientProtocol = encryptedExtensions.alpnProtocol
|
||||
|
||||
if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection {
|
||||
if len(encryptedExtensions.alpnProtocol) == 0 {
|
||||
// the server didn't select an ALPN
|
||||
c.sendAlert(alertNoApplicationProtocol)
|
||||
return errors.New("ALPN negotiation failed. Server didn't offer any protocols")
|
||||
if c.quic != nil {
|
||||
if encryptedExtensions.quicTransportParameters == nil {
|
||||
// RFC 9001 Section 8.2.
|
||||
c.sendAlert(alertMissingExtension)
|
||||
return errors.New("tls: server did not send a quic_transport_parameters extension")
|
||||
}
|
||||
c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
|
||||
} else {
|
||||
if encryptedExtensions.quicTransportParameters != nil {
|
||||
c.sendAlert(alertUnsupportedExtension)
|
||||
return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
|
||||
}
|
||||
}
|
||||
|
||||
if hs.hello.earlyData && !encryptedExtensions.earlyData {
|
||||
c.quicRejectedEarlyData()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -616,8 +631,7 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error {
|
|||
clientApplicationTrafficLabel, hs.transcript)
|
||||
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
||||
serverApplicationTrafficLabel, hs.transcript)
|
||||
c.in.exportKey(EncryptionApplication, hs.suite, serverSecret)
|
||||
c.in.setTrafficSecret(hs.suite, serverSecret)
|
||||
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
|
||||
|
||||
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
|
||||
if err != nil {
|
||||
|
@ -713,14 +727,20 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
|
|||
return err
|
||||
}
|
||||
|
||||
c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret)
|
||||
c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
|
||||
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
|
||||
|
||||
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
|
||||
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
||||
resumptionLabel, hs.transcript)
|
||||
}
|
||||
|
||||
if c.quic != nil {
|
||||
if c.hand.Len() != 0 {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -791,8 +811,10 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
|
|||
scts: c.scts,
|
||||
}
|
||||
|
||||
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
||||
cacheKey := c.clientSessionCacheKey()
|
||||
if cacheKey != "" {
|
||||
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -93,7 +93,7 @@ type clientHelloMsg struct {
|
|||
pskModes []uint8
|
||||
pskIdentities []pskIdentity
|
||||
pskBinders [][]byte
|
||||
additionalExtensions []Extension
|
||||
quicTransportParameters []byte
|
||||
}
|
||||
|
||||
func (m *clientHelloMsg) marshal() ([]byte, error) {
|
||||
|
@ -247,10 +247,11 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
|
|||
})
|
||||
})
|
||||
}
|
||||
for _, ext := range m.additionalExtensions {
|
||||
exts.AddUint16(ext.Type)
|
||||
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
|
||||
// RFC 9001, Section 8.2
|
||||
exts.AddUint16(extensionQUICTransportParameters)
|
||||
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
|
||||
exts.AddBytes(ext.Data)
|
||||
exts.AddBytes(m.quicTransportParameters)
|
||||
})
|
||||
}
|
||||
if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
|
||||
|
@ -567,6 +568,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
|||
if !readUint8LengthPrefixed(&extData, &m.pskModes) {
|
||||
return false
|
||||
}
|
||||
case extensionQUICTransportParameters:
|
||||
m.quicTransportParameters = make([]byte, len(extData))
|
||||
if !extData.CopyBytes(m.quicTransportParameters) {
|
||||
return false
|
||||
}
|
||||
case extensionPreSharedKey:
|
||||
// RFC 8446, Section 4.2.11
|
||||
if !extensions.Empty() {
|
||||
|
@ -598,7 +604,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
|||
m.pskBinders = append(m.pskBinders, binder)
|
||||
}
|
||||
default:
|
||||
m.additionalExtensions = append(m.additionalExtensions, Extension{Type: extension, Data: extData})
|
||||
// Ignore unknown extensions.
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -869,9 +875,8 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
|
|||
type encryptedExtensionsMsg struct {
|
||||
raw []byte
|
||||
alpnProtocol string
|
||||
quicTransportParameters []byte
|
||||
earlyData bool
|
||||
|
||||
additionalExtensions []Extension
|
||||
}
|
||||
|
||||
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
|
||||
|
@ -893,17 +898,18 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
|
|||
})
|
||||
})
|
||||
}
|
||||
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
|
||||
// draft-ietf-quic-tls-32, Section 8.2
|
||||
b.AddUint16(extensionQUICTransportParameters)
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(m.quicTransportParameters)
|
||||
})
|
||||
}
|
||||
if m.earlyData {
|
||||
// RFC 8446, Section 4.2.10
|
||||
b.AddUint16(extensionEarlyData)
|
||||
b.AddUint16(0) // empty extension_data
|
||||
}
|
||||
for _, ext := range m.additionalExtensions {
|
||||
b.AddUint16(ext.Type)
|
||||
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
|
||||
b.AddBytes(ext.Data)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -923,14 +929,14 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
|
|||
}
|
||||
|
||||
for !extensions.Empty() {
|
||||
var ext uint16
|
||||
var extension uint16
|
||||
var extData cryptobyte.String
|
||||
if !extensions.ReadUint16(&ext) ||
|
||||
if !extensions.ReadUint16(&extension) ||
|
||||
!extensions.ReadUint16LengthPrefixed(&extData) {
|
||||
return false
|
||||
}
|
||||
|
||||
switch ext {
|
||||
switch extension {
|
||||
case extensionALPN:
|
||||
var protoList cryptobyte.String
|
||||
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
|
||||
|
@ -942,10 +948,15 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
|
|||
return false
|
||||
}
|
||||
m.alpnProtocol = string(proto)
|
||||
case extensionQUICTransportParameters:
|
||||
m.quicTransportParameters = make([]byte, len(extData))
|
||||
if !extData.CopyBytes(m.quicTransportParameters) {
|
||||
return false
|
||||
}
|
||||
case extensionEarlyData:
|
||||
m.earlyData = true
|
||||
default:
|
||||
m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData})
|
||||
// Ignore unknown extensions.
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -39,8 +39,6 @@ type serverHandshakeState struct {
|
|||
|
||||
// serverHandshake performs a TLS handshake as a server.
|
||||
func (c *Conn) serverHandshake(ctx context.Context) error {
|
||||
c.setAlternativeRecordLayer()
|
||||
|
||||
clientHello, err := c.readClientHello(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -53,12 +51,6 @@ func (c *Conn) serverHandshake(ctx context.Context) error {
|
|||
clientHello: clientHello,
|
||||
}
|
||||
return hs.handshake()
|
||||
} else if c.extraConfig.usesAlternativeRecordLayer() {
|
||||
// This should already have been caught by the check that the ClientHello doesn't
|
||||
// offer any (supported) versions older than TLS 1.3.
|
||||
// Check again to make sure we can't be tricked into using an older version.
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
return errors.New("tls: negotiated TLS < 1.3 when using QUIC")
|
||||
}
|
||||
|
||||
hs := serverHandshakeState{
|
||||
|
@ -131,7 +123,6 @@ func (hs *serverHandshakeState) handshake() error {
|
|||
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
|
||||
c.isHandshakeComplete.Store(true)
|
||||
|
||||
c.updateConnectionState()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -167,27 +158,6 @@ func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
|
|||
if len(clientHello.supportedVersions) == 0 {
|
||||
clientVersions = supportedVersionsFromMax(clientHello.vers)
|
||||
}
|
||||
if c.extraConfig.usesAlternativeRecordLayer() {
|
||||
// In QUIC, the client MUST NOT offer any old TLS versions.
|
||||
// Here, we can only check that none of the other supported versions of this library
|
||||
// (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here.
|
||||
for _, ver := range clientVersions {
|
||||
if ver == VersionTLS13 {
|
||||
continue
|
||||
}
|
||||
for _, v := range supportedVersions {
|
||||
if ver == v {
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Make the config we're using allows us to use TLS 1.3.
|
||||
if c.config.maxSupportedVersion(roleServer) < VersionTLS13 {
|
||||
c.sendAlert(alertInternalError)
|
||||
return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3")
|
||||
}
|
||||
}
|
||||
c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
|
||||
if !ok {
|
||||
c.sendAlert(alertProtocolVersion)
|
||||
|
@ -249,7 +219,7 @@ func (hs *serverHandshakeState) processClientHello() error {
|
|||
c.serverName = hs.clientHello.serverName
|
||||
}
|
||||
|
||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
|
||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
|
||||
if err != nil {
|
||||
c.sendAlert(alertNoApplicationProtocol)
|
||||
return err
|
||||
|
@ -310,8 +280,12 @@ func (hs *serverHandshakeState) processClientHello() error {
|
|||
// negotiateALPN picks a shared ALPN protocol that both sides support in server
|
||||
// preference order. If ALPN is not configured or the peer doesn't support it,
|
||||
// it returns "" and no error.
|
||||
func negotiateALPN(serverProtos, clientProtos []string) (string, error) {
|
||||
func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
|
||||
if len(serverProtos) == 0 || len(clientProtos) == 0 {
|
||||
if quic && len(serverProtos) != 0 {
|
||||
// RFC 9001, Section 8.1
|
||||
return "", fmt.Errorf("tls: client did not request an application protocol")
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
var http11fallback bool
|
||||
|
@ -849,6 +823,10 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
|
|||
c.sendAlert(alertBadCertificate)
|
||||
return errors.New("tls: failed to parse client certificate: " + err.Error())
|
||||
}
|
||||
if certs[i].PublicKeyAlgorithm == x509.RSA && certs[i].PublicKey.(*rsa.PublicKey).N.BitLen() > maxRSAKeySize {
|
||||
c.sendAlert(alertBadCertificate)
|
||||
return fmt.Errorf("tls: client sent certificate containing RSA key larger than %d bits", maxRSAKeySize)
|
||||
}
|
||||
}
|
||||
|
||||
if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) {
|
||||
|
|
|
@ -41,6 +41,7 @@ type serverHandshakeStateTLS13 struct {
|
|||
trafficSecret []byte // client_application_traffic_secret_0
|
||||
transcript hash.Hash
|
||||
clientFinished []byte
|
||||
earlyData bool
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeStateTLS13) handshake() error {
|
||||
|
@ -59,7 +60,6 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
|
|||
if err := hs.checkForResumption(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateConnectionState()
|
||||
if err := hs.pickCertificate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -82,7 +82,6 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
|
|||
if err := hs.readClientCertificate(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.updateConnectionState()
|
||||
if err := hs.readClientFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -94,7 +93,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
|
|||
})
|
||||
|
||||
c.isHandshakeComplete.Store(true)
|
||||
c.updateConnectionState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -236,13 +235,23 @@ GroupSelection:
|
|||
return errors.New("tls: invalid client key share")
|
||||
}
|
||||
|
||||
c.serverName = hs.clientHello.serverName
|
||||
|
||||
if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil {
|
||||
c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions)
|
||||
if c.quic != nil {
|
||||
if hs.clientHello.quicTransportParameters == nil {
|
||||
// RFC 9001 Section 8.2.
|
||||
c.sendAlert(alertMissingExtension)
|
||||
return errors.New("tls: client did not send a quic_transport_parameters extension")
|
||||
}
|
||||
c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
|
||||
} else {
|
||||
if hs.clientHello.quicTransportParameters != nil {
|
||||
c.sendAlert(alertUnsupportedExtension)
|
||||
return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
|
||||
}
|
||||
}
|
||||
|
||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
|
||||
c.serverName = hs.clientHello.serverName
|
||||
|
||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
|
||||
if err != nil {
|
||||
hs.alpnNegotiationErr = err
|
||||
}
|
||||
|
@ -299,10 +308,9 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
|
|||
}
|
||||
|
||||
if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol &&
|
||||
c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 &&
|
||||
c.extraConfig != nil && c.extraConfig.Enable0RTT &&
|
||||
c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) {
|
||||
hs.encryptedExtensions.earlyData = true
|
||||
c.used0RTT = true
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -354,27 +362,23 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
|
|||
return errors.New("tls: invalid PSK binder")
|
||||
}
|
||||
|
||||
if c.quic != nil && hs.clientHello.earlyData && hs.encryptedExtensions.earlyData && i == 0 &&
|
||||
sessionState.maxEarlyData > 0 && sessionState.cipherSuite == hs.suite.id {
|
||||
hs.earlyData = true
|
||||
|
||||
transcript := hs.suite.hash.New()
|
||||
if err := transcriptMsg(hs.clientHello, transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
earlyTrafficSecret := hs.suite.deriveSecret(hs.earlySecret, clientEarlyTrafficLabel, transcript)
|
||||
c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret)
|
||||
}
|
||||
|
||||
c.didResume = true
|
||||
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h := cloneHash(hs.transcript, hs.suite.hash)
|
||||
clientHelloWithBindersBytes, err := hs.clientHello.marshal()
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
h.Write(clientHelloWithBindersBytes)
|
||||
if hs.encryptedExtensions.earlyData {
|
||||
clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h)
|
||||
c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret)
|
||||
if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
hs.hello.selectedIdentityPresent = true
|
||||
hs.hello.selectedIdentity = uint16(i)
|
||||
hs.usingPSK = true
|
||||
|
@ -449,6 +453,9 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
|
|||
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
|
||||
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
|
||||
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
|
||||
if hs.c.quic != nil {
|
||||
return nil
|
||||
}
|
||||
if hs.sentDummyCCS {
|
||||
return nil
|
||||
}
|
||||
|
@ -517,9 +524,9 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
|
|||
return errors.New("tls: client illegally modified second ClientHello")
|
||||
}
|
||||
|
||||
if clientHello.earlyData {
|
||||
if illegalClientHelloChange(clientHello, hs.clientHello) {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: client offered 0-RTT data in second ClientHello")
|
||||
return errors.New("tls: client illegally modified second ClientHello")
|
||||
}
|
||||
|
||||
hs.clientHello = clientHello
|
||||
|
@ -607,12 +614,18 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
|
|||
|
||||
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
||||
clientHandshakeTrafficLabel, hs.transcript)
|
||||
c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret)
|
||||
c.in.setTrafficSecret(hs.suite, clientSecret)
|
||||
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
|
||||
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
|
||||
serverHandshakeTrafficLabel, hs.transcript)
|
||||
c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret)
|
||||
c.out.setTrafficSecret(hs.suite, serverSecret)
|
||||
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
|
||||
|
||||
if c.quic != nil {
|
||||
if c.hand.Len() != 0 {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
|
||||
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
|
||||
}
|
||||
|
||||
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
|
||||
if err != nil {
|
||||
|
@ -625,12 +638,20 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if hs.alpnNegotiationErr != nil {
|
||||
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
|
||||
if err != nil {
|
||||
c.sendAlert(alertNoApplicationProtocol)
|
||||
return hs.alpnNegotiationErr
|
||||
return err
|
||||
}
|
||||
if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil {
|
||||
hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions)
|
||||
hs.encryptedExtensions.alpnProtocol = selectedProto
|
||||
c.clientProtocol = selectedProto
|
||||
|
||||
if c.quic != nil {
|
||||
p, err := c.quicGetTransportParameters()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hs.encryptedExtensions.quicTransportParameters = p
|
||||
}
|
||||
|
||||
if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil {
|
||||
|
@ -731,8 +752,15 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
|
|||
clientApplicationTrafficLabel, hs.transcript)
|
||||
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
|
||||
serverApplicationTrafficLabel, hs.transcript)
|
||||
c.out.exportKey(EncryptionApplication, hs.suite, serverSecret)
|
||||
c.out.setTrafficSecret(hs.suite, serverSecret)
|
||||
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
|
||||
|
||||
if c.quic != nil {
|
||||
if c.hand.Len() != 0 {
|
||||
// TODO: Handle this in setTrafficSecret?
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
|
||||
}
|
||||
|
||||
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
|
||||
if err != nil {
|
||||
|
@ -764,6 +792,10 @@ func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// QUIC tickets are sent by QUICConn.SendSessionTicket, not automatically.
|
||||
if hs.c.quic != nil {
|
||||
return false
|
||||
}
|
||||
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
|
||||
for _, pskMode := range hs.clientHello.pskModes {
|
||||
if pskMode == pskModeDHE {
|
||||
|
@ -783,25 +815,66 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
|
|||
if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
|
||||
return err
|
||||
}
|
||||
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
||||
resumptionLabel, hs.transcript)
|
||||
|
||||
if !hs.shouldSendSessionTickets() {
|
||||
return nil
|
||||
}
|
||||
return c.sendSessionTicket(false)
|
||||
}
|
||||
|
||||
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
|
||||
resumptionLabel, hs.transcript)
|
||||
|
||||
// Don't send session tickets when the alternative record layer is set.
|
||||
// Instead, save the resumption secret on the Conn.
|
||||
// Session tickets can then be generated by calling Conn.GetSessionTicket().
|
||||
if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil {
|
||||
return nil
|
||||
func (c *Conn) sendSessionTicket(earlyData bool) error {
|
||||
suite := cipherSuiteTLS13ByID(c.cipherSuite)
|
||||
if suite == nil {
|
||||
return errors.New("tls: internal error: unknown cipher suite")
|
||||
}
|
||||
|
||||
m, err := hs.c.getSessionTicketMsg(nil)
|
||||
m := new(newSessionTicketMsgTLS13)
|
||||
|
||||
var certsFromClient [][]byte
|
||||
for _, cert := range c.peerCertificates {
|
||||
certsFromClient = append(certsFromClient, cert.Raw)
|
||||
}
|
||||
state := sessionStateTLS13{
|
||||
cipherSuite: suite.id,
|
||||
createdAt: uint64(c.config.time().Unix()),
|
||||
resumptionSecret: c.resumptionSecret,
|
||||
certificate: Certificate{
|
||||
Certificate: certsFromClient,
|
||||
OCSPStaple: c.ocspResponse,
|
||||
SignedCertificateTimestamps: c.scts,
|
||||
},
|
||||
alpn: c.clientProtocol,
|
||||
}
|
||||
if earlyData {
|
||||
state.maxEarlyData = 0xffffffff
|
||||
state.appData = c.extraConfig.GetAppDataForSessionTicket()
|
||||
}
|
||||
stateBytes, err := state.marshal()
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
m.label, err = c.encryptTicket(stateBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
|
||||
|
||||
// ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1
|
||||
// The value is not stored anywhere; we never need to check the ticket age
|
||||
// because 0-RTT is not supported.
|
||||
ageAdd := make([]byte, 4)
|
||||
_, err = c.config.rand().Read(ageAdd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if earlyData {
|
||||
// RFC 9001, Section 4.6.1
|
||||
m.maxEarlyData = 0xffffffff
|
||||
}
|
||||
|
||||
if _, err := c.writeHandshakeRecord(m, nil); err != nil {
|
||||
return err
|
||||
|
@ -919,8 +992,7 @@ func (hs *serverHandshakeStateTLS13) readClientFinished() error {
|
|||
return errors.New("tls: invalid client finished hash")
|
||||
}
|
||||
|
||||
c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret)
|
||||
c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
|
||||
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
const (
|
||||
resumptionBinderLabel = "res binder"
|
||||
clientEarlyTrafficLabel = "c e traffic"
|
||||
clientHandshakeTrafficLabel = "c hs traffic"
|
||||
serverHandshakeTrafficLabel = "s hs traffic"
|
||||
clientApplicationTrafficLabel = "c ap traffic"
|
||||
|
|
|
@ -0,0 +1,418 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// QUICEncryptionLevel represents a QUIC encryption level used to transmit
|
||||
// handshake messages.
|
||||
type QUICEncryptionLevel int
|
||||
|
||||
const (
|
||||
QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
|
||||
QUICEncryptionLevelEarly
|
||||
QUICEncryptionLevelHandshake
|
||||
QUICEncryptionLevelApplication
|
||||
)
|
||||
|
||||
func (l QUICEncryptionLevel) String() string {
|
||||
switch l {
|
||||
case QUICEncryptionLevelInitial:
|
||||
return "Initial"
|
||||
case QUICEncryptionLevelEarly:
|
||||
return "Early"
|
||||
case QUICEncryptionLevelHandshake:
|
||||
return "Handshake"
|
||||
case QUICEncryptionLevelApplication:
|
||||
return "Application"
|
||||
default:
|
||||
return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
|
||||
}
|
||||
}
|
||||
|
||||
// A QUICConn represents a connection which uses a QUIC implementation as the underlying
|
||||
// transport as described in RFC 9001.
|
||||
//
|
||||
// Methods of QUICConn are not safe for concurrent use.
|
||||
type QUICConn struct {
|
||||
conn *Conn
|
||||
|
||||
sessionTicketSent bool
|
||||
}
|
||||
|
||||
// A QUICConfig configures a QUICConn.
|
||||
type QUICConfig struct {
|
||||
TLSConfig *Config
|
||||
ExtraConfig *ExtraConfig
|
||||
}
|
||||
|
||||
// A QUICEventKind is a type of operation on a QUIC connection.
|
||||
type QUICEventKind int
|
||||
|
||||
const (
|
||||
// QUICNoEvent indicates that there are no events available.
|
||||
QUICNoEvent QUICEventKind = iota
|
||||
|
||||
// QUICSetReadSecret and QUICSetWriteSecret provide the read and write
|
||||
// secrets for a given encryption level.
|
||||
// QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
|
||||
//
|
||||
// Secrets for the Initial encryption level are derived from the initial
|
||||
// destination connection ID, and are not provided by the QUICConn.
|
||||
QUICSetReadSecret
|
||||
QUICSetWriteSecret
|
||||
|
||||
// QUICWriteData provides data to send to the peer in CRYPTO frames.
|
||||
// QUICEvent.Data is set.
|
||||
QUICWriteData
|
||||
|
||||
// QUICTransportParameters provides the peer's QUIC transport parameters.
|
||||
// QUICEvent.Data is set.
|
||||
QUICTransportParameters
|
||||
|
||||
// QUICTransportParametersRequired indicates that the caller must provide
|
||||
// QUIC transport parameters to send to the peer. The caller should set
|
||||
// the transport parameters with QUICConn.SetTransportParameters and call
|
||||
// QUICConn.NextEvent again.
|
||||
//
|
||||
// If transport parameters are set before calling QUICConn.Start, the
|
||||
// connection will never generate a QUICTransportParametersRequired event.
|
||||
QUICTransportParametersRequired
|
||||
|
||||
// QUICRejectedEarlyData indicates that the server rejected 0-RTT data even
|
||||
// if we offered it. It's returned before QUICEncryptionLevelApplication
|
||||
// keys are returned.
|
||||
QUICRejectedEarlyData
|
||||
|
||||
// QUICHandshakeDone indicates that the TLS handshake has completed.
|
||||
QUICHandshakeDone
|
||||
)
|
||||
|
||||
// A QUICEvent is an event occurring on a QUIC connection.
|
||||
//
|
||||
// The type of event is specified by the Kind field.
|
||||
// The contents of the other fields are kind-specific.
|
||||
type QUICEvent struct {
|
||||
Kind QUICEventKind
|
||||
|
||||
// Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
|
||||
Level QUICEncryptionLevel
|
||||
|
||||
// Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
|
||||
// The contents are owned by crypto/tls, and are valid until the next NextEvent call.
|
||||
Data []byte
|
||||
|
||||
// Set for QUICSetReadSecret and QUICSetWriteSecret.
|
||||
Suite uint16
|
||||
}
|
||||
|
||||
type quicState struct {
|
||||
events []QUICEvent
|
||||
nextEvent int
|
||||
|
||||
// eventArr is a statically allocated event array, large enough to handle
|
||||
// the usual maximum number of events resulting from a single call: transport
|
||||
// parameters, Initial data, Early read secret, Handshake write and read
|
||||
// secrets, Handshake data, Application write secret, Application data.
|
||||
eventArr [8]QUICEvent
|
||||
|
||||
started bool
|
||||
signalc chan struct{} // handshake data is available to be read
|
||||
blockedc chan struct{} // handshake is waiting for data, closed when done
|
||||
cancelc <-chan struct{} // handshake has been canceled
|
||||
cancel context.CancelFunc
|
||||
|
||||
// readbuf is shared between HandleData and the handshake goroutine.
|
||||
// HandshakeCryptoData passes ownership to the handshake goroutine by
|
||||
// reading from signalc, and reclaims ownership by reading from blockedc.
|
||||
readbuf []byte
|
||||
|
||||
transportParams []byte // to send to the peer
|
||||
}
|
||||
|
||||
// QUICClient returns a new TLS client side connection using QUICTransport as the
|
||||
// underlying transport. The config cannot be nil.
|
||||
//
|
||||
// The config's MinVersion must be at least TLS 1.3.
|
||||
func QUICClient(config *QUICConfig) *QUICConn {
|
||||
return newQUICConn(Client(nil, config.TLSConfig), config.ExtraConfig)
|
||||
}
|
||||
|
||||
// QUICServer returns a new TLS server side connection using QUICTransport as the
|
||||
// underlying transport. The config cannot be nil.
|
||||
//
|
||||
// The config's MinVersion must be at least TLS 1.3.
|
||||
func QUICServer(config *QUICConfig) *QUICConn {
|
||||
return newQUICConn(Server(nil, config.TLSConfig), config.ExtraConfig)
|
||||
}
|
||||
|
||||
func newQUICConn(conn *Conn, extraConfig *ExtraConfig) *QUICConn {
|
||||
conn.quic = &quicState{
|
||||
signalc: make(chan struct{}),
|
||||
blockedc: make(chan struct{}),
|
||||
}
|
||||
conn.quic.events = conn.quic.eventArr[:0]
|
||||
conn.extraConfig = extraConfig
|
||||
return &QUICConn{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the client or server handshake protocol.
|
||||
// It may produce connection events, which may be read with NextEvent.
|
||||
//
|
||||
// Start must be called at most once.
|
||||
func (q *QUICConn) Start(ctx context.Context) error {
|
||||
if q.conn.quic.started {
|
||||
return quicError(errors.New("tls: Start called more than once"))
|
||||
}
|
||||
q.conn.quic.started = true
|
||||
if q.conn.config.MinVersion < VersionTLS13 {
|
||||
return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
|
||||
}
|
||||
go q.conn.HandshakeContext(ctx)
|
||||
if _, ok := <-q.conn.quic.blockedc; !ok {
|
||||
return q.conn.handshakeErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextEvent returns the next event occurring on the connection.
|
||||
// It returns an event with a Kind of QUICNoEvent when no events are available.
|
||||
func (q *QUICConn) NextEvent() QUICEvent {
|
||||
qs := q.conn.quic
|
||||
if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
|
||||
// Write over some of the previous event's data,
|
||||
// to catch callers erroniously retaining it.
|
||||
qs.events[last].Data[0] = 0
|
||||
}
|
||||
if qs.nextEvent >= len(qs.events) {
|
||||
qs.events = qs.events[:0]
|
||||
qs.nextEvent = 0
|
||||
return QUICEvent{Kind: QUICNoEvent}
|
||||
}
|
||||
e := qs.events[qs.nextEvent]
|
||||
qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
|
||||
qs.nextEvent++
|
||||
return e
|
||||
}
|
||||
|
||||
// Close closes the connection and stops any in-progress handshake.
|
||||
func (q *QUICConn) Close() error {
|
||||
if q.conn.quic.cancel == nil {
|
||||
return nil // never started
|
||||
}
|
||||
q.conn.quic.cancel()
|
||||
for range q.conn.quic.blockedc {
|
||||
// Wait for the handshake goroutine to return.
|
||||
}
|
||||
return q.conn.handshakeErr
|
||||
}
|
||||
|
||||
// HandleData handles handshake bytes received from the peer.
|
||||
// It may produce connection events, which may be read with NextEvent.
|
||||
func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
|
||||
c := q.conn
|
||||
if c.in.level != level {
|
||||
return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
|
||||
}
|
||||
c.quic.readbuf = data
|
||||
<-c.quic.signalc
|
||||
_, ok := <-c.quic.blockedc
|
||||
if ok {
|
||||
// The handshake goroutine is waiting for more data.
|
||||
return nil
|
||||
}
|
||||
// The handshake goroutine has exited.
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
c.hand.Write(c.quic.readbuf)
|
||||
c.quic.readbuf = nil
|
||||
for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
|
||||
b := q.conn.hand.Bytes()
|
||||
n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
||||
if n > maxHandshake {
|
||||
q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
|
||||
break
|
||||
}
|
||||
if len(b) < 4+n {
|
||||
return nil
|
||||
}
|
||||
if err := q.conn.handlePostHandshakeMessage(); err != nil {
|
||||
q.conn.handshakeErr = err
|
||||
}
|
||||
}
|
||||
if q.conn.handshakeErr != nil {
|
||||
return quicError(q.conn.handshakeErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendSessionTicket sends a session ticket to the client.
|
||||
// It produces connection events, which may be read with NextEvent.
|
||||
// Currently, it can only be called once.
|
||||
func (q *QUICConn) SendSessionTicket(earlyData bool) error {
|
||||
c := q.conn
|
||||
if !c.isHandshakeComplete.Load() {
|
||||
return quicError(errors.New("tls: SendSessionTicket called before handshake completed"))
|
||||
}
|
||||
if c.isClient {
|
||||
return quicError(errors.New("tls: SendSessionTicket called on the client"))
|
||||
}
|
||||
if q.sessionTicketSent {
|
||||
return quicError(errors.New("tls: SendSessionTicket called multiple times"))
|
||||
}
|
||||
q.sessionTicketSent = true
|
||||
return quicError(c.sendSessionTicket(earlyData))
|
||||
}
|
||||
|
||||
// ConnectionState returns basic TLS details about the connection.
|
||||
func (q *QUICConn) ConnectionState() ConnectionState {
|
||||
return q.conn.ConnectionState()
|
||||
}
|
||||
|
||||
// SetTransportParameters sets the transport parameters to send to the peer.
|
||||
//
|
||||
// Server connections may delay setting the transport parameters until after
|
||||
// receiving the client's transport parameters. See QUICTransportParametersRequired.
|
||||
func (q *QUICConn) SetTransportParameters(params []byte) {
|
||||
if params == nil {
|
||||
params = []byte{}
|
||||
}
|
||||
q.conn.quic.transportParams = params
|
||||
if q.conn.quic.started {
|
||||
<-q.conn.quic.signalc
|
||||
<-q.conn.quic.blockedc
|
||||
}
|
||||
}
|
||||
|
||||
// quicError ensures err is an AlertError.
|
||||
// If err is not already, quicError wraps it with alertInternalError.
|
||||
func quicError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var ae AlertError
|
||||
if errors.As(err, &ae) {
|
||||
return err
|
||||
}
|
||||
var a alert
|
||||
if !errors.As(err, &a) {
|
||||
a = alertInternalError
|
||||
}
|
||||
// Return an error wrapping the original error and an AlertError.
|
||||
// Truncate the text of the alert to 0 characters.
|
||||
return fmt.Errorf("%w%.0w", err, AlertError(a))
|
||||
}
|
||||
|
||||
func (c *Conn) quicReadHandshakeBytes(n int) error {
|
||||
for c.hand.Len() < n {
|
||||
if err := c.quicWaitForSignal(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
|
||||
c.quic.events = append(c.quic.events, QUICEvent{
|
||||
Kind: QUICSetReadSecret,
|
||||
Level: level,
|
||||
Suite: suite,
|
||||
Data: secret,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
|
||||
c.quic.events = append(c.quic.events, QUICEvent{
|
||||
Kind: QUICSetWriteSecret,
|
||||
Level: level,
|
||||
Suite: suite,
|
||||
Data: secret,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
|
||||
var last *QUICEvent
|
||||
if len(c.quic.events) > 0 {
|
||||
last = &c.quic.events[len(c.quic.events)-1]
|
||||
}
|
||||
if last == nil || last.Kind != QUICWriteData || last.Level != level {
|
||||
c.quic.events = append(c.quic.events, QUICEvent{
|
||||
Kind: QUICWriteData,
|
||||
Level: level,
|
||||
})
|
||||
last = &c.quic.events[len(c.quic.events)-1]
|
||||
}
|
||||
last.Data = append(last.Data, data...)
|
||||
}
|
||||
|
||||
func (c *Conn) quicSetTransportParameters(params []byte) {
|
||||
c.quic.events = append(c.quic.events, QUICEvent{
|
||||
Kind: QUICTransportParameters,
|
||||
Data: params,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) quicGetTransportParameters() ([]byte, error) {
|
||||
if c.quic.transportParams == nil {
|
||||
c.quic.events = append(c.quic.events, QUICEvent{
|
||||
Kind: QUICTransportParametersRequired,
|
||||
})
|
||||
}
|
||||
for c.quic.transportParams == nil {
|
||||
if err := c.quicWaitForSignal(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c.quic.transportParams, nil
|
||||
}
|
||||
|
||||
func (c *Conn) quicHandshakeComplete() {
|
||||
c.quic.events = append(c.quic.events, QUICEvent{
|
||||
Kind: QUICHandshakeDone,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) quicRejectedEarlyData() {
|
||||
c.quic.events = append(c.quic.events, QUICEvent{
|
||||
Kind: QUICRejectedEarlyData,
|
||||
})
|
||||
}
|
||||
|
||||
// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
|
||||
// and waits for a signal that the handshake should proceed.
|
||||
//
|
||||
// The handshake may become blocked waiting for handshake bytes
|
||||
// or for the user to provide transport parameters.
|
||||
func (c *Conn) quicWaitForSignal() error {
|
||||
// Drop the handshake mutex while blocked to allow the user
|
||||
// to call ConnectionState before the handshake completes.
|
||||
c.handshakeMutex.Unlock()
|
||||
defer c.handshakeMutex.Lock()
|
||||
// Send on blockedc to notify the QUICConn that the handshake is blocked.
|
||||
// Exported methods of QUICConn wait for the handshake to become blocked
|
||||
// before returning to the user.
|
||||
select {
|
||||
case c.quic.blockedc <- struct{}{}:
|
||||
case <-c.quic.cancelc:
|
||||
return c.sendAlertLocked(alertCloseNotify)
|
||||
}
|
||||
// The QUICConn reads from signalc to notify us that the handshake may
|
||||
// be able to proceed. (The QUICConn reads, because we close signalc to
|
||||
// indicate that the handshake has completed.)
|
||||
select {
|
||||
case c.quic.signalc <- struct{}{}:
|
||||
c.hand.Write(c.quic.readbuf)
|
||||
c.quic.readbuf = nil
|
||||
case <-c.quic.cancelc:
|
||||
return c.sendAlertLocked(alertCloseNotify)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -11,12 +11,9 @@ import (
|
|||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"io"
|
||||
)
|
||||
|
||||
// sessionState contains the information that is serialized into a session
|
||||
|
@ -204,74 +201,3 @@ func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey boo
|
|||
|
||||
return plaintext, keyIndex > 0
|
||||
}
|
||||
|
||||
func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) {
|
||||
m := new(newSessionTicketMsgTLS13)
|
||||
|
||||
var certsFromClient [][]byte
|
||||
for _, cert := range c.peerCertificates {
|
||||
certsFromClient = append(certsFromClient, cert.Raw)
|
||||
}
|
||||
state := sessionStateTLS13{
|
||||
cipherSuite: c.cipherSuite,
|
||||
createdAt: uint64(c.config.time().Unix()),
|
||||
resumptionSecret: c.resumptionSecret,
|
||||
certificate: Certificate{
|
||||
Certificate: certsFromClient,
|
||||
OCSPStaple: c.ocspResponse,
|
||||
SignedCertificateTimestamps: c.scts,
|
||||
},
|
||||
appData: appData,
|
||||
alpn: c.clientProtocol,
|
||||
}
|
||||
if c.extraConfig != nil {
|
||||
state.maxEarlyData = c.extraConfig.MaxEarlyData
|
||||
}
|
||||
stateBytes, err := state.marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.label, err = c.encryptTicket(stateBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
|
||||
|
||||
// ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1
|
||||
// The value is not stored anywhere; we never need to check the ticket age
|
||||
// because 0-RTT is not supported.
|
||||
ageAdd := make([]byte, 4)
|
||||
_, err = c.config.rand().Read(ageAdd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.ageAdd = binary.LittleEndian.Uint32(ageAdd)
|
||||
|
||||
// ticket_nonce, which must be unique per connection, is always left at
|
||||
// zero because we only ever send one ticket per connection.
|
||||
|
||||
if c.extraConfig != nil {
|
||||
m.maxEarlyData = c.extraConfig.MaxEarlyData
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GetSessionTicket generates a new session ticket.
|
||||
// It should only be called after the handshake completes.
|
||||
// It can only be used for servers, and only if the alternative record layer is set.
|
||||
// The ticket may be nil if config.SessionTicketsDisabled is set,
|
||||
// or if the client isn't able to receive session tickets.
|
||||
func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) {
|
||||
if c.isClient || !c.isHandshakeComplete.Load() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil {
|
||||
return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.")
|
||||
}
|
||||
if c.config.SessionTicketsDisabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
m, err := c.getSessionTicketMsg(appData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.marshal()
|
||||
}
|
||||
|
|
|
@ -31,11 +31,10 @@ import (
|
|||
// using conn as the underlying transport.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
func Server(conn net.Conn, config *Config) *Conn {
|
||||
c := &Conn{
|
||||
conn: conn,
|
||||
config: fromConfig(config),
|
||||
extraConfig: extraConfig,
|
||||
}
|
||||
c.handshakeFn = c.serverHandshake
|
||||
return c
|
||||
|
@ -45,11 +44,10 @@ func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
|||
// using conn as the underlying transport.
|
||||
// The config cannot be nil: users must set either ServerName or
|
||||
// InsecureSkipVerify in the config.
|
||||
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
||||
func Client(conn net.Conn, config *Config) *Conn {
|
||||
c := &Conn{
|
||||
conn: conn,
|
||||
config: fromConfig(config),
|
||||
extraConfig: extraConfig,
|
||||
isClient: true,
|
||||
}
|
||||
c.handshakeFn = c.clientHandshake
|
||||
|
@ -60,7 +58,6 @@ func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
|
|||
type listener struct {
|
||||
net.Listener
|
||||
config *Config
|
||||
extraConfig *ExtraConfig
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next incoming TLS connection.
|
||||
|
@ -70,18 +67,17 @@ func (l *listener) Accept() (net.Conn, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Server(c, l.config, l.extraConfig), nil
|
||||
return Server(c, l.config), nil
|
||||
}
|
||||
|
||||
// NewListener creates a Listener which accepts connections from an inner
|
||||
// Listener and wraps each connection with Server.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener {
|
||||
func NewListener(inner net.Listener, config *Config) net.Listener {
|
||||
l := new(listener)
|
||||
l.Listener = inner
|
||||
l.config = config
|
||||
l.extraConfig = extraConfig
|
||||
return l
|
||||
}
|
||||
|
||||
|
@ -89,7 +85,7 @@ func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) n
|
|||
// given network address using net.Listen.
|
||||
// The configuration config must be non-nil and must include
|
||||
// at least one certificate or else set GetCertificate.
|
||||
func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) {
|
||||
func Listen(network, laddr string, config *Config) (net.Listener, error) {
|
||||
if config == nil || len(config.Certificates) == 0 &&
|
||||
config.GetCertificate == nil && config.GetConfigForClient == nil {
|
||||
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
|
||||
|
@ -98,7 +94,7 @@ func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (ne
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewListener(l, config, extraConfig), nil
|
||||
return NewListener(l, config), nil
|
||||
}
|
||||
|
||||
type timeoutError struct{}
|
||||
|
@ -117,11 +113,11 @@ func (timeoutError) Temporary() bool { return true }
|
|||
//
|
||||
// DialWithDialer uses context.Background internally; to specify the context,
|
||||
// use Dialer.DialContext with NetDialer set to the desired dialer.
|
||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
|
||||
return dial(context.Background(), dialer, network, addr, config, extraConfig)
|
||||
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||
return dial(context.Background(), dialer, network, addr, config)
|
||||
}
|
||||
|
||||
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
|
||||
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||
if netDialer.Timeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
|
||||
|
@ -157,7 +153,7 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
|
|||
config = c
|
||||
}
|
||||
|
||||
conn := Client(rawConn, config, extraConfig)
|
||||
conn := Client(rawConn, config)
|
||||
if err := conn.HandshakeContext(ctx); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
|
@ -171,8 +167,8 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
|
|||
// Dial interprets a nil configuration as equivalent to
|
||||
// the zero configuration; see the documentation of Config
|
||||
// for the defaults.
|
||||
func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
|
||||
return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig)
|
||||
func Dial(network, addr string, config *Config) (*Conn, error) {
|
||||
return DialWithDialer(new(net.Dialer), network, addr, config)
|
||||
}
|
||||
|
||||
// Dialer dials TLS connections given a configuration and a Dialer for the
|
||||
|
@ -188,8 +184,6 @@ type Dialer struct {
|
|||
// configuration; see the documentation of Config for the
|
||||
// defaults.
|
||||
Config *Config
|
||||
|
||||
ExtraConfig *ExtraConfig
|
||||
}
|
||||
|
||||
// Dial connects to the given network address and initiates a TLS
|
||||
|
@ -220,7 +214,7 @@ func (d *Dialer) netDialer() *net.Dialer {
|
|||
//
|
||||
// The returned Conn, if any, will always be of type *Conn.
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig)
|
||||
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
|
||||
if err != nil {
|
||||
// Don't return c (a typed nil) in an interface.
|
||||
return nil, err
|
||||
|
|
|
@ -94,3 +94,8 @@ func compareStruct(a, b reflect.Type) bool {
|
|||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// InitSessionTicketKeys triggers the initialization of session ticket keys.
|
||||
func InitSessionTicketKeys(conf *Config) {
|
||||
fromConfig(conf).ticketKeys(nil)
|
||||
}
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
run:
|
||||
skip-files:
|
||||
- internal/handshake/cipher_suite.go
|
||||
linters-settings:
|
||||
depguard:
|
||||
type: blacklist
|
||||
|
|
|
@ -4,29 +4,210 @@
|
|||
|
||||
[![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go)
|
||||
[![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/)
|
||||
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/quic-go.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go)
|
||||
|
||||
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go, including the Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) and Datagram Packetization Layer Path MTU
|
||||
Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)). It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
|
||||
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
|
||||
|
||||
In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem.
|
||||
In addition to these base RFCs, it also implements the following RFCs:
|
||||
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
|
||||
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
|
||||
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
|
||||
* QUIC Event Logging using qlog ([draft-ietf-quic-qlog-main-schema](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-main-schema/) and [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/))
|
||||
|
||||
## Guides
|
||||
Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go).
|
||||
|
||||
*We currently support Go 1.19.x and Go 1.20.x*
|
||||
## Using QUIC
|
||||
|
||||
Running tests:
|
||||
### Running a Server
|
||||
|
||||
go test ./...
|
||||
The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket.
|
||||
|
||||
### QUIC without HTTP/3
|
||||
```go
|
||||
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234})
|
||||
// ... error handling
|
||||
tr := quic.Transport{
|
||||
Conn: udpConn,
|
||||
}
|
||||
ln, err := tr.Listen(tlsConf, quicConf)
|
||||
// ... error handling
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
// ... error handling
|
||||
// handle the connection, usually in a new Go routine
|
||||
}
|
||||
}()
|
||||
```
|
||||
|
||||
Take a look at [this echo example](example/echo/echo.go).
|
||||
The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`).
|
||||
|
||||
## Usage
|
||||
As a shortcut, `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`:
|
||||
|
||||
```
|
||||
ln, err := quic.Listen(udpConn, tlsConf, quicConf)
|
||||
```
|
||||
|
||||
When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections.
|
||||
|
||||
### Running a Client
|
||||
|
||||
As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections.
|
||||
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
|
||||
defer cancel()
|
||||
conn, err := tr.Dial(ctx, <server address>, <tls.Config>, <quic.Config>)
|
||||
// ... error handling
|
||||
```
|
||||
|
||||
As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explictly initializing a `quic.Transport`:
|
||||
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
|
||||
defer cancel()
|
||||
conn, err := quic.Dial(ctx, conn, <server address>, <tls.Config>, <quic.Config>)
|
||||
```
|
||||
|
||||
Just as we saw before when used a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections.
|
||||
|
||||
### Using a QUIC Connection
|
||||
|
||||
#### Accepting Streams
|
||||
|
||||
QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3).
|
||||
|
||||
Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides.
|
||||
|
||||
On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most user cases, it makes sense to call these functions in a loop:
|
||||
|
||||
```go
|
||||
for {
|
||||
str, err := conn.AcceptStream(context.Background()) // for bidirectional streams
|
||||
// ... error handling
|
||||
// handle the stream, usually in a new Go routine
|
||||
}
|
||||
```
|
||||
|
||||
These functions return an error when the underlying QUIC connection is closed.
|
||||
|
||||
#### Opening Streams
|
||||
|
||||
There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so.
|
||||
|
||||
Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away:
|
||||
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream
|
||||
```
|
||||
|
||||
The asynchronous version never blocks. If it's currently not possible to open a new stream, it returns a `net.Error` timeout error:
|
||||
|
||||
```go
|
||||
str, err := conn.OpenStream()
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
||||
// It's currently not possible to open another stream,
|
||||
// but it might be possible later, once the peer allowed us to do so.
|
||||
}
|
||||
```
|
||||
|
||||
These functions return an error when the underlying QUIC connection is closed.
|
||||
|
||||
#### Using Streams
|
||||
|
||||
Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions.
|
||||
|
||||
Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
|
||||
|
||||
In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream` , it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
|
||||
|
||||
Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream.
|
||||
|
||||
A bidirectional stream is only closed once both the read and the write side of the stream have been either closed or reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`.
|
||||
|
||||
### Configuring QUIC
|
||||
|
||||
The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details.
|
||||
|
||||
The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommend to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)).
|
||||
|
||||
### Closing a Connection
|
||||
|
||||
#### When the remote Peer closes the Connection
|
||||
|
||||
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Additionally, it is set as cancellation cause of the connection context. Users can use errors assertions to find out what exactly went wrong:
|
||||
|
||||
* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
|
||||
* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
|
||||
* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur.
|
||||
* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer.
|
||||
* `quic.TransportError`: Happens if when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error.
|
||||
* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below.
|
||||
|
||||
#### Initiated by the Application
|
||||
|
||||
A `quic.Connection` can be closed using `CloseWithError`:
|
||||
|
||||
```go
|
||||
conn.CloseWithError(0x42, "error 0x42 occurred")
|
||||
```
|
||||
|
||||
Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes.
|
||||
|
||||
On the receiver side, this is surfaced as a `quic.ApplicationError`.
|
||||
|
||||
### QUIC Datagrams
|
||||
|
||||
Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`.
|
||||
|
||||
QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not retransmitted.
|
||||
|
||||
Datagrams are sent using the `SendDatagram` method on the `quic.Connection`:
|
||||
|
||||
```go
|
||||
conn.SendDatagram([]byte("foobar"))
|
||||
```
|
||||
|
||||
And received using `ReceiveDatagram`:
|
||||
|
||||
```go
|
||||
msg, err := conn.ReceiveDatagram()
|
||||
```
|
||||
|
||||
Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599).
|
||||
|
||||
### QUIC Event Logging using qlog
|
||||
|
||||
quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/), providing comprehensive insights in the internals of a QUIC connection.
|
||||
|
||||
qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
|
||||
|
||||
qlog is activated by setting a `Tracer` callback on the `Config`. It is called as soon as quic-go decides to starts the QUIC handshake on a new connection.
|
||||
A useful implementation of this callback could look like this:
|
||||
```go
|
||||
quic.Config{
|
||||
Tracer: func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
|
||||
role := "server"
|
||||
if p == logging.PerspectiveClient {
|
||||
role = "client"
|
||||
}
|
||||
filename := fmt.Sprintf("./log_%s_%s.qlog", connID, role)
|
||||
f, err := os.Create(filename)
|
||||
// handle the error
|
||||
return qlog.NewConnectionTracer(f, p, connID)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This implementation of the callback creates a new qlog file in the current directory named `log_<client / server>_<QUIC connection ID>.qlog`.
|
||||
|
||||
|
||||
## Using HTTP/3
|
||||
|
||||
### As a server
|
||||
|
||||
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
|
||||
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go:
|
||||
|
||||
```go
|
||||
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
|
||||
|
@ -46,12 +227,13 @@ http.Client{
|
|||
## Projects using quic-go
|
||||
|
||||
| Project | Description | Stars |
|
||||
|-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
|
||||
| --------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
|
||||
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) |
|
||||
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
|
||||
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
|
||||
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
|
||||
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
|
||||
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
|
||||
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
|
||||
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
|
||||
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
|
||||
|
@ -59,6 +241,17 @@ http.Client{
|
|||
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |
|
||||
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) |
|
||||
|
||||
If you'd like to see your project added to this list, please send us a PR.
|
||||
|
||||
## Release Policy
|
||||
|
||||
quic-go always aims to support the latest two Go releases.
|
||||
|
||||
### Dependency on forked crypto/tls
|
||||
|
||||
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20).
|
||||
This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
|
||||
|
||||
## Contributing
|
||||
|
||||
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.
|
||||
|
|
|
@ -51,18 +51,22 @@ func (b *packetBuffer) Release() {
|
|||
}
|
||||
|
||||
// Len returns the length of Data
|
||||
func (b *packetBuffer) Len() protocol.ByteCount {
|
||||
return protocol.ByteCount(len(b.Data))
|
||||
}
|
||||
func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) }
|
||||
func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) }
|
||||
|
||||
func (b *packetBuffer) putBack() {
|
||||
if cap(b.Data) != int(protocol.MaxPacketBufferSize) {
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
if cap(b.Data) == protocol.MaxPacketBufferSize {
|
||||
bufferPool.Put(b)
|
||||
return
|
||||
}
|
||||
if cap(b.Data) == protocol.MaxLargePacketBufferSize {
|
||||
largeBufferPool.Put(b)
|
||||
return
|
||||
}
|
||||
panic("putPacketBuffer called with packet of wrong size!")
|
||||
}
|
||||
|
||||
var bufferPool sync.Pool
|
||||
var bufferPool, largeBufferPool sync.Pool
|
||||
|
||||
func getPacketBuffer() *packetBuffer {
|
||||
buf := bufferPool.Get().(*packetBuffer)
|
||||
|
@ -71,10 +75,18 @@ func getPacketBuffer() *packetBuffer {
|
|||
return buf
|
||||
}
|
||||
|
||||
func getLargePacketBuffer() *packetBuffer {
|
||||
buf := largeBufferPool.Get().(*packetBuffer)
|
||||
buf.refCount = 1
|
||||
buf.Data = buf.Data[:0]
|
||||
return buf
|
||||
}
|
||||
|
||||
func init() {
|
||||
bufferPool.New = func() interface{} {
|
||||
return &packetBuffer{
|
||||
Data: make([]byte, 0, protocol.MaxPacketBufferSize),
|
||||
bufferPool.New = func() any {
|
||||
return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
|
||||
}
|
||||
largeBufferPool.New = func() any {
|
||||
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ type client struct {
|
|||
|
||||
conn quicConn
|
||||
|
||||
tracer logging.ConnectionTracer
|
||||
tracer *logging.ConnectionTracer
|
||||
tracingID uint64
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -43,7 +43,9 @@ type client struct {
|
|||
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
|
||||
|
||||
// DialAddr establishes a new QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
|
||||
// When the QUIC connection is closed, this UDP connection is closed.
|
||||
// See Dial for more details.
|
||||
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
|
@ -53,15 +55,15 @@ func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Confi
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, err := setupTransport(udpConn, tlsConf, true)
|
||||
tr, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dl.Dial(ctx, udpAddr, tlsConf, conf)
|
||||
return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
|
||||
}
|
||||
|
||||
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
|
||||
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
|
||||
// See DialAddr for more details.
|
||||
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
|
@ -71,20 +73,20 @@ func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl, err := setupTransport(udpConn, tlsConf, true)
|
||||
tr, err := setupTransport(udpConn, tlsConf, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
|
||||
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
|
||||
if err != nil {
|
||||
dl.Close()
|
||||
tr.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
|
||||
// See DialEarly for details.
|
||||
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
|
||||
// See Dial for more details.
|
||||
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
|
@ -98,12 +100,15 @@ func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tl
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
|
||||
// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
|
||||
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
|
||||
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
|
||||
// packets.
|
||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||
// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
|
||||
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
|
||||
// will be used instead of ReadFrom and WriteTo to read/write packets.
|
||||
// The tls.Config must define an application protocol (using NextProtos).
|
||||
//
|
||||
// This is a convenience function. More advanced use cases should instantiate a Transport,
|
||||
// which offers configuration options for a more fine-grained control of the connection establishment,
|
||||
// including reusing the underlying UDP socket for multiple QUIC connections.
|
||||
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
dl, err := setupTransport(c, tlsConf, false)
|
||||
if err != nil {
|
||||
|
@ -148,7 +153,7 @@ func dial(
|
|||
if c.config.Tracer != nil {
|
||||
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
|
||||
}
|
||||
if c.tracer != nil {
|
||||
if c.tracer != nil && c.tracer.StartedConnection != nil {
|
||||
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
|
||||
}
|
||||
if err := c.dial(ctx); err != nil {
|
||||
|
@ -158,12 +163,6 @@ func dial(
|
|||
}
|
||||
|
||||
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{}
|
||||
} else {
|
||||
tlsConf = tlsConf.Clone()
|
||||
}
|
||||
|
||||
srcConnID, err := connIDGenerator.GenerateConnectionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -234,7 +233,7 @@ func (c *client) dial(ctx context.Context) error {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
c.conn.shutdown()
|
||||
return ctx.Err()
|
||||
return context.Cause(ctx)
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
case recreateErr := <-recreateChan:
|
||||
|
|
|
@ -16,13 +16,13 @@ type closedLocalConn struct {
|
|||
perspective protocol.Perspective
|
||||
logger utils.Logger
|
||||
|
||||
sendPacket func(net.Addr, *packetInfo)
|
||||
sendPacket func(net.Addr, packetInfo)
|
||||
}
|
||||
|
||||
var _ packetHandler = &closedLocalConn{}
|
||||
|
||||
// newClosedLocalConn creates a new closedLocalConn and runs it.
|
||||
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
|
||||
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
|
||||
return &closedLocalConn{
|
||||
sendPacket: sendPacket,
|
||||
perspective: pers,
|
||||
|
@ -30,7 +30,7 @@ func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Pe
|
|||
}
|
||||
}
|
||||
|
||||
func (c *closedLocalConn) handlePacket(p *receivedPacket) {
|
||||
func (c *closedLocalConn) handlePacket(p receivedPacket) {
|
||||
c.counter++
|
||||
// exponential backoff
|
||||
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
|
||||
|
@ -58,7 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
|
|||
return &closedRemoteConn{perspective: pers}
|
||||
}
|
||||
|
||||
func (s *closedRemoteConn) handlePacket(*receivedPacket) {}
|
||||
func (s *closedRemoteConn) handlePacket(receivedPacket) {}
|
||||
func (s *closedRemoteConn) shutdown() {}
|
||||
func (s *closedRemoteConn) destroy(error) {}
|
||||
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }
|
||||
|
|
|
@ -1,18 +1,10 @@
|
|||
coverage:
|
||||
round: nearest
|
||||
ignore:
|
||||
- streams_map_incoming_bidi.go
|
||||
- streams_map_incoming_uni.go
|
||||
- streams_map_outgoing_bidi.go
|
||||
- streams_map_outgoing_uni.go
|
||||
- http3/gzip_reader.go
|
||||
- interop/
|
||||
- internal/ackhandler/packet_linkedlist.go
|
||||
- internal/utils/byteinterval_linkedlist.go
|
||||
- internal/utils/newconnectionid_linkedlist.go
|
||||
- internal/utils/packetinterval_linkedlist.go
|
||||
- internal/handshake/cipher_suite.go
|
||||
- internal/utils/linkedlist/linkedlist.go
|
||||
- logging/null_tracer.go
|
||||
- fuzzing/
|
||||
- metrics/
|
||||
status:
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/utils"
|
||||
"github.com/quic-go/quic-go/quicvarint"
|
||||
)
|
||||
|
||||
// Clone clones a Config
|
||||
|
@ -17,18 +16,29 @@ func (c *Config) Clone() *Config {
|
|||
}
|
||||
|
||||
func (c *Config) handshakeTimeout() time.Duration {
|
||||
return utils.Max(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout)
|
||||
return 2 * c.HandshakeIdleTimeout
|
||||
}
|
||||
|
||||
func (c *Config) maxRetryTokenAge() time.Duration {
|
||||
return c.handshakeTimeout()
|
||||
}
|
||||
|
||||
func validateConfig(config *Config) error {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
if config.MaxIncomingStreams > 1<<60 {
|
||||
return errors.New("invalid value for Config.MaxIncomingStreams")
|
||||
const maxStreams = 1 << 60
|
||||
if config.MaxIncomingStreams > maxStreams {
|
||||
config.MaxIncomingStreams = maxStreams
|
||||
}
|
||||
if config.MaxIncomingUniStreams > 1<<60 {
|
||||
return errors.New("invalid value for Config.MaxIncomingUniStreams")
|
||||
if config.MaxIncomingUniStreams > maxStreams {
|
||||
config.MaxIncomingUniStreams = maxStreams
|
||||
}
|
||||
if config.MaxStreamReceiveWindow > quicvarint.Max {
|
||||
config.MaxStreamReceiveWindow = quicvarint.Max
|
||||
}
|
||||
if config.MaxConnectionReceiveWindow > quicvarint.Max {
|
||||
config.MaxConnectionReceiveWindow = quicvarint.Max
|
||||
}
|
||||
// check that all QUIC versions are actually supported
|
||||
for _, v := range config.Versions {
|
||||
|
@ -43,12 +53,6 @@ func validateConfig(config *Config) error {
|
|||
// it may be called with nil
|
||||
func populateServerConfig(config *Config) *Config {
|
||||
config = populateConfig(config)
|
||||
if config.MaxTokenAge == 0 {
|
||||
config.MaxTokenAge = protocol.TokenValidity
|
||||
}
|
||||
if config.MaxRetryTokenAge == 0 {
|
||||
config.MaxRetryTokenAge = protocol.RetryTokenValidity
|
||||
}
|
||||
if config.RequireAddressValidation == nil {
|
||||
config.RequireAddressValidation = func(net.Addr) bool { return false }
|
||||
}
|
||||
|
@ -101,18 +105,12 @@ func populateConfig(config *Config) *Config {
|
|||
} else if maxIncomingUniStreams < 0 {
|
||||
maxIncomingUniStreams = 0
|
||||
}
|
||||
maxDatagrameFrameSize := config.MaxDatagramFrameSize
|
||||
if maxDatagrameFrameSize == 0 {
|
||||
maxDatagrameFrameSize = int64(protocol.DefaultMaxDatagramFrameSize)
|
||||
}
|
||||
|
||||
return &Config{
|
||||
GetConfigForClient: config.GetConfigForClient,
|
||||
Versions: versions,
|
||||
HandshakeIdleTimeout: handshakeIdleTimeout,
|
||||
MaxIdleTimeout: idleTimeout,
|
||||
MaxTokenAge: config.MaxTokenAge,
|
||||
MaxRetryTokenAge: config.MaxRetryTokenAge,
|
||||
RequireAddressValidation: config.RequireAddressValidation,
|
||||
KeepAlivePeriod: config.KeepAlivePeriod,
|
||||
InitialStreamReceiveWindow: initialStreamReceiveWindow,
|
||||
|
@ -124,9 +122,7 @@ func populateConfig(config *Config) *Config {
|
|||
MaxIncomingUniStreams: maxIncomingUniStreams,
|
||||
TokenStore: config.TokenStore,
|
||||
EnableDatagrams: config.EnableDatagrams,
|
||||
MaxDatagramFrameSize: maxDatagrameFrameSize,
|
||||
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
|
||||
DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets,
|
||||
Allow0RTT: config.Allow0RTT,
|
||||
Tracer: config.Tracer,
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -71,17 +71,9 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
|||
|
||||
// GetCryptoData retrieves data that was received in CRYPTO frames
|
||||
func (s *cryptoStreamImpl) GetCryptoData() []byte {
|
||||
if len(s.msgBuf) < 4 {
|
||||
return nil
|
||||
}
|
||||
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
|
||||
if len(s.msgBuf) < msgLen {
|
||||
return nil
|
||||
}
|
||||
msg := make([]byte, msgLen)
|
||||
copy(msg, s.msgBuf[:msgLen])
|
||||
s.msgBuf = s.msgBuf[msgLen:]
|
||||
return msg
|
||||
b := s.msgBuf
|
||||
s.msgBuf = nil
|
||||
return b
|
||||
}
|
||||
|
||||
func (s *cryptoStreamImpl) Finish() error {
|
||||
|
|
|
@ -3,12 +3,14 @@ package quic
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/handshake"
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoDataHandler interface {
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) error
|
||||
NextEvent() handshake.Event
|
||||
}
|
||||
|
||||
type cryptoStreamManager struct {
|
||||
|
@ -33,7 +35,7 @@ func newCryptoStreamManager(
|
|||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
|
||||
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
|
||||
var str cryptoStream
|
||||
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
|
||||
switch encLevel {
|
||||
|
@ -44,18 +46,37 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
|
|||
case protocol.Encryption1RTT:
|
||||
str = m.oneRTTStream
|
||||
default:
|
||||
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
|
||||
}
|
||||
if err := str.HandleCryptoFrame(frame); err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
for {
|
||||
data := str.GetCryptoData()
|
||||
if data == nil {
|
||||
return false, nil
|
||||
return nil
|
||||
}
|
||||
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
|
||||
return true, str.Finish()
|
||||
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
|
||||
if !m.oneRTTStream.HasData() {
|
||||
return nil
|
||||
}
|
||||
return m.oneRTTStream.PopCryptoFrame(maxSize)
|
||||
}
|
||||
|
||||
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
|
||||
//nolint:exhaustive // 1-RTT keys should never get dropped.
|
||||
switch encLevel {
|
||||
case protocol.EncryptionInitial:
|
||||
return m.initialStream.Finish()
|
||||
case protocol.EncryptionHandshake:
|
||||
return m.handshakeStream.Finish()
|
||||
default:
|
||||
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
@ -98,7 +99,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
|
|||
}
|
||||
|
||||
// Receive gets a received DATAGRAM frame.
|
||||
func (h *datagramQueue) Receive() ([]byte, error) {
|
||||
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
|
||||
for {
|
||||
h.rcvMx.Lock()
|
||||
if len(h.rcvQueue) > 0 {
|
||||
|
@ -113,6 +114,8 @@ func (h *datagramQueue) Receive() ([]byte, error) {
|
|||
continue
|
||||
case <-h.closed:
|
||||
return nil, h.closeErr
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,3 +61,15 @@ func (e *StreamError) Error() string {
|
|||
}
|
||||
return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode)
|
||||
}
|
||||
|
||||
// DatagramTooLargeError is returned from Connection.SendDatagram if the payload is too large to be sent.
|
||||
type DatagramTooLargeError struct {
|
||||
PeerMaxDatagramFrameSize int64
|
||||
}
|
||||
|
||||
func (e *DatagramTooLargeError) Is(target error) bool {
|
||||
_, ok := target.(*DatagramTooLargeError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" }
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue