TUN-8006: Update quic-go to latest upstream

This commit is contained in:
Chung-Ting 2023-12-04 09:49:00 +00:00
parent 45236a1f7d
commit 8068cdebb6
219 changed files with 10032 additions and 17038 deletions

View File

@ -9,6 +9,7 @@ import (
"net"
"net/http"
"net/netip"
"runtime"
"strconv"
"strings"
"sync"
@ -623,9 +624,19 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, logger *zerolog.
localIP = net.IPv4zero
}
listenNetwork := "udp"
// https://github.com/quic-go/quic-go/issues/3793 DF bit cannot be set for dual stack listener on OSX
if runtime.GOOS == "darwin" {
if localIP.To4() != nil {
listenNetwork = "udp4"
} else {
listenNetwork = "udp6"
}
}
// if port was not set yet, it will be zero, so bind will randomly allocate one.
if port, ok := portForConnIndex[connIndex]; ok {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: localIP, Port: port})
udpConn, err := net.ListenUDP(listenNetwork, &net.UDPAddr{IP: localIP, Port: port})
// if there wasn't an error, or if port was 0 (independently of error or not, just return)
if err == nil {
return udpConn, nil
@ -635,7 +646,7 @@ func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, logger *zerolog.
}
// if we reached here, then there was an error or port as not been allocated it.
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: localIP, Port: 0})
udpConn, err := net.ListenUDP(listenNetwork, &net.UDPAddr{IP: localIP, Port: 0})
if err == nil {
udpAddr, ok := (udpConn.LocalAddr()).(*net.UDPAddr)
if !ok {

View File

@ -621,7 +621,7 @@ func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.
muxedPayload, err = quicpogs.SuffixType(muxedPayload, quicpogs.DatagramTypeUDP)
require.NoError(t, err)
err = edgeQUICSession.SendMessage(muxedPayload)
err = edgeQUICSession.SendDatagram(muxedPayload)
require.NoError(t, err)
readBuffer := make([]byte, len(payload)+1)

27
go.mod
View File

@ -3,7 +3,6 @@ module github.com/cloudflare/cloudflared
go 1.20
require (
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc
github.com/coredns/coredns v1.10.0
github.com/coreos/go-oidc/v3 v3.6.0
github.com/coreos/go-systemd/v22 v22.5.0
@ -25,7 +24,8 @@ require (
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0
github.com/prometheus/client_model v0.2.0
github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000
github.com/quic-go/qtls-go1-20 v0.4.1
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55
github.com/rs/zerolog v1.20.0
github.com/stretchr/testify v1.8.1
github.com/urfave/cli/v2 v2.3.0
@ -38,7 +38,7 @@ require (
go.uber.org/automaxprocs v1.4.0
golang.org/x/crypto v0.11.0
golang.org/x/net v0.12.0
golang.org/x/sync v0.1.0
golang.org/x/sync v0.2.0
golang.org/x/sys v0.10.0
golang.org/x/term v0.10.0
google.golang.org/protobuf v1.28.1
@ -62,13 +62,12 @@ require (
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 // indirect
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/logr v1.2.4 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect
@ -79,20 +78,18 @@ require (
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/onsi/ginkgo/v2 v2.4.0 // indirect
github.com/onsi/gomega v1.23.0 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
go.uber.org/mock v0.3.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/oauth2 v0.6.0 // indirect
golang.org/x/text v0.11.0 // indirect
golang.org/x/tools v0.6.0 // indirect
golang.org/x/tools v0.9.1 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20221202195650-67e5cbc046fd // indirect
google.golang.org/grpc v1.51.0 // indirect
@ -107,8 +104,6 @@ replace github.com/prometheus/golang_client => github.com/prometheus/golang_clie
replace gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1
replace github.com/quic-go/quic-go => github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7
// Post-quantum tunnel RTG-1339
// Branches go1.20 on github.com/cloudflare/qtls-pq
replace github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633
replace github.com/quic-go/qtls-go1-20 => github.com/cloudflare/qtls-pq v0.0.0-20231024102457-5b458bcaf6d4

49
go.sum
View File

@ -61,10 +61,8 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudflare/circl v1.2.1-0.20220809205628-0a9554f37a47 h1:YzpECHxZ9TzO7LpnKmPxItSd79lLgrR5heIlnqU4dTU=
github.com/cloudflare/circl v1.2.1-0.20220809205628-0a9554f37a47/go.mod h1:qhx8gBILsYlbam7h09SvHDSkjpe3TfLA7b/z4rxJvkE=
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc h1:Dvk3ySBsOm5EviLx6VCyILnafPcQinXGP5jbTdHUJgE=
github.com/cloudflare/golibs v0.0.0-20170913112048-333127dbecfc/go.mod h1:HlgKKR8V5a1wroIDDIz3/A+T+9Janfq+7n1P5sEFdi0=
github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633 h1:ZTub2XMOBpxyBiJf6Q+UKqAi07yt1rZmFitriHvFd8M=
github.com/cloudflare/qtls-pq v0.0.0-20230320122459-4ed280d0d633/go.mod h1:j/igSUc4PgBMayIsBGjAFu2i7g663rm6kZrKy4htb7E=
github.com/cloudflare/qtls-pq v0.0.0-20231024102457-5b458bcaf6d4 h1:gkjG0LZZTHDehVlLbY8pGcaCgeNHawJqV2IHgOjlGDM=
github.com/cloudflare/qtls-pq v0.0.0-20231024102457-5b458bcaf6d4/go.mod h1:bhkEYs+1JsfHM6xqs1h4eprhpmUk/UTjdmOZK3kLIpM=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI=
@ -88,14 +86,13 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7 h1:qxyoKKPXmPsbvT7SZTcvhEgUaZhEttk4f6u8rIawKj0=
github.com/devincarr/quic-go v0.0.0-20230502200822-d1f4edacbee7/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0=
github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51 h1:0JZ+dUmQeA8IIVUMzysrX4/AKuQwWhV2dYQuPZdvdSQ=
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51/go.mod h1:Yg+htXGokKKdzcwhuNDwVvN+uBxDGXJ7G/VN1d8fa64=
@ -137,8 +134,9 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@ -149,8 +147,8 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ=
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
@ -178,8 +176,6 @@ github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt
github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@ -195,8 +191,9 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
@ -295,10 +292,9 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs=
github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo=
github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys=
github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg=
@ -335,8 +331,10 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI=
github.com/quic-go/quic-go v0.40.0 h1:GYd1iznlKm7dpHD7pOVpUvItgMPo/jrMgDWZhMCecqw=
github.com/quic-go/quic-go v0.40.0/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c=
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55 h1:I4N3ZRnkZPbDN935Tg8QDf8fRpHp3bZ0U0/L42jBgNE=
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
@ -396,6 +394,8 @@ go.opentelemetry.io/proto/otlp v0.15.0 h1:h0bKrvdrT/9sBwEJ6iWUqT/N/xPcS66bL4u3is
go.opentelemetry.io/proto/otlp v0.15.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U=
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -439,8 +439,8 @@ golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -497,8 +497,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -605,10 +605,9 @@ golang.org/x/tools v0.0.0-20200618134242-20370b0cb4b2/go.mod h1:EkVYQZoAsY45+roY
golang.org/x/tools v0.0.0-20200729194436-6467de6f59a7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200804011535-6c149bb5ef0d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200825202427-b303f430e36d/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -53,7 +53,7 @@ func (dm *DatagramMuxer) SendToSession(session *packet.Session) error {
if err != nil {
return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped")
}
if err := dm.session.SendMessage(payloadWithMetadata); err != nil {
if err := dm.session.SendDatagram(payloadWithMetadata); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge")
}
return nil
@ -64,7 +64,7 @@ func (dm *DatagramMuxer) ServeReceive(ctx context.Context) error {
// Extracts datagram session ID, then sends the session ID and payload to receiver
// which determines how to proxy to the origin. It assumes the datagram session has already been
// registered with receiver through other side channel
msg, err := dm.session.ReceiveMessage()
msg, err := dm.session.ReceiveDatagram(ctx)
if err != nil {
return err
}

View File

@ -10,6 +10,7 @@ import (
"encoding/pem"
"fmt"
"math/big"
"net"
"net/netip"
"testing"
"time"
@ -114,9 +115,8 @@ func TestDatagram(t *testing.T) {
func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Session, packets []packet.ICMP) {
quicConfig := &quic.Config{
KeepAlivePeriod: 5 * time.Millisecond,
EnableDatagrams: true,
MaxDatagramFrameSize: MaxDatagramFrameSize,
KeepAlivePeriod: 5 * time.Millisecond,
EnableDatagrams: true,
}
quicListener := newQUICListener(t, quicConfig)
defer quicListener.Close()
@ -182,8 +182,12 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// https://github.com/quic-go/quic-go/issues/3793 MTU discovery is disabled on OSX for dual stack listeners
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
// Establish quic connection
quicSession, err := quic.DialAddrEarly(ctx, quicListener.Addr().String(), tlsClientConfig, quicConfig)
quicSession, err := quic.DialEarly(ctx, udpConn, quicListener.Addr(), tlsClientConfig, quicConfig)
require.NoError(t, err)
defer quicSession.CloseWithError(0, "")

View File

@ -86,7 +86,7 @@ func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error {
if err != nil {
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
}
if err := dm.session.SendMessage(msgWithIDAndType); err != nil {
if err := dm.session.SendDatagram(msgWithIDAndType); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge")
}
return nil
@ -104,7 +104,7 @@ func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
if err != nil {
return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped")
}
if err := dm.session.SendMessage(payloadWithMetadataAndType); err != nil {
if err := dm.session.SendDatagram(payloadWithMetadataAndType); err != nil {
return errors.Wrap(err, "Failed to send datagram back to edge")
}
return nil
@ -113,7 +113,7 @@ func (dm *DatagramMuxerV2) SendPacket(pk Packet) error {
// Demux reads datagrams from the QUIC connection and demuxes depending on whether it's a session or packet
func (dm *DatagramMuxerV2) ServeReceive(ctx context.Context) error {
for {
msg, err := dm.session.ReceiveMessage()
msg, err := dm.session.ReceiveDatagram(ctx)
if err != nil {
return err
}

View File

@ -21,7 +21,7 @@ type tracerConfig struct {
index uint8
}
func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context, logging.Perspective, logging.ConnectionID) logging.ConnectionTracer {
func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context, logging.Perspective, logging.ConnectionID) *logging.ConnectionTracer {
t := &tracer{
logger: logger,
config: &tracerConfig{
@ -32,42 +32,61 @@ func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context,
return t.TracerForConnection
}
func NewServerTracer(logger *zerolog.Logger) logging.Tracer {
return &tracer{
logger: logger,
config: &tracerConfig{
isClient: false,
},
func NewServerTracer(logger *zerolog.Logger) *logging.Tracer {
return &logging.Tracer{
SentPacket: func(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {},
SentVersionNegotiationPacket: func(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {},
DroppedPacket: func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) {},
}
}
func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspective, _odcid logging.ConnectionID) logging.ConnectionTracer {
func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspective, _odcid logging.ConnectionID) *logging.ConnectionTracer {
if t.config.isClient {
return newConnTracer(newClientCollector(t.config.index))
}
return newConnTracer(newServiceCollector())
}
func (*tracer) SentPacket(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {
}
func (*tracer) SentVersionNegotiationPacket(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
}
func (*tracer) DroppedPacket(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) {
}
var _ logging.Tracer = (*tracer)(nil)
// connTracer collects connection level metrics
type connTracer struct {
metricsCollector MetricsCollector
}
var _ logging.ConnectionTracer = (*connTracer)(nil)
func newConnTracer(metricsCollector MetricsCollector) logging.ConnectionTracer {
return &connTracer{
func newConnTracer(metricsCollector MetricsCollector) *logging.ConnectionTracer {
tracer := connTracer{
metricsCollector: metricsCollector,
}
return &logging.ConnectionTracer{
StartedConnection: tracer.StartedConnection,
NegotiatedVersion: tracer.NegotiatedVersion,
ClosedConnection: tracer.ClosedConnection,
SentTransportParameters: tracer.SentTransportParameters,
ReceivedTransportParameters: tracer.ReceivedTransportParameters,
RestoredTransportParameters: tracer.RestoredTransportParameters,
SentLongHeaderPacket: tracer.SentLongHeaderPacket,
SentShortHeaderPacket: tracer.SentShortHeaderPacket,
ReceivedVersionNegotiationPacket: tracer.ReceivedVersionNegotiationPacket,
ReceivedRetry: tracer.ReceivedRetry,
ReceivedLongHeaderPacket: tracer.ReceivedLongHeaderPacket,
ReceivedShortHeaderPacket: tracer.ReceivedShortHeaderPacket,
BufferedPacket: tracer.BufferedPacket,
DroppedPacket: tracer.DroppedPacket,
UpdatedMetrics: tracer.UpdatedMetrics,
AcknowledgedPacket: tracer.AcknowledgedPacket,
LostPacket: tracer.LostPacket,
UpdatedCongestionState: tracer.UpdatedCongestionState,
UpdatedPTOCount: tracer.UpdatedPTOCount,
UpdatedKeyFromTLS: tracer.UpdatedKeyFromTLS,
UpdatedKey: tracer.UpdatedKey,
DroppedEncryptionLevel: tracer.DroppedEncryptionLevel,
DroppedKey: tracer.DroppedKey,
SetLossTimer: tracer.SetLossTimer,
LossTimerExpired: tracer.LossTimerExpired,
LossTimerCanceled: tracer.LossTimerCanceled,
ECNStateUpdated: tracer.ECNStateUpdated,
Close: tracer.Close,
Debug: tracer.Debug,
}
}
func (ct *connTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID logging.ConnectionID) {
@ -90,7 +109,7 @@ func (ct *connTracer) BufferedPacket(pt logging.PacketType, size logging.ByteCou
ct.metricsCollector.bufferedPackets(pt)
}
func (ct *connTracer) DroppedPacket(pt logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) {
func (ct *connTracer) DroppedPacket(pt logging.PacketType, number logging.PacketNumber, size logging.ByteCount, reason logging.PacketDropReason) {
ct.metricsCollector.droppedPackets(pt, size, reason)
}
@ -114,10 +133,10 @@ func (ct *connTracer) ReceivedTransportParameters(parameters *logging.TransportP
func (ct *connTracer) RestoredTransportParameters(parameters *logging.TransportParameters) {
}
func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
}
func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
}
func (ct *connTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
@ -126,10 +145,10 @@ func (ct *connTracer) ReceivedVersionNegotiationPacket(dest, src logging.Arbitra
func (ct *connTracer) ReceivedRetry(header *logging.Header) {
}
func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, frames []logging.Frame) {
func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
}
func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, frames []logging.Frame) {
func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
}
func (ct *connTracer) AcknowledgedPacket(level logging.EncryptionLevel, number logging.PacketNumber) {
@ -162,6 +181,10 @@ func (ct *connTracer) LossTimerExpired(timerType logging.TimerType, level loggin
func (ct *connTracer) LossTimerCanceled() {
}
func (ct *connTracer) ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) {
}
func (ct *connTracer) Close() {
}

View File

@ -597,7 +597,6 @@ func (e *EdgeTunnelServer) serveQUIC(
MaxIncomingStreams: quicpogs.MaxIncomingStreams,
MaxIncomingUniStreams: quicpogs.MaxIncomingStreams,
EnableDatagrams: true,
MaxDatagramFrameSize: quicpogs.MaxDatagramFrameSize,
Tracer: quicpogs.NewClientTracer(connLogger.Logger(), connIndex),
DisablePathMTUDiscovery: e.config.DisableQUICPathMTUDiscovery,
}

View File

@ -1,27 +0,0 @@
Copyright (c) 2013 CloudFlare, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of the CloudFlare, Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,3 +0,0 @@
cover.out~
benchmark/benchmark

View File

@ -1,41 +0,0 @@
# Copyright (c) 2013 CloudFlare, Inc.
RACE+=--race
PKGNAME=github.com/cloudflare/golibs/lrucache
SKIPCOVER=list.go|list_extension.go|priorityqueue.go
.PHONY: all test bench cover clean
all:
@echo "Targets:"
@echo " test: run tests with race detector"
@echo " cover: print test coverage"
@echo " bench: run basic benchmarks"
test:
@go test $(RACE) -bench=. -v $(PKGNAME)
COVEROUT=cover.out
cover:
@go test -coverprofile=$(COVEROUT) -v $(PKGNAME)
@cat $(COVEROUT) | egrep -v '$(SKIPCOVER)' > $(COVEROUT)~
@go tool cover -func=$(COVEROUT)~|sed 's|^.*/\([^/]*/[^/]*/[^/]*\)$$|\1|g'
bench:
@echo "[*] Scalability of cache/lrucache"
@echo "[ ] Operations in shared cache using one core"
@GOMAXPROCS=1 go test -run=- -bench='.*LRUCache.*' $(PKGNAME) \
| egrep -v "^PASS|^ok"
@echo "[*] Scalability of cache/multilru"
@echo "[ ] Operations in four caches using four cores "
@GOMAXPROCS=4 go test -run=- -bench='.*MultiLRU.*' $(PKGNAME) \
| egrep -v "^PASS|^ok"
@(cd benchmark; go build $(PKGNAME)/benchmark)
@./benchmark/benchmark
clean:
rm -rf $(COVEROUT) $(COVEROUT)~ benchmark/benchmark

View File

@ -1,40 +0,0 @@
LRU Cache
---------
A `golang` implementation of last recently used cache data structure.
To install:
go get github.com/cloudflare/golibs/lrucache
To test:
cd $GOPATH/src/github.com/cloudflare/golibs/lrucache
make test
For coverage:
make cover
Basic benchmarks:
$ make bench # As tested on my two core i5
[*] Scalability of cache/lrucache
[ ] Operations in shared cache using one core
BenchmarkConcurrentGetLRUCache 5000000 450 ns/op
BenchmarkConcurrentSetLRUCache 2000000 821 ns/op
BenchmarkConcurrentSetNXLRUCache 5000000 664 ns/op
[*] Scalability of cache/multilru
[ ] Operations in four caches using four cores
BenchmarkConcurrentGetMultiLRU-4 5000000 475 ns/op
BenchmarkConcurrentSetMultiLRU-4 2000000 809 ns/op
BenchmarkConcurrentSetNXMultiLRU-4 5000000 643 ns/op
[*] Capacity=4096 Keys=30000 KeySpace=15625
vitess LRUCache MultiLRUCache-4
create 1.709us 1.626374ms 343.54us
Get (miss) 144.266083ms 132.470397ms 177.277193ms
SetNX #1 338.637977ms 380.733302ms 411.709204ms
Get (hit) 195.896066ms 173.252112ms 234.109494ms
SetNX #2 349.785951ms 367.255624ms 419.129127ms

View File

@ -1,69 +0,0 @@
// Copyright (c) 2013 CloudFlare, Inc.
// Package lrucache implements a last recently used cache data structure.
//
// This code tries to avoid dynamic memory allocations - all required
// memory is allocated on creation. Access to the data structure is
// O(1). Modification O(log(n)) if expiry is used, O(1)
// otherwise.
//
// This package exports three things:
// LRUCache: is the main implementation. It supports multithreading by
// using guarding mutex lock.
//
// MultiLRUCache: is a sharded implementation. It supports the same
// API as LRUCache and uses it internally, but is not limited to
// a single CPU as every shard is separately locked. Use this
// data structure instead of LRUCache if you have have lock
// contention issues.
//
// Cache interface: Both implementations fulfill it.
package lrucache
import (
"time"
)
// Cache interface is fulfilled by the LRUCache and MultiLRUCache
// implementations.
type Cache interface {
// Methods not needing to know current time.
//
// Get a key from the cache, possibly stale. Update its LRU
// score.
Get(key string) (value interface{}, ok bool)
// Get a key from the cache, possibly stale. Don't modify its LRU score. O(1)
GetQuiet(key string) (value interface{}, ok bool)
// Get and remove a key from the cache.
Del(key string) (value interface{}, ok bool)
// Evict all items from the cache.
Clear() int
// Number of entries used in the LRU
Len() int
// Get the total capacity of the LRU
Capacity() int
// Methods use time.Now() when neccessary to determine expiry.
//
// Add an item to the cache overwriting existing one if it
// exists.
Set(key string, value interface{}, expire time.Time)
// Get a key from the cache, make sure it's not stale. Update
// its LRU score.
GetNotStale(key string) (value interface{}, ok bool)
// Evict all the expired items.
Expire() int
// Methods allowing to explicitly specify time used to
// determine if items are expired.
//
// Add an item to the cache overwriting existing one if it
// exists. Allows specifing current time required to expire an
// item when no more slots are used.
SetNow(key string, value interface{}, expire time.Time, now time.Time)
// Get a key from the cache, make sure it's not stale. Update
// its LRU score.
GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool)
// Evict items that expire before Now.
ExpireNow(now time.Time) int
}

View File

@ -1,238 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/* This file is a slightly modified file from the go package sources
and is released on the following license:
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
// Package list implements a doubly linked list.
//
// To iterate over a list (where l is a *List):
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.Value
// }
//
package lrucache
// Element is an element of a linked list.
type element struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *element
// The list to which this element belongs.
list *list
// The value stored with this element.
Value interface{}
}
// Next returns the next list element or nil.
func (e *element) Next() *element {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *element) Prev() *element {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// List represents a doubly linked list.
// The zero value for List is an empty list ready to use.
type list struct {
root element // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *list) Init() *list {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// New returns an initialized list.
// func New() *list { return new(list).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *list) Len() int { return l.len }
// Front returns the first element of list l or nil
func (l *list) Front() *element {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil.
func (l *list) Back() *element {
if l.len == 0 {
return nil
}
return l.root.prev
}
// insert inserts e after at, increments l.len, and returns e.
func (l *list) insert(e, at *element) *element {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *list) insertValue(v interface{}, at *element) *element {
return l.insert(&element{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *list) remove(e *element) *element {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
func (l *list) Remove(e *element) interface{} {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *list) PushFront(v interface{}) *element {
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *list) PushBack(v interface{}) *element {
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
func (l *list) InsertBefore(v interface{}, mark *element) *element {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
func (l *list) InsertAfter(v interface{}, mark *element) *element {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
func (l *list) MoveToFront(e *element) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
func (l *list) MoveToBack(e *element) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e is not an element of l, or e == mark, the list is not modified.
func (l *list) MoveBefore(e, mark *element) {
if e.list != l || e == mark {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e is not an element of l, or e == mark, the list is not modified.
func (l *list) MoveAfter(e, mark *element) {
if e.list != l || e == mark {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same.
func (l *list) PushBackList(other *list) {
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same.
func (l *list) PushFrontList(other *list) {
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

View File

@ -1,25 +0,0 @@
// Copyright (c) 2013 CloudFlare, Inc.
// Extensions to "container/list" that allowing reuse of Elements.
package lrucache
func (l *list) PushElementFront(e *element) *element {
return l.insert(e, &l.root)
}
func (l *list) PushElementBack(e *element) *element {
return l.insert(e, l.root.prev)
}
func (l *list) PopElementFront() *element {
el := l.Front()
l.Remove(el)
return el
}
func (l *list) PopFront() interface{} {
el := l.Front()
l.Remove(el)
return el.Value
}

View File

@ -1,316 +0,0 @@
// Copyright (c) 2013 CloudFlare, Inc.
package lrucache
import (
"container/heap"
"sync"
"time"
)
// Every element in the cache is linked to three data structures:
// Table map, PriorityQueue heap ordered by expiry and a LruList list
// ordered by decreasing popularity.
type entry struct {
element element // list element. value is a pointer to this entry
key string // key is a key!
value interface{} //
expire time.Time // time when the item is expired. it's okay to be stale.
index int // index for priority queue needs. -1 if entry is free
}
// LRUCache data structure. Never dereference it or copy it by
// value. Always use it through a pointer.
type LRUCache struct {
lock sync.Mutex
table map[string]*entry // all entries in table must be in lruList
priorityQueue priorityQueue // some elements from table may be in priorityQueue
lruList list // every entry is either used and resides in lruList
freeList list // or free and is linked to freeList
ExpireGracePeriod time.Duration // time after an expired entry is purged from cache (unless pushed out of LRU)
}
// Initialize the LRU cache instance. O(capacity)
func (b *LRUCache) Init(capacity uint) {
b.table = make(map[string]*entry, capacity)
b.priorityQueue = make([]*entry, 0, capacity)
b.lruList.Init()
b.freeList.Init()
heap.Init(&b.priorityQueue)
// Reserve all the entries in one giant continous block of memory
arrayOfEntries := make([]entry, capacity)
for i := uint(0); i < capacity; i++ {
e := &arrayOfEntries[i]
e.element.Value = e
e.index = -1
b.freeList.PushElementBack(&e.element)
}
}
// Create new LRU cache instance. Allocate all the needed memory. O(capacity)
func NewLRUCache(capacity uint) *LRUCache {
b := &LRUCache{}
b.Init(capacity)
return b
}
// Give me the entry with lowest expiry field if it's before now.
func (b *LRUCache) expiredEntry(now time.Time) *entry {
if len(b.priorityQueue) == 0 {
return nil
}
if now.IsZero() {
// Fill it only when actually used.
now = time.Now()
}
if e := b.priorityQueue[0]; e.expire.Before(now) {
return e
}
return nil
}
// Give me the least used entry.
func (b *LRUCache) leastUsedEntry() *entry {
return b.lruList.Back().Value.(*entry)
}
func (b *LRUCache) freeSomeEntry(now time.Time) (e *entry, used bool) {
if b.freeList.Len() > 0 {
return b.freeList.Front().Value.(*entry), false
}
e = b.expiredEntry(now)
if e != nil {
return e, true
}
if b.lruList.Len() == 0 {
return nil, false
}
return b.leastUsedEntry(), true
}
// Move entry from used/lru list to a free list. Clear the entry as well.
func (b *LRUCache) removeEntry(e *entry) {
if e.element.list != &b.lruList {
panic("list lruList")
}
if e.index != -1 {
heap.Remove(&b.priorityQueue, e.index)
}
b.lruList.Remove(&e.element)
b.freeList.PushElementFront(&e.element)
delete(b.table, e.key)
e.key = ""
e.value = nil
}
func (b *LRUCache) insertEntry(e *entry) {
if e.element.list != &b.freeList {
panic("list freeList")
}
if !e.expire.IsZero() {
heap.Push(&b.priorityQueue, e)
}
b.freeList.Remove(&e.element)
b.lruList.PushElementFront(&e.element)
b.table[e.key] = e
}
func (b *LRUCache) touchEntry(e *entry) {
b.lruList.MoveToFront(&e.element)
}
// Add an item to the cache overwriting existing one if it
// exists. Allows specifing current time required to expire an item
// when no more slots are used. O(log(n)) if expiry is set, O(1) when
// clear.
func (b *LRUCache) SetNow(key string, value interface{}, expire time.Time, now time.Time) {
b.lock.Lock()
defer b.lock.Unlock()
var used bool
e := b.table[key]
if e != nil {
used = true
} else {
e, used = b.freeSomeEntry(now)
if e == nil {
return
}
}
if used {
b.removeEntry(e)
}
e.key = key
e.value = value
e.expire = expire
b.insertEntry(e)
}
// Add an item to the cache overwriting existing one if it
// exists. O(log(n)) if expiry is set, O(1) when clear.
func (b *LRUCache) Set(key string, value interface{}, expire time.Time) {
b.SetNow(key, value, expire, time.Time{})
}
// Get a key from the cache, possibly stale. Update its LRU score. O(1)
func (b *LRUCache) Get(key string) (v interface{}, ok bool) {
b.lock.Lock()
defer b.lock.Unlock()
e := b.table[key]
if e == nil {
return nil, false
}
b.touchEntry(e)
return e.value, true
}
// Get a key from the cache, possibly stale. Don't modify its LRU score. O(1)
func (b *LRUCache) GetQuiet(key string) (v interface{}, ok bool) {
b.lock.Lock()
defer b.lock.Unlock()
e := b.table[key]
if e == nil {
return nil, false
}
return e.value, true
}
// Get a key from the cache, make sure it's not stale. Update its
// LRU score. O(log(n)) if the item is expired.
func (b *LRUCache) GetNotStale(key string) (value interface{}, ok bool) {
return b.GetNotStaleNow(key, time.Now())
}
// Get a key from the cache, make sure it's not stale. Update its
// LRU score. O(log(n)) if the item is expired.
func (b *LRUCache) GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool) {
b.lock.Lock()
defer b.lock.Unlock()
e := b.table[key]
if e == nil {
return nil, false
}
if e.expire.Before(now) {
// Remove entries expired for more than a graceful period
if b.ExpireGracePeriod == 0 || e.expire.Sub(now) > b.ExpireGracePeriod {
b.removeEntry(e)
}
return nil, false
}
b.touchEntry(e)
return e.value, true
}
// Get a key from the cache, possibly stale. Update its LRU
// score. O(1) always.
func (b *LRUCache) GetStale(key string) (value interface{}, ok, expired bool) {
return b.GetStaleNow(key, time.Now())
}
// Get a key from the cache, possibly stale. Update its LRU
// score. O(1) always.
func (b *LRUCache) GetStaleNow(key string, now time.Time) (value interface{}, ok, expired bool) {
b.lock.Lock()
defer b.lock.Unlock()
e := b.table[key]
if e == nil {
return nil, false, false
}
b.touchEntry(e)
return e.value, true, e.expire.Before(now)
}
// Get and remove a key from the cache. O(log(n)) if the item is using expiry, O(1) otherwise.
func (b *LRUCache) Del(key string) (v interface{}, ok bool) {
b.lock.Lock()
defer b.lock.Unlock()
e := b.table[key]
if e == nil {
return nil, false
}
value := e.value
b.removeEntry(e)
return value, true
}
// Evict all items from the cache. O(n*log(n))
func (b *LRUCache) Clear() int {
b.lock.Lock()
defer b.lock.Unlock()
// First, remove entries that have expiry set
l := len(b.priorityQueue)
for i := 0; i < l; i++ {
// This could be reduced to O(n).
b.removeEntry(b.priorityQueue[0])
}
// Second, remove all remaining entries
r := b.lruList.Len()
for i := 0; i < r; i++ {
b.removeEntry(b.leastUsedEntry())
}
return l + r
}
// Evict all the expired items. O(n*log(n))
func (b *LRUCache) Expire() int {
return b.ExpireNow(time.Now())
}
// Evict items that expire before `now`. O(n*log(n))
func (b *LRUCache) ExpireNow(now time.Time) int {
b.lock.Lock()
defer b.lock.Unlock()
i := 0
for {
e := b.expiredEntry(now)
if e == nil {
break
}
b.removeEntry(e)
i += 1
}
return i
}
// Number of entries used in the LRU
func (b *LRUCache) Len() int {
// yes. this stupid thing requires locking
b.lock.Lock()
defer b.lock.Unlock()
return b.lruList.Len()
}
// Get the total capacity of the LRU
func (b *LRUCache) Capacity() int {
// yes. this stupid thing requires locking
b.lock.Lock()
defer b.lock.Unlock()
return b.lruList.Len() + b.freeList.Len()
}

View File

@ -1,118 +0,0 @@
// Copyright (c) 2013 CloudFlare, Inc.
package lrucache
import (
"hash/crc32"
"time"
)
// MultiLRUCache data structure. Never dereference it or copy it by
// value. Always use it through a pointer.
type MultiLRUCache struct {
buckets uint
cache []*LRUCache
}
// Using this constructor is almost always wrong. Use NewMultiLRUCache instead.
func (m *MultiLRUCache) Init(buckets, bucket_capacity uint) {
m.buckets = buckets
m.cache = make([]*LRUCache, buckets)
for i := uint(0); i < buckets; i++ {
m.cache[i] = NewLRUCache(bucket_capacity)
}
}
// Set the stale expiry grace period for each cache in the multicache instance.
func (m *MultiLRUCache) SetExpireGracePeriod(p time.Duration) {
for _, c := range m.cache {
c.ExpireGracePeriod = p
}
}
func NewMultiLRUCache(buckets, bucket_capacity uint) *MultiLRUCache {
m := &MultiLRUCache{}
m.Init(buckets, bucket_capacity)
return m
}
func (m *MultiLRUCache) bucketNo(key string) uint {
// Arbitrary choice. Any fast hash will do.
return uint(crc32.ChecksumIEEE([]byte(key))) % m.buckets
}
func (m *MultiLRUCache) Set(key string, value interface{}, expire time.Time) {
m.cache[m.bucketNo(key)].Set(key, value, expire)
}
func (m *MultiLRUCache) SetNow(key string, value interface{}, expire time.Time, now time.Time) {
m.cache[m.bucketNo(key)].SetNow(key, value, expire, now)
}
func (m *MultiLRUCache) Get(key string) (value interface{}, ok bool) {
return m.cache[m.bucketNo(key)].Get(key)
}
func (m *MultiLRUCache) GetQuiet(key string) (value interface{}, ok bool) {
return m.cache[m.bucketNo(key)].Get(key)
}
func (m *MultiLRUCache) GetNotStale(key string) (value interface{}, ok bool) {
return m.cache[m.bucketNo(key)].GetNotStale(key)
}
func (m *MultiLRUCache) GetNotStaleNow(key string, now time.Time) (value interface{}, ok bool) {
return m.cache[m.bucketNo(key)].GetNotStaleNow(key, now)
}
func (m *MultiLRUCache) GetStale(key string) (value interface{}, ok, expired bool) {
return m.cache[m.bucketNo(key)].GetStale(key)
}
func (m *MultiLRUCache) GetStaleNow(key string, now time.Time) (value interface{}, ok, expired bool) {
return m.cache[m.bucketNo(key)].GetStaleNow(key, now)
}
func (m *MultiLRUCache) Del(key string) (value interface{}, ok bool) {
return m.cache[m.bucketNo(key)].Del(key)
}
func (m *MultiLRUCache) Clear() int {
var s int
for _, c := range m.cache {
s += c.Clear()
}
return s
}
func (m *MultiLRUCache) Len() int {
var s int
for _, c := range m.cache {
s += c.Len()
}
return s
}
func (m *MultiLRUCache) Capacity() int {
var s int
for _, c := range m.cache {
s += c.Capacity()
}
return s
}
func (m *MultiLRUCache) Expire() int {
var s int
for _, c := range m.cache {
s += c.Expire()
}
return s
}
func (m *MultiLRUCache) ExpireNow(now time.Time) int {
var s int
for _, c := range m.cache {
s += c.ExpireNow(now)
}
return s
}

View File

@ -1,37 +0,0 @@
// Copyright (c) 2013 CloudFlare, Inc.
// This code is based on golang example from "container/heap" package.
package lrucache
type priorityQueue []*entry
func (pq priorityQueue) Len() int {
return len(pq)
}
func (pq priorityQueue) Less(i, j int) bool {
return pq[i].expire.Before(pq[j].expire)
}
func (pq priorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (pq *priorityQueue) Push(e interface{}) {
n := len(*pq)
item := e.(*entry)
item.index = n
*pq = append(*pq, item)
}
func (pq *priorityQueue) Pop() interface{} {
old := *pq
n := len(old)
item := old[n-1]
item.index = -1
*pq = old[0 : n-1]
return item
}

View File

@ -6,7 +6,6 @@ linters:
disable-all: true
enable:
- asciicheck
- deadcode
- errcheck
- forcetypeassert
- gocritic
@ -18,10 +17,8 @@ linters:
- misspell
- revive
- staticcheck
- structcheck
- typecheck
- unused
- varcheck
issues:
exclude-use-default: false

View File

@ -20,35 +20,5 @@ package logr
// used whenever the caller is not interested in the logs. Logger instances
// produced by this function always compare as equal.
func Discard() Logger {
return Logger{
level: 0,
sink: discardLogSink{},
}
}
// discardLogSink is a LogSink that discards all messages.
type discardLogSink struct{}
// Verify that it actually implements the interface
var _ LogSink = discardLogSink{}
func (l discardLogSink) Init(RuntimeInfo) {
}
func (l discardLogSink) Enabled(int) bool {
return false
}
func (l discardLogSink) Info(int, string, ...interface{}) {
}
func (l discardLogSink) Error(error, string, ...interface{}) {
}
func (l discardLogSink) WithValues(...interface{}) LogSink {
return l
}
func (l discardLogSink) WithName(string) LogSink {
return l
return New(nil)
}

View File

@ -21,13 +21,13 @@ limitations under the License.
// github.com/go-logr/logr.LogSink with output through an arbitrary
// "write" function. See New and NewJSON for details.
//
// Custom LogSinks
// # Custom LogSinks
//
// For users who need more control, a funcr.Formatter can be embedded inside
// your own custom LogSink implementation. This is useful when the LogSink
// needs to implement additional methods, for example.
//
// Formatting
// # Formatting
//
// This will respect logr.Marshaler, fmt.Stringer, and error interfaces for
// values which are being logged. When rendering a struct, funcr will use Go's
@ -37,6 +37,7 @@ package funcr
import (
"bytes"
"encoding"
"encoding/json"
"fmt"
"path/filepath"
"reflect"
@ -217,7 +218,7 @@ func newFormatter(opts Options, outfmt outputFormat) Formatter {
prefix: "",
values: nil,
depth: 0,
opts: opts,
opts: &opts,
}
return f
}
@ -231,7 +232,7 @@ type Formatter struct {
values []interface{}
valuesStr string
depth int
opts Options
opts *Options
}
// outputFormat indicates which outputFormat to use.
@ -447,6 +448,7 @@ func (f Formatter) prettyWithFlags(value interface{}, flags uint32, depth int) s
if flags&flagRawStruct == 0 {
buf.WriteByte('{')
}
printComma := false // testing i>0 is not enough because of JSON omitted fields
for i := 0; i < t.NumField(); i++ {
fld := t.Field(i)
if fld.PkgPath != "" {
@ -478,9 +480,10 @@ func (f Formatter) prettyWithFlags(value interface{}, flags uint32, depth int) s
if omitempty && isEmpty(v.Field(i)) {
continue
}
if i > 0 {
if printComma {
buf.WriteByte(',')
}
printComma = true // if we got here, we are rendering a field
if fld.Anonymous && fld.Type.Kind() == reflect.Struct && name == "" {
buf.WriteString(f.prettyWithFlags(v.Field(i).Interface(), flags|flagRawStruct, depth+1))
continue
@ -500,6 +503,20 @@ func (f Formatter) prettyWithFlags(value interface{}, flags uint32, depth int) s
}
return buf.String()
case reflect.Slice, reflect.Array:
// If this is outputing as JSON make sure this isn't really a json.RawMessage.
// If so just emit "as-is" and don't pretty it as that will just print
// it as [X,Y,Z,...] which isn't terribly useful vs the string form you really want.
if f.outputFormat == outputJSON {
if rm, ok := value.(json.RawMessage); ok {
// If it's empty make sure we emit an empty value as the array style would below.
if len(rm) > 0 {
buf.Write(rm)
} else {
buf.WriteString("null")
}
return buf.String()
}
}
buf.WriteByte('[')
for i := 0; i < v.Len(); i++ {
if i > 0 {

View File

@ -21,7 +21,7 @@ limitations under the License.
// to back that API. Packages in the Go ecosystem can depend on this package,
// while callers can implement logging with whatever backend is appropriate.
//
// Usage
// # Usage
//
// Logging is done using a Logger instance. Logger is a concrete type with
// methods, which defers the actual logging to a LogSink interface. The main
@ -30,16 +30,20 @@ limitations under the License.
// "structured logging".
//
// With Go's standard log package, we might write:
// log.Printf("setting target value %s", targetValue)
//
// log.Printf("setting target value %s", targetValue)
//
// With logr's structured logging, we'd write:
// logger.Info("setting target", "value", targetValue)
//
// logger.Info("setting target", "value", targetValue)
//
// Errors are much the same. Instead of:
// log.Printf("failed to open the pod bay door for user %s: %v", user, err)
//
// log.Printf("failed to open the pod bay door for user %s: %v", user, err)
//
// We'd write:
// logger.Error(err, "failed to open the pod bay door", "user", user)
//
// logger.Error(err, "failed to open the pod bay door", "user", user)
//
// Info() and Error() are very similar, but they are separate methods so that
// LogSink implementations can choose to do things like attach additional
@ -47,7 +51,7 @@ limitations under the License.
// always logged, regardless of the current verbosity. If there is no error
// instance available, passing nil is valid.
//
// Verbosity
// # Verbosity
//
// Often we want to log information only when the application in "verbose
// mode". To write log lines that are more verbose, Logger has a V() method.
@ -58,20 +62,22 @@ limitations under the License.
// Error messages do not have a verbosity level and are always logged.
//
// Where we might have written:
// if flVerbose >= 2 {
// log.Printf("an unusual thing happened")
// }
//
// if flVerbose >= 2 {
// log.Printf("an unusual thing happened")
// }
//
// We can write:
// logger.V(2).Info("an unusual thing happened")
//
// Logger Names
// logger.V(2).Info("an unusual thing happened")
//
// # Logger Names
//
// Logger instances can have name strings so that all messages logged through
// that instance have additional context. For example, you might want to add
// a subsystem name:
//
// logger.WithName("compactor").Info("started", "time", time.Now())
// logger.WithName("compactor").Info("started", "time", time.Now())
//
// The WithName() method returns a new Logger, which can be passed to
// constructors or other functions for further use. Repeated use of WithName()
@ -82,25 +88,27 @@ limitations under the License.
// joining operation (e.g. whitespace, commas, periods, slashes, brackets,
// quotes, etc).
//
// Saved Values
// # Saved Values
//
// Logger instances can store any number of key/value pairs, which will be
// logged alongside all messages logged through that instance. For example,
// you might want to create a Logger instance per managed object:
//
// With the standard log package, we might write:
// log.Printf("decided to set field foo to value %q for object %s/%s",
// targetValue, object.Namespace, object.Name)
//
// log.Printf("decided to set field foo to value %q for object %s/%s",
// targetValue, object.Namespace, object.Name)
//
// With logr we'd write:
// // Elsewhere: set up the logger to log the object name.
// obj.logger = mainLogger.WithValues(
// "name", obj.name, "namespace", obj.namespace)
//
// // later on...
// obj.logger.Info("setting foo", "value", targetValue)
// // Elsewhere: set up the logger to log the object name.
// obj.logger = mainLogger.WithValues(
// "name", obj.name, "namespace", obj.namespace)
//
// Best Practices
// // later on...
// obj.logger.Info("setting foo", "value", targetValue)
//
// # Best Practices
//
// Logger has very few hard rules, with the goal that LogSink implementations
// might have a lot of freedom to differentiate. There are, however, some
@ -124,15 +132,15 @@ limitations under the License.
// around. For cases where passing a logger is optional, a pointer to Logger
// should be used.
//
// Key Naming Conventions
// # Key Naming Conventions
//
// Keys are not strictly required to conform to any specification or regex, but
// it is recommended that they:
// * be human-readable and meaningful (not auto-generated or simple ordinals)
// * be constant (not dependent on input data)
// * contain only printable characters
// * not contain whitespace or punctuation
// * use lower case for simple keys and lowerCamelCase for more complex ones
// - be human-readable and meaningful (not auto-generated or simple ordinals)
// - be constant (not dependent on input data)
// - contain only printable characters
// - not contain whitespace or punctuation
// - use lower case for simple keys and lowerCamelCase for more complex ones
//
// These guidelines help ensure that log data is processed properly regardless
// of the log implementation. For example, log implementations will try to
@ -141,51 +149,54 @@ limitations under the License.
// While users are generally free to use key names of their choice, it's
// generally best to avoid using the following keys, as they're frequently used
// by implementations:
// * "caller": the calling information (file/line) of a particular log line
// * "error": the underlying error value in the `Error` method
// * "level": the log level
// * "logger": the name of the associated logger
// * "msg": the log message
// * "stacktrace": the stack trace associated with a particular log line or
// error (often from the `Error` message)
// * "ts": the timestamp for a log line
// - "caller": the calling information (file/line) of a particular log line
// - "error": the underlying error value in the `Error` method
// - "level": the log level
// - "logger": the name of the associated logger
// - "msg": the log message
// - "stacktrace": the stack trace associated with a particular log line or
// error (often from the `Error` message)
// - "ts": the timestamp for a log line
//
// Implementations are encouraged to make use of these keys to represent the
// above concepts, when necessary (for example, in a pure-JSON output form, it
// would be necessary to represent at least message and timestamp as ordinary
// named values).
//
// Break Glass
// # Break Glass
//
// Implementations may choose to give callers access to the underlying
// logging implementation. The recommended pattern for this is:
// // Underlier exposes access to the underlying logging implementation.
// // Since callers only have a logr.Logger, they have to know which
// // implementation is in use, so this interface is less of an abstraction
// // and more of way to test type conversion.
// type Underlier interface {
// GetUnderlying() <underlying-type>
// }
//
// // Underlier exposes access to the underlying logging implementation.
// // Since callers only have a logr.Logger, they have to know which
// // implementation is in use, so this interface is less of an abstraction
// // and more of way to test type conversion.
// type Underlier interface {
// GetUnderlying() <underlying-type>
// }
//
// Logger grants access to the sink to enable type assertions like this:
// func DoSomethingWithImpl(log logr.Logger) {
// if underlier, ok := log.GetSink()(impl.Underlier) {
// implLogger := underlier.GetUnderlying()
// ...
// }
// }
//
// func DoSomethingWithImpl(log logr.Logger) {
// if underlier, ok := log.GetSink().(impl.Underlier); ok {
// implLogger := underlier.GetUnderlying()
// ...
// }
// }
//
// Custom `With*` functions can be implemented by copying the complete
// Logger struct and replacing the sink in the copy:
// // WithFooBar changes the foobar parameter in the log sink and returns a
// // new logger with that modified sink. It does nothing for loggers where
// // the sink doesn't support that parameter.
// func WithFoobar(log logr.Logger, foobar int) logr.Logger {
// if foobarLogSink, ok := log.GetSink()(FoobarSink); ok {
// log = log.WithSink(foobarLogSink.WithFooBar(foobar))
// }
// return log
// }
//
// // WithFooBar changes the foobar parameter in the log sink and returns a
// // new logger with that modified sink. It does nothing for loggers where
// // the sink doesn't support that parameter.
// func WithFoobar(log logr.Logger, foobar int) logr.Logger {
// if foobarLogSink, ok := log.GetSink().(FoobarSink); ok {
// log = log.WithSink(foobarLogSink.WithFooBar(foobar))
// }
// return log
// }
//
// Don't use New to construct a new Logger with a LogSink retrieved from an
// existing Logger. Source code attribution might not work correctly and
@ -201,11 +212,14 @@ import (
)
// New returns a new Logger instance. This is primarily used by libraries
// implementing LogSink, rather than end users.
// implementing LogSink, rather than end users. Passing a nil sink will create
// a Logger which discards all log lines.
func New(sink LogSink) Logger {
logger := Logger{}
logger.setSink(sink)
sink.Init(runtimeInfo)
if sink != nil {
sink.Init(runtimeInfo)
}
return logger
}
@ -244,7 +258,7 @@ type Logger struct {
// Enabled tests whether this Logger is enabled. For example, commandline
// flags might be used to set the logging verbosity and disable some info logs.
func (l Logger) Enabled() bool {
return l.sink.Enabled(l.level)
return l.sink != nil && l.sink.Enabled(l.level)
}
// Info logs a non-error message with the given key/value pairs as context.
@ -254,6 +268,9 @@ func (l Logger) Enabled() bool {
// information. The key/value pairs must alternate string keys and arbitrary
// values.
func (l Logger) Info(msg string, keysAndValues ...interface{}) {
if l.sink == nil {
return
}
if l.Enabled() {
if withHelper, ok := l.sink.(CallStackHelperLogSink); ok {
withHelper.GetCallStackHelper()()
@ -273,6 +290,9 @@ func (l Logger) Info(msg string, keysAndValues ...interface{}) {
// triggered this log line, if present. The err parameter is optional
// and nil may be passed instead of an error instance.
func (l Logger) Error(err error, msg string, keysAndValues ...interface{}) {
if l.sink == nil {
return
}
if withHelper, ok := l.sink.(CallStackHelperLogSink); ok {
withHelper.GetCallStackHelper()()
}
@ -284,6 +304,9 @@ func (l Logger) Error(err error, msg string, keysAndValues ...interface{}) {
// level means a log message is less important. Negative V-levels are treated
// as 0.
func (l Logger) V(level int) Logger {
if l.sink == nil {
return l
}
if level < 0 {
level = 0
}
@ -294,6 +317,9 @@ func (l Logger) V(level int) Logger {
// WithValues returns a new Logger instance with additional key/value pairs.
// See Info for documentation on how key/value pairs work.
func (l Logger) WithValues(keysAndValues ...interface{}) Logger {
if l.sink == nil {
return l
}
l.setSink(l.sink.WithValues(keysAndValues...))
return l
}
@ -304,6 +330,9 @@ func (l Logger) WithValues(keysAndValues ...interface{}) Logger {
// contain only letters, digits, and hyphens (see the package documentation for
// more information).
func (l Logger) WithName(name string) Logger {
if l.sink == nil {
return l
}
l.setSink(l.sink.WithName(name))
return l
}
@ -324,6 +353,9 @@ func (l Logger) WithName(name string) Logger {
// WithCallDepth(1) because it works with implementions that support the
// CallDepthLogSink and/or CallStackHelperLogSink interfaces.
func (l Logger) WithCallDepth(depth int) Logger {
if l.sink == nil {
return l
}
if withCallDepth, ok := l.sink.(CallDepthLogSink); ok {
l.setSink(withCallDepth.WithCallDepth(depth))
}
@ -345,6 +377,9 @@ func (l Logger) WithCallDepth(depth int) Logger {
// implementation does not support either of these, the original Logger will be
// returned.
func (l Logger) WithCallStackHelper() (func(), Logger) {
if l.sink == nil {
return func() {}, l
}
var helper func()
if withCallDepth, ok := l.sink.(CallDepthLogSink); ok {
l.setSink(withCallDepth.WithCallDepth(1))
@ -357,6 +392,11 @@ func (l Logger) WithCallStackHelper() (func(), Logger) {
return helper, l
}
// IsZero returns true if this logger is an uninitialized zero value
func (l Logger) IsZero() bool {
return l.sink == nil
}
// contextKey is how we find Loggers in a context.Context.
type contextKey struct{}
@ -442,7 +482,7 @@ type LogSink interface {
WithName(name string) LogSink
}
// CallDepthLogSink represents a Logger that knows how to climb the call stack
// CallDepthLogSink represents a LogSink that knows how to climb the call stack
// to identify the original call site and can offset the depth by a specified
// number of frames. This is useful for users who have helper functions
// between the "real" call site and the actual calls to Logger methods.
@ -467,7 +507,7 @@ type CallDepthLogSink interface {
WithCallDepth(depth int) LogSink
}
// CallStackHelperLogSink represents a Logger that knows how to climb
// CallStackHelperLogSink represents a LogSink that knows how to climb
// the call stack to identify the original call site and can skip
// intermediate helper functions if they mark themselves as
// helper. Go's testing package uses that approach.

View File

@ -1,26 +0,0 @@
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !go1.12
package main
import (
"log"
)
func printModuleVersion() {
log.Printf("No version information is available for Mockgen compiled with " +
"version 1.11")
}

View File

@ -386,8 +386,14 @@ func (u *Unmarshaler) unmarshalMessage(m protoreflect.Message, in []byte) error
}
func isSingularWellKnownValue(fd protoreflect.FieldDescriptor) bool {
if fd.Cardinality() == protoreflect.Repeated {
return false
}
if md := fd.Message(); md != nil {
return md.FullName() == "google.protobuf.Value" && fd.Cardinality() != protoreflect.Repeated
return md.FullName() == "google.protobuf.Value"
}
if ed := fd.Enum(); ed != nil {
return ed.FullName() == "google.protobuf.NullValue"
}
return false
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"os"
"regexp"
"strconv"
"strings"
)
@ -50,6 +51,37 @@ func NewWithNoColorBool(noColor bool) Formatter {
}
func New(colorMode ColorMode) Formatter {
colorAliases := map[string]int{
"black": 0,
"red": 1,
"green": 2,
"yellow": 3,
"blue": 4,
"magenta": 5,
"cyan": 6,
"white": 7,
}
for colorAlias, n := range colorAliases {
colorAliases[fmt.Sprintf("bright-%s", colorAlias)] = n + 8
}
getColor := func(color, defaultEscapeCode string) string {
color = strings.ToUpper(strings.ReplaceAll(color, "-", "_"))
envVar := fmt.Sprintf("GINKGO_CLI_COLOR_%s", color)
envVarColor := os.Getenv(envVar)
if envVarColor == "" {
return defaultEscapeCode
}
if colorCode, ok := colorAliases[envVarColor]; ok {
return fmt.Sprintf("\x1b[38;5;%dm", colorCode)
}
colorCode, err := strconv.Atoi(envVarColor)
if err != nil || colorCode < 0 || colorCode > 255 {
return defaultEscapeCode
}
return fmt.Sprintf("\x1b[38;5;%dm", colorCode)
}
f := Formatter{
ColorMode: colorMode,
colors: map[string]string{
@ -57,18 +89,18 @@ func New(colorMode ColorMode) Formatter {
"bold": "\x1b[1m",
"underline": "\x1b[4m",
"red": "\x1b[38;5;9m",
"orange": "\x1b[38;5;214m",
"coral": "\x1b[38;5;204m",
"magenta": "\x1b[38;5;13m",
"green": "\x1b[38;5;10m",
"dark-green": "\x1b[38;5;28m",
"yellow": "\x1b[38;5;11m",
"light-yellow": "\x1b[38;5;228m",
"cyan": "\x1b[38;5;14m",
"gray": "\x1b[38;5;243m",
"light-gray": "\x1b[38;5;246m",
"blue": "\x1b[38;5;12m",
"red": getColor("red", "\x1b[38;5;9m"),
"orange": getColor("orange", "\x1b[38;5;214m"),
"coral": getColor("coral", "\x1b[38;5;204m"),
"magenta": getColor("magenta", "\x1b[38;5;13m"),
"green": getColor("green", "\x1b[38;5;10m"),
"dark-green": getColor("dark-green", "\x1b[38;5;28m"),
"yellow": getColor("yellow", "\x1b[38;5;11m"),
"light-yellow": getColor("light-yellow", "\x1b[38;5;228m"),
"cyan": getColor("cyan", "\x1b[38;5;14m"),
"gray": getColor("gray", "\x1b[38;5;243m"),
"light-gray": getColor("light-gray", "\x1b[38;5;246m"),
"blue": getColor("blue", "\x1b[38;5;12m"),
},
}
colors := []string{}
@ -88,7 +120,10 @@ func (f Formatter) Fi(indentation uint, format string, args ...interface{}) stri
}
func (f Formatter) Fiw(indentation uint, maxWidth uint, format string, args ...interface{}) string {
out := fmt.Sprintf(f.style(format), args...)
out := f.style(format)
if len(args) > 0 {
out = fmt.Sprintf(out, args...)
}
if indentation == 0 && maxWidth == 0 {
return out

View File

@ -2,6 +2,7 @@ package generators
import (
"bytes"
"encoding/json"
"fmt"
"os"
"text/template"
@ -25,6 +26,9 @@ func BuildBootstrapCommand() command.Command {
{Name: "template", KeyPath: "CustomTemplate",
UsageArgument: "template-file",
Usage: "If specified, generate will use the contents of the file passed as the bootstrap template"},
{Name: "template-data", KeyPath: "CustomTemplateData",
UsageArgument: "template-data-file",
Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the bootstrap template"},
},
&conf,
types.GinkgoFlagSections{},
@ -57,6 +61,7 @@ type bootstrapData struct {
GomegaImport string
GinkgoPackage string
GomegaPackage string
CustomData map[string]any
}
func generateBootstrap(conf GeneratorsConfig) {
@ -95,17 +100,32 @@ func generateBootstrap(conf GeneratorsConfig) {
tpl, err := os.ReadFile(conf.CustomTemplate)
command.AbortIfError("Failed to read custom bootstrap file:", err)
templateText = string(tpl)
if conf.CustomTemplateData != "" {
var tplCustomDataMap map[string]any
tplCustomData, err := os.ReadFile(conf.CustomTemplateData)
command.AbortIfError("Failed to read custom boostrap data file:", err)
if !json.Valid([]byte(tplCustomData)) {
command.AbortWith("Invalid JSON object in custom data file.")
}
//create map from the custom template data
json.Unmarshal(tplCustomData, &tplCustomDataMap)
data.CustomData = tplCustomDataMap
}
} else if conf.Agouti {
templateText = agoutiBootstrapText
} else {
templateText = bootstrapText
}
bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Parse(templateText)
//Setting the option to explicitly fail if template is rendered trying to access missing key
bootstrapTemplate, err := template.New("bootstrap").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText)
command.AbortIfError("Failed to parse bootstrap template:", err)
buf := &bytes.Buffer{}
bootstrapTemplate.Execute(buf, data)
//Being explicit about failing sooner during template rendering
//when accessing custom data rather than during the go fmt command
err = bootstrapTemplate.Execute(buf, data)
command.AbortIfError("Failed to render bootstrap template:", err)
buf.WriteTo(f)

View File

@ -2,6 +2,7 @@ package generators
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
@ -28,6 +29,9 @@ func BuildGenerateCommand() command.Command {
{Name: "template", KeyPath: "CustomTemplate",
UsageArgument: "template-file",
Usage: "If specified, generate will use the contents of the file passed as the test file template"},
{Name: "template-data", KeyPath: "CustomTemplateData",
UsageArgument: "template-data-file",
Usage: "If specified, generate will use the contents of the file passed as data to be rendered in the test file template"},
},
&conf,
types.GinkgoFlagSections{},
@ -64,6 +68,7 @@ type specData struct {
GomegaImport string
GinkgoPackage string
GomegaPackage string
CustomData map[string]any
}
func generateTestFiles(conf GeneratorsConfig, args []string) {
@ -122,16 +127,31 @@ func generateTestFileForSubject(subject string, conf GeneratorsConfig) {
tpl, err := os.ReadFile(conf.CustomTemplate)
command.AbortIfError("Failed to read custom template file:", err)
templateText = string(tpl)
if conf.CustomTemplateData != "" {
var tplCustomDataMap map[string]any
tplCustomData, err := os.ReadFile(conf.CustomTemplateData)
command.AbortIfError("Failed to read custom template data file:", err)
if !json.Valid([]byte(tplCustomData)) {
command.AbortWith("Invalid JSON object in custom data file.")
}
//create map from the custom template data
json.Unmarshal(tplCustomData, &tplCustomDataMap)
data.CustomData = tplCustomDataMap
}
} else if conf.Agouti {
templateText = agoutiSpecText
} else {
templateText = specText
}
specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Parse(templateText)
//Setting the option to explicitly fail if template is rendered trying to access missing key
specTemplate, err := template.New("spec").Funcs(sprig.TxtFuncMap()).Option("missingkey=error").Parse(templateText)
command.AbortIfError("Failed to read parse test template:", err)
specTemplate.Execute(f, data)
//Being explicit about failing sooner during template rendering
//when accessing custom data rather than during the go fmt command
err = specTemplate.Execute(f, data)
command.AbortIfError("Failed to render bootstrap template:", err)
internal.GoFmt(targetFile)
}

View File

@ -13,6 +13,7 @@ import (
type GeneratorsConfig struct {
Agouti, NoDot, Internal bool
CustomTemplate string
CustomTemplateData string
}
func getPackageAndFormattedName() (string, string, string) {

View File

@ -25,7 +25,16 @@ func CompileSuite(suite TestSuite, goFlagsConfig types.GoFlagsConfig) TestSuite
return suite
}
args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./")
ginkgoInvocationPath, _ := os.Getwd()
ginkgoInvocationPath, _ = filepath.Abs(ginkgoInvocationPath)
packagePath := suite.AbsPath()
pathToInvocationPath, err := filepath.Rel(packagePath, ginkgoInvocationPath)
if err != nil {
suite.State = TestSuiteStateFailedToCompile
suite.CompilationError = fmt.Errorf("Failed to get relative path from package to the current working directory:\n%s", err.Error())
return suite
}
args, err := types.GenerateGoTestCompileArgs(goFlagsConfig, path, "./", pathToInvocationPath)
if err != nil {
suite.State = TestSuiteStateFailedToCompile
suite.CompilationError = fmt.Errorf("Failed to generate go test compile flags:\n%s", err.Error())

View File

@ -6,6 +6,7 @@ import (
"io"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"syscall"
@ -63,6 +64,12 @@ func checkForNoTestsWarning(buf *bytes.Buffer) bool {
}
func runGoTest(suite TestSuite, cliConfig types.CLIConfig, goFlagsConfig types.GoFlagsConfig) TestSuite {
// As we run the go test from the suite directory, make sure the cover profile is absolute
// and placed into the expected output directory when one is configured.
if goFlagsConfig.Cover && !filepath.IsAbs(goFlagsConfig.CoverProfile) {
goFlagsConfig.CoverProfile = AbsPathForGeneratedAsset(goFlagsConfig.CoverProfile, suite, cliConfig, 0)
}
args, err := types.GenerateGoTestRunArgs(goFlagsConfig)
command.AbortIfError("Failed to generate test run arguments", err)
cmd, buf := buildAndStartCommand(suite, args, true)

View File

@ -1,6 +1,7 @@
package outline
import (
"github.com/onsi/ginkgo/v2/types"
"go/ast"
"go/token"
"strconv"
@ -25,9 +26,10 @@ type ginkgoMetadata struct {
// End is the position of first character immediately after the spec or container block
End int `json:"end"`
Spec bool `json:"spec"`
Focused bool `json:"focused"`
Pending bool `json:"pending"`
Spec bool `json:"spec"`
Focused bool `json:"focused"`
Pending bool `json:"pending"`
Labels []string `json:"labels"`
}
// ginkgoNode is used to construct the outline as a tree
@ -145,27 +147,35 @@ func ginkgoNodeFromCallExpr(fset *token.FileSet, ce *ast.CallExpr, ginkgoPackage
case "It", "Specify", "Entry":
n.Spec = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
n.Pending = pendingFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "FIt", "FSpecify", "FEntry":
n.Spec = true
n.Focused = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "PIt", "PSpecify", "XIt", "XSpecify", "PEntry", "XEntry":
n.Spec = true
n.Pending = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "Context", "Describe", "When", "DescribeTable":
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
n.Pending = pendingFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "FContext", "FDescribe", "FWhen", "FDescribeTable":
n.Focused = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "PContext", "PDescribe", "PWhen", "XContext", "XDescribe", "XWhen", "PDescribeTable", "XDescribeTable":
n.Pending = true
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
n.Labels = labelFromCallExpr(ce)
return &n, ginkgoPackageName != nil && *ginkgoPackageName == packageName
case "By":
n.Text = textOrAltFromCallExpr(ce, undefinedTextAlt)
@ -216,3 +226,77 @@ func textFromCallExpr(ce *ast.CallExpr) (string, bool) {
return text.Value, true
}
}
func labelFromCallExpr(ce *ast.CallExpr) []string {
labels := []string{}
if len(ce.Args) < 2 {
return labels
}
for _, arg := range ce.Args[1:] {
switch expr := arg.(type) {
case *ast.CallExpr:
id, ok := expr.Fun.(*ast.Ident)
if !ok {
// to skip over cases where the expr.Fun. is actually *ast.SelectorExpr
continue
}
if id.Name == "Label" {
ls := extractLabels(expr)
for _, label := range ls {
labels = append(labels, label)
}
}
}
}
return labels
}
func extractLabels(expr *ast.CallExpr) []string {
out := []string{}
for _, arg := range expr.Args {
switch expr := arg.(type) {
case *ast.BasicLit:
if expr.Kind == token.STRING {
unquoted, err := strconv.Unquote(expr.Value)
if err != nil {
unquoted = expr.Value
}
validated, err := types.ValidateAndCleanupLabel(unquoted, types.CodeLocation{})
if err == nil {
out = append(out, validated)
}
}
}
}
return out
}
func pendingFromCallExpr(ce *ast.CallExpr) bool {
pending := false
if len(ce.Args) < 2 {
return pending
}
for _, arg := range ce.Args[1:] {
switch expr := arg.(type) {
case *ast.CallExpr:
id, ok := expr.Fun.(*ast.Ident)
if !ok {
// to skip over cases where the expr.Fun. is actually *ast.SelectorExpr
continue
}
if id.Name == "Pending" {
pending = true
}
case *ast.Ident:
if expr.Name == "Pending" {
pending = true
}
}
}
return pending
}

View File

@ -85,12 +85,19 @@ func (o *outline) String() string {
// one 'width' of spaces for every level of nesting.
func (o *outline) StringIndent(width int) string {
var b strings.Builder
b.WriteString("Name,Text,Start,End,Spec,Focused,Pending\n")
b.WriteString("Name,Text,Start,End,Spec,Focused,Pending,Labels\n")
currentIndent := 0
pre := func(n *ginkgoNode) {
b.WriteString(fmt.Sprintf("%*s", currentIndent, ""))
b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending))
var labels string
if len(n.Labels) == 1 {
labels = n.Labels[0]
} else {
labels = strings.Join(n.Labels, ", ")
}
//enclosing labels in a double quoted comma separate listed so that when inmported into a CSV app the Labels column has comma separate strings
b.WriteString(fmt.Sprintf("%s,%s,%d,%d,%t,%t,%t,\"%s\"\n", n.Name, n.Text, n.Start, n.End, n.Spec, n.Focused, n.Pending, labels))
currentIndent += width
}
post := func(n *ginkgoNode) {

View File

@ -10,7 +10,7 @@ import (
"github.com/onsi/ginkgo/v2/internal/parallel_support"
)
const ABORT_POLLING_INTERVAL = 500 * time.Millisecond
var ABORT_POLLING_INTERVAL = 500 * time.Millisecond
type InterruptCause uint
@ -62,13 +62,14 @@ type InterruptHandlerInterface interface {
}
type InterruptHandler struct {
c chan interface{}
lock *sync.Mutex
level InterruptLevel
cause InterruptCause
client parallel_support.Client
stop chan interface{}
signals []os.Signal
c chan interface{}
lock *sync.Mutex
level InterruptLevel
cause InterruptCause
client parallel_support.Client
stop chan interface{}
signals []os.Signal
requestAbortCheck chan interface{}
}
func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *InterruptHandler {
@ -76,11 +77,12 @@ func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *
signals = []os.Signal{os.Interrupt, syscall.SIGTERM}
}
handler := &InterruptHandler{
c: make(chan interface{}),
lock: &sync.Mutex{},
stop: make(chan interface{}),
client: client,
signals: signals,
c: make(chan interface{}),
lock: &sync.Mutex{},
stop: make(chan interface{}),
requestAbortCheck: make(chan interface{}),
client: client,
signals: signals,
}
handler.registerForInterrupts()
return handler
@ -109,6 +111,12 @@ func (handler *InterruptHandler) registerForInterrupts() {
pollTicker.Stop()
return
}
case <-handler.requestAbortCheck:
if handler.client.ShouldAbort() {
close(abortChannel)
pollTicker.Stop()
return
}
case <-handler.stop:
pollTicker.Stop()
return
@ -152,11 +160,18 @@ func (handler *InterruptHandler) registerForInterrupts() {
func (handler *InterruptHandler) Status() InterruptStatus {
handler.lock.Lock()
defer handler.lock.Unlock()
return InterruptStatus{
status := InterruptStatus{
Level: handler.level,
Channel: handler.c,
Cause: handler.cause,
}
handler.lock.Unlock()
if handler.client != nil && handler.client.ShouldAbort() && !status.Interrupted() {
close(handler.requestAbortCheck)
<-status.Channel
return handler.Status()
}
return status
}

View File

@ -42,6 +42,8 @@ type Client interface {
PostSuiteWillBegin(report types.Report) error
PostDidRun(report types.SpecReport) error
PostSuiteDidEnd(report types.Report) error
PostReportBeforeSuiteCompleted(state types.SpecState) error
BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error)
PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error
BlockUntilSynchronizedBeforeSuiteData() (types.SpecState, []byte, error)
BlockUntilNonprimaryProcsHaveFinished() error

View File

@ -98,6 +98,19 @@ func (client *httpClient) PostEmitProgressReport(report types.ProgressReport) er
return client.post("/progress-report", report)
}
func (client *httpClient) PostReportBeforeSuiteCompleted(state types.SpecState) error {
return client.post("/report-before-suite-completed", state)
}
func (client *httpClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) {
var state types.SpecState
err := client.poll("/report-before-suite-state", &state)
if err == ErrorGone {
return types.SpecStateFailed, nil
}
return state, err
}
func (client *httpClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error {
beforeSuiteState := BeforeSuiteState{
State: state,

View File

@ -26,7 +26,7 @@ type httpServer struct {
handler *ServerHandler
}
//Create a new server, automatically selecting a port
// Create a new server, automatically selecting a port
func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@ -38,7 +38,7 @@ func newHttpServer(parallelTotal int, reporter reporters.Reporter) (*httpServer,
}, nil
}
//Start the server. You don't need to `go s.Start()`, just `s.Start()`
// Start the server. You don't need to `go s.Start()`, just `s.Start()`
func (server *httpServer) Start() {
httpServer := &http.Server{}
mux := http.NewServeMux()
@ -52,6 +52,8 @@ func (server *httpServer) Start() {
mux.HandleFunc("/progress-report", server.emitProgressReport)
//synchronization endpoints
mux.HandleFunc("/report-before-suite-completed", server.handleReportBeforeSuiteCompleted)
mux.HandleFunc("/report-before-suite-state", server.handleReportBeforeSuiteState)
mux.HandleFunc("/before-suite-completed", server.handleBeforeSuiteCompleted)
mux.HandleFunc("/before-suite-state", server.handleBeforeSuiteState)
mux.HandleFunc("/have-nonprimary-procs-finished", server.handleHaveNonprimaryProcsFinished)
@ -63,12 +65,12 @@ func (server *httpServer) Start() {
go httpServer.Serve(server.listener)
}
//Stop the server
// Stop the server
func (server *httpServer) Close() {
server.listener.Close()
}
//The address the server can be reached it. Pass this into the `ForwardingReporter`.
// The address the server can be reached it. Pass this into the `ForwardingReporter`.
func (server *httpServer) Address() string {
return "http://" + server.listener.Addr().String()
}
@ -93,7 +95,7 @@ func (server *httpServer) RegisterAlive(node int, alive func() bool) {
// Streaming Endpoints
//
//The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters`
// The server will forward all received messages to Ginkgo reporters registered with `RegisterReporters`
func (server *httpServer) decode(writer http.ResponseWriter, request *http.Request, object interface{}) bool {
defer request.Body.Close()
if json.NewDecoder(request.Body).Decode(object) != nil {
@ -164,6 +166,23 @@ func (server *httpServer) emitProgressReport(writer http.ResponseWriter, request
server.handleError(server.handler.EmitProgressReport(report, voidReceiver), writer)
}
func (server *httpServer) handleReportBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) {
var state types.SpecState
if !server.decode(writer, request, &state) {
return
}
server.handleError(server.handler.ReportBeforeSuiteCompleted(state, voidReceiver), writer)
}
func (server *httpServer) handleReportBeforeSuiteState(writer http.ResponseWriter, request *http.Request) {
var state types.SpecState
if server.handleError(server.handler.ReportBeforeSuiteState(voidSender, &state), writer) {
return
}
json.NewEncoder(writer).Encode(state)
}
func (server *httpServer) handleBeforeSuiteCompleted(writer http.ResponseWriter, request *http.Request) {
var beforeSuiteState BeforeSuiteState
if !server.decode(writer, request, &beforeSuiteState) {

View File

@ -76,6 +76,19 @@ func (client *rpcClient) PostEmitProgressReport(report types.ProgressReport) err
return client.client.Call("Server.EmitProgressReport", report, voidReceiver)
}
func (client *rpcClient) PostReportBeforeSuiteCompleted(state types.SpecState) error {
return client.client.Call("Server.ReportBeforeSuiteCompleted", state, voidReceiver)
}
func (client *rpcClient) BlockUntilReportBeforeSuiteCompleted() (types.SpecState, error) {
var state types.SpecState
err := client.poll("Server.ReportBeforeSuiteState", &state)
if err == ErrorGone {
return types.SpecStateFailed, nil
}
return state, err
}
func (client *rpcClient) PostSynchronizedBeforeSuiteCompleted(state types.SpecState, data []byte) error {
beforeSuiteState := BeforeSuiteState{
State: state,

View File

@ -18,16 +18,17 @@ var voidSender Void
// It handles all the business logic to avoid duplication between the two servers
type ServerHandler struct {
done chan interface{}
outputDestination io.Writer
reporter reporters.Reporter
alives []func() bool
lock *sync.Mutex
beforeSuiteState BeforeSuiteState
parallelTotal int
counter int
counterLock *sync.Mutex
shouldAbort bool
done chan interface{}
outputDestination io.Writer
reporter reporters.Reporter
alives []func() bool
lock *sync.Mutex
beforeSuiteState BeforeSuiteState
reportBeforeSuiteState types.SpecState
parallelTotal int
counter int
counterLock *sync.Mutex
shouldAbort bool
numSuiteDidBegins int
numSuiteDidEnds int
@ -37,11 +38,12 @@ type ServerHandler struct {
func newServerHandler(parallelTotal int, reporter reporters.Reporter) *ServerHandler {
return &ServerHandler{
reporter: reporter,
lock: &sync.Mutex{},
counterLock: &sync.Mutex{},
alives: make([]func() bool, parallelTotal),
beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid},
reporter: reporter,
lock: &sync.Mutex{},
counterLock: &sync.Mutex{},
alives: make([]func() bool, parallelTotal),
beforeSuiteState: BeforeSuiteState{Data: nil, State: types.SpecStateInvalid},
parallelTotal: parallelTotal,
outputDestination: os.Stdout,
done: make(chan interface{}),
@ -140,6 +142,29 @@ func (handler *ServerHandler) haveNonprimaryProcsFinished() bool {
return true
}
func (handler *ServerHandler) ReportBeforeSuiteCompleted(reportBeforeSuiteState types.SpecState, _ *Void) error {
handler.lock.Lock()
defer handler.lock.Unlock()
handler.reportBeforeSuiteState = reportBeforeSuiteState
return nil
}
func (handler *ServerHandler) ReportBeforeSuiteState(_ Void, reportBeforeSuiteState *types.SpecState) error {
proc1IsAlive := handler.procIsAlive(1)
handler.lock.Lock()
defer handler.lock.Unlock()
if handler.reportBeforeSuiteState == types.SpecStateInvalid {
if proc1IsAlive {
return ErrorEarly
} else {
return ErrorGone
}
}
*reportBeforeSuiteState = handler.reportBeforeSuiteState
return nil
}
func (handler *ServerHandler) BeforeSuiteCompleted(beforeSuiteState BeforeSuiteState, _ *Void) error {
handler.lock.Lock()
defer handler.lock.Unlock()

View File

@ -12,6 +12,7 @@ import (
"io"
"runtime"
"strings"
"sync"
"time"
"github.com/onsi/ginkgo/v2/formatter"
@ -23,13 +24,16 @@ type DefaultReporter struct {
writer io.Writer
// managing the emission stream
lastChar string
lastCharWasNewline bool
lastEmissionWasDelimiter bool
// rendering
specDenoter string
retryDenoter string
formatter formatter.Formatter
runningInParallel bool
lock *sync.Mutex
}
func NewDefaultReporterUnderTest(conf types.ReporterConfig, writer io.Writer) *DefaultReporter {
@ -44,12 +48,13 @@ func NewDefaultReporter(conf types.ReporterConfig, writer io.Writer) *DefaultRep
conf: conf,
writer: writer,
lastChar: "\n",
lastCharWasNewline: true,
lastEmissionWasDelimiter: false,
specDenoter: "•",
retryDenoter: "↺",
formatter: formatter.NewWithNoColorBool(conf.NoColor),
lock: &sync.Mutex{},
}
if runtime.GOOS == "windows" {
reporter.specDenoter = "+"
@ -97,230 +102,10 @@ func (r *DefaultReporter) SuiteWillBegin(report types.Report) {
}
}
func (r *DefaultReporter) WillRun(report types.SpecReport) {
if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) {
return
}
r.emitDelimiter()
indentation := uint(0)
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
r.emitBlock(r.f("{{bold}}[%s] %s{{/}}", report.LeafNodeType.String(), report.LeafNodeText))
} else {
if len(report.ContainerHierarchyTexts) > 0 {
r.emitBlock(r.cycleJoin(report.ContainerHierarchyTexts, " "))
indentation = 1
}
line := r.fi(indentation, "{{bold}}%s{{/}}", report.LeafNodeText)
labels := report.Labels()
if len(labels) > 0 {
line += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels, ", "))
}
r.emitBlock(line)
}
r.emitBlock(r.fi(indentation, "{{gray}}%s{{/}}", report.LeafNodeLocation))
}
func (r *DefaultReporter) DidRun(report types.SpecReport) {
v := r.conf.Verbosity()
var header, highlightColor string
includeRuntime, emitGinkgoWriterOutput, stream, denoter := true, true, false, r.specDenoter
succinctLocationBlock := v.Is(types.VerbosityLevelSuccinct)
hasGW := report.CapturedGinkgoWriterOutput != ""
hasStd := report.CapturedStdOutErr != ""
hasEmittableReports := report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways) || (report.ReportEntries.HasVisibility(types.ReportEntryVisibilityFailureOrVerbose) && (!report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose)))
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
denoter = fmt.Sprintf("[%s]", report.LeafNodeType)
}
highlightColor = r.highlightColorForState(report.State)
switch report.State {
case types.SpecStatePassed:
succinctLocationBlock = v.LT(types.VerbosityLevelVerbose)
emitGinkgoWriterOutput = (r.conf.AlwaysEmitGinkgoWriter || v.GTE(types.VerbosityLevelVerbose)) && hasGW
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
if v.GTE(types.VerbosityLevelVerbose) || hasStd || hasEmittableReports {
header = fmt.Sprintf("%s PASSED", denoter)
} else {
return
}
} else {
header, stream = denoter, true
if report.NumAttempts > 1 && report.MaxFlakeAttempts > 1 {
header, stream = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), false
}
if report.RunTime > r.conf.SlowSpecThreshold {
header, stream = fmt.Sprintf("%s [SLOW TEST]", header), false
}
}
if hasStd || emitGinkgoWriterOutput || hasEmittableReports {
stream = false
}
case types.SpecStatePending:
includeRuntime, emitGinkgoWriterOutput = false, false
if v.Is(types.VerbosityLevelSuccinct) {
header, stream = "P", true
} else {
header, succinctLocationBlock = "P [PENDING]", v.LT(types.VerbosityLevelVeryVerbose)
}
case types.SpecStateSkipped:
if report.Failure.Message != "" || v.Is(types.VerbosityLevelVeryVerbose) {
header = "S [SKIPPED]"
} else {
header, stream = "S", true
}
case types.SpecStateFailed:
header = fmt.Sprintf("%s [FAILED]", denoter)
case types.SpecStateTimedout:
header = fmt.Sprintf("%s [TIMEDOUT]", denoter)
case types.SpecStatePanicked:
header = fmt.Sprintf("%s! [PANICKED]", denoter)
case types.SpecStateInterrupted:
header = fmt.Sprintf("%s! [INTERRUPTED]", denoter)
case types.SpecStateAborted:
header = fmt.Sprintf("%s! [ABORTED]", denoter)
}
if report.State.Is(types.SpecStateFailureStates) && report.MaxMustPassRepeatedly > 1 {
header, stream = fmt.Sprintf("%s DURING REPETITION #%d", header, report.NumAttempts), false
}
// Emit stream and return
if stream {
r.emit(r.f(highlightColor + header + "{{/}}"))
return
}
// Emit header
r.emitDelimiter()
if includeRuntime {
header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds())
}
r.emitBlock(r.f(highlightColor + header + "{{/}}"))
// Emit Code Location Block
r.emitBlock(r.codeLocationBlock(report, highlightColor, succinctLocationBlock, false))
//Emit Stdout/Stderr Output
if hasStd {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Begin Captured StdOut/StdErr Output >>{{/}}"))
r.emitBlock(r.fi(2, "%s", report.CapturedStdOutErr))
r.emitBlock(r.fi(1, "{{gray}}<< End Captured StdOut/StdErr Output{{/}}"))
}
//Emit Captured GinkgoWriter Output
if emitGinkgoWriterOutput && hasGW {
r.emitBlock("\n")
r.emitGinkgoWriterOutput(1, report.CapturedGinkgoWriterOutput, 0)
}
if hasEmittableReports {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Begin Report Entries >>{{/}}"))
reportEntries := report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways)
if !report.Failure.IsZero() || v.GTE(types.VerbosityLevelVerbose) {
reportEntries = report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways, types.ReportEntryVisibilityFailureOrVerbose)
}
for _, entry := range reportEntries {
r.emitBlock(r.fi(2, "{{bold}}"+entry.Name+"{{gray}} - %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT)))
if representation := entry.StringRepresentation(); representation != "" {
r.emitBlock(r.fi(3, representation))
}
}
r.emitBlock(r.fi(1, "{{gray}}<< End Report Entries{{/}}"))
}
// Emit Failure Message
if !report.Failure.IsZero() {
r.emitBlock("\n")
r.EmitFailure(1, report.State, report.Failure, false)
}
if len(report.AdditionalFailures) > 0 {
if v.GTE(types.VerbosityLevelVerbose) {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{bold}}There were additional failures detected after the initial failure:{{/}}"))
for i, additionalFailure := range report.AdditionalFailures {
r.EmitFailure(2, additionalFailure.State, additionalFailure.Failure, true)
if i < len(report.AdditionalFailures)-1 {
r.emitBlock(r.fi(2, "{{gray}}%s{{/}}", strings.Repeat("-", 10)))
}
}
} else {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{bold}}There were additional failures detected after the initial failure. Here's a summary - for full details run Ginkgo in verbose mode:{{/}}"))
for _, additionalFailure := range report.AdditionalFailures {
r.emitBlock(r.fi(2, r.highlightColorForState(additionalFailure.State)+"[%s]{{/}} in [%s] at %s",
r.humanReadableState(additionalFailure.State),
additionalFailure.Failure.FailureNodeType,
additionalFailure.Failure.Location,
))
}
}
}
r.emitDelimiter()
}
func (r *DefaultReporter) highlightColorForState(state types.SpecState) string {
switch state {
case types.SpecStatePassed:
return "{{green}}"
case types.SpecStatePending:
return "{{yellow}}"
case types.SpecStateSkipped:
return "{{cyan}}"
case types.SpecStateFailed:
return "{{red}}"
case types.SpecStateTimedout:
return "{{orange}}"
case types.SpecStatePanicked:
return "{{magenta}}"
case types.SpecStateInterrupted:
return "{{orange}}"
case types.SpecStateAborted:
return "{{coral}}"
default:
return "{{gray}}"
}
}
func (r *DefaultReporter) humanReadableState(state types.SpecState) string {
return strings.ToUpper(state.String())
}
func (r *DefaultReporter) EmitFailure(indent uint, state types.SpecState, failure types.Failure, includeState bool) {
highlightColor := r.highlightColorForState(state)
if includeState {
r.emitBlock(r.fi(indent, highlightColor+"[%s]{{/}}", r.humanReadableState(state)))
}
r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.Message))
r.emitBlock(r.fi(indent, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}}\n", failure.FailureNodeType, failure.Location))
if failure.ForwardedPanic != "" {
r.emitBlock("\n")
r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.ForwardedPanic))
}
if r.conf.FullTrace || failure.ForwardedPanic != "" {
r.emitBlock("\n")
r.emitBlock(r.fi(indent, highlightColor+"Full Stack Trace{{/}}"))
r.emitBlock(r.fi(indent+1, "%s", failure.Location.FullStackTrace))
}
if !failure.ProgressReport.IsZero() {
r.emitBlock("\n")
r.emitProgressReport(indent, false, failure.ProgressReport)
}
}
func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
failures := report.SpecReports.WithState(types.SpecStateFailureStates)
if len(failures) > 0 {
r.emitBlock("\n\n")
r.emitBlock("\n")
if len(failures) > 1 {
r.emitBlock(r.f("{{red}}{{bold}}Summarizing %d Failures:{{/}}", len(failures)))
} else {
@ -338,7 +123,7 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
case types.SpecStateInterrupted:
highlightColor, heading = "{{orange}}", "[INTERRUPTED]"
}
locationBlock := r.codeLocationBlock(specReport, highlightColor, true, true)
locationBlock := r.codeLocationBlock(specReport, highlightColor, false, true)
r.emitBlock(r.fi(1, highlightColor+"%s{{/}} %s", heading, locationBlock))
}
}
@ -387,14 +172,271 @@ func (r *DefaultReporter) SuiteDidEnd(report types.Report) {
}
}
func (r *DefaultReporter) WillRun(report types.SpecReport) {
v := r.conf.Verbosity()
if v.LT(types.VerbosityLevelVerbose) || report.State.Is(types.SpecStatePending|types.SpecStateSkipped) || report.RunningInParallel {
return
}
r.emitDelimiter(0)
r.emitBlock(r.f(r.codeLocationBlock(report, "{{/}}", v.Is(types.VerbosityLevelVeryVerbose), false)))
}
func (r *DefaultReporter) DidRun(report types.SpecReport) {
v := r.conf.Verbosity()
inParallel := report.RunningInParallel
header := r.specDenoter
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
header = fmt.Sprintf("[%s]", report.LeafNodeType)
}
highlightColor := r.highlightColorForState(report.State)
// have we already been streaming the timeline?
timelineHasBeenStreaming := v.GTE(types.VerbosityLevelVerbose) && !inParallel
// should we show the timeline?
var timeline types.Timeline
showTimeline := !timelineHasBeenStreaming && (v.GTE(types.VerbosityLevelVerbose) || report.Failed())
if showTimeline {
timeline = report.Timeline().WithoutHiddenReportEntries()
keepVeryVerboseSpecEvents := v.Is(types.VerbosityLevelVeryVerbose) ||
(v.Is(types.VerbosityLevelVerbose) && r.conf.ShowNodeEvents) ||
(report.Failed() && r.conf.ShowNodeEvents)
if !keepVeryVerboseSpecEvents {
timeline = timeline.WithoutVeryVerboseSpecEvents()
}
if len(timeline) == 0 && report.CapturedGinkgoWriterOutput == "" {
// the timeline is completely empty - don't show it
showTimeline = false
}
if v.LT(types.VerbosityLevelVeryVerbose) && report.CapturedGinkgoWriterOutput == "" && len(timeline) > 0 {
//if we aren't -vv and the timeline only has a single failure, don't show it as it will appear at the end of the report
failure, isFailure := timeline[0].(types.Failure)
if isFailure && (len(timeline) == 1 || (len(timeline) == 2 && failure.AdditionalFailure != nil)) {
showTimeline = false
}
}
}
// should we have a separate section for always-visible reports?
showSeparateVisibilityAlwaysReportsSection := !timelineHasBeenStreaming && !showTimeline && report.ReportEntries.HasVisibility(types.ReportEntryVisibilityAlways)
// should we have a separate section for captured stdout/stderr
showSeparateStdSection := inParallel && (report.CapturedStdOutErr != "")
// given all that - do we have any actual content to show? or are we a single denoter in a stream?
reportHasContent := v.Is(types.VerbosityLevelVeryVerbose) || showTimeline || showSeparateVisibilityAlwaysReportsSection || showSeparateStdSection || report.Failed() || (v.Is(types.VerbosityLevelVerbose) && !report.State.Is(types.SpecStateSkipped))
// should we show a runtime?
includeRuntime := !report.State.Is(types.SpecStateSkipped|types.SpecStatePending) || (report.State.Is(types.SpecStateSkipped) && report.Failure.Message != "")
// should we show the codelocation block?
showCodeLocation := !timelineHasBeenStreaming || !report.State.Is(types.SpecStatePassed)
switch report.State {
case types.SpecStatePassed:
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) && !reportHasContent {
return
}
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
header = fmt.Sprintf("%s PASSED", header)
}
if report.NumAttempts > 1 && report.MaxFlakeAttempts > 1 {
header, reportHasContent = fmt.Sprintf("%s [FLAKEY TEST - TOOK %d ATTEMPTS TO PASS]", r.retryDenoter, report.NumAttempts), true
}
case types.SpecStatePending:
header = "P"
if v.GT(types.VerbosityLevelSuccinct) {
header, reportHasContent = "P [PENDING]", true
}
case types.SpecStateSkipped:
header = "S"
if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && report.Failure.Message != "") {
header, reportHasContent = "S [SKIPPED]", true
}
default:
header = fmt.Sprintf("%s [%s]", header, r.humanReadableState(report.State))
if report.MaxMustPassRepeatedly > 1 {
header = fmt.Sprintf("%s DURING REPETITION #%d", header, report.NumAttempts)
}
}
// If we have no content to show, jsut emit the header and return
if !reportHasContent {
r.emit(r.f(highlightColor + header + "{{/}}"))
return
}
if includeRuntime {
header = r.f("%s [%.3f seconds]", header, report.RunTime.Seconds())
}
// Emit header
if !timelineHasBeenStreaming {
r.emitDelimiter(0)
}
r.emitBlock(r.f(highlightColor + header + "{{/}}"))
if showCodeLocation {
r.emitBlock(r.codeLocationBlock(report, highlightColor, v.Is(types.VerbosityLevelVeryVerbose), false))
}
//Emit Stdout/Stderr Output
if showSeparateStdSection {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Captured StdOut/StdErr Output >>{{/}}"))
r.emitBlock(r.fi(1, "%s", report.CapturedStdOutErr))
r.emitBlock(r.fi(1, "{{gray}}<< Captured StdOut/StdErr Output{{/}}"))
}
if showSeparateVisibilityAlwaysReportsSection {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Report Entries >>{{/}}"))
for _, entry := range report.ReportEntries.WithVisibility(types.ReportEntryVisibilityAlways) {
r.emitReportEntry(1, entry)
}
r.emitBlock(r.fi(1, "{{gray}}<< Report Entries{{/}}"))
}
if showTimeline {
r.emitBlock("\n")
r.emitBlock(r.fi(1, "{{gray}}Timeline >>{{/}}"))
r.emitTimeline(1, report, timeline)
r.emitBlock(r.fi(1, "{{gray}}<< Timeline{{/}}"))
}
// Emit Failure Message
if !report.Failure.IsZero() && !v.Is(types.VerbosityLevelVeryVerbose) {
r.emitBlock("\n")
r.emitFailure(1, report.State, report.Failure, true)
if len(report.AdditionalFailures) > 0 {
r.emitBlock(r.fi(1, "\nThere were {{bold}}{{red}}additional failures{{/}} detected. To view them in detail run {{bold}}ginkgo -vv{{/}}"))
}
}
r.emitDelimiter(0)
}
func (r *DefaultReporter) highlightColorForState(state types.SpecState) string {
switch state {
case types.SpecStatePassed:
return "{{green}}"
case types.SpecStatePending:
return "{{yellow}}"
case types.SpecStateSkipped:
return "{{cyan}}"
case types.SpecStateFailed:
return "{{red}}"
case types.SpecStateTimedout:
return "{{orange}}"
case types.SpecStatePanicked:
return "{{magenta}}"
case types.SpecStateInterrupted:
return "{{orange}}"
case types.SpecStateAborted:
return "{{coral}}"
default:
return "{{gray}}"
}
}
func (r *DefaultReporter) humanReadableState(state types.SpecState) string {
return strings.ToUpper(state.String())
}
func (r *DefaultReporter) emitTimeline(indent uint, report types.SpecReport, timeline types.Timeline) {
isVeryVerbose := r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose)
gw := report.CapturedGinkgoWriterOutput
cursor := 0
for _, entry := range timeline {
tl := entry.GetTimelineLocation()
if tl.Offset < len(gw) {
r.emit(r.fi(indent, "%s", gw[cursor:tl.Offset]))
cursor = tl.Offset
} else if cursor < len(gw) {
r.emit(r.fi(indent, "%s", gw[cursor:]))
cursor = len(gw)
}
switch x := entry.(type) {
case types.Failure:
if isVeryVerbose {
r.emitFailure(indent, report.State, x, false)
} else {
r.emitShortFailure(indent, report.State, x)
}
case types.AdditionalFailure:
if isVeryVerbose {
r.emitFailure(indent, x.State, x.Failure, true)
} else {
r.emitShortFailure(indent, x.State, x.Failure)
}
case types.ReportEntry:
r.emitReportEntry(indent, x)
case types.ProgressReport:
r.emitProgressReport(indent, false, x)
case types.SpecEvent:
if isVeryVerbose || !x.IsOnlyVisibleAtVeryVerbose() || r.conf.ShowNodeEvents {
r.emitSpecEvent(indent, x, isVeryVerbose)
}
}
}
if cursor < len(gw) {
r.emit(r.fi(indent, "%s", gw[cursor:]))
}
}
func (r *DefaultReporter) EmitFailure(state types.SpecState, failure types.Failure) {
if r.conf.Verbosity().Is(types.VerbosityLevelVerbose) {
r.emitShortFailure(1, state, failure)
} else if r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose) {
r.emitFailure(1, state, failure, true)
}
}
func (r *DefaultReporter) emitShortFailure(indent uint, state types.SpecState, failure types.Failure) {
r.emitBlock(r.fi(indent, r.highlightColorForState(state)+"[%s]{{/}} in [%s] - %s {{gray}}@ %s{{/}}",
r.humanReadableState(state),
failure.FailureNodeType,
failure.Location,
failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT),
))
}
func (r *DefaultReporter) emitFailure(indent uint, state types.SpecState, failure types.Failure, includeAdditionalFailure bool) {
highlightColor := r.highlightColorForState(state)
r.emitBlock(r.fi(indent, highlightColor+"[%s] %s{{/}}", r.humanReadableState(state), failure.Message))
r.emitBlock(r.fi(indent, highlightColor+"In {{bold}}[%s]{{/}}"+highlightColor+" at: {{bold}}%s{{/}} {{gray}}@ %s{{/}}\n", failure.FailureNodeType, failure.Location, failure.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
if failure.ForwardedPanic != "" {
r.emitBlock("\n")
r.emitBlock(r.fi(indent, highlightColor+"%s{{/}}", failure.ForwardedPanic))
}
if r.conf.FullTrace || failure.ForwardedPanic != "" {
r.emitBlock("\n")
r.emitBlock(r.fi(indent, highlightColor+"Full Stack Trace{{/}}"))
r.emitBlock(r.fi(indent+1, "%s", failure.Location.FullStackTrace))
}
if !failure.ProgressReport.IsZero() {
r.emitBlock("\n")
r.emitProgressReport(indent, false, failure.ProgressReport)
}
if failure.AdditionalFailure != nil && includeAdditionalFailure {
r.emitBlock("\n")
r.emitFailure(indent, failure.AdditionalFailure.State, failure.AdditionalFailure.Failure, true)
}
}
func (r *DefaultReporter) EmitProgressReport(report types.ProgressReport) {
r.emitDelimiter()
r.emitDelimiter(1)
if report.RunningInParallel {
r.emit(r.f("{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess))
r.emit(r.fi(1, "{{coral}}Progress Report for Ginkgo Process #{{bold}}%d{{/}}\n", report.ParallelProcess))
}
r.emitProgressReport(0, true, report)
r.emitDelimiter()
shouldEmitGW := report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)
r.emitProgressReport(1, shouldEmitGW, report)
r.emitDelimiter(1)
}
func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput bool, report types.ProgressReport) {
@ -409,7 +451,7 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
r.emit(" ")
subjectIndent = 0
}
r.emit(r.fi(subjectIndent, "{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time.Sub(report.SpecStartTime).Round(time.Millisecond)))
r.emit(r.fi(subjectIndent, "{{bold}}{{orange}}%s{{/}} (Spec Runtime: %s)\n", report.LeafNodeText, report.Time().Sub(report.SpecStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.LeafNodeLocation))
indent += 1
}
@ -419,12 +461,12 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
r.emit(r.f(" {{bold}}{{orange}}%s{{/}}", report.CurrentNodeText))
}
r.emit(r.f(" (Node Runtime: %s)\n", report.Time.Sub(report.CurrentNodeStartTime).Round(time.Millisecond)))
r.emit(r.f(" (Node Runtime: %s)\n", report.Time().Sub(report.CurrentNodeStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentNodeLocation))
indent += 1
}
if report.CurrentStepText != "" {
r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time.Sub(report.CurrentStepStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent, "At {{bold}}{{orange}}[By Step] %s{{/}} (Step Runtime: %s)\n", report.CurrentStepText, report.Time().Sub(report.CurrentStepStartTime).Round(time.Millisecond)))
r.emit(r.fi(indent+1, "{{gray}}%s{{/}}\n", report.CurrentStepLocation))
indent += 1
}
@ -433,9 +475,19 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
indent -= 1
}
if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" && (report.RunningInParallel || r.conf.Verbosity().LT(types.VerbosityLevelVerbose)) {
if emitGinkgoWriterOutput && report.CapturedGinkgoWriterOutput != "" {
r.emit("\n")
r.emitGinkgoWriterOutput(indent, report.CapturedGinkgoWriterOutput, 10)
r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}"))
limit, lines := 10, strings.Split(report.CapturedGinkgoWriterOutput, "\n")
if len(lines) <= limit {
r.emitBlock(r.fi(indent+1, "%s", report.CapturedGinkgoWriterOutput))
} else {
r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}"))
for _, line := range lines[len(lines)-limit-1:] {
r.emitBlock(r.fi(indent+1, "%s", line))
}
}
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}"))
}
if !report.SpecGoroutine().IsZero() {
@ -471,22 +523,48 @@ func (r *DefaultReporter) emitProgressReport(indent uint, emitGinkgoWriterOutput
}
}
func (r *DefaultReporter) emitGinkgoWriterOutput(indent uint, output string, limit int) {
r.emitBlock(r.fi(indent, "{{gray}}Begin Captured GinkgoWriter Output >>{{/}}"))
if limit == 0 {
r.emitBlock(r.fi(indent+1, "%s", output))
} else {
lines := strings.Split(output, "\n")
if len(lines) <= limit {
r.emitBlock(r.fi(indent+1, "%s", output))
} else {
r.emitBlock(r.fi(indent+1, "{{gray}}...{{/}}"))
for _, line := range lines[len(lines)-limit-1:] {
r.emitBlock(r.fi(indent+1, "%s", line))
}
}
func (r *DefaultReporter) EmitReportEntry(entry types.ReportEntry) {
if r.conf.Verbosity().LT(types.VerbosityLevelVerbose) || entry.Visibility == types.ReportEntryVisibilityNever {
return
}
r.emitReportEntry(1, entry)
}
func (r *DefaultReporter) emitReportEntry(indent uint, entry types.ReportEntry) {
r.emitBlock(r.fi(indent, "{{bold}}"+entry.Name+"{{gray}} "+fmt.Sprintf("- %s @ %s{{/}}", entry.Location, entry.Time.Format(types.GINKGO_TIME_FORMAT))))
if representation := entry.StringRepresentation(); representation != "" {
r.emitBlock(r.fi(indent+1, representation))
}
}
func (r *DefaultReporter) EmitSpecEvent(event types.SpecEvent) {
v := r.conf.Verbosity()
if v.Is(types.VerbosityLevelVeryVerbose) || (v.Is(types.VerbosityLevelVerbose) && (r.conf.ShowNodeEvents || !event.IsOnlyVisibleAtVeryVerbose())) {
r.emitSpecEvent(1, event, r.conf.Verbosity().Is(types.VerbosityLevelVeryVerbose))
}
}
func (r *DefaultReporter) emitSpecEvent(indent uint, event types.SpecEvent, includeLocation bool) {
location := ""
if includeLocation {
location = fmt.Sprintf("- %s ", event.CodeLocation.String())
}
switch event.SpecEventType {
case types.SpecEventInvalid:
return
case types.SpecEventByStart:
r.emitBlock(r.fi(indent, "{{bold}}STEP:{{/}} %s {{gray}}%s@ %s{{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventByEnd:
r.emitBlock(r.fi(indent, "{{bold}}END STEP:{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond)))
case types.SpecEventNodeStart:
r.emitBlock(r.fi(indent, "> Enter {{bold}}[%s]{{/}} %s {{gray}}%s@ %s{{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventNodeEnd:
r.emitBlock(r.fi(indent, "< Exit {{bold}}[%s]{{/}} %s {{gray}}%s@ %s (%s){{/}}", event.NodeType.String(), event.Message, location, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT), event.Duration.Round(time.Millisecond)))
case types.SpecEventSpecRepeat:
r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{green}}Passed{{/}}{{bold}}. Repeating %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
case types.SpecEventSpecRetry:
r.emitBlock(r.fi(indent, "\n{{bold}}Attempt #%d {{red}}Failed{{/}}{{bold}}. Retrying %s{{/}} {{gray}}@ %s{{/}}\n\n", event.Attempt, r.retryDenoter, event.TimelineLocation.Time.Format(types.GINKGO_TIME_FORMAT)))
}
r.emitBlock(r.fi(indent, "{{gray}}<< End Captured GinkgoWriter Output{{/}}"))
}
func (r *DefaultReporter) emitGoroutines(indent uint, goroutines ...types.Goroutine) {
@ -544,31 +622,37 @@ func (r *DefaultReporter) emitSource(indent uint, fc types.FunctionCall) {
/* Emitting to the writer */
func (r *DefaultReporter) emit(s string) {
if len(s) > 0 {
r.lastChar = s[len(s)-1:]
r.lastEmissionWasDelimiter = false
r.writer.Write([]byte(s))
}
r._emit(s, false, false)
}
func (r *DefaultReporter) emitBlock(s string) {
if len(s) > 0 {
if r.lastChar != "\n" {
r.emit("\n")
}
r.emit(s)
if r.lastChar != "\n" {
r.emit("\n")
}
}
r._emit(s, true, false)
}
func (r *DefaultReporter) emitDelimiter() {
if r.lastEmissionWasDelimiter {
func (r *DefaultReporter) emitDelimiter(indent uint) {
r._emit(r.fi(indent, "{{gray}}%s{{/}}", strings.Repeat("-", 30)), true, true)
}
// a bit ugly - but we're trying to minimize locking on this hot codepath
func (r *DefaultReporter) _emit(s string, block bool, isDelimiter bool) {
if len(s) == 0 {
return
}
r.emitBlock(r.f("{{gray}}%s{{/}}", strings.Repeat("-", 30)))
r.lastEmissionWasDelimiter = true
r.lock.Lock()
defer r.lock.Unlock()
if isDelimiter && r.lastEmissionWasDelimiter {
return
}
if block && !r.lastCharWasNewline {
r.writer.Write([]byte("\n"))
}
r.lastCharWasNewline = (s[len(s)-1:] == "\n")
r.writer.Write([]byte(s))
if block && !r.lastCharWasNewline {
r.writer.Write([]byte("\n"))
r.lastCharWasNewline = true
}
r.lastEmissionWasDelimiter = isDelimiter
}
/* Rendering text */
@ -584,13 +668,14 @@ func (r *DefaultReporter) cycleJoin(elements []string, joiner string) string {
return r.formatter.CycleJoin(elements, joiner, []string{"{{/}}", "{{gray}}"})
}
func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, succinct bool, usePreciseFailureLocation bool) string {
func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightColor string, veryVerbose bool, usePreciseFailureLocation bool) string {
texts, locations, labels := []string{}, []types.CodeLocation{}, [][]string{}
texts, locations, labels = append(texts, report.ContainerHierarchyTexts...), append(locations, report.ContainerHierarchyLocations...), append(labels, report.ContainerHierarchyLabels...)
if report.LeafNodeType.Is(types.NodeTypesForSuiteLevelNodes) {
texts = append(texts, r.f("[%s] %s", report.LeafNodeType, report.LeafNodeText))
} else {
texts = append(texts, report.LeafNodeText)
texts = append(texts, r.f(report.LeafNodeText))
}
labels = append(labels, report.LeafNodeLabels)
locations = append(locations, report.LeafNodeLocation)
@ -600,24 +685,58 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo
failureLocation = report.Failure.Location
}
highlightIndex := -1
switch report.Failure.FailureNodeContext {
case types.FailureNodeAtTopLevel:
texts = append([]string{r.f(highlightColor+"{{bold}}TOP-LEVEL [%s]{{/}}", report.Failure.FailureNodeType)}, texts...)
texts = append([]string{fmt.Sprintf("TOP-LEVEL [%s]", report.Failure.FailureNodeType)}, texts...)
locations = append([]types.CodeLocation{failureLocation}, locations...)
labels = append([][]string{{}}, labels...)
highlightIndex = 0
case types.FailureNodeInContainer:
i := report.Failure.FailureNodeContainerIndex
texts[i] = r.f(highlightColor+"{{bold}}%s [%s]{{/}}", texts[i], report.Failure.FailureNodeType)
texts[i] = fmt.Sprintf("%s [%s]", texts[i], report.Failure.FailureNodeType)
locations[i] = failureLocation
highlightIndex = i
case types.FailureNodeIsLeafNode:
i := len(texts) - 1
texts[i] = r.f(highlightColor+"{{bold}}[%s] %s{{/}}", report.LeafNodeType, report.LeafNodeText)
texts[i] = fmt.Sprintf("[%s] %s", report.LeafNodeType, report.LeafNodeText)
locations[i] = failureLocation
highlightIndex = i
default:
//there is no failure, so we highlight the leaf ndoe
highlightIndex = len(texts) - 1
}
out := ""
if succinct {
out += r.f("%s", r.cycleJoin(texts, " "))
if veryVerbose {
for i := range texts {
if i == highlightIndex {
out += r.fi(uint(i), highlightColor+"{{bold}}%s{{/}}", texts[i])
} else {
out += r.fi(uint(i), "%s", texts[i])
}
if len(labels[i]) > 0 {
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", "))
}
out += "\n"
out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i])
}
} else {
for i := range texts {
style := "{{/}}"
if i%2 == 1 {
style = "{{gray}}"
}
if i == highlightIndex {
style = highlightColor + "{{bold}}"
}
out += r.f(style+"%s", texts[i])
if i < len(texts)-1 {
out += " "
} else {
out += r.f("{{/}}")
}
}
flattenedLabels := report.Labels()
if len(flattenedLabels) > 0 {
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(flattenedLabels, ", "))
@ -626,17 +745,15 @@ func (r *DefaultReporter) codeLocationBlock(report types.SpecReport, highlightCo
if usePreciseFailureLocation {
out += r.f("{{gray}}%s{{/}}", failureLocation)
} else {
out += r.f("{{gray}}%s{{/}}", locations[len(locations)-1])
}
} else {
for i := range texts {
out += r.fi(uint(i), "%s", texts[i])
if len(labels[i]) > 0 {
out += r.f(" {{coral}}[%s]{{/}}", strings.Join(labels[i], ", "))
leafLocation := locations[len(locations)-1]
if (report.Failure.FailureNodeLocation != types.CodeLocation{}) && (report.Failure.FailureNodeLocation != leafLocation) {
out += r.fi(1, highlightColor+"[%s]{{/}} {{gray}}%s{{/}}\n", report.Failure.FailureNodeType, report.Failure.FailureNodeLocation)
out += r.fi(1, "{{gray}}[%s] %s{{/}}", report.LeafNodeType, leafLocation)
} else {
out += r.f("{{gray}}%s{{/}}", leafLocation)
}
out += "\n"
out += r.fi(uint(i), "{{gray}}%s{{/}}\n", locations[i])
}
}
return out
}

View File

@ -35,7 +35,7 @@ func ReportViaDeprecatedReporter(reporter DeprecatedReporter, report types.Repor
FailOnPending: report.SuiteConfig.FailOnPending,
FailFast: report.SuiteConfig.FailFast,
FlakeAttempts: report.SuiteConfig.FlakeAttempts,
EmitSpecProgress: report.SuiteConfig.EmitSpecProgress,
EmitSpecProgress: false,
DryRun: report.SuiteConfig.DryRun,
ParallelNode: report.SuiteConfig.ParallelProcess,
ParallelTotal: report.SuiteConfig.ParallelTotal,

View File

@ -15,12 +15,32 @@ import (
"fmt"
"os"
"strings"
"time"
"github.com/onsi/ginkgo/v2/config"
"github.com/onsi/ginkgo/v2/types"
)
type JunitReportConfig struct {
// Spec States for which no timeline should be emitted for system-err
// set this to types.SpecStatePassed|types.SpecStateSkipped|types.SpecStatePending to only match failing specs
OmitTimelinesForSpecState types.SpecState
// Enable OmitFailureMessageAttr to prevent failure messages appearing in the "message" attribute of the Failure and Error tags
OmitFailureMessageAttr bool
//Enable OmitCapturedStdOutErr to prevent captured stdout/stderr appearing in system-out
OmitCapturedStdOutErr bool
// Enable OmitSpecLabels to prevent labels from appearing in the spec name
OmitSpecLabels bool
// Enable OmitLeafNodeType to prevent the spec leaf node type from appearing in the spec name
OmitLeafNodeType bool
// Enable OmitSuiteSetupNodes to prevent the creation of testcase entries for setup nodes
OmitSuiteSetupNodes bool
}
type JUnitTestSuites struct {
XMLName xml.Name `xml:"testsuites"`
// Tests maps onto the total number of specs in all test suites (this includes any suite nodes such as BeforeSuite)
@ -128,6 +148,10 @@ type JUnitFailure struct {
}
func GenerateJUnitReport(report types.Report, dst string) error {
return GenerateJUnitReportWithConfig(report, dst, JunitReportConfig{})
}
func GenerateJUnitReportWithConfig(report types.Report, dst string, config JunitReportConfig) error {
suite := JUnitTestSuite{
Name: report.SuiteDescription,
Package: report.SuitePath,
@ -149,7 +173,6 @@ func GenerateJUnitReport(report types.Report, dst string) error {
{"FailOnPending", fmt.Sprintf("%t", report.SuiteConfig.FailOnPending)},
{"FailFast", fmt.Sprintf("%t", report.SuiteConfig.FailFast)},
{"FlakeAttempts", fmt.Sprintf("%d", report.SuiteConfig.FlakeAttempts)},
{"EmitSpecProgress", fmt.Sprintf("%t", report.SuiteConfig.EmitSpecProgress)},
{"DryRun", fmt.Sprintf("%t", report.SuiteConfig.DryRun)},
{"ParallelTotal", fmt.Sprintf("%d", report.SuiteConfig.ParallelTotal)},
{"OutputInterceptorMode", report.SuiteConfig.OutputInterceptorMode},
@ -157,22 +180,33 @@ func GenerateJUnitReport(report types.Report, dst string) error {
},
}
for _, spec := range report.SpecReports {
if config.OmitSuiteSetupNodes && spec.LeafNodeType != types.NodeTypeIt {
continue
}
name := fmt.Sprintf("[%s]", spec.LeafNodeType)
if config.OmitLeafNodeType {
name = ""
}
if spec.FullText() != "" {
name = name + " " + spec.FullText()
}
labels := spec.Labels()
if len(labels) > 0 {
if len(labels) > 0 && !config.OmitSpecLabels {
name = name + " [" + strings.Join(labels, ", ") + "]"
}
name = strings.TrimSpace(name)
test := JUnitTestCase{
Name: name,
Classname: report.SuiteDescription,
Status: spec.State.String(),
Time: spec.RunTime.Seconds(),
SystemOut: systemOutForUnstructuredReporters(spec),
SystemErr: systemErrForUnstructuredReporters(spec),
}
if !spec.State.Is(config.OmitTimelinesForSpecState) {
test.SystemErr = systemErrForUnstructuredReporters(spec)
}
if !config.OmitCapturedStdOutErr {
test.SystemOut = systemOutForUnstructuredReporters(spec)
}
suite.Tests += 1
@ -193,6 +227,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
Type: "failed",
Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
}
suite.Failures += 1
case types.SpecStateTimedout:
test.Failure = &JUnitFailure{
@ -200,6 +237,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
Type: "timedout",
Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
}
suite.Failures += 1
case types.SpecStateInterrupted:
test.Error = &JUnitError{
@ -207,6 +247,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
Type: "interrupted",
Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Error.Message = ""
}
suite.Errors += 1
case types.SpecStateAborted:
test.Failure = &JUnitFailure{
@ -214,6 +257,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
Type: "aborted",
Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Failure.Message = ""
}
suite.Errors += 1
case types.SpecStatePanicked:
test.Error = &JUnitError{
@ -221,6 +267,9 @@ func GenerateJUnitReport(report types.Report, dst string) error {
Type: "panicked",
Description: failureDescriptionForUnstructuredReporters(spec),
}
if config.OmitFailureMessageAttr {
test.Error.Message = ""
}
suite.Errors += 1
}
@ -287,63 +336,25 @@ func MergeAndCleanupJUnitReports(sources []string, dst string) ([]string, error)
func failureDescriptionForUnstructuredReporters(spec types.SpecReport) string {
out := &strings.Builder{}
out.WriteString(spec.Failure.Location.String() + "\n")
out.WriteString(spec.Failure.Location.FullStackTrace)
if !spec.Failure.ProgressReport.IsZero() {
out.WriteString("\n")
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(spec.Failure.ProgressReport)
}
NewDefaultReporter(types.ReporterConfig{NoColor: true, VeryVerbose: true}, out).emitFailure(0, spec.State, spec.Failure, true)
if len(spec.AdditionalFailures) > 0 {
out.WriteString("\nThere were additional failures detected after the initial failure:\n")
for i, additionalFailure := range spec.AdditionalFailures {
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitFailure(0, additionalFailure.State, additionalFailure.Failure, true)
if i < len(spec.AdditionalFailures)-1 {
out.WriteString("----------\n")
}
}
out.WriteString("\nThere were additional failures detected after the initial failure. These are visible in the timeline\n")
}
return out.String()
}
func systemErrForUnstructuredReporters(spec types.SpecReport) string {
return RenderTimeline(spec, true)
}
func RenderTimeline(spec types.SpecReport, noColor bool) string {
out := &strings.Builder{}
gw := spec.CapturedGinkgoWriterOutput
cursor := 0
for _, pr := range spec.ProgressReports {
if cursor < pr.GinkgoWriterOffset {
if pr.GinkgoWriterOffset < len(gw) {
out.WriteString(gw[cursor:pr.GinkgoWriterOffset])
cursor = pr.GinkgoWriterOffset
} else if cursor < len(gw) {
out.WriteString(gw[cursor:])
cursor = len(gw)
}
}
NewDefaultReporter(types.ReporterConfig{NoColor: true}, out).EmitProgressReport(pr)
}
if cursor < len(gw) {
out.WriteString(gw[cursor:])
}
NewDefaultReporter(types.ReporterConfig{NoColor: noColor, VeryVerbose: true}, out).emitTimeline(0, spec, spec.Timeline())
return out.String()
}
func systemOutForUnstructuredReporters(spec types.SpecReport) string {
systemOut := spec.CapturedStdOutErr
if len(spec.ReportEntries) > 0 {
systemOut += "\nReport Entries:\n"
for i, entry := range spec.ReportEntries {
systemOut += fmt.Sprintf("%s\n%s\n%s\n", entry.Name, entry.Location, entry.Time.Format(time.RFC3339Nano))
if representation := entry.StringRepresentation(); representation != "" {
systemOut += representation + "\n"
}
if i+1 < len(spec.ReportEntries) {
systemOut += "--\n"
}
}
}
return systemOut
return spec.CapturedStdOutErr
}
// Deprecated JUnitReporter (so folks can still compile their suites)

View File

@ -9,13 +9,21 @@ type Reporter interface {
WillRun(report types.SpecReport)
DidRun(report types.SpecReport)
SuiteDidEnd(report types.Report)
//Timeline emission
EmitFailure(state types.SpecState, failure types.Failure)
EmitProgressReport(progressReport types.ProgressReport)
EmitReportEntry(entry types.ReportEntry)
EmitSpecEvent(event types.SpecEvent)
}
type NoopReporter struct{}
func (n NoopReporter) SuiteWillBegin(report types.Report) {}
func (n NoopReporter) WillRun(report types.SpecReport) {}
func (n NoopReporter) DidRun(report types.SpecReport) {}
func (n NoopReporter) SuiteDidEnd(report types.Report) {}
func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {}
func (n NoopReporter) SuiteWillBegin(report types.Report) {}
func (n NoopReporter) WillRun(report types.SpecReport) {}
func (n NoopReporter) DidRun(report types.SpecReport) {}
func (n NoopReporter) SuiteDidEnd(report types.Report) {}
func (n NoopReporter) EmitFailure(state types.SpecState, failure types.Failure) {}
func (n NoopReporter) EmitProgressReport(progressReport types.ProgressReport) {}
func (n NoopReporter) EmitReportEntry(entry types.ReportEntry) {}
func (n NoopReporter) EmitSpecEvent(event types.SpecEvent) {}

View File

@ -1,4 +1,5 @@
package types
import (
"fmt"
"os"
@ -6,6 +7,7 @@ import (
"runtime"
"runtime/debug"
"strings"
"sync"
)
type CodeLocation struct {
@ -37,6 +39,73 @@ func (codeLocation CodeLocation) ContentsOfLine() string {
return lines[codeLocation.LineNumber-1]
}
type codeLocationLocator struct {
pcs map[uintptr]bool
helpers map[string]bool
lock *sync.Mutex
}
func (c *codeLocationLocator) addHelper(pc uintptr) {
c.lock.Lock()
defer c.lock.Unlock()
if c.pcs[pc] {
return
}
c.lock.Unlock()
f := runtime.FuncForPC(pc)
c.lock.Lock()
if f == nil {
return
}
c.helpers[f.Name()] = true
c.pcs[pc] = true
}
func (c *codeLocationLocator) hasHelper(name string) bool {
c.lock.Lock()
defer c.lock.Unlock()
return c.helpers[name]
}
func (c *codeLocationLocator) getCodeLocation(skip int) CodeLocation {
pc := make([]uintptr, 40)
n := runtime.Callers(skip+2, pc)
if n == 0 {
return CodeLocation{}
}
pc = pc[:n]
frames := runtime.CallersFrames(pc)
for {
frame, more := frames.Next()
if !c.hasHelper(frame.Function) {
return CodeLocation{FileName: frame.File, LineNumber: frame.Line}
}
if !more {
break
}
}
return CodeLocation{}
}
var clLocator = &codeLocationLocator{
pcs: map[uintptr]bool{},
helpers: map[string]bool{},
lock: &sync.Mutex{},
}
// MarkAsHelper is used by GinkgoHelper to mark the caller (appropriately offset by skip)as a helper. You can use this directly if you need to provide an optional `skip` to mark functions further up the call stack as helpers.
func MarkAsHelper(optionalSkip ...int) {
skip := 1
if len(optionalSkip) > 0 {
skip += optionalSkip[0]
}
pc, _, _, ok := runtime.Caller(skip)
if ok {
clLocator.addHelper(pc)
}
}
func NewCustomCodeLocation(message string) CodeLocation {
return CodeLocation{
CustomMessage: message,
@ -44,14 +113,13 @@ func NewCustomCodeLocation(message string) CodeLocation {
}
func NewCodeLocation(skip int) CodeLocation {
_, file, line, _ := runtime.Caller(skip + 1)
return CodeLocation{FileName: file, LineNumber: line}
return clLocator.getCodeLocation(skip + 1)
}
func NewCodeLocationWithStackTrace(skip int) CodeLocation {
_, file, line, _ := runtime.Caller(skip + 1)
stackTrace := PruneStack(string(debug.Stack()), skip+1)
return CodeLocation{FileName: file, LineNumber: line, FullStackTrace: stackTrace}
cl := clLocator.getCodeLocation(skip + 1)
cl.FullStackTrace = PruneStack(string(debug.Stack()), skip+1)
return cl
}
// PruneStack removes references to functions that are internal to Ginkgo

View File

@ -8,6 +8,7 @@ package types
import (
"flag"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
@ -26,11 +27,11 @@ type SuiteConfig struct {
FailOnPending bool
FailFast bool
FlakeAttempts int
EmitSpecProgress bool
DryRun bool
PollProgressAfter time.Duration
PollProgressInterval time.Duration
Timeout time.Duration
EmitSpecProgress bool // this is deprecated but its removal is causing compile issue for some users that were setting it manually
OutputInterceptorMode string
SourceRoots []string
GracePeriod time.Duration
@ -81,13 +82,12 @@ func (vl VerbosityLevel) LT(comp VerbosityLevel) bool {
// Configuration for Ginkgo's reporter
type ReporterConfig struct {
NoColor bool
SlowSpecThreshold time.Duration
Succinct bool
Verbose bool
VeryVerbose bool
FullTrace bool
AlwaysEmitGinkgoWriter bool
NoColor bool
Succinct bool
Verbose bool
VeryVerbose bool
FullTrace bool
ShowNodeEvents bool
JSONReport string
JUnitReport string
@ -110,9 +110,7 @@ func (rc ReporterConfig) WillGenerateReport() bool {
}
func NewDefaultReporterConfig() ReporterConfig {
return ReporterConfig{
SlowSpecThreshold: 5 * time.Second,
}
return ReporterConfig{}
}
// Configuration for the Ginkgo CLI
@ -235,6 +233,9 @@ type deprecatedConfig struct {
SlowSpecThresholdWithFLoatUnits float64
Stream bool
Notify bool
EmitSpecProgress bool
SlowSpecThreshold time.Duration
AlwaysEmitGinkgoWriter bool
}
// Flags
@ -275,8 +276,6 @@ var SuiteConfigFlags = GinkgoFlags{
{KeyPath: "S.DryRun", Name: "dry-run", SectionKey: "debug", DeprecatedName: "dryRun", DeprecatedDocLink: "changed-command-line-flags",
Usage: "If set, ginkgo will walk the test hierarchy without actually running anything. Best paired with -v."},
{KeyPath: "S.EmitSpecProgress", Name: "progress", SectionKey: "debug",
Usage: "If set, ginkgo will emit progress information as each spec runs to the GinkgoWriter."},
{KeyPath: "S.PollProgressAfter", Name: "poll-progress-after", SectionKey: "debug", UsageDefaultValue: "0",
Usage: "Emit node progress reports periodically if node hasn't completed after this duration."},
{KeyPath: "S.PollProgressInterval", Name: "poll-progress-interval", SectionKey: "debug", UsageDefaultValue: "10s",
@ -303,6 +302,8 @@ var SuiteConfigFlags = GinkgoFlags{
{KeyPath: "D.RegexScansFilePath", DeprecatedName: "regexScansFilePath", DeprecatedDocLink: "removed--regexscansfilepath", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.DebugParallel", DeprecatedName: "debug", DeprecatedDocLink: "removed--debug", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.EmitSpecProgress", DeprecatedName: "progress", SectionKey: "debug",
DeprecatedVersion: "2.5.0", Usage: ". The functionality provided by --progress was confusing and is no longer needed. Use --show-node-events instead to see node entry and exit events included in the timeline of failed and verbose specs. Or you can run with -vv to always see all node events. Lastly, --poll-progress-after and the PollProgressAfter decorator now provide a better mechanism for debugging specs that tend to get stuck."},
}
// ParallelConfigFlags provides flags for the Ginkgo test process (not the CLI)
@ -319,8 +320,6 @@ var ParallelConfigFlags = GinkgoFlags{
var ReporterConfigFlags = GinkgoFlags{
{KeyPath: "R.NoColor", Name: "no-color", SectionKey: "output", DeprecatedName: "noColor", DeprecatedDocLink: "changed-command-line-flags",
Usage: "If set, suppress color output in default reporter."},
{KeyPath: "R.SlowSpecThreshold", Name: "slow-spec-threshold", SectionKey: "output", UsageArgument: "duration", UsageDefaultValue: "5s",
Usage: "Specs that take longer to run than this threshold are flagged as slow by the default reporter."},
{KeyPath: "R.Verbose", Name: "v", SectionKey: "output",
Usage: "If set, emits more output including GinkgoWriter contents."},
{KeyPath: "R.VeryVerbose", Name: "vv", SectionKey: "output",
@ -329,8 +328,8 @@ var ReporterConfigFlags = GinkgoFlags{
Usage: "If set, default reporter prints out a very succinct report"},
{KeyPath: "R.FullTrace", Name: "trace", SectionKey: "output",
Usage: "If set, default reporter prints out the full stack trace when a failure occurs"},
{KeyPath: "R.AlwaysEmitGinkgoWriter", Name: "always-emit-ginkgo-writer", SectionKey: "output", DeprecatedName: "reportPassed", DeprecatedDocLink: "renamed--reportpassed",
Usage: "If set, default reporter prints out captured output of passed tests."},
{KeyPath: "R.ShowNodeEvents", Name: "show-node-events", SectionKey: "output",
Usage: "If set, default reporter prints node > Enter and < Exit events when specs fail"},
{KeyPath: "R.JSONReport", Name: "json-report", UsageArgument: "filename.json", SectionKey: "output",
Usage: "If set, Ginkgo will generate a JSON-formatted test report at the specified location."},
@ -343,6 +342,8 @@ var ReporterConfigFlags = GinkgoFlags{
Usage: "use --slow-spec-threshold instead and pass in a duration string (e.g. '5s', not '5.0')"},
{KeyPath: "D.NoisyPendings", DeprecatedName: "noisyPendings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.NoisySkippings", DeprecatedName: "noisySkippings", DeprecatedDocLink: "removed--noisypendings-and--noisyskippings", DeprecatedVersion: "2.0.0"},
{KeyPath: "D.SlowSpecThreshold", DeprecatedName: "slow-spec-threshold", SectionKey: "output", Usage: "--slow-spec-threshold has been deprecated and will be removed in a future version of Ginkgo. This feature has proved to be more noisy than useful. You can use --poll-progress-after, instead, to get more actionable feedback about potentially slow specs and understand where they might be getting stuck.", DeprecatedVersion: "2.5.0"},
{KeyPath: "D.AlwaysEmitGinkgoWriter", DeprecatedName: "always-emit-ginkgo-writer", SectionKey: "output", Usage: " - use -v instead, or one of Ginkgo's machine-readable report formats to get GinkgoWriter output for passing specs."},
}
// BuildTestSuiteFlagSet attaches to the CommandLine flagset and provides flags for the Ginkgo test process
@ -600,13 +601,29 @@ func VetAndInitializeCLIAndGoConfig(cliConfig CLIConfig, goFlagsConfig GoFlagsCo
}
// GenerateGoTestCompileArgs is used by the Ginkgo CLI to generate command line arguments to pass to the go test -c command when compiling the test
func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string) ([]string, error) {
func GenerateGoTestCompileArgs(goFlagsConfig GoFlagsConfig, destination string, packageToBuild string, pathToInvocationPath string) ([]string, error) {
// if the user has set the CoverProfile run-time flag make sure to set the build-time cover flag to make sure
// the built test binary can generate a coverprofile
if goFlagsConfig.CoverProfile != "" {
goFlagsConfig.Cover = true
}
if goFlagsConfig.CoverPkg != "" {
coverPkgs := strings.Split(goFlagsConfig.CoverPkg, ",")
adjustedCoverPkgs := make([]string, len(coverPkgs))
for i, coverPkg := range coverPkgs {
coverPkg = strings.Trim(coverPkg, " ")
if strings.HasPrefix(coverPkg, "./") {
// this is a relative coverPkg - we need to reroot it
adjustedCoverPkgs[i] = "./" + filepath.Join(pathToInvocationPath, strings.TrimPrefix(coverPkg, "./"))
} else {
// this is a package name - don't touch it
adjustedCoverPkgs[i] = coverPkg
}
}
goFlagsConfig.CoverPkg = strings.Join(adjustedCoverPkgs, ",")
}
args := []string{"test", "-c", "-o", destination, packageToBuild}
goArgs, err := GenerateFlagArgs(
GoBuildFlags,

View File

@ -38,7 +38,7 @@ func (d deprecations) Async() Deprecation {
func (d deprecations) Measure() Deprecation {
return Deprecation{
Message: "Measure is deprecated and will be removed in Ginkgo V2. Please migrate to gomega/gmeasure.",
Message: "Measure is deprecated and has been removed from Ginkgo V2. Any Measure tests in your spec will not run. Please migrate to gomega/gmeasure.",
DocLink: "removed-measure",
Version: "1.16.3",
}
@ -83,6 +83,13 @@ func (d deprecations) Nodot() Deprecation {
}
}
func (d deprecations) SuppressProgressReporting() Deprecation {
return Deprecation{
Message: "Improvements to how reporters emit timeline information means that SuppressProgressReporting is no longer necessary and has been deprecated.",
Version: "2.5.0",
}
}
type DeprecationTracker struct {
deprecations map[Deprecation][]CodeLocation
lock *sync.Mutex

View File

@ -108,8 +108,8 @@ Please ensure all assertions are inside leaf nodes such as {{bold}}BeforeEach{{/
func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocation) error {
docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite"
if nodeType.Is(NodeTypeReportAfterSuite) {
docLink = "reporting-nodes---reportaftersuite"
if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) {
docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite"
}
return GinkgoError{
@ -125,8 +125,8 @@ func (g ginkgoErrors) SuiteNodeInNestedContext(nodeType NodeType, cl CodeLocatio
func (g ginkgoErrors) SuiteNodeDuringRunPhase(nodeType NodeType, cl CodeLocation) error {
docLink := "suite-setup-and-cleanup-beforesuite-and-aftersuite"
if nodeType.Is(NodeTypeReportAfterSuite) {
docLink = "reporting-nodes---reportaftersuite"
if nodeType.Is(NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite) {
docLink = "reporting-nodes---reportbeforesuite-and-reportaftersuite"
}
return GinkgoError{
@ -298,6 +298,15 @@ func (g ginkgoErrors) SetupNodeNotInOrderedContainer(cl CodeLocation, nodeType N
}
}
func (g ginkgoErrors) InvalidContinueOnFailureDecoration(cl CodeLocation) error {
return GinkgoError{
Heading: "ContinueOnFailure not decorating an outermost Ordered Container",
Message: "ContinueOnFailure can only decorate an Ordered container, and this Ordered container must be the outermost Ordered container.",
CodeLocation: cl,
DocLink: "ordered-containers",
}
}
/* DeferCleanup errors */
func (g ginkgoErrors) DeferCleanupInvalidFunction(cl CodeLocation) error {
return GinkgoError{
@ -320,7 +329,7 @@ func (g ginkgoErrors) PushingCleanupNodeDuringTreeConstruction(cl CodeLocation)
func (g ginkgoErrors) PushingCleanupInReportingNode(cl CodeLocation, nodeType NodeType) error {
return GinkgoError{
Heading: fmt.Sprintf("DeferCleanup cannot be called in %s", nodeType),
Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a ReportAfterEach or ReportAfterSuite.",
Message: "Please inline your cleanup code - Ginkgo won't run cleanup code after a Reporting node.",
CodeLocation: cl,
DocLink: "cleaning-up-our-cleanup-code-defercleanup",
}

View File

@ -272,12 +272,23 @@ func tokenize(input string) func() (*treeNode, error) {
}
}
func MustParseLabelFilter(input string) LabelFilter {
filter, err := ParseLabelFilter(input)
if err != nil {
panic(err)
}
return filter
}
func ParseLabelFilter(input string) (LabelFilter, error) {
if DEBUG_LABEL_FILTER_PARSING {
fmt.Println("\n==============")
fmt.Println("Input: ", input)
fmt.Print("Tokens: ")
}
if input == "" {
return func(_ []string) bool { return true }, nil
}
nextToken := tokenize(input)
root := &treeNode{token: lfTokenRoot}

View File

@ -6,8 +6,8 @@ import (
"time"
)
//ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports
//and across the network connection when running in parallel
// ReportEntryValue wraps a report entry's value ensuring it can be encoded and decoded safely into reports
// and across the network connection when running in parallel
type ReportEntryValue struct {
raw interface{} //unexported to prevent gob from freaking out about unregistered structs
AsJSON string
@ -85,10 +85,12 @@ func (rev *ReportEntryValue) GobDecode(data []byte) error {
type ReportEntry struct {
// Visibility captures the visibility policy for this ReportEntry
Visibility ReportEntryVisibility
// Time captures the time the AddReportEntry was called
Time time.Time
// Location captures the location of the AddReportEntry call
Location CodeLocation
Time time.Time //need this for backwards compatibility
TimelineLocation TimelineLocation
// Name captures the name of this report
Name string
// Value captures the (optional) object passed into AddReportEntry - this can be
@ -120,7 +122,9 @@ func (entry ReportEntry) GetRawValue() interface{} {
return entry.Value.GetRawValue()
}
func (entry ReportEntry) GetTimelineLocation() TimelineLocation {
return entry.TimelineLocation
}
type ReportEntries []ReportEntry

View File

@ -2,6 +2,8 @@ package types
import (
"encoding/json"
"fmt"
"sort"
"strings"
"time"
)
@ -56,19 +58,20 @@ type Report struct {
SuiteConfig SuiteConfig
//SpecReports is a list of all SpecReports generated by this test run
//It is empty when the SuiteReport is provided to ReportBeforeSuite
SpecReports SpecReports
}
//PreRunStats contains a set of stats captured before the test run begins. This is primarily used
//by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs)
//and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters.
// PreRunStats contains a set of stats captured before the test run begins. This is primarily used
// by Ginkgo's reporter to tell the user how many specs are in the current suite (PreRunStats.TotalSpecs)
// and how many it intends to run (PreRunStats.SpecsThatWillRun) after applying any relevant focus or skip filters.
type PreRunStats struct {
TotalSpecs int
SpecsThatWillRun int
}
//Add is used by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes
//to form a complete final report.
// Add is used by Ginkgo's parallel aggregation mechanisms to combine test run reports form individual parallel processes
// to form a complete final report.
func (report Report) Add(other Report) Report {
report.SuiteSucceeded = report.SuiteSucceeded && other.SuiteSucceeded
@ -147,6 +150,9 @@ type SpecReport struct {
// ParallelProcess captures the parallel process that this spec ran on
ParallelProcess int
// RunningInParallel captures whether this spec is part of a suite that ran in parallel
RunningInParallel bool
//Failure is populated if a spec has failed, panicked, been interrupted, or skipped by the user (e.g. calling Skip())
//It includes detailed information about the Failure
Failure Failure
@ -178,6 +184,9 @@ type SpecReport struct {
// AdditionalFailures contains any failures that occurred after the initial spec failure. These typically occur in cleanup nodes after the initial failure and are only emitted when running in verbose mode.
AdditionalFailures []AdditionalFailure
// SpecEvents capture additional events that occur during the spec run
SpecEvents SpecEvents
}
func (report SpecReport) MarshalJSON() ([]byte, error) {
@ -204,6 +213,7 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
ReportEntries ReportEntries `json:",omitempty"`
ProgressReports []ProgressReport `json:",omitempty"`
AdditionalFailures []AdditionalFailure `json:",omitempty"`
SpecEvents SpecEvents `json:",omitempty"`
}{
ContainerHierarchyTexts: report.ContainerHierarchyTexts,
ContainerHierarchyLocations: report.ContainerHierarchyLocations,
@ -238,6 +248,9 @@ func (report SpecReport) MarshalJSON() ([]byte, error) {
if len(report.AdditionalFailures) > 0 {
out.AdditionalFailures = report.AdditionalFailures
}
if len(report.SpecEvents) > 0 {
out.SpecEvents = report.SpecEvents
}
return json.Marshal(out)
}
@ -255,13 +268,13 @@ func (report SpecReport) CombinedOutput() string {
return report.CapturedStdOutErr + "\n" + report.CapturedGinkgoWriterOutput
}
//Failed returns true if report.State is one of the SpecStateFailureStates
// Failed returns true if report.State is one of the SpecStateFailureStates
// (SpecStateFailed, SpecStatePanicked, SpecStateinterrupted, SpecStateAborted)
func (report SpecReport) Failed() bool {
return report.State.Is(SpecStateFailureStates)
}
//FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText
// FullText returns a concatenation of all the report.ContainerHierarchyTexts and report.LeafNodeText
func (report SpecReport) FullText() string {
texts := []string{}
texts = append(texts, report.ContainerHierarchyTexts...)
@ -271,7 +284,7 @@ func (report SpecReport) FullText() string {
return strings.Join(texts, " ")
}
//Labels returns a deduped set of all the spec's Labels.
// Labels returns a deduped set of all the spec's Labels.
func (report SpecReport) Labels() []string {
out := []string{}
seen := map[string]bool{}
@ -293,7 +306,7 @@ func (report SpecReport) Labels() []string {
return out
}
//MatchesLabelFilter returns true if the spec satisfies the passed in label filter query
// MatchesLabelFilter returns true if the spec satisfies the passed in label filter query
func (report SpecReport) MatchesLabelFilter(query string) (bool, error) {
filter, err := ParseLabelFilter(query)
if err != nil {
@ -302,29 +315,54 @@ func (report SpecReport) MatchesLabelFilter(query string) (bool, error) {
return filter(report.Labels()), nil
}
//FileName() returns the name of the file containing the spec
// FileName() returns the name of the file containing the spec
func (report SpecReport) FileName() string {
return report.LeafNodeLocation.FileName
}
//LineNumber() returns the line number of the leaf node
// LineNumber() returns the line number of the leaf node
func (report SpecReport) LineNumber() int {
return report.LeafNodeLocation.LineNumber
}
//FailureMessage() returns the failure message (or empty string if the test hasn't failed)
// FailureMessage() returns the failure message (or empty string if the test hasn't failed)
func (report SpecReport) FailureMessage() string {
return report.Failure.Message
}
//FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed)
// FailureLocation() returns the location of the failure (or an empty CodeLocation if the test hasn't failed)
func (report SpecReport) FailureLocation() CodeLocation {
return report.Failure.Location
}
// Timeline() returns a timeline view of the report
func (report SpecReport) Timeline() Timeline {
timeline := Timeline{}
if !report.Failure.IsZero() {
timeline = append(timeline, report.Failure)
if report.Failure.AdditionalFailure != nil {
timeline = append(timeline, *(report.Failure.AdditionalFailure))
}
}
for _, additionalFailure := range report.AdditionalFailures {
timeline = append(timeline, additionalFailure)
}
for _, reportEntry := range report.ReportEntries {
timeline = append(timeline, reportEntry)
}
for _, progressReport := range report.ProgressReports {
timeline = append(timeline, progressReport)
}
for _, specEvent := range report.SpecEvents {
timeline = append(timeline, specEvent)
}
sort.Sort(timeline)
return timeline
}
type SpecReports []SpecReport
//WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes
// WithLeafNodeType returns the subset of SpecReports with LeafNodeType matching one of the requested NodeTypes
func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports {
count := 0
for i := range reports {
@ -344,7 +382,7 @@ func (reports SpecReports) WithLeafNodeType(nodeTypes NodeType) SpecReports {
return out
}
//WithState returns the subset of SpecReports with State matching one of the requested SpecStates
// WithState returns the subset of SpecReports with State matching one of the requested SpecStates
func (reports SpecReports) WithState(states SpecState) SpecReports {
count := 0
for i := range reports {
@ -363,7 +401,7 @@ func (reports SpecReports) WithState(states SpecState) SpecReports {
return out
}
//CountWithState returns the number of SpecReports with State matching one of the requested SpecStates
// CountWithState returns the number of SpecReports with State matching one of the requested SpecStates
func (reports SpecReports) CountWithState(states SpecState) int {
n := 0
for i := range reports {
@ -374,7 +412,7 @@ func (reports SpecReports) CountWithState(states SpecState) int {
return n
}
//If the Spec passes, CountOfFlakedSpecs returns the number of SpecReports that failed after multiple attempts.
// If the Spec passes, CountOfFlakedSpecs returns the number of SpecReports that failed after multiple attempts.
func (reports SpecReports) CountOfFlakedSpecs() int {
n := 0
for i := range reports {
@ -385,7 +423,7 @@ func (reports SpecReports) CountOfFlakedSpecs() int {
return n
}
//If the Spec fails, CountOfRepeatedSpecs returns the number of SpecReports that passed after multiple attempts
// If the Spec fails, CountOfRepeatedSpecs returns the number of SpecReports that passed after multiple attempts
func (reports SpecReports) CountOfRepeatedSpecs() int {
n := 0
for i := range reports {
@ -396,6 +434,53 @@ func (reports SpecReports) CountOfRepeatedSpecs() int {
return n
}
// TimelineLocation captures the location of an event in the spec's timeline
type TimelineLocation struct {
//Offset is the offset (in bytes) of the event relative to the GinkgoWriter stream
Offset int `json:",omitempty"`
//Order is the order of the event with respect to other events. The absolute value of Order
//is irrelevant. All that matters is that an event with a lower Order occurs before ane vent with a higher Order
Order int `json:",omitempty"`
Time time.Time
}
// TimelineEvent represent an event on the timeline
// consumers of Timeline will need to check the concrete type of each entry to determine how to handle it
type TimelineEvent interface {
GetTimelineLocation() TimelineLocation
}
type Timeline []TimelineEvent
func (t Timeline) Len() int { return len(t) }
func (t Timeline) Less(i, j int) bool {
return t[i].GetTimelineLocation().Order < t[j].GetTimelineLocation().Order
}
func (t Timeline) Swap(i, j int) { t[i], t[j] = t[j], t[i] }
func (t Timeline) WithoutHiddenReportEntries() Timeline {
out := Timeline{}
for _, event := range t {
if reportEntry, isReportEntry := event.(ReportEntry); isReportEntry && reportEntry.Visibility == ReportEntryVisibilityNever {
continue
}
out = append(out, event)
}
return out
}
func (t Timeline) WithoutVeryVerboseSpecEvents() Timeline {
out := Timeline{}
for _, event := range t {
if specEvent, isSpecEvent := event.(SpecEvent); isSpecEvent && specEvent.IsOnlyVisibleAtVeryVerbose() {
continue
}
out = append(out, event)
}
return out
}
// Failure captures failure information for an individual test
type Failure struct {
// Message - the failure message passed into Fail(...). When using a matcher library
@ -408,6 +493,8 @@ type Failure struct {
// This CodeLocation will include a fully-populated StackTrace
Location CodeLocation
TimelineLocation TimelineLocation
// ForwardedPanic - if the failure represents a captured panic (i.e. Summary.State == SpecStatePanicked)
// then ForwardedPanic will be populated with a string representation of the captured panic.
ForwardedPanic string `json:",omitempty"`
@ -420,19 +507,29 @@ type Failure struct {
// FailureNodeType will contain the NodeType of the node in which the failure occurred.
// FailureNodeLocation will contain the CodeLocation of the node in which the failure occurred.
// If populated, FailureNodeContainerIndex will be the index into SpecReport.ContainerHierarchyTexts and SpecReport.ContainerHierarchyLocations that represents the parent container of the node in which the failure occurred.
FailureNodeContext FailureNodeContext
FailureNodeType NodeType
FailureNodeLocation CodeLocation
FailureNodeContainerIndex int
FailureNodeContext FailureNodeContext `json:",omitempty"`
FailureNodeType NodeType `json:",omitempty"`
FailureNodeLocation CodeLocation `json:",omitempty"`
FailureNodeContainerIndex int `json:",omitempty"`
//ProgressReport is populated if the spec was interrupted or timed out
ProgressReport ProgressReport
ProgressReport ProgressReport `json:",omitempty"`
//AdditionalFailure is non-nil if a follow-on failure occurred within the same node after the primary failure. This only happens when a node has timed out or been interrupted. In such cases the AdditionalFailure can include information about where/why the spec was stuck.
AdditionalFailure *AdditionalFailure `json:",omitempty"`
}
func (f Failure) IsZero() bool {
return f.Message == "" && (f.Location == CodeLocation{})
}
func (f Failure) GetTimelineLocation() TimelineLocation {
return f.TimelineLocation
}
// FailureNodeContext captures the location context for the node containing the failing line of code
type FailureNodeContext uint
@ -471,6 +568,10 @@ type AdditionalFailure struct {
Failure Failure
}
func (f AdditionalFailure) GetTimelineLocation() TimelineLocation {
return f.Failure.TimelineLocation
}
// SpecState captures the state of a spec
// To determine if a given `state` represents a failure state, use `state.Is(SpecStateFailureStates)`
type SpecState uint
@ -503,6 +604,9 @@ var ssEnumSupport = NewEnumSupport(map[uint]string{
func (ss SpecState) String() string {
return ssEnumSupport.String(uint(ss))
}
func (ss SpecState) GomegaString() string {
return ssEnumSupport.String(uint(ss))
}
func (ss *SpecState) UnmarshalJSON(b []byte) error {
out, err := ssEnumSupport.UnmarshJSON(b)
*ss = SpecState(out)
@ -520,38 +624,40 @@ func (ss SpecState) Is(states SpecState) bool {
// ProgressReport captures the progress of the current spec. It is, effectively, a structured Ginkgo-aware stack trace
type ProgressReport struct {
Message string
ParallelProcess int
RunningInParallel bool
Message string `json:",omitempty"`
ParallelProcess int `json:",omitempty"`
RunningInParallel bool `json:",omitempty"`
Time time.Time
ContainerHierarchyTexts []string `json:",omitempty"`
LeafNodeText string `json:",omitempty"`
LeafNodeLocation CodeLocation `json:",omitempty"`
SpecStartTime time.Time `json:",omitempty"`
ContainerHierarchyTexts []string
LeafNodeText string
LeafNodeLocation CodeLocation
SpecStartTime time.Time
CurrentNodeType NodeType `json:",omitempty"`
CurrentNodeText string `json:",omitempty"`
CurrentNodeLocation CodeLocation `json:",omitempty"`
CurrentNodeStartTime time.Time `json:",omitempty"`
CurrentNodeType NodeType
CurrentNodeText string
CurrentNodeLocation CodeLocation
CurrentNodeStartTime time.Time
CurrentStepText string `json:",omitempty"`
CurrentStepLocation CodeLocation `json:",omitempty"`
CurrentStepStartTime time.Time `json:",omitempty"`
CurrentStepText string
CurrentStepLocation CodeLocation
CurrentStepStartTime time.Time
AdditionalReports []string `json:",omitempty"`
AdditionalReports []string
CapturedGinkgoWriterOutput string `json:",omitempty"`
TimelineLocation TimelineLocation `json:",omitempty"`
CapturedGinkgoWriterOutput string `json:",omitempty"`
GinkgoWriterOffset int
Goroutines []Goroutine
Goroutines []Goroutine `json:",omitempty"`
}
func (pr ProgressReport) IsZero() bool {
return pr.CurrentNodeType == NodeTypeInvalid
}
func (pr ProgressReport) Time() time.Time {
return pr.TimelineLocation.Time
}
func (pr ProgressReport) SpecGoroutine() Goroutine {
for _, goroutine := range pr.Goroutines {
if goroutine.IsSpecGoroutine {
@ -589,6 +695,22 @@ func (pr ProgressReport) WithoutCapturedGinkgoWriterOutput() ProgressReport {
return out
}
func (pr ProgressReport) WithoutOtherGoroutines() ProgressReport {
out := pr
filteredGoroutines := []Goroutine{}
for _, goroutine := range pr.Goroutines {
if goroutine.IsSpecGoroutine || goroutine.HasHighlights() {
filteredGoroutines = append(filteredGoroutines, goroutine)
}
}
out.Goroutines = filteredGoroutines
return out
}
func (pr ProgressReport) GetTimelineLocation() TimelineLocation {
return pr.TimelineLocation
}
type Goroutine struct {
ID uint64
State string
@ -643,6 +765,7 @@ const (
NodeTypeReportBeforeEach
NodeTypeReportAfterEach
NodeTypeReportBeforeSuite
NodeTypeReportAfterSuite
NodeTypeCleanupInvalid
@ -652,9 +775,9 @@ const (
)
var NodeTypesForContainerAndIt = NodeTypeContainer | NodeTypeIt
var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite
var NodeTypesForSuiteLevelNodes = NodeTypeBeforeSuite | NodeTypeSynchronizedBeforeSuite | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite | NodeTypeCleanupAfterSuite
var NodeTypesAllowedDuringCleanupInterrupt = NodeTypeAfterEach | NodeTypeJustAfterEach | NodeTypeAfterAll | NodeTypeAfterSuite | NodeTypeSynchronizedAfterSuite | NodeTypeCleanupAfterEach | NodeTypeCleanupAfterAll | NodeTypeCleanupAfterSuite
var NodeTypesAllowedDuringReportInterrupt = NodeTypeReportBeforeEach | NodeTypeReportAfterEach | NodeTypeReportAfterSuite
var NodeTypesAllowedDuringReportInterrupt = NodeTypeReportBeforeEach | NodeTypeReportAfterEach | NodeTypeReportBeforeSuite | NodeTypeReportAfterSuite
var ntEnumSupport = NewEnumSupport(map[uint]string{
uint(NodeTypeInvalid): "INVALID NODE TYPE",
@ -672,6 +795,7 @@ var ntEnumSupport = NewEnumSupport(map[uint]string{
uint(NodeTypeSynchronizedAfterSuite): "SynchronizedAfterSuite",
uint(NodeTypeReportBeforeEach): "ReportBeforeEach",
uint(NodeTypeReportAfterEach): "ReportAfterEach",
uint(NodeTypeReportBeforeSuite): "ReportBeforeSuite",
uint(NodeTypeReportAfterSuite): "ReportAfterSuite",
uint(NodeTypeCleanupInvalid): "DeferCleanup",
uint(NodeTypeCleanupAfterEach): "DeferCleanup (Each)",
@ -694,3 +818,99 @@ func (nt NodeType) MarshalJSON() ([]byte, error) {
func (nt NodeType) Is(nodeTypes NodeType) bool {
return nt&nodeTypes != 0
}
/*
SpecEvent captures a vareity of events that can occur when specs run. See SpecEventType for the list of available events.
*/
type SpecEvent struct {
SpecEventType SpecEventType
CodeLocation CodeLocation
TimelineLocation TimelineLocation
Message string `json:",omitempty"`
Duration time.Duration `json:",omitempty"`
NodeType NodeType `json:",omitempty"`
Attempt int `json:",omitempty"`
}
func (se SpecEvent) GetTimelineLocation() TimelineLocation {
return se.TimelineLocation
}
func (se SpecEvent) IsOnlyVisibleAtVeryVerbose() bool {
return se.SpecEventType.Is(SpecEventByEnd | SpecEventNodeStart | SpecEventNodeEnd)
}
func (se SpecEvent) GomegaString() string {
out := &strings.Builder{}
out.WriteString("[" + se.SpecEventType.String() + " SpecEvent] ")
if se.Message != "" {
out.WriteString("Message=")
out.WriteString(`"` + se.Message + `",`)
}
if se.Duration != 0 {
out.WriteString("Duration=" + se.Duration.String() + ",")
}
if se.NodeType != NodeTypeInvalid {
out.WriteString("NodeType=" + se.NodeType.String() + ",")
}
if se.Attempt != 0 {
out.WriteString(fmt.Sprintf("Attempt=%d", se.Attempt) + ",")
}
out.WriteString("CL=" + se.CodeLocation.String() + ",")
out.WriteString(fmt.Sprintf("TL.Offset=%d", se.TimelineLocation.Offset))
return out.String()
}
type SpecEvents []SpecEvent
func (se SpecEvents) WithType(seType SpecEventType) SpecEvents {
out := SpecEvents{}
for _, event := range se {
if event.SpecEventType.Is(seType) {
out = append(out, event)
}
}
return out
}
type SpecEventType uint
const (
SpecEventInvalid SpecEventType = 0
SpecEventByStart SpecEventType = 1 << iota
SpecEventByEnd
SpecEventNodeStart
SpecEventNodeEnd
SpecEventSpecRepeat
SpecEventSpecRetry
)
var seEnumSupport = NewEnumSupport(map[uint]string{
uint(SpecEventInvalid): "INVALID SPEC EVENT",
uint(SpecEventByStart): "By",
uint(SpecEventByEnd): "By (End)",
uint(SpecEventNodeStart): "Node",
uint(SpecEventNodeEnd): "Node (End)",
uint(SpecEventSpecRepeat): "Repeat",
uint(SpecEventSpecRetry): "Retry",
})
func (se SpecEventType) String() string {
return seEnumSupport.String(uint(se))
}
func (se *SpecEventType) UnmarshalJSON(b []byte) error {
out, err := seEnumSupport.UnmarshJSON(b)
*se = SpecEventType(out)
return err
}
func (se SpecEventType) MarshalJSON() ([]byte, error) {
return seEnumSupport.MarshJSON(uint(se))
}
func (se SpecEventType) Is(specEventTypes SpecEventType) bool {
return se&specEventTypes != 0
}

View File

@ -1,3 +1,3 @@
package types
const VERSION = "2.4.0"
const VERSION = "2.9.5"

View File

@ -1,27 +0,0 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,6 +0,0 @@
# qtls
[![Go Reference](https://pkg.go.dev/badge/github.com/quic-go/qtls-go1-19.svg)](https://pkg.go.dev/github.com/quic-go/qtls-go1-19)
[![.github/workflows/go-test.yml](https://github.com/quic-go/qtls-go1-19/actions/workflows/go-test.yml/badge.svg)](https://github.com/quic-go/qtls-go1-19/actions/workflows/go-test.yml)
This repository contains a modified version of the standard library's TLS implementation, modified for the QUIC protocol. It is used by [quic-go](https://github.com/lucas-clemente/quic-go).

View File

@ -1,102 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import "strconv"
type alert uint8
// Alert is a TLS alert
type Alert = alert
const (
// alert level
alertLevelWarning = 1
alertLevelError = 2
)
const (
alertCloseNotify alert = 0
alertUnexpectedMessage alert = 10
alertBadRecordMAC alert = 20
alertDecryptionFailed alert = 21
alertRecordOverflow alert = 22
alertDecompressionFailure alert = 30
alertHandshakeFailure alert = 40
alertBadCertificate alert = 42
alertUnsupportedCertificate alert = 43
alertCertificateRevoked alert = 44
alertCertificateExpired alert = 45
alertCertificateUnknown alert = 46
alertIllegalParameter alert = 47
alertUnknownCA alert = 48
alertAccessDenied alert = 49
alertDecodeError alert = 50
alertDecryptError alert = 51
alertExportRestriction alert = 60
alertProtocolVersion alert = 70
alertInsufficientSecurity alert = 71
alertInternalError alert = 80
alertInappropriateFallback alert = 86
alertUserCanceled alert = 90
alertNoRenegotiation alert = 100
alertMissingExtension alert = 109
alertUnsupportedExtension alert = 110
alertCertificateUnobtainable alert = 111
alertUnrecognizedName alert = 112
alertBadCertificateStatusResponse alert = 113
alertBadCertificateHashValue alert = 114
alertUnknownPSKIdentity alert = 115
alertCertificateRequired alert = 116
alertNoApplicationProtocol alert = 120
)
var alertText = map[alert]string{
alertCloseNotify: "close notify",
alertUnexpectedMessage: "unexpected message",
alertBadRecordMAC: "bad record MAC",
alertDecryptionFailed: "decryption failed",
alertRecordOverflow: "record overflow",
alertDecompressionFailure: "decompression failure",
alertHandshakeFailure: "handshake failure",
alertBadCertificate: "bad certificate",
alertUnsupportedCertificate: "unsupported certificate",
alertCertificateRevoked: "revoked certificate",
alertCertificateExpired: "expired certificate",
alertCertificateUnknown: "unknown certificate",
alertIllegalParameter: "illegal parameter",
alertUnknownCA: "unknown certificate authority",
alertAccessDenied: "access denied",
alertDecodeError: "error decoding message",
alertDecryptError: "error decrypting message",
alertExportRestriction: "export restriction",
alertProtocolVersion: "protocol version not supported",
alertInsufficientSecurity: "insufficient security level",
alertInternalError: "internal error",
alertInappropriateFallback: "inappropriate fallback",
alertUserCanceled: "user canceled",
alertNoRenegotiation: "no renegotiation",
alertMissingExtension: "missing extension",
alertUnsupportedExtension: "unsupported extension",
alertCertificateUnobtainable: "certificate unobtainable",
alertUnrecognizedName: "unrecognized name",
alertBadCertificateStatusResponse: "bad certificate status response",
alertBadCertificateHashValue: "bad certificate hash value",
alertUnknownPSKIdentity: "unknown PSK identity",
alertCertificateRequired: "certificate required",
alertNoApplicationProtocol: "no application protocol",
}
func (e alert) String() string {
s, ok := alertText[e]
if ok {
return "tls: " + s
}
return "tls: alert(" + strconv.Itoa(int(e)) + ")"
}
func (e alert) Error() string {
return e.String()
}

View File

@ -1,293 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"errors"
"fmt"
"hash"
"io"
)
// verifyHandshakeSignature verifies a signature against pre-hashed
// (if required) handshake contents.
func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error {
switch sigType {
case signatureECDSA:
pubKey, ok := pubkey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("expected an ECDSA public key, got %T", pubkey)
}
if !ecdsa.VerifyASN1(pubKey, signed, sig) {
return errors.New("ECDSA verification failure")
}
case signatureEd25519:
pubKey, ok := pubkey.(ed25519.PublicKey)
if !ok {
return fmt.Errorf("expected an Ed25519 public key, got %T", pubkey)
}
if !ed25519.Verify(pubKey, signed, sig) {
return errors.New("Ed25519 verification failure")
}
case signaturePKCS1v15:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, signed, sig); err != nil {
return err
}
case signatureRSAPSS:
pubKey, ok := pubkey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("expected an RSA public key, got %T", pubkey)
}
signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
if err := rsa.VerifyPSS(pubKey, hashFunc, signed, sig, signOpts); err != nil {
return err
}
default:
return errors.New("internal error: unknown signature type")
}
return nil
}
const (
serverSignatureContext = "TLS 1.3, server CertificateVerify\x00"
clientSignatureContext = "TLS 1.3, client CertificateVerify\x00"
)
var signaturePadding = []byte{
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
}
// signedMessage returns the pre-hashed (if necessary) message to be signed by
// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3.
func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte {
if sigHash == directSigning {
b := &bytes.Buffer{}
b.Write(signaturePadding)
io.WriteString(b, context)
b.Write(transcript.Sum(nil))
return b.Bytes()
}
h := sigHash.New()
h.Write(signaturePadding)
io.WriteString(h, context)
h.Write(transcript.Sum(nil))
return h.Sum(nil)
}
// typeAndHashFromSignatureScheme returns the corresponding signature type and
// crypto.Hash for a given TLS SignatureScheme.
func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType uint8, hash crypto.Hash, err error) {
switch signatureAlgorithm {
case PKCS1WithSHA1, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512:
sigType = signaturePKCS1v15
case PSSWithSHA256, PSSWithSHA384, PSSWithSHA512:
sigType = signatureRSAPSS
case ECDSAWithSHA1, ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512:
sigType = signatureECDSA
case Ed25519:
sigType = signatureEd25519
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
switch signatureAlgorithm {
case PKCS1WithSHA1, ECDSAWithSHA1:
hash = crypto.SHA1
case PKCS1WithSHA256, PSSWithSHA256, ECDSAWithP256AndSHA256:
hash = crypto.SHA256
case PKCS1WithSHA384, PSSWithSHA384, ECDSAWithP384AndSHA384:
hash = crypto.SHA384
case PKCS1WithSHA512, PSSWithSHA512, ECDSAWithP521AndSHA512:
hash = crypto.SHA512
case Ed25519:
hash = directSigning
default:
return 0, 0, fmt.Errorf("unsupported signature algorithm: %v", signatureAlgorithm)
}
return sigType, hash, nil
}
// legacyTypeAndHashFromPublicKey returns the fixed signature type and crypto.Hash for
// a given public key used with TLS 1.0 and 1.1, before the introduction of
// signature algorithm negotiation.
func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash crypto.Hash, err error) {
switch pub.(type) {
case *rsa.PublicKey:
return signaturePKCS1v15, crypto.MD5SHA1, nil
case *ecdsa.PublicKey:
return signatureECDSA, crypto.SHA1, nil
case ed25519.PublicKey:
// RFC 8422 specifies support for Ed25519 in TLS 1.0 and 1.1,
// but it requires holding on to a handshake transcript to do a
// full signature, and not even OpenSSL bothers with the
// complexity, so we can't even test it properly.
return 0, 0, fmt.Errorf("tls: Ed25519 public keys are not supported before TLS 1.2")
default:
return 0, 0, fmt.Errorf("tls: unsupported public key: %T", pub)
}
}
var rsaSignatureSchemes = []struct {
scheme SignatureScheme
minModulusBytes int
maxVersion uint16
}{
// RSA-PSS is used with PSSSaltLengthEqualsHash, and requires
// emLen >= hLen + sLen + 2
{PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13},
{PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13},
{PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13},
// PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires
// emLen >= len(prefix) + hLen + 11
// TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS.
{PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12},
{PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12},
{PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12},
{PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12},
}
// signatureSchemesForCertificate returns the list of supported SignatureSchemes
// for a given certificate, based on the public key and the protocol version,
// and optionally filtered by its explicit SupportedSignatureAlgorithms.
//
// This function must be kept in sync with supportedSignatureAlgorithms.
// FIPS filtering is applied in the caller, selectSignatureScheme.
func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme {
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil
}
var sigAlgs []SignatureScheme
switch pub := priv.Public().(type) {
case *ecdsa.PublicKey:
if version != VersionTLS13 {
// In TLS 1.2 and earlier, ECDSA algorithms are not
// constrained to a single curve.
sigAlgs = []SignatureScheme{
ECDSAWithP256AndSHA256,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
ECDSAWithSHA1,
}
break
}
switch pub.Curve {
case elliptic.P256():
sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256}
case elliptic.P384():
sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384}
case elliptic.P521():
sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512}
default:
return nil
}
case *rsa.PublicKey:
size := pub.Size()
sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes))
for _, candidate := range rsaSignatureSchemes {
if size >= candidate.minModulusBytes && version <= candidate.maxVersion {
sigAlgs = append(sigAlgs, candidate.scheme)
}
}
case ed25519.PublicKey:
sigAlgs = []SignatureScheme{Ed25519}
default:
return nil
}
if cert.SupportedSignatureAlgorithms != nil {
var filteredSigAlgs []SignatureScheme
for _, sigAlg := range sigAlgs {
if isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) {
filteredSigAlgs = append(filteredSigAlgs, sigAlg)
}
}
return filteredSigAlgs
}
return sigAlgs
}
// selectSignatureScheme picks a SignatureScheme from the peer's preference list
// that works with the selected certificate. It's only called for protocol
// versions that support signature algorithms, so TLS 1.2 and 1.3.
func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) {
supportedAlgs := signatureSchemesForCertificate(vers, c)
if len(supportedAlgs) == 0 {
return 0, unsupportedCertificateError(c)
}
if len(peerAlgs) == 0 && vers == VersionTLS12 {
// For TLS 1.2, if the client didn't send signature_algorithms then we
// can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1.
peerAlgs = []SignatureScheme{PKCS1WithSHA1, ECDSAWithSHA1}
}
// Pick signature scheme in the peer's preference order, as our
// preference order is not configurable.
for _, preferredAlg := range peerAlgs {
if needFIPS() && !isSupportedSignatureAlgorithm(preferredAlg, fipsSupportedSignatureAlgorithms) {
continue
}
if isSupportedSignatureAlgorithm(preferredAlg, supportedAlgs) {
return preferredAlg, nil
}
}
return 0, errors.New("tls: peer doesn't support any of the certificate's signature algorithms")
}
// unsupportedCertificateError returns a helpful error for certificates with
// an unsupported private key.
func unsupportedCertificateError(cert *Certificate) error {
switch cert.PrivateKey.(type) {
case rsa.PrivateKey, ecdsa.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is %T, expected *%T",
cert.PrivateKey, cert.PrivateKey)
case *ed25519.PrivateKey:
return fmt.Errorf("tls: unsupported certificate: private key is *ed25519.PrivateKey, expected ed25519.PrivateKey")
}
signer, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return fmt.Errorf("tls: certificate private key (%T) does not implement crypto.Signer",
cert.PrivateKey)
}
switch pub := signer.Public().(type) {
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
case elliptic.P384():
case elliptic.P521():
default:
return fmt.Errorf("tls: unsupported certificate curve (%s)", pub.Curve.Params().Name)
}
case *rsa.PublicKey:
return fmt.Errorf("tls: certificate RSA key size too small for supported signature algorithms")
case ed25519.PublicKey:
default:
return fmt.Errorf("tls: unsupported certificate key (%T)", pub)
}
if cert.SupportedSignatureAlgorithms != nil {
return fmt.Errorf("tls: peer doesn't support the certificate custom signature algorithms")
}
return fmt.Errorf("tls: internal error: unsupported key (%T)", cert.PrivateKey)
}

View File

@ -1,693 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/hmac"
"crypto/rc4"
"crypto/sha1"
"crypto/sha256"
"fmt"
"hash"
"golang.org/x/crypto/chacha20poly1305"
)
// CipherSuite is a TLS cipher suite. Note that most functions in this package
// accept and expose cipher suite IDs instead of this type.
type CipherSuite struct {
ID uint16
Name string
// Supported versions is the list of TLS protocol versions that can
// negotiate this cipher suite.
SupportedVersions []uint16
// Insecure is true if the cipher suite has known security issues
// due to its primitives, design, or implementation.
Insecure bool
}
var (
supportedUpToTLS12 = []uint16{VersionTLS10, VersionTLS11, VersionTLS12}
supportedOnlyTLS12 = []uint16{VersionTLS12}
supportedOnlyTLS13 = []uint16{VersionTLS13}
)
// CipherSuites returns a list of cipher suites currently implemented by this
// package, excluding those with security issues, which are returned by
// InsecureCipherSuites.
//
// The list is sorted by ID. Note that the default cipher suites selected by
// this package might depend on logic that can't be captured by a static list,
// and might not match those returned by this function.
func CipherSuites() []*CipherSuite {
return []*CipherSuite{
{TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_RSA_WITH_AES_128_GCM_SHA256, "TLS_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_RSA_WITH_AES_256_GCM_SHA384, "TLS_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256", supportedOnlyTLS13, false},
{TLS_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384", supportedOnlyTLS13, false},
{TLS_CHACHA20_POLY1305_SHA256, "TLS_CHACHA20_POLY1305_SHA256", supportedOnlyTLS13, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", supportedUpToTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", supportedOnlyTLS12, false},
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", supportedOnlyTLS12, false},
}
}
// InsecureCipherSuites returns a list of cipher suites currently implemented by
// this package and which have security issues.
//
// Most applications should not use the cipher suites in this list, and should
// only use those returned by CipherSuites.
func InsecureCipherSuites() []*CipherSuite {
// This list includes RC4, CBC_SHA256, and 3DES cipher suites. See
// cipherSuitesPreferenceOrder for details.
return []*CipherSuite{
{TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
{TLS_RSA_WITH_AES_128_CBC_SHA256, "TLS_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", supportedUpToTLS12, true},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", supportedOnlyTLS12, true},
}
}
// CipherSuiteName returns the standard name for the passed cipher suite ID
// (e.g. "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"), or a fallback representation
// of the ID value if the cipher suite is not implemented by this package.
func CipherSuiteName(id uint16) string {
for _, c := range CipherSuites() {
if c.ID == id {
return c.Name
}
}
for _, c := range InsecureCipherSuites() {
if c.ID == id {
return c.Name
}
}
return fmt.Sprintf("0x%04X", id)
}
const (
// suiteECDHE indicates that the cipher suite involves elliptic curve
// Diffie-Hellman. This means that it should only be selected when the
// client indicates that it supports ECC with a curve and point format
// that we're happy with.
suiteECDHE = 1 << iota
// suiteECSign indicates that the cipher suite involves an ECDSA or
// EdDSA signature and therefore may only be selected when the server's
// certificate is ECDSA or EdDSA. If this is not set then the cipher suite
// is RSA based.
suiteECSign
// suiteTLS12 indicates that the cipher suite should only be advertised
// and accepted when using TLS 1.2.
suiteTLS12
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
// handshake hash.
suiteSHA384
)
// A cipherSuite is a TLS 1.01.2 cipher suite, and defines the key exchange
// mechanism, as well as the cipher+MAC pair or the AEAD.
type cipherSuite struct {
id uint16
// the lengths, in bytes, of the key material needed for each component.
keyLen int
macLen int
ivLen int
ka func(version uint16) keyAgreement
// flags is a bitmask of the suite* values, above.
flags int
cipher func(key, iv []byte, isRead bool) any
mac func(key []byte) hash.Hash
aead func(key, fixedNonce []byte) aead
}
var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter.
{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305},
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, cipherAES, macSHA256, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil},
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
{TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECSign, cipherRC4, macSHA1, nil},
}
// selectCipherSuite returns the first TLS 1.01.2 cipher suite from ids which
// is also in supportedIDs and passes the ok filter.
func selectCipherSuite(ids, supportedIDs []uint16, ok func(*cipherSuite) bool) *cipherSuite {
for _, id := range ids {
candidate := cipherSuiteByID(id)
if candidate == nil || !ok(candidate) {
continue
}
for _, suppID := range supportedIDs {
if id == suppID {
return candidate
}
}
}
return nil
}
// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
type cipherSuiteTLS13 struct {
id uint16
keyLen int
aead func(key, fixedNonce []byte) aead
hash crypto.Hash
}
type CipherSuiteTLS13 struct {
ID uint16
KeyLen int
Hash crypto.Hash
AEAD func(key, fixedNonce []byte) cipher.AEAD
}
func (c *CipherSuiteTLS13) IVLen() int {
return aeadNonceLength
}
var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
{TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384},
}
// cipherSuitesPreferenceOrder is the order in which we'll select (on the
// server) or advertise (on the client) TLS 1.01.2 cipher suites.
//
// Cipher suites are filtered but not reordered based on the application and
// peer's preferences, meaning we'll never select a suite lower in this list if
// any higher one is available. This makes it more defensible to keep weaker
// cipher suites enabled, especially on the server side where we get the last
// word, since there are no known downgrade attacks on cipher suites selection.
//
// The list is sorted by applying the following priority rules, stopping at the
// first (most important) applicable one:
//
// - Anything else comes before RC4
//
// RC4 has practically exploitable biases. See https://www.rc4nomore.com.
//
// - Anything else comes before CBC_SHA256
//
// SHA-256 variants of the CBC ciphersuites don't implement any Lucky13
// countermeasures. See http://www.isg.rhul.ac.uk/tls/Lucky13.html and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
//
// - Anything else comes before 3DES
//
// 3DES has 64-bit blocks, which makes it fundamentally susceptible to
// birthday attacks. See https://sweet32.info.
//
// - ECDHE comes before anything else
//
// Once we got the broken stuff out of the way, the most important
// property a cipher suite can have is forward secrecy. We don't
// implement FFDHE, so that means ECDHE.
//
// - AEADs come before CBC ciphers
//
// Even with Lucky13 countermeasures, MAC-then-Encrypt CBC cipher suites
// are fundamentally fragile, and suffered from an endless sequence of
// padding oracle attacks. See https://eprint.iacr.org/2015/1129,
// https://www.imperialviolet.org/2014/12/08/poodleagain.html, and
// https://blog.cloudflare.com/yet-another-padding-oracle-in-openssl-cbc-ciphersuites/.
//
// - AES comes before ChaCha20
//
// When AES hardware is available, AES-128-GCM and AES-256-GCM are faster
// than ChaCha20Poly1305.
//
// When AES hardware is not available, AES-128-GCM is one or more of: much
// slower, way more complex, and less safe (because not constant time)
// than ChaCha20Poly1305.
//
// We use this list if we think both peers have AES hardware, and
// cipherSuitesPreferenceOrderNoAES otherwise.
//
// - AES-128 comes before AES-256
//
// The only potential advantages of AES-256 are better multi-target
// margins, and hypothetical post-quantum properties. Neither apply to
// TLS, and AES-256 is slower due to its four extra rounds (which don't
// contribute to the advantages above).
//
// - ECDSA comes before RSA
//
// The relative order of ECDSA and RSA cipher suites doesn't matter,
// as they depend on the certificate. Pick one to get a stable order.
var cipherSuitesPreferenceOrder = []uint16{
// AEADs w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
// CBC w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
// AEADs w/o ECDHE
TLS_RSA_WITH_AES_128_GCM_SHA256,
TLS_RSA_WITH_AES_256_GCM_SHA384,
// CBC w/o ECDHE
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA,
// 3DES
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
// CBC_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
var cipherSuitesPreferenceOrderNoAES = []uint16{
// ChaCha20Poly1305
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
// AES-GCM w/ ECDHE
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
// The rest of cipherSuitesPreferenceOrder.
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
TLS_RSA_WITH_AES_128_GCM_SHA256,
TLS_RSA_WITH_AES_256_GCM_SHA384,
TLS_RSA_WITH_AES_128_CBC_SHA,
TLS_RSA_WITH_AES_256_CBC_SHA,
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_RSA_WITH_3DES_EDE_CBC_SHA,
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
// disabledCipherSuites are not used unless explicitly listed in
// Config.CipherSuites. They MUST be at the end of cipherSuitesPreferenceOrder.
var disabledCipherSuites = []uint16{
// CBC_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
TLS_RSA_WITH_AES_128_CBC_SHA256,
// RC4
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA,
TLS_RSA_WITH_RC4_128_SHA,
}
var (
defaultCipherSuitesLen = len(cipherSuitesPreferenceOrder) - len(disabledCipherSuites)
defaultCipherSuites = cipherSuitesPreferenceOrder[:defaultCipherSuitesLen]
)
// defaultCipherSuitesTLS13 is also the preference order, since there are no
// disabled by default TLS 1.3 cipher suites. The same AES vs ChaCha20 logic as
// cipherSuitesPreferenceOrder applies.
var defaultCipherSuitesTLS13 = []uint16{
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
TLS_CHACHA20_POLY1305_SHA256,
}
var defaultCipherSuitesTLS13NoAES = []uint16{
TLS_CHACHA20_POLY1305_SHA256,
TLS_AES_128_GCM_SHA256,
TLS_AES_256_GCM_SHA384,
}
var aesgcmCiphers = map[uint16]bool{
// TLS 1.2
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true,
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: true,
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: true,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: true,
// TLS 1.3
TLS_AES_128_GCM_SHA256: true,
TLS_AES_256_GCM_SHA384: true,
}
var nonAESGCMAEADCiphers = map[uint16]bool{
// TLS 1.2
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: true,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: true,
// TLS 1.3
TLS_CHACHA20_POLY1305_SHA256: true,
}
// aesgcmPreferred returns whether the first known cipher in the preference list
// is an AES-GCM cipher, implying the peer has hardware support for it.
func aesgcmPreferred(ciphers []uint16) bool {
for _, cID := range ciphers {
if c := cipherSuiteByID(cID); c != nil {
return aesgcmCiphers[cID]
}
if c := cipherSuiteTLS13ByID(cID); c != nil {
return aesgcmCiphers[cID]
}
}
return false
}
func cipherRC4(key, iv []byte, isRead bool) any {
cipher, _ := rc4.NewCipher(key)
return cipher
}
func cipher3DES(key, iv []byte, isRead bool) any {
block, _ := des.NewTripleDESCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
func cipherAES(key, iv []byte, isRead bool) any {
block, _ := aes.NewCipher(key)
if isRead {
return cipher.NewCBCDecrypter(block, iv)
}
return cipher.NewCBCEncrypter(block, iv)
}
// macSHA1 returns a SHA-1 based constant time MAC.
func macSHA1(key []byte) hash.Hash {
h := sha1.New
h = newConstantTimeHash(h)
return hmac.New(h, key)
}
// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
// is currently only used in disabled-by-default cipher suites.
func macSHA256(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}
type aead interface {
cipher.AEAD
// explicitNonceLen returns the number of bytes of explicit nonce
// included in each record. This is eight for older AEADs and
// zero for modern ones.
explicitNonceLen() int
}
const (
aeadNonceLength = 12
noncePrefixLength = 4
)
// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
// each call.
type prefixNonceAEAD struct {
// nonce contains the fixed part of the nonce in the first four bytes.
nonce [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength }
func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() }
func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
copy(f.nonce[4:], nonce)
return f.aead.Seal(out, f.nonce[:], plaintext, additionalData)
}
func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
copy(f.nonce[4:], nonce)
return f.aead.Open(out, f.nonce[:], ciphertext, additionalData)
}
// xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce
// before each call.
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result
}
func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData)
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}
return result, err
}
func aeadAESGCM(key, noncePrefix []byte) aead {
if len(noncePrefix) != noncePrefixLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
var aead cipher.AEAD
aead, err = cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &prefixNonceAEAD{aead: aead}
copy(ret.nonce[:], noncePrefix)
return ret
}
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return aeadAESGCMTLS13(key, fixedNonce)
}
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aes, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
func aeadChaCha20Poly1305(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")
}
aead, err := chacha20poly1305.New(key)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
copy(ret.nonceMask[:], nonceMask)
return ret
}
type constantTimeHash interface {
hash.Hash
ConstantTimeSum(b []byte) []byte
}
// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces
// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC.
type cthWrapper struct {
h constantTimeHash
}
func (c *cthWrapper) Size() int { return c.h.Size() }
func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() }
func (c *cthWrapper) Reset() { c.h.Reset() }
func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) }
func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) }
func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
return func() hash.Hash {
return &cthWrapper{h().(constantTimeHash)}
}
}
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
h.Reset()
h.Write(seq)
h.Write(header)
h.Write(data)
res := h.Sum(out)
if extra != nil {
h.Write(extra)
}
return res
}
func rsaKA(version uint16) keyAgreement {
return rsaKeyAgreement{}
}
func ecdheECDSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: false,
version: version,
}
}
func ecdheRSAKA(version uint16) keyAgreement {
return &ecdheKeyAgreement{
isRSA: true,
version: version,
}
}
// mutualCipherSuite returns a cipherSuite given a list of supported
// ciphersuites and the id requested by the peer.
func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
for _, id := range have {
if id == want {
return cipherSuiteByID(id)
}
}
return nil
}
func cipherSuiteByID(id uint16) *cipherSuite {
for _, cipherSuite := range cipherSuites {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
func mutualCipherSuiteTLS13(have []uint16, want uint16) *cipherSuiteTLS13 {
for _, id := range have {
if id == want {
return cipherSuiteTLS13ByID(id)
}
}
return nil
}
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 {
for _, cipherSuite := range cipherSuitesTLS13 {
if cipherSuite.id == id {
return cipherSuite
}
}
return nil
}
// A list of cipher suite IDs that are, or have been, implemented by this
// package.
//
// See https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
const (
// TLS 1.0 - 1.2 cipher suites.
TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a
TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f
TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c
TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c
TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d
TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009
TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a
TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011
TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014
TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca8
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xcca9
// TLS 1.3 cipher suites.
TLS_AES_128_GCM_SHA256 uint16 = 0x1301
TLS_AES_256_GCM_SHA384 uint16 = 0x1302
TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303
// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
// that the client is doing version fallback. See RFC 7507.
TLS_FALLBACK_SCSV uint16 = 0x5600
// Legacy names for the corresponding cipher suites with the correct _SHA256
// suffix, retained for backward compatibility.
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 = TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,22 +0,0 @@
//go:build !js
// +build !js
package qtls
import (
"runtime"
"golang.org/x/sys/cpu"
)
var (
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
// Keep in sync with crypto/aes/cipher_s390x.go.
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
)

View File

@ -1,12 +0,0 @@
//go:build js
// +build js
package qtls
var (
hasGCMAsmAMD64 = false
hasGCMAsmARM64 = false
hasGCMAsmS390X = false
hasAESGCMHardwareSupport = false
)

File diff suppressed because it is too large Load Diff

View File

@ -1,755 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"bytes"
"context"
"crypto"
"crypto/hmac"
"crypto/rsa"
"encoding/binary"
"errors"
"hash"
"sync/atomic"
"time"
"golang.org/x/crypto/cryptobyte"
)
type clientHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
serverHello *serverHelloMsg
hello *clientHelloMsg
ecdheParams ecdheParameters
session *clientSessionState
earlySecret []byte
binderKey []byte
certReq *certificateRequestMsgTLS13
usingPSK bool
sentDummyCCS bool
suite *cipherSuiteTLS13
transcript hash.Hash
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
}
// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and,
// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
func (hs *clientHandshakeStateTLS13) handshake() error {
c := hs.c
if needFIPS() {
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
}
// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
// sections 4.1.2 and 4.1.3.
if c.handshakes > 0 {
c.sendAlert(alertProtocolVersion)
return errors.New("tls: server selected TLS 1.3 in a renegotiation")
}
// Consistency check on the presence of a keyShare and its parameters.
if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 {
return c.sendAlert(alertInternalError)
}
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
hs.transcript = hs.suite.hash.New()
if err := transcriptMsg(hs.hello, hs.transcript); err != nil {
return err
}
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.processHelloRetryRequest(); err != nil {
return err
}
}
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
c.buffering = true
if err := hs.processServerHello(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
if err := hs.establishHandshakeKeys(); err != nil {
return err
}
if err := hs.readServerParameters(); err != nil {
return err
}
if err := hs.readServerCertificate(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.readServerFinished(); err != nil {
return err
}
if err := hs.sendClientCertificate(); err != nil {
return err
}
if err := hs.sendClientFinished(); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
atomic.StoreUint32(&c.handshakeStatus, 1)
c.updateConnectionState()
return nil
}
// checkServerHelloOrHRR does validity checks that apply to both ServerHello and
// HelloRetryRequest messages. It sets hs.suite.
func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
c := hs.c
if hs.serverHello.supportedVersion == 0 {
c.sendAlert(alertMissingExtension)
return errors.New("tls: server selected TLS 1.3 using the legacy version field")
}
if hs.serverHello.supportedVersion != VersionTLS13 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid version after a HelloRetryRequest")
}
if hs.serverHello.vers != VersionTLS12 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an incorrect legacy version")
}
if hs.serverHello.ocspStapling ||
hs.serverHello.ticketSupported ||
hs.serverHello.secureRenegotiationSupported ||
len(hs.serverHello.secureRenegotiation) != 0 ||
len(hs.serverHello.alpnProtocol) != 0 ||
len(hs.serverHello.scts) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3")
}
if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not echo the legacy session ID")
}
if hs.serverHello.compressionMethod != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported compression format")
}
selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite)
if hs.suite != nil && selectedSuite != hs.suite {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server changed cipher suite after a HelloRetryRequest")
}
if selectedSuite == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server chose an unconfigured cipher suite")
}
hs.suite = selectedSuite
c.cipherSuite = hs.suite.id
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
return hs.c.writeChangeCipherRecord()
}
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
// resends hs.hello, and reads the new ServerHello into hs.serverHello.
func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. (The idea is that the server might offload transcript
// storage to the client in the cookie.) See RFC 8446, Section 4.4.1.
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
// The only HelloRetryRequest extensions we support are key_share and
// cookie, and clients must abort the handshake if the HRR would not result
// in any change in the ClientHello.
if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
}
if hs.serverHello.cookie != nil {
hs.hello.cookie = hs.serverHello.cookie
}
if hs.serverHello.serverShare.group != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received malformed key_share extension")
}
// If the server sent a key_share extension selecting a group, ensure it's
// a group we advertised but did not send a key share for, and send a key
// share for it this time.
if curveID := hs.serverHello.selectedGroup; curveID != 0 {
curveOK := false
for _, id := range hs.hello.supportedCurves {
if id == curveID {
curveOK = true
break
}
}
if !curveOK {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if hs.ecdheParams.CurveID() == curveID {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
}
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
params, err := generateECDHEParameters(c.config.rand(), curveID)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.ecdheParams = params
hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}}
}
hs.hello.raw = nil
if len(hs.hello.pskIdentities) > 0 {
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash == hs.suite.hash {
// Update binders and obfuscated_ticket_age.
ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond)
hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd
transcript := hs.suite.hash.New()
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
return err
}
helloBytes, err := hs.hello.marshalWithoutBinders()
if err != nil {
return err
}
transcript.Write(helloBytes)
pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
if err := hs.hello.updateBinders(pskBinders); err != nil {
return err
}
} else {
// Server selected a cipher suite incompatible with the PSK.
hs.hello.pskIdentities = nil
hs.hello.pskBinders = nil
}
}
if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil {
c.extraConfig.Rejected0RTT()
}
hs.hello.earlyData = false // disable 0-RTT
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
// serverHelloMsg is not included in the transcript
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverHello, msg)
}
hs.serverHello = serverHello
if err := hs.checkServerHelloOrHRR(); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) processServerHello() error {
c := hs.c
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: server sent two HelloRetryRequest messages")
}
if len(hs.serverHello.cookie) != 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent a cookie in a normal ServerHello")
}
if hs.serverHello.selectedGroup != 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: malformed key_share extension")
}
if hs.serverHello.serverShare.group == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not send a key share")
}
if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if !hs.serverHello.selectedIdentityPresent {
return nil
}
if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK")
}
if len(hs.hello.pskIdentities) != 1 || hs.session == nil {
return c.sendAlert(alertInternalError)
}
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil {
return c.sendAlert(alertInternalError)
}
if pskSuite.hash != hs.suite.hash {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected an invalid PSK and cipher suite pair")
}
hs.usingPSK = true
c.didResume = true
c.peerCertificates = hs.session.serverCertificates
c.verifiedChains = hs.session.verifiedChains
c.ocspResponse = hs.session.ocspResponse
c.scts = hs.session.scts
return nil
}
func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
c := hs.c
sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data)
if sharedKey == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid server key share")
}
earlySecret := hs.earlySecret
if !hs.usingPSK {
earlySecret = hs.suite.extract(nil, nil)
}
handshakeSecret := hs.suite.extract(sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
clientSecret := hs.suite.deriveSecret(handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret)
c.out.setTrafficSecret(hs.suite, clientSecret)
serverSecret := hs.suite.deriveSecret(handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret)
c.in.setTrafficSecret(hs.suite, serverSecret)
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.masterSecret = hs.suite.extract(nil,
hs.suite.deriveSecret(handshakeSecret, "derived", nil))
return nil
}
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c := hs.c
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
// Notify the caller if 0-RTT was rejected.
if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil {
c.extraConfig.Rejected0RTT()
}
c.used0RTT = encryptedExtensions.earlyData
if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil {
hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions)
}
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
c.sendAlert(alertUnsupportedExtension)
return err
}
c.clientProtocol = encryptedExtensions.alpnProtocol
if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection {
if len(encryptedExtensions.alpnProtocol) == 0 {
// the server didn't select an ALPN
c.sendAlert(alertNoApplicationProtocol)
return errors.New("ALPN negotiation failed. Server didn't offer any protocols")
}
}
return nil
}
func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
c := hs.c
// Either a PSK or a certificate is always used, but not both.
// See RFC 8446, Section 4.1.1.
if hs.usingPSK {
// Make sure the connection is still being verified whether or not this
// is a resumption. Resumptions currently don't reverify certificates so
// they don't call verifyServerCertificate. See Issue 31641.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certReq, ok := msg.(*certificateRequestMsgTLS13)
if ok {
hs.certReq = certReq
msg, err = c.readHandshake(hs.transcript)
if err != nil {
return err
}
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if len(certMsg.certificate.Certificate) == 0 {
c.sendAlert(alertDecodeError)
return errors.New("tls: received empty certificates message")
}
c.scts = certMsg.certificate.SignedCertificateTimestamps
c.ocspResponse = certMsg.certificate.OCSPStaple
if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil {
return err
}
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
c := hs.c
// finishedMsg is included in the transcript, but not until after we
// check the client version, since the state before this message was
// sent is used during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
if !hmac.Equal(expectedMAC, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid server finished hash")
}
if err := transcriptMsg(finished, hs.transcript); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.in.exportKey(EncryptionApplication, hs.suite, serverSecret)
c.in.setTrafficSecret(hs.suite, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
c := hs.c
if hs.certReq == nil {
return nil
}
cert, err := c.getClientCertificate(toCertificateRequestInfo(&certificateRequestInfo{
AcceptableCAs: hs.certReq.certificateAuthorities,
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
Version: c.vers,
ctx: hs.ctx,
}))
if err != nil {
return err
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *cert
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
// If we sent an empty certificate message, skip the CertificateVerify.
if len(cert.Certificate) == 0 {
return nil
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms)
if err != nil {
// getClientCertificate returned a certificate incompatible with the
// CertificateRequestInfo supported signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
c.sendAlert(alertInternalError)
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret)
c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
}
return nil
}
func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
if !c.isClient {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received new session ticket from a client")
}
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
return nil
}
// See RFC 8446, Section 4.6.1.
if msg.lifetime == 0 {
return nil
}
lifetime := time.Duration(msg.lifetime) * time.Second
if lifetime > maxSessionTicketLifetime {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: received a session ticket with invalid lifetime")
}
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil || c.resumptionSecret == nil {
return c.sendAlert(alertInternalError)
}
// We need to save the max_early_data_size that the server sent us, in order
// to decide if we're going to try 0-RTT with this ticket.
// However, at the same time, the qtls.ClientSessionTicket needs to be equal to
// the tls.ClientSessionTicket, so we can't just add a new field to the struct.
// We therefore abuse the nonce field (which is a byte slice)
nonceWithEarlyData := make([]byte, len(msg.nonce)+4)
binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData)
copy(nonceWithEarlyData[4:], msg.nonce)
var appData []byte
if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil {
appData = c.extraConfig.GetAppDataForSessionState()
}
var b cryptobyte.Builder
b.AddUint16(clientSessionStateVersion) // revision
b.AddUint32(msg.maxEarlyData)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(appData)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(msg.nonce)
})
// Save the resumption_master_secret and nonce instead of deriving the PSK
// to do the least amount of work on NewSessionTicket messages before we
// know if the ticket will be used. Forward secrecy of resumed connections
// is guaranteed by the requirement for pskModeDHE.
session := &clientSessionState{
sessionTicket: msg.label,
vers: c.vers,
cipherSuite: c.cipherSuite,
masterSecret: c.resumptionSecret,
serverCertificates: c.peerCertificates,
verifiedChains: c.verifiedChains,
receivedAt: c.config.time(),
nonce: b.BytesOrPanic(),
useBy: c.config.time().Add(lifetime),
ageAdd: msg.ageAdd,
ocspResponse: c.ocspResponse,
scts: c.scts,
}
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session))
return nil
}

File diff suppressed because it is too large Load Diff

View File

@ -1,922 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/subtle"
"crypto/x509"
"errors"
"fmt"
"hash"
"io"
"sync/atomic"
"time"
)
// serverHandshakeState contains details of a server handshake in progress.
// It's discarded once the handshake has completed.
type serverHandshakeState struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
suite *cipherSuite
ecdheOk bool
ecSignOk bool
rsaDecryptOk bool
rsaSignOk bool
sessionState *sessionState
finishedHash finishedHash
masterSecret []byte
cert *Certificate
}
// serverHandshake performs a TLS handshake as a server.
func (c *Conn) serverHandshake(ctx context.Context) error {
c.setAlternativeRecordLayer()
clientHello, err := c.readClientHello(ctx)
if err != nil {
return err
}
if c.vers == VersionTLS13 {
hs := serverHandshakeStateTLS13{
c: c,
ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
} else if c.extraConfig.usesAlternativeRecordLayer() {
// This should already have been caught by the check that the ClientHello doesn't
// offer any (supported) versions older than TLS 1.3.
// Check again to make sure we can't be tricked into using an older version.
c.sendAlert(alertProtocolVersion)
return errors.New("tls: negotiated TLS < 1.3 when using QUIC")
}
hs := serverHandshakeState{
c: c,
ctx: ctx,
clientHello: clientHello,
}
return hs.handshake()
}
func (hs *serverHandshakeState) handshake() error {
c := hs.c
if err := hs.processClientHello(); err != nil {
return err
}
// For an overview of TLS handshaking, see RFC 5246, Section 7.3.
c.buffering = true
if hs.checkForResumption() {
// The client has included a session ticket and so we do an abbreviated handshake.
c.didResume = true
if err := hs.doResumeHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.sendSessionTicket(); err != nil {
return err
}
if err := hs.sendFinished(c.serverFinished[:]); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.clientFinishedIsFirst = false
if err := hs.readFinished(nil); err != nil {
return err
}
} else {
// The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake.
if err := hs.pickCipherSuite(); err != nil {
return err
}
if err := hs.doFullHandshake(); err != nil {
return err
}
if err := hs.establishKeys(); err != nil {
return err
}
if err := hs.readFinished(c.clientFinished[:]); err != nil {
return err
}
c.clientFinishedIsFirst = true
c.buffering = true
if err := hs.sendSessionTicket(); err != nil {
return err
}
if err := hs.sendFinished(nil); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
}
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
atomic.StoreUint32(&c.handshakeStatus, 1)
c.updateConnectionState()
return nil
}
// readClientHello reads a ClientHello message and selects the protocol version.
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
// clientHelloMsg is included in the transcript, but we haven't initialized
// it yet. The respective handshake functions will record it themselves.
msg, err := c.readHandshake(nil)
if err != nil {
return nil, err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return nil, unexpectedMessageError(clientHello, msg)
}
var configForClient *config
originalConfig := c.config
if c.config.GetConfigForClient != nil {
chi := newClientHelloInfo(ctx, c, clientHello)
if cfc, err := c.config.GetConfigForClient(chi); err != nil {
c.sendAlert(alertInternalError)
return nil, err
} else if cfc != nil {
configForClient = fromConfig(cfc)
c.config = configForClient
}
}
c.ticketKeys = originalConfig.ticketKeys(configForClient)
clientVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
clientVersions = supportedVersionsFromMax(clientHello.vers)
}
if c.extraConfig.usesAlternativeRecordLayer() {
// In QUIC, the client MUST NOT offer any old TLS versions.
// Here, we can only check that none of the other supported versions of this library
// (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here.
for _, ver := range clientVersions {
if ver == VersionTLS13 {
continue
}
for _, v := range supportedVersions {
if ver == v {
c.sendAlert(alertProtocolVersion)
return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver)
}
}
}
// Make the config we're using allows us to use TLS 1.3.
if c.config.maxSupportedVersion(roleServer) < VersionTLS13 {
c.sendAlert(alertInternalError)
return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3")
}
}
c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
if !ok {
c.sendAlert(alertProtocolVersion)
return nil, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
}
c.haveVers = true
c.in.version = c.vers
c.out.version = c.vers
return clientHello, nil
}
func (hs *serverHandshakeState) processClientHello() error {
c := hs.c
hs.hello = new(serverHelloMsg)
hs.hello.vers = c.vers
foundCompression := false
// We only support null compression, so check that the client offered it.
for _, compression := range hs.clientHello.compressionMethods {
if compression == compressionNone {
foundCompression = true
break
}
}
if !foundCompression {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client does not support uncompressed connections")
}
hs.hello.random = make([]byte, 32)
serverRandom := hs.hello.random
// Downgrade protection canaries. See RFC 8446, Section 4.1.3.
maxVers := c.config.maxSupportedVersion(roleServer)
if maxVers >= VersionTLS12 && c.vers < maxVers || testingOnlyForceDowngradeCanary {
if c.vers == VersionTLS12 {
copy(serverRandom[24:], downgradeCanaryTLS12)
} else {
copy(serverRandom[24:], downgradeCanaryTLS11)
}
serverRandom = serverRandom[:24]
}
_, err := io.ReadFull(c.config.rand(), serverRandom)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
hs.hello.compressionMethod = compressionNone
if len(hs.clientHello.serverName) > 0 {
c.serverName = hs.clientHello.serverName
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
}
hs.hello.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
hs.cert, err = c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
if hs.clientHello.scts {
hs.hello.scts = hs.cert.SignedCertificateTimestamps
}
hs.ecdheOk = supportsECDHE(c.config, hs.clientHello.supportedCurves, hs.clientHello.supportedPoints)
if hs.ecdheOk && len(hs.clientHello.supportedPoints) > 0 {
// Although omitting the ec_point_formats extension is permitted, some
// old OpenSSL version will refuse to handshake if not present.
//
// Per RFC 4492, section 5.1.2, implementations MUST support the
// uncompressed point format. See golang.org/issue/31943.
hs.hello.supportedPoints = []uint8{pointFormatUncompressed}
}
if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok {
switch priv.Public().(type) {
case *ecdsa.PublicKey:
hs.ecSignOk = true
case ed25519.PublicKey:
hs.ecSignOk = true
case *rsa.PublicKey:
hs.rsaSignOk = true
default:
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: unsupported signing key type (%T)", priv.Public())
}
}
if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok {
switch priv.Public().(type) {
case *rsa.PublicKey:
hs.rsaDecryptOk = true
default:
c.sendAlert(alertInternalError)
return fmt.Errorf("tls: unsupported decryption key type (%T)", priv.Public())
}
}
return nil
}
// negotiateALPN picks a shared ALPN protocol that both sides support in server
// preference order. If ALPN is not configured or the peer doesn't support it,
// it returns "" and no error.
func negotiateALPN(serverProtos, clientProtos []string) (string, error) {
if len(serverProtos) == 0 || len(clientProtos) == 0 {
return "", nil
}
var http11fallback bool
for _, s := range serverProtos {
for _, c := range clientProtos {
if s == c {
return s, nil
}
if s == "h2" && c == "http/1.1" {
http11fallback = true
}
}
}
// As a special case, let http/1.1 clients connect to h2 servers as if they
// didn't support ALPN. We used not to enforce protocol overlap, so over
// time a number of HTTP servers were configured with only "h2", but
// expected to accept connections from "http/1.1" clients. See Issue 46310.
if http11fallback {
return "", nil
}
return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos)
}
// supportsECDHE returns whether ECDHE key exchanges can be used with this
// pre-TLS 1.3 client.
func supportsECDHE(c *config, supportedCurves []CurveID, supportedPoints []uint8) bool {
supportsCurve := false
for _, curve := range supportedCurves {
if c.supportsCurve(curve) {
supportsCurve = true
break
}
}
supportsPointFormat := false
for _, pointFormat := range supportedPoints {
if pointFormat == pointFormatUncompressed {
supportsPointFormat = true
break
}
}
// Per RFC 8422, Section 5.1.2, if the Supported Point Formats extension is
// missing, uncompressed points are supported. If supportedPoints is empty,
// the extension must be missing, as an empty extension body is rejected by
// the parser. See https://go.dev/issue/49126.
if len(supportedPoints) == 0 {
supportsPointFormat = true
}
return supportsCurve && supportsPointFormat
}
func (hs *serverHandshakeState) pickCipherSuite() error {
c := hs.c
preferenceOrder := cipherSuitesPreferenceOrder
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
preferenceOrder = cipherSuitesPreferenceOrderNoAES
}
configCipherSuites := c.config.cipherSuites()
preferenceList := make([]uint16, 0, len(configCipherSuites))
for _, suiteID := range preferenceOrder {
for _, id := range configCipherSuites {
if id == suiteID {
preferenceList = append(preferenceList, id)
break
}
}
}
hs.suite = selectCipherSuite(preferenceList, hs.clientHello.cipherSuites, hs.cipherSuiteOk)
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server")
}
c.cipherSuite = hs.suite.id
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// The client is doing a fallback connection. See RFC 7507.
if hs.clientHello.vers < c.config.maxSupportedVersion(roleServer) {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
return nil
}
func (hs *serverHandshakeState) cipherSuiteOk(c *cipherSuite) bool {
if c.flags&suiteECDHE != 0 {
if !hs.ecdheOk {
return false
}
if c.flags&suiteECSign != 0 {
if !hs.ecSignOk {
return false
}
} else if !hs.rsaSignOk {
return false
}
} else if !hs.rsaDecryptOk {
return false
}
if hs.c.vers < VersionTLS12 && c.flags&suiteTLS12 != 0 {
return false
}
return true
}
// checkForResumption reports whether we should perform resumption on this connection.
func (hs *serverHandshakeState) checkForResumption() bool {
c := hs.c
if c.config.SessionTicketsDisabled {
return false
}
plaintext, usedOldKey := c.decryptTicket(hs.clientHello.sessionTicket)
if plaintext == nil {
return false
}
hs.sessionState = &sessionState{usedOldKey: usedOldKey}
ok := hs.sessionState.unmarshal(plaintext)
if !ok {
return false
}
createdAt := time.Unix(int64(hs.sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
return false
}
// Never resume a session for a different TLS version.
if c.vers != hs.sessionState.vers {
return false
}
cipherSuiteOk := false
// Check that the client is still offering the ciphersuite in the session.
for _, id := range hs.clientHello.cipherSuites {
if id == hs.sessionState.cipherSuite {
cipherSuiteOk = true
break
}
}
if !cipherSuiteOk {
return false
}
// Check that we also support the ciphersuite from the session.
hs.suite = selectCipherSuite([]uint16{hs.sessionState.cipherSuite},
c.config.cipherSuites(), hs.cipherSuiteOk)
if hs.suite == nil {
return false
}
sessionHasClientCerts := len(hs.sessionState.certificates) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
return false
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
return false
}
return true
}
func (hs *serverHandshakeState) doResumeHandshake() error {
c := hs.c
hs.hello.cipherSuite = hs.suite.id
c.cipherSuite = hs.suite.id
// We echo the client's session ID in the ServerHello to let it know
// that we're doing a resumption.
hs.hello.sessionId = hs.clientHello.sessionId
hs.hello.ticketSupported = hs.sessionState.usedOldKey
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
hs.finishedHash.discardHandshakeBuffer()
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
if err := c.processCertsFromClient(Certificate{
Certificate: hs.sessionState.certificates,
}); err != nil {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
hs.masterSecret = hs.sessionState.masterSecret
return nil
}
func (hs *serverHandshakeState) doFullHandshake() error {
c := hs.c
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
hs.hello.ocspStapling = true
}
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
hs.hello.cipherSuite = hs.suite.id
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
if c.config.ClientAuth == NoClientCert {
// No need to keep a full record of the handshake if client
// certificates won't be used.
hs.finishedHash.discardHandshakeBuffer()
}
if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.response = hs.cert.OCSPStaple
if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil {
return err
}
}
keyAgreement := hs.suite.ka(c.vers)
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
if skx != nil {
if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
return err
}
}
var certReq *certificateRequestMsg
if c.config.ClientAuth >= RequestClientCert {
// Request a client certificate
certReq = new(certificateRequestMsg)
certReq.certificateTypes = []byte{
byte(certTypeRSASign),
byte(certTypeECDSASign),
}
if c.vers >= VersionTLS12 {
certReq.hasSignatureAlgorithm = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
}
// An empty list of certificateAuthorities signals to
// the client that it may send any certificate in response
// to our request. When we know the CAs we trust, then
// we can send them down, so that the client can choose
// an appropriate certificate to give to us.
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil {
return err
}
}
helloDone := new(serverHelloDoneMsg)
if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
var pub crypto.PublicKey // public key for client auth, if any
msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
if c.config.ClientAuth >= RequestClientCert {
certMsg, ok := msg.(*certificateMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if err := c.processCertsFromClient(Certificate{
Certificate: certMsg.certificates,
}); err != nil {
return err
}
if len(certMsg.certificates) != 0 {
pub = c.peerCertificates[0].PublicKey
}
msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
// Get client key exchange
ckx, ok := msg.(*clientKeyExchangeMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(ckx, msg)
}
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil {
c.sendAlert(alertHandshakeFailure)
return err
}
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil {
c.sendAlert(alertInternalError)
return err
}
// If we received a client cert in response to our certificate request message,
// the client will send us a certificateVerifyMsg immediately after the
// clientKeyExchangeMsg. This message is a digest of all preceding
// handshake-layer messages that is signed using the private key corresponding
// to the client's certificate. This allows us to verify that the client is in
// possession of the private key of the certificate.
if len(c.peerCertificates) > 0 {
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
var sigType uint8
var sigHash crypto.Hash
if c.vers >= VersionTLS12 {
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, certReq.supportedSignatureAlgorithms) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub)
if err != nil {
c.sendAlert(alertIllegalParameter)
return err
}
}
signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash, hs.masterSecret)
if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {
return err
}
}
hs.finishedHash.discardHandshakeBuffer()
return nil
}
func (hs *serverHandshakeState) establishKeys() error {
c := hs.c
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
var clientCipher, serverCipher any
var clientHash, serverHash hash.Hash
if hs.suite.aead == nil {
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
clientHash = hs.suite.mac(clientMAC)
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
serverHash = hs.suite.mac(serverMAC)
} else {
clientCipher = hs.suite.aead(clientKey, clientIV)
serverCipher = hs.suite.aead(serverKey, serverIV)
}
c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
return nil
}
func (hs *serverHandshakeState) readFinished(out []byte) error {
c := hs.c
if err := c.readChangeCipherSpec(); err != nil {
return err
}
// finishedMsg is included in the transcript, but not until after we
// check the client version, since the state before this message was
// sent is used during verification.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
clientFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientFinished, msg)
}
verify := hs.finishedHash.clientSum(hs.masterSecret)
if len(verify) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's Finished message is incorrect")
}
if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil {
return err
}
copy(out, verify)
return nil
}
func (hs *serverHandshakeState) sendSessionTicket() error {
// ticketSupported is set in a resumption handshake if the
// ticket from the client was encrypted with an old session
// ticket key and thus a refreshed ticket should be sent.
if !hs.hello.ticketSupported {
return nil
}
c := hs.c
m := new(newSessionTicketMsg)
createdAt := uint64(c.config.time().Unix())
if hs.sessionState != nil {
// If this is re-wrapping an old key, then keep
// the original time it was created.
createdAt = hs.sessionState.createdAt
}
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionState{
vers: c.vers,
cipherSuite: hs.suite.id,
createdAt: createdAt,
masterSecret: hs.masterSecret,
certificates: certsFromClient,
}
stateBytes, err := state.marshal()
if err != nil {
return err
}
m.ticket, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeState) sendFinished(out []byte) error {
c := hs.c
if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
copy(out, finished.verifyData)
return nil
}
// processCertsFromClient takes a chain of client certificates either from a
// Certificates message or from a sessionState and verifies them. It returns
// the public key of the leaf certificate.
func (c *Conn) processCertsFromClient(certificate Certificate) error {
certificates := certificate.Certificate
certs := make([]*x509.Certificate, len(certificates))
var err error
for i, asn1Data := range certificates {
if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to parse client certificate: " + err.Error())
}
}
if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) {
c.sendAlert(alertBadCertificate)
return errors.New("tls: client didn't provide a certificate")
}
if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
opts := x509.VerifyOptions{
Roots: c.config.ClientCAs,
CurrentTime: c.config.time(),
Intermediates: x509.NewCertPool(),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
chains, err := certs[0].Verify(opts)
if err != nil {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to verify client certificate: " + err.Error())
}
c.verifiedChains = chains
}
c.peerCertificates = certs
c.ocspResponse = certificate.OCSPStaple
c.scts = certificate.SignedCertificateTimestamps
if len(certs) > 0 {
switch certs[0].PublicKey.(type) {
case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey:
default:
c.sendAlert(alertUnsupportedCertificate)
return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey)
}
}
if c.config.VerifyPeerCertificate != nil {
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
func newClientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
supportedVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 {
supportedVersions = supportedVersionsFromMax(clientHello.vers)
}
return toClientHelloInfo(&clientHelloInfo{
CipherSuites: clientHello.cipherSuites,
ServerName: clientHello.serverName,
SupportedCurves: clientHello.supportedCurves,
SupportedPoints: clientHello.supportedPoints,
SignatureSchemes: clientHello.supportedSignatureAlgorithms,
SupportedProtos: clientHello.alpnProtocols,
SupportedVersions: supportedVersions,
Conn: c.conn,
config: toConfig(c.config),
ctx: ctx,
})
}

View File

@ -1,903 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"bytes"
"context"
"crypto"
"crypto/hmac"
"crypto/rsa"
"errors"
"hash"
"io"
"sync/atomic"
"time"
)
// maxClientPSKIdentities is the number of client PSK identities the server will
// attempt to validate. It will ignore the rest not to let cheap ClientHello
// messages cause too much work in session ticket decryption attempts.
const maxClientPSKIdentities = 5
type serverHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
alpnNegotiationErr error
encryptedExtensions *encryptedExtensionsMsg
sentDummyCCS bool
usingPSK bool
suite *cipherSuiteTLS13
cert *Certificate
sigAlg SignatureScheme
earlySecret []byte
sharedKey []byte
handshakeSecret []byte
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
transcript hash.Hash
clientFinished []byte
}
func (hs *serverHandshakeStateTLS13) handshake() error {
c := hs.c
if needFIPS() {
return errors.New("tls: internal error: TLS 1.3 reached in FIPS mode")
}
// For an overview of the TLS 1.3 handshake, see RFC 8446, Section 2.
if err := hs.processClientHello(); err != nil {
return err
}
if err := hs.checkForResumption(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.pickCertificate(); err != nil {
return err
}
c.buffering = true
if err := hs.sendServerParameters(); err != nil {
return err
}
if err := hs.sendServerCertificate(); err != nil {
return err
}
if err := hs.sendServerFinished(); err != nil {
return err
}
// Note that at this point we could start sending application data without
// waiting for the client's second flight, but the application might not
// expect the lack of replay protection of the ClientHello parameters.
if _, err := c.flush(); err != nil {
return err
}
if err := hs.readClientCertificate(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.readClientFinished(); err != nil {
return err
}
atomic.StoreUint32(&c.handshakeStatus, 1)
c.updateConnectionState()
return nil
}
func (hs *serverHandshakeStateTLS13) processClientHello() error {
c := hs.c
hs.hello = new(serverHelloMsg)
hs.encryptedExtensions = new(encryptedExtensionsMsg)
// TLS 1.3 froze the ServerHello.legacy_version field, and uses
// supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1.
hs.hello.vers = VersionTLS12
hs.hello.supportedVersion = c.vers
if len(hs.clientHello.supportedVersions) == 0 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client used the legacy version field to negotiate TLS 1.3")
}
// Abort if the client is doing a fallback and landing lower than what we
// support. See RFC 7507, which however does not specify the interaction
// with supported_versions. The only difference is that with
// supported_versions a client has a chance to attempt a [TLS 1.2, TLS 1.4]
// handshake in case TLS 1.3 is broken but 1.2 is not. Alas, in that case,
// it will have to drop the TLS_FALLBACK_SCSV protection if it falls back to
// TLS 1.2, because a TLS 1.3 server would abort here. The situation before
// supported_versions was not better because there was just no way to do a
// TLS 1.4 handshake without risking the server selecting TLS 1.3.
for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV {
// Use c.vers instead of max(supported_versions) because an attacker
// could defeat this by adding an arbitrary high version otherwise.
if c.vers < c.config.maxSupportedVersion(roleServer) {
c.sendAlert(alertInappropriateFallback)
return errors.New("tls: client using inappropriate protocol fallback")
}
break
}
}
if len(hs.clientHello.compressionMethods) != 1 ||
hs.clientHello.compressionMethods[0] != compressionNone {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: TLS 1.3 client supports illegal compression methods")
}
hs.hello.random = make([]byte, 32)
if _, err := io.ReadFull(c.config.rand(), hs.hello.random); err != nil {
c.sendAlert(alertInternalError)
return err
}
if len(hs.clientHello.secureRenegotiation) != 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: initial handshake had non-empty renegotiation extension")
}
hs.hello.sessionId = hs.clientHello.sessionId
hs.hello.compressionMethod = compressionNone
preferenceList := defaultCipherSuitesTLS13
if !hasAESGCMHardwareSupport || !aesgcmPreferred(hs.clientHello.cipherSuites) {
preferenceList = defaultCipherSuitesTLS13NoAES
}
for _, suiteID := range preferenceList {
hs.suite = mutualCipherSuiteTLS13(hs.clientHello.cipherSuites, suiteID)
if hs.suite != nil {
break
}
}
if hs.suite == nil {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no cipher suite supported by both client and server")
}
c.cipherSuite = hs.suite.id
hs.hello.cipherSuite = hs.suite.id
hs.transcript = hs.suite.hash.New()
// Pick the ECDHE group in server preference order, but give priority to
// groups with a key share, to avoid a HelloRetryRequest round-trip.
var selectedGroup CurveID
var clientKeyShare *keyShare
GroupSelection:
for _, preferredGroup := range c.config.curvePreferences() {
for _, ks := range hs.clientHello.keyShares {
if ks.group == preferredGroup {
selectedGroup = ks.group
clientKeyShare = &ks
break GroupSelection
}
}
if selectedGroup != 0 {
continue
}
for _, group := range hs.clientHello.supportedCurves {
if group == preferredGroup {
selectedGroup = group
break
}
}
}
if selectedGroup == 0 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: no ECDHE curve supported by both client and server")
}
if clientKeyShare == nil {
if err := hs.doHelloRetryRequest(selectedGroup); err != nil {
return err
}
clientKeyShare = &hs.clientHello.keyShares[0]
}
if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok {
c.sendAlert(alertInternalError)
return errors.New("tls: CurvePreferences includes unsupported curve")
}
params, err := generateECDHEParameters(c.config.rand(), selectedGroup)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()}
hs.sharedKey = params.SharedKey(clientKeyShare.data)
if hs.sharedKey == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid client key share")
}
c.serverName = hs.clientHello.serverName
if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil {
c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions)
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
if err != nil {
hs.alpnNegotiationErr = err
}
hs.encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
return nil
}
func (hs *serverHandshakeStateTLS13) checkForResumption() error {
c := hs.c
if c.config.SessionTicketsDisabled {
return nil
}
modeOK := false
for _, mode := range hs.clientHello.pskModes {
if mode == pskModeDHE {
modeOK = true
break
}
}
if !modeOK {
return nil
}
if len(hs.clientHello.pskIdentities) != len(hs.clientHello.pskBinders) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: invalid or missing PSK binders")
}
if len(hs.clientHello.pskIdentities) == 0 {
return nil
}
for i, identity := range hs.clientHello.pskIdentities {
if i >= maxClientPSKIdentities {
break
}
plaintext, _ := c.decryptTicket(identity.label)
if plaintext == nil {
continue
}
sessionState := new(sessionStateTLS13)
if ok := sessionState.unmarshal(plaintext); !ok {
continue
}
if hs.clientHello.earlyData {
if sessionState.maxEarlyData == 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent unexpected early data")
}
if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol &&
c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 &&
c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) {
hs.encryptedExtensions.earlyData = true
c.used0RTT = true
}
}
createdAt := time.Unix(int64(sessionState.createdAt), 0)
if c.config.time().Sub(createdAt) > maxSessionTicketLifetime {
continue
}
// We don't check the obfuscated ticket age because it's affected by
// clock skew and it's only a freshness signal useful for shrinking the
// window for replay attacks, which don't affect us as we don't do 0-RTT.
pskSuite := cipherSuiteTLS13ByID(sessionState.cipherSuite)
if pskSuite == nil || pskSuite.hash != hs.suite.hash {
continue
}
// PSK connections don't re-establish client certificates, but carry
// them over in the session ticket. Ensure the presence of client certs
// in the ticket is consistent with the configured requirements.
sessionHasClientCerts := len(sessionState.certificate.Certificate) != 0
needClientCerts := requiresClientCert(c.config.ClientAuth)
if needClientCerts && !sessionHasClientCerts {
continue
}
if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
continue
}
psk := hs.suite.expandLabel(sessionState.resumptionSecret, "resumption",
nil, hs.suite.hash.Size())
hs.earlySecret = hs.suite.extract(psk, nil)
binderKey := hs.suite.deriveSecret(hs.earlySecret, resumptionBinderLabel, nil)
// Clone the transcript in case a HelloRetryRequest was recorded.
transcript := cloneHash(hs.transcript, hs.suite.hash)
if transcript == nil {
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: failed to clone hash")
}
clientHelloBytes, err := hs.clientHello.marshalWithoutBinders()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
transcript.Write(clientHelloBytes)
pskBinder := hs.suite.finishedHash(binderKey, transcript)
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid PSK binder")
}
c.didResume = true
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
return err
}
h := cloneHash(hs.transcript, hs.suite.hash)
clientHelloWithBindersBytes, err := hs.clientHello.marshal()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
h.Write(clientHelloWithBindersBytes)
if hs.encryptedExtensions.earlyData {
clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h)
c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret)
if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil {
c.sendAlert(alertInternalError)
return err
}
}
hs.hello.selectedIdentityPresent = true
hs.hello.selectedIdentity = uint16(i)
hs.usingPSK = true
return nil
}
return nil
}
// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
// interfaces implemented by standard library hashes to clone the state of in
// to a new instance of h. It returns nil if the operation fails.
func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash {
// Recreate the interface to avoid importing encoding.
type binaryMarshaler interface {
MarshalBinary() (data []byte, err error)
UnmarshalBinary(data []byte) error
}
marshaler, ok := in.(binaryMarshaler)
if !ok {
return nil
}
state, err := marshaler.MarshalBinary()
if err != nil {
return nil
}
out := h.New()
unmarshaler, ok := out.(binaryMarshaler)
if !ok {
return nil
}
if err := unmarshaler.UnmarshalBinary(state); err != nil {
return nil
}
return out
}
func (hs *serverHandshakeStateTLS13) pickCertificate() error {
c := hs.c
// Only one of PSK and certificates are used at a time.
if hs.usingPSK {
return nil
}
// signature_algorithms is required in TLS 1.3. See RFC 8446, Section 4.2.3.
if len(hs.clientHello.supportedSignatureAlgorithms) == 0 {
return c.sendAlert(alertMissingExtension)
}
certificate, err := c.config.getCertificate(newClientHelloInfo(hs.ctx, c, hs.clientHello))
if err != nil {
if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName)
} else {
c.sendAlert(alertInternalError)
}
return err
}
hs.sigAlg, err = selectSignatureScheme(c.vers, certificate, hs.clientHello.supportedSignatureAlgorithms)
if err != nil {
// getCertificate returned a certificate that is unsupported or
// incompatible with the client's signature algorithms.
c.sendAlert(alertHandshakeFailure)
return err
}
hs.cert = certificate
return nil
}
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.sentDummyCCS {
return nil
}
hs.sentDummyCCS = true
return hs.c.writeChangeCipherRecord()
}
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
c := hs.c
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
return err
}
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
helloRetryRequest := &serverHelloMsg{
vers: hs.hello.vers,
random: helloRetryRequestRandom,
sessionId: hs.hello.sessionId,
cipherSuite: hs.hello.cipherSuite,
compressionMethod: hs.hello.compressionMethod,
supportedVersion: hs.hello.supportedVersion,
selectedGroup: selectedGroup,
}
if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
// clientHelloMsg is not included in the transcript.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
clientHello, ok := msg.(*clientHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientHello, msg)
}
if len(clientHello.keyShares) != 1 || clientHello.keyShares[0].group != selectedGroup {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client sent invalid key share in second ClientHello")
}
if clientHello.earlyData {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client indicated early data in second ClientHello")
}
if illegalClientHelloChange(clientHello, hs.clientHello) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client illegally modified second ClientHello")
}
if clientHello.earlyData {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client offered 0-RTT data in second ClientHello")
}
hs.clientHello = clientHello
return nil
}
// illegalClientHelloChange reports whether the two ClientHello messages are
// different, with the exception of the changes allowed before and after a
// HelloRetryRequest. See RFC 8446, Section 4.1.2.
func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
if len(ch.supportedVersions) != len(ch1.supportedVersions) ||
len(ch.cipherSuites) != len(ch1.cipherSuites) ||
len(ch.supportedCurves) != len(ch1.supportedCurves) ||
len(ch.supportedSignatureAlgorithms) != len(ch1.supportedSignatureAlgorithms) ||
len(ch.supportedSignatureAlgorithmsCert) != len(ch1.supportedSignatureAlgorithmsCert) ||
len(ch.alpnProtocols) != len(ch1.alpnProtocols) {
return true
}
for i := range ch.supportedVersions {
if ch.supportedVersions[i] != ch1.supportedVersions[i] {
return true
}
}
for i := range ch.cipherSuites {
if ch.cipherSuites[i] != ch1.cipherSuites[i] {
return true
}
}
for i := range ch.supportedCurves {
if ch.supportedCurves[i] != ch1.supportedCurves[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithms {
if ch.supportedSignatureAlgorithms[i] != ch1.supportedSignatureAlgorithms[i] {
return true
}
}
for i := range ch.supportedSignatureAlgorithmsCert {
if ch.supportedSignatureAlgorithmsCert[i] != ch1.supportedSignatureAlgorithmsCert[i] {
return true
}
}
for i := range ch.alpnProtocols {
if ch.alpnProtocols[i] != ch1.alpnProtocols[i] {
return true
}
}
return ch.vers != ch1.vers ||
!bytes.Equal(ch.random, ch1.random) ||
!bytes.Equal(ch.sessionId, ch1.sessionId) ||
!bytes.Equal(ch.compressionMethods, ch1.compressionMethods) ||
ch.serverName != ch1.serverName ||
ch.ocspStapling != ch1.ocspStapling ||
!bytes.Equal(ch.supportedPoints, ch1.supportedPoints) ||
ch.ticketSupported != ch1.ticketSupported ||
!bytes.Equal(ch.sessionTicket, ch1.sessionTicket) ||
ch.secureRenegotiationSupported != ch1.secureRenegotiationSupported ||
!bytes.Equal(ch.secureRenegotiation, ch1.secureRenegotiation) ||
ch.scts != ch1.scts ||
!bytes.Equal(ch.cookie, ch1.cookie) ||
!bytes.Equal(ch.pskModes, ch1.pskModes)
}
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
c := hs.c
if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
return err
}
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
earlySecret := hs.earlySecret
if earlySecret == nil {
earlySecret = hs.suite.extract(nil, nil)
}
hs.handshakeSecret = hs.suite.extract(hs.sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret)
c.in.setTrafficSecret(hs.suite, clientSecret)
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret)
c.out.setTrafficSecret(hs.suite, serverSecret)
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if hs.alpnNegotiationErr != nil {
c.sendAlert(alertNoApplicationProtocol)
return hs.alpnNegotiationErr
}
if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil {
hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions)
}
if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) requestClientCert() bool {
return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK
}
func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
c := hs.c
// Only one of PSK and certificates are used at a time.
if hs.usingPSK {
return nil
}
if hs.requestClientCert() {
// Request a client certificate
certReq := new(certificateRequestMsgTLS13)
certReq.ocspStapling = true
certReq.scts = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil {
return err
}
}
certMsg := new(certificateMsgTLS13)
certMsg.certificate = *hs.cert
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
certVerifyMsg := new(certificateVerifyMsg)
certVerifyMsg.hasSignatureAlgorithm = true
certVerifyMsg.signatureAlgorithm = hs.sigAlg
sigType, sigHash, err := typeAndHashFromSignatureScheme(hs.sigAlg)
if err != nil {
return c.sendAlert(alertInternalError)
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
if err != nil {
public := hs.cert.PrivateKey.(crypto.Signer).Public()
if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS &&
rsaKey.N.BitLen()/8 < sigHash.Size()*2+2 { // key too small for RSA-PSS
c.sendAlert(alertHandshakeFailure)
} else {
c.sendAlert(alertInternalError)
}
return errors.New("tls: failed to sign handshake: " + err.Error())
}
certVerifyMsg.signature = sig
if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
c := hs.c
finished := &finishedMsg{
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
// Derive secrets that take context through the server Finished.
hs.masterSecret = hs.suite.extract(nil,
hs.suite.deriveSecret(hs.handshakeSecret, "derived", nil))
hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.out.exportKey(EncryptionApplication, hs.suite, serverSecret)
c.out.setTrafficSecret(hs.suite, serverSecret)
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.clientHello.random, serverSecret)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
// If we did not request client certificates, at this point we can
// precompute the client finished and roll the transcript forward to send
// session tickets in our first flight.
if !hs.requestClientCert() {
if err := hs.sendSessionTickets(); err != nil {
return err
}
}
return nil
}
func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
if hs.c.config.SessionTicketsDisabled {
return false
}
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
for _, pskMode := range hs.clientHello.pskModes {
if pskMode == pskModeDHE {
return true
}
}
return false
}
func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
c := hs.c
hs.clientFinished = hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
finishedMsg := &finishedMsg{
verifyData: hs.clientFinished,
}
if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
return err
}
if !hs.shouldSendSessionTickets() {
return nil
}
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
// Don't send session tickets when the alternative record layer is set.
// Instead, save the resumption secret on the Conn.
// Session tickets can then be generated by calling Conn.GetSessionTicket().
if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil {
return nil
}
m, err := hs.c.getSessionTicketMsg(nil)
if err != nil {
return err
}
if _, err := c.writeHandshakeRecord(m, nil); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
c := hs.c
if !hs.requestClientCert() {
// Make sure the connection is still being verified whether or not
// the server requested a client certificate.
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
return nil
}
// If we requested a client certificate, then the client must send a
// certificate message. If it's empty, no CertificateVerify is sent.
msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
certMsg, ok := msg.(*certificateMsgTLS13)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
return err
}
if c.config.VerifyConnection != nil {
if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil {
c.sendAlert(alertBadCertificate)
return err
}
}
if len(certMsg.certificate.Certificate) != 0 {
// certificateVerifyMsg is included in the transcript, but not until
// after we verify the handshake signature, since the state before
// this message was sent is used.
msg, err = c.readHandshake(nil)
if err != nil {
return err
}
certVerify, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerify, msg)
}
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
return c.sendAlert(alertInternalError)
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
sigHash, signed, certVerify.signature); err != nil {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
if err := transcriptMsg(certVerify, hs.transcript); err != nil {
return err
}
}
// If we waited until the client certificates to send session tickets, we
// are ready to do it now.
if err := hs.sendSessionTickets(); err != nil {
return err
}
return nil
}
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
c := hs.c
// finishedMsg is not included in the transcript.
msg, err := c.readHandshake(nil)
if err != nil {
return err
}
finished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(finished, msg)
}
if !hmac.Equal(hs.clientFinished, finished.verifyData) {
c.sendAlert(alertDecryptError)
return errors.New("tls: invalid client finished hash")
}
c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret)
c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
return nil
}

View File

@ -1,357 +0,0 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto"
"crypto/md5"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"errors"
"fmt"
"io"
)
// a keyAgreement implements the client and server side of a TLS key agreement
// protocol by generating and processing key exchange messages.
type keyAgreement interface {
// On the server side, the first two methods are called in order.
// In the case that the key agreement protocol doesn't use a
// ServerKeyExchange message, generateServerKeyExchange can return nil,
// nil.
generateServerKeyExchange(*config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
processClientKeyExchange(*config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error)
// On the client side, the next two methods are called in order.
// This method may not be called if the server doesn't send a
// ServerKeyExchange message.
processServerKeyExchange(*config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error
generateClientKeyExchange(*config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error)
}
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
// rsaKeyAgreement implements the standard TLS key agreement where the client
// encrypts the pre-master secret to the server's public key.
type rsaKeyAgreement struct{}
func (ka rsaKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
return nil, nil
}
func (ka rsaKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) < 2 {
return nil, errClientKeyExchange
}
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
if ciphertextLen != len(ckx.ciphertext)-2 {
return nil, errClientKeyExchange
}
ciphertext := ckx.ciphertext[2:]
priv, ok := cert.PrivateKey.(crypto.Decrypter)
if !ok {
return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
}
// Perform constant time RSA PKCS #1 v1.5 decryption
preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
if err != nil {
return nil, err
}
// We don't check the version number in the premaster secret. For one,
// by checking it, we would leak information about the validity of the
// encrypted pre-master secret. Secondly, it provides only a small
// benefit against a downgrade attack and some implementations send the
// wrong version anyway. See the discussion at the end of section
// 7.4.7.1 of RFC 4346.
return preMasterSecret, nil
}
func (ka rsaKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
return errors.New("tls: unexpected ServerKeyExchange")
}
func (ka rsaKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
preMasterSecret := make([]byte, 48)
preMasterSecret[0] = byte(clientHello.vers >> 8)
preMasterSecret[1] = byte(clientHello.vers)
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
if err != nil {
return nil, nil, err
}
rsaKey, ok := cert.PublicKey.(*rsa.PublicKey)
if !ok {
return nil, nil, errors.New("tls: server certificate contains incorrect key type for selected ciphersuite")
}
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), rsaKey, preMasterSecret)
if err != nil {
return nil, nil, err
}
ckx := new(clientKeyExchangeMsg)
ckx.ciphertext = make([]byte, len(encrypted)+2)
ckx.ciphertext[0] = byte(len(encrypted) >> 8)
ckx.ciphertext[1] = byte(len(encrypted))
copy(ckx.ciphertext[2:], encrypted)
return preMasterSecret, ckx, nil
}
// sha1Hash calculates a SHA1 hash over the given byte slices.
func sha1Hash(slices [][]byte) []byte {
hsha1 := sha1.New()
for _, slice := range slices {
hsha1.Write(slice)
}
return hsha1.Sum(nil)
}
// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
// concatenation of an MD5 and SHA1 hash.
func md5SHA1Hash(slices [][]byte) []byte {
md5sha1 := make([]byte, md5.Size+sha1.Size)
hmd5 := md5.New()
for _, slice := range slices {
hmd5.Write(slice)
}
copy(md5sha1, hmd5.Sum(nil))
copy(md5sha1[md5.Size:], sha1Hash(slices))
return md5sha1
}
// hashForServerKeyExchange hashes the given slices and returns their digest
// using the given hash function (for >= TLS 1.2) or using a default based on
// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't
// do pre-hashing, it returns the concatenation of the slices.
func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte {
if sigType == signatureEd25519 {
var signed []byte
for _, slice := range slices {
signed = append(signed, slice...)
}
return signed
}
if version >= VersionTLS12 {
h := hashFunc.New()
for _, slice := range slices {
h.Write(slice)
}
digest := h.Sum(nil)
return digest
}
if sigType == signatureECDSA {
return sha1Hash(slices)
}
return md5SHA1Hash(slices)
}
// ecdheKeyAgreement implements a TLS key agreement where the server
// generates an ephemeral EC public/private key pair and signs it. The
// pre-master secret is then calculated using ECDH. The signature may
// be ECDSA, Ed25519 or RSA.
type ecdheKeyAgreement struct {
version uint16
isRSA bool
params ecdheParameters
// ckx and preMasterSecret are generated in processServerKeyExchange
// and returned in generateClientKeyExchange.
ckx *clientKeyExchangeMsg
preMasterSecret []byte
}
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
var curveID CurveID
for _, c := range clientHello.supportedCurves {
if config.supportsCurve(c) {
curveID = c
break
}
}
if curveID == 0 {
return nil, errors.New("tls: no supported elliptic curves offered")
}
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
return nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
params, err := generateECDHEParameters(config.rand(), curveID)
if err != nil {
return nil, err
}
ka.params = params
// See RFC 4492, Section 5.4.
ecdhePublic := params.PublicKey()
serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic))
serverECDHEParams[0] = 3 // named curve
serverECDHEParams[1] = byte(curveID >> 8)
serverECDHEParams[2] = byte(curveID)
serverECDHEParams[3] = byte(len(ecdhePublic))
copy(serverECDHEParams[4:], ecdhePublic)
priv, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey)
}
var signatureAlgorithm SignatureScheme
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms)
if err != nil {
return nil, err
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return nil, err
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public())
if err != nil {
return nil, err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
}
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams)
signOpts := crypto.SignerOpts(sigHash)
if sigType == signatureRSAPSS {
signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
sig, err := priv.Sign(config.rand(), signed, signOpts)
if err != nil {
return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
}
skx := new(serverKeyExchangeMsg)
sigAndHashLen := 0
if ka.version >= VersionTLS12 {
sigAndHashLen = 2
}
skx.key = make([]byte, len(serverECDHEParams)+sigAndHashLen+2+len(sig))
copy(skx.key, serverECDHEParams)
k := skx.key[len(serverECDHEParams):]
if ka.version >= VersionTLS12 {
k[0] = byte(signatureAlgorithm >> 8)
k[1] = byte(signatureAlgorithm)
k = k[2:]
}
k[0] = byte(len(sig) >> 8)
k[1] = byte(len(sig))
copy(k[2:], sig)
return skx, nil
}
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
return nil, errClientKeyExchange
}
preMasterSecret := ka.params.SharedKey(ckx.ciphertext[1:])
if preMasterSecret == nil {
return nil, errClientKeyExchange
}
return preMasterSecret, nil
}
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
if len(skx.key) < 4 {
return errServerKeyExchange
}
if skx.key[0] != 3 { // named curve
return errors.New("tls: server selected unsupported curve")
}
curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
publicLen := int(skx.key[3])
if publicLen+4 > len(skx.key) {
return errServerKeyExchange
}
serverECDHEParams := skx.key[:4+publicLen]
publicKey := serverECDHEParams[4:]
sig := skx.key[4+publicLen:]
if len(sig) < 2 {
return errServerKeyExchange
}
if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
return errors.New("tls: server selected unsupported curve")
}
params, err := generateECDHEParameters(config.rand(), curveID)
if err != nil {
return err
}
ka.params = params
ka.preMasterSecret = params.SharedKey(publicKey)
if ka.preMasterSecret == nil {
return errServerKeyExchange
}
ourPublicKey := params.PublicKey()
ka.ckx = new(clientKeyExchangeMsg)
ka.ckx.ciphertext = make([]byte, 1+len(ourPublicKey))
ka.ckx.ciphertext[0] = byte(len(ourPublicKey))
copy(ka.ckx.ciphertext[1:], ourPublicKey)
var sigType uint8
var sigHash crypto.Hash
if ka.version >= VersionTLS12 {
signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
sig = sig[2:]
if len(sig) < 2 {
return errServerKeyExchange
}
if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
return err
}
} else {
sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey)
if err != nil {
return err
}
}
if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
return errServerKeyExchange
}
sigLen := int(sig[0])<<8 | int(sig[1])
if sigLen+2 != len(sig) {
return errServerKeyExchange
}
sig = sig[2:]
signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams)
if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
return nil
}
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
if ka.ckx == nil {
return nil, nil, errors.New("tls: missing ServerKeyExchange message")
}
return ka.preMasterSecret, ka.ckx, nil
}

View File

@ -1,216 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto/elliptic"
"crypto/hmac"
"errors"
"fmt"
"hash"
"io"
"math/big"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/hkdf"
)
// This file contains the functions necessary to compute the TLS 1.3 key
// schedule. See RFC 8446, Section 7.
const (
resumptionBinderLabel = "res binder"
clientHandshakeTrafficLabel = "c hs traffic"
serverHandshakeTrafficLabel = "s hs traffic"
clientApplicationTrafficLabel = "c ap traffic"
serverApplicationTrafficLabel = "s ap traffic"
exporterLabel = "exp master"
resumptionLabel = "res master"
trafficUpdateLabel = "traffic upd"
)
// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte {
var hkdfLabel cryptobyte.Builder
hkdfLabel.AddUint16(uint16(length))
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte("tls13 "))
b.AddBytes([]byte(label))
})
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(context)
})
hkdfLabelBytes, err := hkdfLabel.Bytes()
if err != nil {
// Rather than calling BytesOrPanic, we explicitly handle this error, in
// order to provide a reasonable error message. It should be basically
// impossible for this to panic, and routing errors back through the
// tree rooted in this function is quite painful. The labels are fixed
// size, and the context is either a fixed-length computed hash, or
// parsed from a field which has the same length limitation. As such, an
// error here is likely to only be caused during development.
//
// NOTE: another reasonable approach here might be to return a
// randomized slice if we encounter an error, which would break the
// connection, but avoid panicking. This would perhaps be safer but
// significantly more confusing to users.
panic(fmt.Errorf("failed to construct HKDF label: %s", err))
}
out := make([]byte, length)
n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out)
if err != nil || n != length {
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}
// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
if transcript == nil {
transcript = c.hash.New()
}
return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size())
}
// extract implements HKDF-Extract with the cipher suite hash.
func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte {
if newSecret == nil {
newSecret = make([]byte, c.hash.Size())
}
return hkdf.Extract(c.hash.New, newSecret, currentSecret)
}
// nextTrafficSecret generates the next traffic secret, given the current one,
// according to RFC 8446, Section 7.2.
func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte {
return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size())
}
// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) {
key = c.expandLabel(trafficSecret, "key", nil, c.keyLen)
iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength)
return
}
// finishedHash generates the Finished verify_data or PskBinderEntry according
// to RFC 8446, Section 4.4.4. See sections 4.4 and 4.2.11.2 for the baseKey
// selection.
func (c *cipherSuiteTLS13) finishedHash(baseKey []byte, transcript hash.Hash) []byte {
finishedKey := c.expandLabel(baseKey, "finished", nil, c.hash.Size())
verifyData := hmac.New(c.hash.New, finishedKey)
verifyData.Write(transcript.Sum(nil))
return verifyData.Sum(nil)
}
// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to
// RFC 8446, Section 7.5.
func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) {
expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript)
return func(label string, context []byte, length int) ([]byte, error) {
secret := c.deriveSecret(expMasterSecret, label, nil)
h := c.hash.New()
h.Write(context)
return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil
}
}
// ecdheParameters implements Diffie-Hellman with either NIST curves or X25519,
// according to RFC 8446, Section 4.2.8.2.
type ecdheParameters interface {
CurveID() CurveID
PublicKey() []byte
SharedKey(peerPublicKey []byte) []byte
}
func generateECDHEParameters(rand io.Reader, curveID CurveID) (ecdheParameters, error) {
if curveID == X25519 {
privateKey := make([]byte, curve25519.ScalarSize)
if _, err := io.ReadFull(rand, privateKey); err != nil {
return nil, err
}
publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint)
if err != nil {
return nil, err
}
return &x25519Parameters{privateKey: privateKey, publicKey: publicKey}, nil
}
curve, ok := curveForCurveID(curveID)
if !ok {
return nil, errors.New("tls: internal error: unsupported curve")
}
p := &nistParameters{curveID: curveID}
var err error
p.privateKey, p.x, p.y, err = elliptic.GenerateKey(curve, rand)
if err != nil {
return nil, err
}
return p, nil
}
func curveForCurveID(id CurveID) (elliptic.Curve, bool) {
switch id {
case CurveP256:
return elliptic.P256(), true
case CurveP384:
return elliptic.P384(), true
case CurveP521:
return elliptic.P521(), true
default:
return nil, false
}
}
type nistParameters struct {
privateKey []byte
x, y *big.Int // public key
curveID CurveID
}
func (p *nistParameters) CurveID() CurveID {
return p.curveID
}
func (p *nistParameters) PublicKey() []byte {
curve, _ := curveForCurveID(p.curveID)
return elliptic.Marshal(curve, p.x, p.y)
}
func (p *nistParameters) SharedKey(peerPublicKey []byte) []byte {
curve, _ := curveForCurveID(p.curveID)
// Unmarshal also checks whether the given point is on the curve.
x, y := elliptic.Unmarshal(curve, peerPublicKey)
if x == nil {
return nil
}
xShared, _ := curve.ScalarMult(x, y, p.privateKey)
sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
return xShared.FillBytes(sharedKey)
}
type x25519Parameters struct {
privateKey []byte
publicKey []byte
}
func (p *x25519Parameters) CurveID() CurveID {
return X25519
}
func (p *x25519Parameters) PublicKey() []byte {
return p.publicKey[:]
}
func (p *x25519Parameters) SharedKey(peerPublicKey []byte) []byte {
sharedKey, err := curve25519.X25519(p.privateKey, peerPublicKey)
if err != nil {
return nil
}
return sharedKey
}

View File

@ -1,18 +0,0 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
func needFIPS() bool { return false }
func supportedSignatureAlgorithms() []SignatureScheme {
return defaultSupportedSignatureAlgorithms
}
func fipsMinVersion(c *config) uint16 { panic("fipsMinVersion") }
func fipsMaxVersion(c *config) uint16 { panic("fipsMaxVersion") }
func fipsCurvePreferences(c *config) []CurveID { panic("fipsCurvePreferences") }
func fipsCipherSuites(c *config) []uint16 { panic("fipsCipherSuites") }
var fipsSupportedSignatureAlgorithms []SignatureScheme

View File

@ -1,283 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"crypto"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"errors"
"fmt"
"hash"
)
// Split a premaster secret in two as specified in RFC 4346, Section 5.
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
s1 = secret[0 : (len(secret)+1)/2]
s2 = secret[len(secret)/2:]
return
}
// pHash implements the P_hash function, as defined in RFC 4346, Section 5.
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret)
h.Write(seed)
a := h.Sum(nil)
j := 0
for j < len(result) {
h.Reset()
h.Write(a)
h.Write(seed)
b := h.Sum(nil)
copy(result[j:], b)
j += len(b)
h.Reset()
h.Write(a)
a = h.Sum(nil)
}
}
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
func prf10(result, secret, label, seed []byte) {
hashSHA1 := sha1.New
hashMD5 := md5.New
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
s1, s2 := splitPreMasterSecret(secret)
pHash(result, s1, labelAndSeed, hashMD5)
result2 := make([]byte, len(result))
pHash(result2, s2, labelAndSeed, hashSHA1)
for i, b := range result2 {
result[i] ^= b
}
}
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5.
func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
return func(result, secret, label, seed []byte) {
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
copy(labelAndSeed[len(label):], seed)
pHash(result, secret, labelAndSeed, hashFunc)
}
}
const (
masterSecretLength = 48 // Length of a master secret in TLS 1.1.
finishedVerifyLength = 12 // Length of verify_data in a Finished message.
)
var masterSecretLabel = []byte("master secret")
var keyExpansionLabel = []byte("key expansion")
var clientFinishedLabel = []byte("client finished")
var serverFinishedLabel = []byte("server finished")
func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
switch version {
case VersionTLS10, VersionTLS11:
return prf10, crypto.Hash(0)
case VersionTLS12:
if suite.flags&suiteSHA384 != 0 {
return prf12(sha512.New384), crypto.SHA384
}
return prf12(sha256.New), crypto.SHA256
default:
panic("unknown version")
}
}
func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
prf, _ := prfAndHashForVersion(version, suite)
return prf
}
// masterFromPreMasterSecret generates the master secret from the pre-master
// secret. See RFC 5246, Section 8.1.
func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
masterSecret := make([]byte, masterSecretLength)
prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed)
return masterSecret
}
// keysFromMasterSecret generates the connection keys from the master
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
// RFC 2246, Section 6.3.
func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
seed = append(seed, serverRandom...)
seed = append(seed, clientRandom...)
n := 2*macLen + 2*keyLen + 2*ivLen
keyMaterial := make([]byte, n)
prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed)
clientMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
serverMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
clientKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
serverKey = keyMaterial[:keyLen]
keyMaterial = keyMaterial[keyLen:]
clientIV = keyMaterial[:ivLen]
keyMaterial = keyMaterial[ivLen:]
serverIV = keyMaterial[:ivLen]
return
}
func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
var buffer []byte
if version >= VersionTLS12 {
buffer = []byte{}
}
prf, hash := prfAndHashForVersion(version, cipherSuite)
if hash != 0 {
return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
}
return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
}
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
type finishedHash struct {
client hash.Hash
server hash.Hash
// Prior to TLS 1.2, an additional MD5 hash is required.
clientMD5 hash.Hash
serverMD5 hash.Hash
// In TLS 1.2, a full buffer is sadly required.
buffer []byte
version uint16
prf func(result, secret, label, seed []byte)
}
func (h *finishedHash) Write(msg []byte) (n int, err error) {
h.client.Write(msg)
h.server.Write(msg)
if h.version < VersionTLS12 {
h.clientMD5.Write(msg)
h.serverMD5.Write(msg)
}
if h.buffer != nil {
h.buffer = append(h.buffer, msg...)
}
return len(msg), nil
}
func (h finishedHash) Sum() []byte {
if h.version >= VersionTLS12 {
return h.client.Sum(nil)
}
out := make([]byte, 0, md5.Size+sha1.Size)
out = h.clientMD5.Sum(out)
return h.client.Sum(out)
}
// clientSum returns the contents of the verify_data member of a client's
// Finished message.
func (h finishedHash) clientSum(masterSecret []byte) []byte {
out := make([]byte, finishedVerifyLength)
h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
return out
}
// serverSum returns the contents of the verify_data member of a server's
// Finished message.
func (h finishedHash) serverSum(masterSecret []byte) []byte {
out := make([]byte, finishedVerifyLength)
h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
return out
}
// hashForClientCertificate returns the handshake messages so far, pre-hashed if
// necessary, suitable for signing by a TLS client certificate.
func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) []byte {
if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil {
panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer")
}
if sigType == signatureEd25519 {
return h.buffer
}
if h.version >= VersionTLS12 {
hash := hashAlg.New()
hash.Write(h.buffer)
return hash.Sum(nil)
}
if sigType == signatureECDSA {
return h.server.Sum(nil)
}
return h.Sum()
}
// discardHandshakeBuffer is called when there is no more need to
// buffer the entirety of the handshake messages.
func (h *finishedHash) discardHandshakeBuffer() {
h.buffer = nil
}
// noExportedKeyingMaterial is used as a value of
// ConnectionState.ekm when renegotiation is enabled and thus
// we wish to fail all key-material export requests.
func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
}
// ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
return func(label string, context []byte, length int) ([]byte, error) {
switch label {
case "client finished", "server finished", "master secret", "key expansion":
// These values are reserved and may not be used.
return nil, fmt.Errorf("crypto/tls: reserved ExportKeyingMaterial label: %s", label)
}
seedLen := len(serverRandom) + len(clientRandom)
if context != nil {
seedLen += 2 + len(context)
}
seed := make([]byte, 0, seedLen)
seed = append(seed, clientRandom...)
seed = append(seed, serverRandom...)
if context != nil {
if len(context) >= 1<<16 {
return nil, fmt.Errorf("crypto/tls: ExportKeyingMaterial context too long")
}
seed = append(seed, byte(len(context)>>8), byte(len(context)))
seed = append(seed, context...)
}
keyMaterial := make([]byte, length)
prfForVersion(version, suite)(keyMaterial, masterSecret, []byte(label), seed)
return keyMaterial, nil
}
}

View File

@ -1,277 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/binary"
"errors"
"io"
"time"
"golang.org/x/crypto/cryptobyte"
)
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
type sessionState struct {
vers uint16
cipherSuite uint16
createdAt uint64
masterSecret []byte // opaque master_secret<1..2^16-1>;
// struct { opaque certificate<1..2^24-1> } Certificate;
certificates [][]byte // Certificate certificate_list<0..2^24-1>;
// usedOldKey is true if the ticket from which this session came from
// was encrypted with an older key and thus should be refreshed.
usedOldKey bool
}
func (m *sessionState) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(m.vers)
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.masterSecret)
})
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for _, cert := range m.certificates {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
}
})
return b.Bytes()
}
func (m *sessionState) unmarshal(data []byte) bool {
*m = sessionState{usedOldKey: m.usedOldKey}
s := cryptobyte.String(data)
if ok := s.ReadUint16(&m.vers) &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint16LengthPrefixed(&s, &m.masterSecret) &&
len(m.masterSecret) != 0; !ok {
return false
}
var certList cryptobyte.String
if !s.ReadUint24LengthPrefixed(&certList) {
return false
}
for !certList.Empty() {
var cert []byte
if !readUint24LengthPrefixed(&certList, &cert) {
return false
}
m.certificates = append(m.certificates, cert)
}
return s.Empty()
}
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its first
// version (revision = 0) doesn't carry any of the information needed for 0-RTT
// validation and the nonce is always empty.
// version (revision = 1) carries the max_early_data_size sent in the ticket.
// version (revision = 2) carries the ALPN sent in the ticket.
type sessionStateTLS13 struct {
// uint8 version = 0x0304;
// uint8 revision = 2;
cipherSuite uint16
createdAt uint64
resumptionSecret []byte // opaque resumption_master_secret<1..2^8-1>;
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
maxEarlyData uint32
alpn string
appData []byte
}
func (m *sessionStateTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(VersionTLS13)
b.AddUint8(2) // revision
b.AddUint16(m.cipherSuite)
addUint64(&b, m.createdAt)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.resumptionSecret)
})
marshalCertificate(&b, m.certificate)
b.AddUint32(m.maxEarlyData)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(m.alpn))
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.appData)
})
return b.Bytes()
}
func (m *sessionStateTLS13) unmarshal(data []byte) bool {
*m = sessionStateTLS13{}
s := cryptobyte.String(data)
var version uint16
var revision uint8
var alpn []byte
ret := s.ReadUint16(&version) &&
version == VersionTLS13 &&
s.ReadUint8(&revision) &&
revision == 2 &&
s.ReadUint16(&m.cipherSuite) &&
readUint64(&s, &m.createdAt) &&
readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
len(m.resumptionSecret) != 0 &&
unmarshalCertificate(&s, &m.certificate) &&
s.ReadUint32(&m.maxEarlyData) &&
readUint8LengthPrefixed(&s, &alpn) &&
readUint16LengthPrefixed(&s, &m.appData) &&
s.Empty()
m.alpn = string(alpn)
return ret
}
func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
if len(c.ticketKeys) == 0 {
return nil, errors.New("tls: internal error: session ticket keys unavailable")
}
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
return nil, err
}
key := c.ticketKeys[0]
copy(keyName, key.keyName[:])
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
}
cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state)
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
mac.Sum(macBytes[:0])
return encrypted, nil
}
func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
return nil, false
}
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
keyIndex := -1
for i, candidateKey := range c.ticketKeys {
if bytes.Equal(keyName, candidateKey.keyName[:]) {
keyIndex = i
break
}
}
if keyIndex == -1 {
return nil, false
}
key := &c.ticketKeys[keyIndex]
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
expected := mac.Sum(nil)
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
return nil, false
}
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, false
}
plaintext = make([]byte, len(ciphertext))
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
return plaintext, keyIndex > 0
}
func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) {
m := new(newSessionTicketMsgTLS13)
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionStateTLS13{
cipherSuite: c.cipherSuite,
createdAt: uint64(c.config.time().Unix()),
resumptionSecret: c.resumptionSecret,
certificate: Certificate{
Certificate: certsFromClient,
OCSPStaple: c.ocspResponse,
SignedCertificateTimestamps: c.scts,
},
appData: appData,
alpn: c.clientProtocol,
}
if c.extraConfig != nil {
state.maxEarlyData = c.extraConfig.MaxEarlyData
}
stateBytes, err := state.marshal()
if err != nil {
return nil, err
}
m.label, err = c.encryptTicket(stateBytes)
if err != nil {
return nil, err
}
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
// ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1
// The value is not stored anywhere; we never need to check the ticket age
// because 0-RTT is not supported.
ageAdd := make([]byte, 4)
_, err = c.config.rand().Read(ageAdd)
if err != nil {
return nil, err
}
m.ageAdd = binary.LittleEndian.Uint32(ageAdd)
// ticket_nonce, which must be unique per connection, is always left at
// zero because we only ever send one ticket per connection.
if c.extraConfig != nil {
m.maxEarlyData = c.extraConfig.MaxEarlyData
}
return m, nil
}
// GetSessionTicket generates a new session ticket.
// It should only be called after the handshake completes.
// It can only be used for servers, and only if the alternative record layer is set.
// The ticket may be nil if config.SessionTicketsDisabled is set,
// or if the client isn't able to receive session tickets.
func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) {
if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil {
return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.")
}
if c.config.SessionTicketsDisabled {
return nil, nil
}
m, err := c.getSessionTicketMsg(appData)
if err != nil {
return nil, err
}
return m.marshal()
}

View File

@ -1,362 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// package qtls partially implements TLS 1.2, as specified in RFC 5246,
// and TLS 1.3, as specified in RFC 8446.
package qtls
// BUG(agl): The crypto/tls package only implements some countermeasures
// against Lucky13 attacks on CBC-mode encryption, and only on SHA1
// variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"os"
"strings"
)
// Server returns a new TLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
c := &Conn{
conn: conn,
config: fromConfig(config),
extraConfig: extraConfig,
}
c.handshakeFn = c.serverHandshake
return c
}
// Client returns a new TLS client side connection
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
c := &Conn{
conn: conn,
config: fromConfig(config),
extraConfig: extraConfig,
isClient: true,
}
c.handshakeFn = c.clientHandshake
return c
}
// A listener implements a network listener (net.Listener) for TLS connections.
type listener struct {
net.Listener
config *Config
extraConfig *ExtraConfig
}
// Accept waits for and returns the next incoming TLS connection.
// The returned connection is of type *Conn.
func (l *listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return Server(c, l.config, l.extraConfig), nil
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener {
l := new(listener)
l.Listener = inner
l.config = config
l.extraConfig = extraConfig
return l
}
// Listen creates a TLS listener accepting connections on the
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) {
if config == nil || len(config.Certificates) == 0 &&
config.GetCertificate == nil && config.GetConfigForClient == nil {
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
}
l, err := net.Listen(network, laddr)
if err != nil {
return nil, err
}
return NewListener(l, config, extraConfig), nil
}
type timeoutError struct{}
func (timeoutError) Error() string { return "tls: DialWithDialer timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
// DialWithDialer connects to the given network address using dialer.Dial and
// then initiates a TLS handshake, returning the resulting TLS connection. Any
// timeout or deadline given in the dialer apply to connection and TLS
// handshake as a whole.
//
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
//
// DialWithDialer uses context.Background internally; to specify the context,
// use Dialer.DialContext with NetDialer set to the desired dialer.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
return dial(context.Background(), dialer, network, addr, config, extraConfig)
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
if netDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
defer cancel()
}
if !netDialer.Deadline.IsZero() {
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
defer cancel()
}
rawConn, err := netDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
if config == nil {
config = defaultConfig()
}
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if config.ServerName == "" {
// Make a copy to avoid polluting argument or default.
c := config.Clone()
c.ServerName = hostname
config = c
}
conn := Client(rawConn, config, extraConfig)
if err := conn.HandshakeContext(ctx); err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// Dial connects to the given network address using net.Dial
// and then initiates a TLS handshake, returning the resulting
// TLS connection.
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig)
}
// Dialer dials TLS connections given a configuration and a Dialer for the
// underlying connection.
type Dialer struct {
// NetDialer is the optional dialer to use for the TLS connections'
// underlying TCP connections.
// A nil NetDialer is equivalent to the net.Dialer zero value.
NetDialer *net.Dialer
// Config is the TLS configuration to use for new connections.
// A nil configuration is equivalent to the zero
// configuration; see the documentation of Config for the
// defaults.
Config *Config
ExtraConfig *ExtraConfig
}
// Dial connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The returned Conn, if any, will always be of type *Conn.
//
// Dial uses context.Background internally; to specify the context,
// use DialContext.
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func (d *Dialer) netDialer() *net.Dialer {
if d.NetDialer != nil {
return d.NetDialer
}
return new(net.Dialer)
}
// DialContext connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The returned Conn, if any, will always be of type *Conn.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig)
if err != nil {
// Don't return c (a typed nil) in an interface.
return nil, err
}
return c, nil
}
// LoadX509KeyPair reads and parses a public/private key pair from a pair
// of files. The files must contain PEM encoded data. The certificate file
// may contain intermediate certificates following the leaf certificate to
// form a certificate chain. On successful return, Certificate.Leaf will
// be nil because the parsed form of the certificate is not retained.
func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := os.ReadFile(certFile)
if err != nil {
return Certificate{}, err
}
keyPEMBlock, err := os.ReadFile(keyFile)
if err != nil {
return Certificate{}, err
}
return X509KeyPair(certPEMBlock, keyPEMBlock)
}
// X509KeyPair parses a public/private key pair from a pair of
// PEM encoded data. On successful return, Certificate.Leaf will be nil because
// the parsed form of the certificate is not retained.
func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
fail := func(err error) (Certificate, error) { return Certificate{}, err }
var cert Certificate
var skippedBlockTypes []string
for {
var certDERBlock *pem.Block
certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
} else {
skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
}
}
if len(cert.Certificate) == 0 {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in certificate input"))
}
if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
}
return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
skippedBlockTypes = skippedBlockTypes[:0]
var keyDERBlock *pem.Block
for {
keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
if keyDERBlock == nil {
if len(skippedBlockTypes) == 0 {
return fail(errors.New("tls: failed to find any PEM data in key input"))
}
if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
}
return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
}
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
break
}
skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
}
// We don't need to parse the public key for TLS, but we so do anyway
// to check that it looks sane and matches the private key.
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return fail(err)
}
cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
if err != nil {
return fail(err)
}
switch pub := x509Cert.PublicKey.(type) {
case *rsa.PublicKey:
priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.N.Cmp(priv.N) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case *ecdsa.PublicKey:
priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
return fail(errors.New("tls: private key does not match public key"))
}
case ed25519.PublicKey:
priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
if !ok {
return fail(errors.New("tls: private key type does not match public key type"))
}
if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
return fail(errors.New("tls: private key does not match public key"))
}
default:
return fail(errors.New("tls: unknown public key algorithm"))
}
return cert, nil
}
// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
// PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
return key, nil
default:
return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
}
}
if key, err := x509.ParseECPrivateKey(der); err == nil {
return key, nil
}
return nil, errors.New("tls: failed to parse private key")
}

View File

@ -1,96 +0,0 @@
package qtls
import (
"crypto/tls"
"reflect"
"unsafe"
)
func init() {
if !structsEqual(&tls.ConnectionState{}, &connectionState{}) {
panic("qtls.ConnectionState doesn't match")
}
if !structsEqual(&tls.ClientSessionState{}, &clientSessionState{}) {
panic("qtls.ClientSessionState doesn't match")
}
if !structsEqual(&tls.CertificateRequestInfo{}, &certificateRequestInfo{}) {
panic("qtls.CertificateRequestInfo doesn't match")
}
if !structsEqual(&tls.Config{}, &config{}) {
panic("qtls.Config doesn't match")
}
if !structsEqual(&tls.ClientHelloInfo{}, &clientHelloInfo{}) {
panic("qtls.ClientHelloInfo doesn't match")
}
}
func toConnectionState(c connectionState) ConnectionState {
return *(*ConnectionState)(unsafe.Pointer(&c))
}
func toClientSessionState(s *clientSessionState) *ClientSessionState {
return (*ClientSessionState)(unsafe.Pointer(s))
}
func fromClientSessionState(s *ClientSessionState) *clientSessionState {
return (*clientSessionState)(unsafe.Pointer(s))
}
func toCertificateRequestInfo(i *certificateRequestInfo) *CertificateRequestInfo {
return (*CertificateRequestInfo)(unsafe.Pointer(i))
}
func toConfig(c *config) *Config {
return (*Config)(unsafe.Pointer(c))
}
func fromConfig(c *Config) *config {
return (*config)(unsafe.Pointer(c))
}
func toClientHelloInfo(chi *clientHelloInfo) *ClientHelloInfo {
return (*ClientHelloInfo)(unsafe.Pointer(chi))
}
func structsEqual(a, b interface{}) bool {
return compare(reflect.ValueOf(a), reflect.ValueOf(b))
}
func compare(a, b reflect.Value) bool {
sa := a.Elem()
sb := b.Elem()
if sa.NumField() != sb.NumField() {
return false
}
for i := 0; i < sa.NumField(); i++ {
fa := sa.Type().Field(i)
fb := sb.Type().Field(i)
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
if fa.Type.Kind() != fb.Type.Kind() {
return false
}
if fa.Type.Kind() == reflect.Slice {
if !compareStruct(fa.Type.Elem(), fb.Type.Elem()) {
return false
}
continue
}
return false
}
}
return true
}
func compareStruct(a, b reflect.Type) bool {
if a.NumField() != b.NumField() {
return false
}
for i := 0; i < a.NumField(); i++ {
fa := a.Field(i)
fb := b.Field(i)
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
return false
}
}
return true
}

View File

@ -6,10 +6,17 @@ package qtls
import "strconv"
type alert uint8
// An AlertError is a TLS alert.
//
// When using a QUIC transport, QUICConn methods will return an error
// which wraps AlertError rather than sending a TLS alert.
type AlertError uint8
// Alert is a TLS alert
type Alert = alert
func (e AlertError) Error() string {
return alert(e).String()
}
type alert uint8
const (
// alert level

View File

@ -15,8 +15,10 @@ import (
"crypto/sha256"
"fmt"
"hash"
"runtime"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/sys/cpu"
)
// CipherSuite is a TLS cipher suite. Note that most functions in this package
@ -195,17 +197,6 @@ type cipherSuiteTLS13 struct {
hash crypto.Hash
}
type CipherSuiteTLS13 struct {
ID uint16
KeyLen int
Hash crypto.Hash
AEAD func(key, fixedNonce []byte) cipher.AEAD
}
func (c *CipherSuiteTLS13) IVLen() int {
return aeadNonceLength
}
var cipherSuitesTLS13 = []*cipherSuiteTLS13{ // TODO: replace with a map.
{TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256},
{TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256},
@ -362,6 +353,18 @@ var defaultCipherSuitesTLS13NoAES = []uint16{
TLS_AES_256_GCM_SHA384,
}
var (
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
// Keep in sync with crypto/aes/cipher_s390x.go.
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
)
var aesgcmCiphers = map[uint16]bool{
// TLS 1.2
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: true,
@ -519,11 +522,6 @@ func aeadAESGCM(key, noncePrefix []byte) aead {
return ret
}
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return aeadAESGCMTLS13(key, fixedNonce)
}
func aeadAESGCMTLS13(key, nonceMask []byte) aead {
if len(nonceMask) != aeadNonceLength {
panic("tls: internal error: wrong nonce length")

View File

@ -82,11 +82,6 @@ const (
compressionNone uint8 = 0
)
type Extension struct {
Type uint16
Data []byte
}
// TLS extension numbers
const (
extensionServerName uint16 = 0
@ -105,6 +100,7 @@ const (
extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51
extensionQUICTransportParameters uint16 = 57
extensionRenegotiationInfo uint16 = 0xff01
)
@ -113,14 +109,6 @@ const (
scsvRenegotiation uint16 = 0x00ff
)
type EncryptionLevel uint8
const (
EncryptionHandshake EncryptionLevel = iota
Encryption0RTT
EncryptionApplication
)
// CurveID is a tls.CurveID
type CurveID = tls.CurveID
@ -294,12 +282,6 @@ type connectionState struct {
ekm func(label string, context []byte, length int) ([]byte, error)
}
type ConnectionStateWith0RTT struct {
ConnectionState
Used0RTT bool // true if 0-RTT was both offered and accepted
}
// ClientAuthType is tls.ClientAuthType
type ClientAuthType = tls.ClientAuthType
@ -349,8 +331,6 @@ type clientSessionState struct {
// goroutines. Up to TLS 1.2, only ticket-based resumption is supported, not
// SessionID-based resumption. In TLS 1.3 they were merged into PSK modes, which
// are supported via this interface.
//
//go:generate sh -c "mockgen -package qtls -destination mock_client_session_cache_test.go github.com/quic-go/qtls-go1-20 ClientSessionCache"
type ClientSessionCache = tls.ClientSessionCache
// SignatureScheme is a tls.SignatureScheme
@ -736,64 +716,22 @@ type config struct {
autoSessionTicketKeys []ticketKey
}
// A RecordLayer handles encrypting and decrypting of TLS messages.
type RecordLayer interface {
SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte)
SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte)
ReadHandshakeMessage() ([]byte, error)
WriteRecord([]byte) (int, error)
SendAlert(uint8)
}
type ExtraConfig struct {
// GetExtensions, if not nil, is called before a message that allows
// sending of extensions is sent.
// Currently only implemented for the ClientHello message (for the client)
// and for the EncryptedExtensions message (for the server).
// Only valid for TLS 1.3.
GetExtensions func(handshakeMessageType uint8) []Extension
// ReceivedExtensions, if not nil, is called when a message that allows the
// inclusion of extensions is received.
// It is called with an empty slice of extensions, if the message didn't
// contain any extensions.
// Currently only implemented for the ClientHello message (sent by the
// client) and for the EncryptedExtensions message (sent by the server).
// Only valid for TLS 1.3.
ReceivedExtensions func(handshakeMessageType uint8, exts []Extension)
// AlternativeRecordLayer is used by QUIC
AlternativeRecordLayer RecordLayer
// Enforce the selection of a supported application protocol.
// Only works for TLS 1.3.
// If enabled, client and server have to agree on an application protocol.
// Otherwise, connection establishment fails.
EnforceNextProtoSelection bool
// If MaxEarlyData is greater than 0, the client will be allowed to send early
// data when resuming a session.
// Requires the AlternativeRecordLayer to be set.
// If Enable0RTT is enabled, the client will be allowed to send early data when resuming a session.
//
// It has no meaning on the client.
MaxEarlyData uint32
Enable0RTT bool
// GetAppDataForSessionTicket requests application data to be sent with a session ticket.
//
// It has no meaning on the client.
GetAppDataForSessionTicket func() []byte
// The Accept0RTT callback is called when the client offers 0-RTT.
// The server then has to decide if it wants to accept or reject 0-RTT.
// It is only used for servers.
Accept0RTT func(appData []byte) bool
// 0RTTRejected is called when the server rejectes 0-RTT.
// It is only used for clients.
Rejected0RTT func()
// If set, the client will export the 0-RTT key when resuming a session that
// allows sending of early data.
// Requires the AlternativeRecordLayer to be set.
//
// It has no meaning to the server.
Enable0RTT bool
// Is called when the client saves a session ticket to the session ticket.
// This gives the application the opportunity to save some data along with the ticket,
// which can be restored when the session ticket is used.
@ -801,29 +739,20 @@ type ExtraConfig struct {
// Is called when the client uses a session ticket.
// Restores the application data that was saved earlier on GetAppDataForSessionTicket.
SetAppDataFromSessionState func([]byte)
SetAppDataFromSessionState func([]byte) (allowEarlyData bool)
}
// Clone clones.
func (c *ExtraConfig) Clone() *ExtraConfig {
return &ExtraConfig{
GetExtensions: c.GetExtensions,
ReceivedExtensions: c.ReceivedExtensions,
AlternativeRecordLayer: c.AlternativeRecordLayer,
EnforceNextProtoSelection: c.EnforceNextProtoSelection,
MaxEarlyData: c.MaxEarlyData,
Enable0RTT: c.Enable0RTT,
GetAppDataForSessionTicket: c.GetAppDataForSessionTicket,
Accept0RTT: c.Accept0RTT,
Rejected0RTT: c.Rejected0RTT,
GetAppDataForSessionState: c.GetAppDataForSessionState,
SetAppDataFromSessionState: c.SetAppDataFromSessionState,
}
}
func (c *ExtraConfig) usesAlternativeRecordLayer() bool {
return c != nil && c.AlternativeRecordLayer != nil
}
const (
// ticketKeyNameLen is the number of bytes of identifier that is prepended to
// an encrypted session ticket in order to identify the key used to encrypt it.
@ -1384,7 +1313,6 @@ func (c *config) BuildNameToCertificate() {
const (
keyLogLabelTLS12 = "CLIENT_RANDOM"
keyLogLabelEarlyTraffic = "CLIENT_EARLY_TRAFFIC_SECRET"
keyLogLabelClientHandshake = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"
keyLogLabelServerHandshake = "SERVER_HANDSHAKE_TRAFFIC_SECRET"
keyLogLabelClientTraffic = "CLIENT_TRAFFIC_SECRET_0"
@ -1523,16 +1451,4 @@ func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlg
}
// CertificateVerificationError is returned when certificate verification fails during the handshake.
type CertificateVerificationError struct {
// UnverifiedCertificates and its contents should not be modified.
UnverifiedCertificates []*x509.Certificate
Err error
}
func (e *CertificateVerificationError) Error() string {
return fmt.Sprintf("tls: failed to verify certificate: %s", e.Err)
}
func (e *CertificateVerificationError) Unwrap() error {
return e.Err
}
type CertificateVerificationError = tls.CertificateVerificationError

View File

@ -29,6 +29,7 @@ type Conn struct {
conn net.Conn
isClient bool
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
quic *quicState // nil for non-QUIC connections
// isHandshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
@ -40,11 +41,10 @@ type Conn struct {
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *config // configuration passed to constructor
extraConfig *ExtraConfig
// handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either
// zero or one.
extraConfig *ExtraConfig
handshakes int
didResume bool // whether this connection was a session resumption
cipherSuite uint16
@ -65,13 +65,8 @@ type Conn struct {
secureRenegotiation bool
// ekm is a closure for exporting keying material.
ekm func(label string, context []byte, length int) ([]byte, error)
// For the client:
// resumptionSecret is the resumption_master_secret for handling
// NewSessionTicket messages. nil if config.SessionTicketsDisabled.
// For the server:
// resumptionSecret is the resumption_master_secret for generating
// NewSessionTicket messages. Only used when the alternative record
// layer is set. nil if config.SessionTicketsDisabled.
// or sending NewSessionTicket messages.
resumptionSecret []byte
// ticketKeys is the set of active session ticket keys for this
@ -123,12 +118,7 @@ type Conn struct {
// the rest of the bits are the number of goroutines in Conn.Write.
activeCall atomic.Int32
used0RTT bool
tmp [16]byte
connStateMutex sync.Mutex
connState ConnectionStateWith0RTT
}
// Access to net.Conn methods.
@ -188,9 +178,8 @@ type halfConn struct {
nextCipher any // next encryption state
nextMac hash.Hash // next MAC algorithm
trafficSecret []byte // current TLS 1.3 traffic secret
setKeyCallback func(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte)
level QUICEncryptionLevel // current QUIC encryption level
trafficSecret []byte // current TLS 1.3 traffic secret
}
type permanentError struct {
@ -235,20 +224,9 @@ func (hc *halfConn) changeCipherSpec() error {
return nil
}
func (hc *halfConn) exportKey(encLevel EncryptionLevel, suite *cipherSuiteTLS13, trafficSecret []byte) {
if hc.setKeyCallback != nil {
s := &CipherSuiteTLS13{
ID: suite.id,
KeyLen: suite.keyLen,
Hash: suite.hash,
AEAD: func(key, fixedNonce []byte) cipher.AEAD { return suite.aead(key, fixedNonce) },
}
hc.setKeyCallback(encLevel, s, trafficSecret)
}
}
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
hc.trafficSecret = secret
hc.level = level
key, iv := suite.trafficKey(secret)
hc.cipher = suite.aead(key, iv)
for i := range hc.seq {
@ -481,13 +459,6 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
return plaintext, typ, nil
}
func (c *Conn) setAlternativeRecordLayer() {
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
c.in.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetReadKey
c.out.setKeyCallback = c.extraConfig.AlternativeRecordLayer.SetWriteKey
}
}
// sliceForAppend extends the input slice by n bytes. head is the full extended
// slice, while tail is the appended part. If the original slice has sufficient
// capacity no allocation is performed.
@ -646,6 +617,10 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
}
c.input.Reset(nil)
if c.quic != nil {
return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
}
// Read header, payload.
if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
@ -729,6 +704,9 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert:
if c.quic != nil {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
if len(data) != 2 {
return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
@ -846,6 +824,9 @@ func (c *Conn) readFromUntil(r io.Reader, n int) error {
// sendAlert sends a TLS alert message.
func (c *Conn) sendAlertLocked(err alert) error {
if c.quic != nil {
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
switch err {
case alertNoRenegotiation, alertCloseNotify:
c.tmp[0] = alertLevelWarning
@ -865,11 +846,6 @@ func (c *Conn) sendAlertLocked(err alert) error {
// sendAlert sends a TLS alert message.
func (c *Conn) sendAlert(err alert) error {
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
c.extraConfig.AlternativeRecordLayer.SendAlert(uint8(err))
return &net.OpError{Op: "local error", Err: err}
}
c.out.Lock()
defer c.out.Unlock()
return c.sendAlertLocked(err)
@ -985,6 +961,19 @@ var outBufPool = sync.Pool{
// writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
if c.quic != nil {
if typ != recordTypeHandshake {
return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
}
c.quicWriteCryptoData(c.out.level, data)
if !c.buffering {
if _, err := c.flush(); err != nil {
return 0, err
}
}
return len(data), nil
}
outBufPtr := outBufPool.Get().(*[]byte)
outBuf := *outBufPtr
defer func() {
@ -1046,69 +1035,63 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
// the record layer state. If transcript is non-nil the marshalled message is
// written to it.
func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
c.out.Lock()
defer c.out.Unlock()
data, err := msg.marshal()
if err != nil {
return 0, err
}
c.out.Lock()
defer c.out.Unlock()
if transcript != nil {
transcript.Write(data)
}
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
return c.extraConfig.AlternativeRecordLayer.WriteRecord(data)
}
return c.writeRecordLocked(recordTypeHandshake, data)
}
// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
// updates the record layer state.
func (c *Conn) writeChangeCipherRecord() error {
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
return nil
}
c.out.Lock()
defer c.out.Unlock()
_, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
return err
}
// readHandshakeBytes reads handshake data until c.hand contains at least n bytes.
func (c *Conn) readHandshakeBytes(n int) error {
if c.quic != nil {
return c.quicReadHandshakeBytes(n)
}
for c.hand.Len() < n {
if err := c.readRecord(); err != nil {
return err
}
}
return nil
}
// readHandshake reads the next handshake message from
// the record layer. If transcript is non-nil, the message
// is written to the passed transcriptHash.
func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
var data []byte
if c.extraConfig != nil && c.extraConfig.AlternativeRecordLayer != nil {
var err error
data, err = c.extraConfig.AlternativeRecordLayer.ReadHandshakeMessage()
if err != nil {
return nil, err
}
} else {
for c.hand.Len() < 4 {
if err := c.readRecord(); err != nil {
return nil, err
}
}
data = c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlertLocked(alertInternalError)
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
}
for c.hand.Len() < 4+n {
if err := c.readRecord(); err != nil {
return nil, err
}
}
data = c.hand.Next(4 + n)
if err := c.readHandshakeBytes(4); err != nil {
return nil, err
}
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlertLocked(alertInternalError)
return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
}
if err := c.readHandshakeBytes(4 + n); err != nil {
return nil, err
}
data = c.hand.Next(4 + n)
return c.unmarshalHandshakeMessage(data, transcript)
}
func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
var m handshakeMessage
switch data[0] {
case typeHelloRequest:
@ -1288,10 +1271,6 @@ func (c *Conn) handleRenegotiation() error {
return c.handshakeErr
}
func (c *Conn) HandlePostHandshakeMessage() error {
return c.handlePostHandshakeMessage()
}
// handlePostHandshakeMessage processes a handshake message arrived after the
// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
func (c *Conn) handlePostHandshakeMessage() error {
@ -1303,7 +1282,6 @@ func (c *Conn) handlePostHandshakeMessage() error {
if err != nil {
return err
}
c.retryCount++
if c.retryCount > maxUselessRecords {
c.sendAlert(alertUnexpectedMessage)
@ -1315,20 +1293,28 @@ func (c *Conn) handlePostHandshakeMessage() error {
return c.handleNewSessionTicket(msg)
case *keyUpdateMsg:
return c.handleKeyUpdate(msg)
default:
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
}
// The QUIC layer is supposed to treat an unexpected post-handshake CertificateRequest
// as a QUIC-level PROTOCOL_VIOLATION error (RFC 9001, Section 4.4). Returning an
// unexpected_message alert here doesn't provide it with enough information to distinguish
// this condition from other unexpected messages. This is probably fine.
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
}
func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
if c.quic != nil {
c.sendAlert(alertUnexpectedMessage)
return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
}
cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
if cipherSuite == nil {
return c.in.setErrorLocked(c.sendAlert(alertInternalError))
}
newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
c.in.setTrafficSecret(cipherSuite, newSecret)
c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
if keyUpdate.updateRequested {
c.out.Lock()
@ -1347,7 +1333,7 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
}
newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
c.out.setTrafficSecret(cipherSuite, newSecret)
c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
}
return nil
@ -1508,12 +1494,15 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
// this cancellation. In the former case, we need to close the connection.
defer cancel()
// Start the "interrupter" goroutine, if this context might be canceled.
// (The background context cannot).
//
// The interrupter goroutine waits for the input context to be done and
// closes the connection if this happens before the function returns.
if ctx.Done() != nil {
if c.quic != nil {
c.quic.cancelc = handshakeCtx.Done()
c.quic.cancel = cancel
} else if ctx.Done() != nil {
// Start the "interrupter" goroutine, if this context might be canceled.
// (The background context cannot).
//
// The interrupter goroutine waits for the input context to be done and
// closes the connection if this happens before the function returns.
done := make(chan struct{})
interruptRes := make(chan error, 1)
defer func() {
@ -1564,21 +1553,38 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
panic("tls: internal error: handshake returned an error but is marked successful")
}
if c.quic != nil {
if c.handshakeErr == nil {
c.quicHandshakeComplete()
// Provide the 1-RTT read secret now that the handshake is complete.
// The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing
// the handshake (RFC 9001, Section 5.7).
c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret)
} else {
var a alert
c.out.Lock()
if !errors.As(c.out.err, &a) {
a = alertInternalError
}
c.out.Unlock()
// Return an error which wraps both the handshake error and
// any alert error we may have sent, or alertInternalError
// if we didn't send an alert.
// Truncate the text of the alert to 0 characters.
c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
}
close(c.quic.blockedc)
close(c.quic.signalc)
}
return c.handshakeErr
}
// ConnectionState returns basic TLS details about the connection.
func (c *Conn) ConnectionState() ConnectionState {
c.connStateMutex.Lock()
defer c.connStateMutex.Unlock()
return c.connState.ConnectionState
}
// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection.
func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT {
c.connStateMutex.Lock()
defer c.connStateMutex.Unlock()
return c.connState
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.connectionStateLocked()
}
func (c *Conn) connectionStateLocked() ConnectionState {
@ -1609,15 +1615,6 @@ func (c *Conn) connectionStateLocked() ConnectionState {
return toConnectionState(state)
}
func (c *Conn) updateConnectionState() {
c.connStateMutex.Lock()
defer c.connStateMutex.Unlock()
c.connState = ConnectionStateWith0RTT{
Used0RTT: c.used0RTT,
ConnectionState: c.connectionStateLocked(),
}
}
// OCSPResponse returns the stapled OCSP response from the TLS server, if
// any. (Only valid for client connections.)
func (c *Conn) OCSPResponse() []byte {

View File

@ -1,22 +0,0 @@
//go:build !js
// +build !js
package qtls
import (
"runtime"
"golang.org/x/sys/cpu"
)
var (
hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ
hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
// Keep in sync with crypto/aes/cipher_s390x.go.
hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR &&
(cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM)
hasAESGCMHardwareSupport = runtime.GOARCH == "amd64" && hasGCMAsmAMD64 ||
runtime.GOARCH == "arm64" && hasGCMAsmARM64 ||
runtime.GOARCH == "s390x" && hasGCMAsmS390X
)

View File

@ -1,12 +0,0 @@
//go:build js
// +build js
package qtls
var (
hasGCMAsmAMD64 = false
hasGCMAsmARM64 = false
hasGCMAsmS390X = false
hasAESGCMHardwareSupport = false
)

View File

@ -57,23 +57,12 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
return nil, nil, errors.New("tls: NextProtos values too large")
}
var supportedVersions []uint16
var clientHelloVersion uint16
if c.extraConfig.usesAlternativeRecordLayer() {
if config.maxSupportedVersion(roleClient) < VersionTLS13 {
return nil, nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3")
}
// Only offer TLS 1.3 when QUIC is used.
supportedVersions = []uint16{VersionTLS13}
clientHelloVersion = VersionTLS13
} else {
supportedVersions = config.supportedVersions(roleClient)
if len(supportedVersions) == 0 {
return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
}
clientHelloVersion = config.maxSupportedVersion(roleClient)
supportedVersions := config.supportedVersions(roleClient)
if len(supportedVersions) == 0 {
return nil, nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
}
clientHelloVersion := config.maxSupportedVersion(roleClient)
// The version at the beginning of the ClientHello was capped at TLS 1.2
// for compatibility reasons. The supported_versions extension is used
// to negotiate versions now. See RFC 8446, Section 4.2.1.
@ -127,7 +116,9 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
// A random session ID is used to detect when the server accepted a ticket
// and is resuming a session (see RFC 5077). In TLS 1.3, it's always set as
// a compatibility measure (see RFC 8446, Section 4.1.2).
if c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil {
//
// The session ID is not set for QUIC connections (see RFC 9001, Section 8.4).
if c.quic == nil {
hello.sessionId = make([]byte, 32)
if _, err := io.ReadFull(config.rand(), hello.sessionId); err != nil {
return nil, nil, errors.New("tls: short read from Rand: " + err.Error())
@ -143,6 +134,9 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
var secret clientKeySharePrivate
if hello.supportedVersions[0] == VersionTLS13 {
if len(hello.supportedVersions) == 1 {
hello.cipherSuites = hello.cipherSuites[:0]
}
if hasAESGCMHardwareSupport {
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...)
} else {
@ -176,8 +170,15 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, clientKeySharePrivate, error)
}
}
if hello.supportedVersions[0] == VersionTLS13 && c.extraConfig != nil && c.extraConfig.GetExtensions != nil {
hello.additionalExtensions = c.extraConfig.GetExtensions(typeClientHello)
if c.quic != nil {
p, err := c.quicGetTransportParameters()
if err != nil {
return nil, nil, err
}
if p == nil {
p = []byte{}
}
hello.quicTransportParameters = p
}
return hello, secret, nil
@ -187,7 +188,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if c.config == nil {
c.config = fromConfig(defaultConfig())
}
c.setAlternativeRecordLayer()
// This may be a renegotiation handshake, in which case some fields
// need to be reset.
@ -204,45 +204,33 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
return err
}
if cacheKey != "" && session != nil {
var deletedTicket bool
if session.vers == VersionTLS13 && hello.earlyData && c.extraConfig != nil && c.extraConfig.Enable0RTT {
// don't reuse a session ticket that enabled 0-RTT
c.config.ClientSessionCache.Put(cacheKey, nil)
deletedTicket = true
if suite := cipherSuiteTLS13ByID(session.cipherSuite); suite != nil {
h := suite.hash.New()
helloBytes, err := hello.marshal()
if err != nil {
return err
}
h.Write(helloBytes)
clientEarlySecret := suite.deriveSecret(earlySecret, "c e traffic", h)
c.out.exportKey(Encryption0RTT, suite, clientEarlySecret)
if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hello.random, clientEarlySecret); err != nil {
return err
}
defer func() {
// If we got a handshake failure when resuming a session, throw away
// the session ticket. See RFC 5077, Section 3.2.
//
// RFC 8446 makes no mention of dropping tickets on failure, but it
// does require servers to abort on invalid binders, so we need to
// delete tickets to recover from a corrupted PSK.
if err != nil {
c.config.ClientSessionCache.Put(cacheKey, nil)
}
}
if !deletedTicket {
defer func() {
// If we got a handshake failure when resuming a session, throw away
// the session ticket. See RFC 5077, Section 3.2.
//
// RFC 8446 makes no mention of dropping tickets on failure, but it
// does require servers to abort on invalid binders, so we need to
// delete tickets to recover from a corrupted PSK.
if err != nil {
c.config.ClientSessionCache.Put(cacheKey, nil)
}
}()
}
}()
}
if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
return err
}
if hello.earlyData {
suite := cipherSuiteTLS13ByID(session.cipherSuite)
transcript := suite.hash.New()
if err := transcriptMsg(hello, transcript); err != nil {
return err
}
earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript)
c.quicSetWriteSecret(QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret)
}
// serverHelloMsg is not included in the transcript
msg, err := c.readHandshake(nil)
if err != nil {
@ -305,7 +293,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session))
}
c.updateConnectionState()
return nil
}
@ -358,7 +345,10 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
}
// Try to resume a previously negotiated TLS session, if available.
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
cacheKey = c.clientSessionCacheKey()
if cacheKey == "" {
return "", nil, nil, nil, nil
}
sess, ok := c.config.ClientSessionCache.Get(cacheKey)
if !ok || sess == nil {
return cacheKey, nil, nil, nil, nil
@ -442,6 +432,17 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
return cacheKey, nil, nil, nil, nil
}
if c.quic != nil && maxEarlyData > 0 {
var earlyData bool
if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil {
earlyData = c.extraConfig.SetAppDataFromSessionState(appData)
}
// For 0-RTT, the cipher suite has to match exactly.
if earlyData && mutualCipherSuiteTLS13(hello.cipherSuites, session.cipherSuite) != nil {
hello.earlyData = true
}
}
// Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1.
ticketAge := uint32(c.config.time().Sub(session.receivedAt) / time.Millisecond)
identity := pskIdentity{
@ -456,9 +457,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
session.nonce, cipherSuite.hash.Size())
earlySecret = cipherSuite.extract(psk, nil)
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
if c.extraConfig != nil {
hello.earlyData = c.extraConfig.Enable0RTT && maxEarlyData > 0
}
transcript := cipherSuite.hash.New()
helloBytes, err := hello.marshalWithoutBinders()
if err != nil {
@ -470,9 +468,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
return "", nil, nil, nil, err
}
if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil {
c.extraConfig.SetAppDataFromSessionState(appData)
}
return
}
@ -827,7 +822,7 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
}
}
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol); err != nil {
if err := checkALPN(hs.hello.alpnProtocols, hs.serverHello.alpnProtocol, false); err != nil {
c.sendAlert(alertUnsupportedExtension)
return false, err
}
@ -865,8 +860,12 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
// checkALPN ensure that the server's choice of ALPN protocol is compatible with
// the protocols that we advertised in the Client Hello.
func checkALPN(clientProtos []string, serverProto string) error {
func checkALPN(clientProtos []string, serverProto string, quic bool) error {
if serverProto == "" {
if quic && len(clientProtos) > 0 {
// RFC 9001, Section 8.1
return errors.New("tls: server did not select an ALPN protocol")
}
return nil
}
if len(clientProtos) == 0 {
@ -962,6 +961,10 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error {
return nil
}
// maxRSAKeySize is the maximum RSA key size in bits that we are willing
// to verify the signatures of during a TLS handshake.
const maxRSAKeySize = 8192
// verifyServerCertificate parses and verifies the provided chain, setting
// c.verifiedChains and c.peerCertificates or sending the appropriate alert.
func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
@ -973,6 +976,10 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to parse certificate from server: " + err.Error())
}
if cert.cert.PublicKeyAlgorithm == x509.RSA && cert.cert.PublicKey.(*rsa.PublicKey).N.BitLen() > maxRSAKeySize {
c.sendAlert(alertBadCertificate)
return fmt.Errorf("tls: server sent certificate containing RSA key larger than %d bits", maxRSAKeySize)
}
activeHandles[i] = cert
certs[i] = cert.cert
}
@ -1106,15 +1113,16 @@ func (c *Conn) getClientCertificate(cri *CertificateRequestInfo) (*Certificate,
return new(Certificate), nil
}
const clientSessionCacheKeyPrefix = "qtls-"
// clientSessionCacheKey returns a key used to cache sessionTickets that could
// be used to resume previously negotiated TLS sessions with a server.
func clientSessionCacheKey(serverAddr net.Addr, config *config) string {
if len(config.ServerName) > 0 {
return clientSessionCacheKeyPrefix + config.ServerName
func (c *Conn) clientSessionCacheKey() string {
if len(c.config.ServerName) > 0 {
return c.config.ServerName
}
return clientSessionCacheKeyPrefix + serverAddr.String()
if c.conn != nil {
return c.conn.RemoteAddr().String()
}
return ""
}
// hostnameInSNI converts name into an appropriate hostname for SNI.

View File

@ -91,7 +91,6 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
if err := hs.processServerHello(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
@ -104,7 +103,6 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
if err := hs.readServerCertificate(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.readServerFinished(); err != nil {
return err
}
@ -125,7 +123,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
})
c.isHandshakeComplete.Store(true)
c.updateConnectionState()
return nil
}
@ -187,6 +185,9 @@ func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.c.quic != nil {
return nil
}
if hs.sentDummyCCS {
return nil
}
@ -293,7 +294,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
transcript := hs.suite.hash.New()
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
transcript.Write(chHash)
if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
if err := transcriptMsg(hs.serverHello, transcript); err != nil {
return err
}
helloBytes, err := hs.hello.marshalWithoutBinders()
@ -312,10 +313,11 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
}
}
if hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil {
c.extraConfig.Rejected0RTT()
if hs.hello.earlyData {
hs.hello.earlyData = false
c.quicRejectedEarlyData()
}
hs.hello.earlyData = false // disable 0-RTT
if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
@ -430,12 +432,18 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
clientSecret := hs.suite.deriveSecret(handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.out.exportKey(EncryptionHandshake, hs.suite, clientSecret)
c.out.setTrafficSecret(hs.suite, clientSecret)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.suite.deriveSecret(handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.in.exportKey(EncryptionHandshake, hs.suite, serverSecret)
c.in.setTrafficSecret(hs.suite, serverSecret)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
}
err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
if err != nil {
@ -467,28 +475,35 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
// Notify the caller if 0-RTT was rejected.
if !encryptedExtensions.earlyData && hs.hello.earlyData && c.extraConfig != nil && c.extraConfig.Rejected0RTT != nil {
c.extraConfig.Rejected0RTT()
}
c.used0RTT = encryptedExtensions.earlyData
if hs.c.extraConfig != nil && hs.c.extraConfig.ReceivedExtensions != nil {
hs.c.extraConfig.ReceivedExtensions(typeEncryptedExtensions, encryptedExtensions.additionalExtensions)
}
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
c.sendAlert(alertUnsupportedExtension)
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol, c.quic != nil); err != nil {
// RFC 8446 specifies that no_application_protocol is sent by servers, but
// does not specify how clients handle the selection of an incompatible protocol.
// RFC 9001 Section 8.1 specifies that QUIC clients send no_application_protocol
// in this case. Always sending no_application_protocol seems reasonable.
c.sendAlert(alertNoApplicationProtocol)
return err
}
c.clientProtocol = encryptedExtensions.alpnProtocol
if c.extraConfig != nil && c.extraConfig.EnforceNextProtoSelection {
if len(encryptedExtensions.alpnProtocol) == 0 {
// the server didn't select an ALPN
c.sendAlert(alertNoApplicationProtocol)
return errors.New("ALPN negotiation failed. Server didn't offer any protocols")
if c.quic != nil {
if encryptedExtensions.quicTransportParameters == nil {
// RFC 9001 Section 8.2.
c.sendAlert(alertMissingExtension)
return errors.New("tls: server did not send a quic_transport_parameters extension")
}
c.quicSetTransportParameters(encryptedExtensions.quicTransportParameters)
} else {
if encryptedExtensions.quicTransportParameters != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent an unexpected quic_transport_parameters extension")
}
}
if hs.hello.earlyData && !encryptedExtensions.earlyData {
c.quicRejectedEarlyData()
}
return nil
}
@ -616,8 +631,7 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error {
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.in.exportKey(EncryptionApplication, hs.suite, serverSecret)
c.in.setTrafficSecret(hs.suite, serverSecret)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
if err != nil {
@ -713,14 +727,20 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
return err
}
c.out.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret)
c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
}
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret)
}
return nil
}
@ -791,8 +811,10 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
scts: c.scts,
}
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session))
cacheKey := c.clientSessionCacheKey()
if cacheKey != "" {
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(session))
}
return nil
}

View File

@ -93,7 +93,7 @@ type clientHelloMsg struct {
pskModes []uint8
pskIdentities []pskIdentity
pskBinders [][]byte
additionalExtensions []Extension
quicTransportParameters []byte
}
func (m *clientHelloMsg) marshal() ([]byte, error) {
@ -247,10 +247,11 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
})
})
}
for _, ext := range m.additionalExtensions {
exts.AddUint16(ext.Type)
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
// RFC 9001, Section 8.2
exts.AddUint16(extensionQUICTransportParameters)
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
exts.AddBytes(ext.Data)
exts.AddBytes(m.quicTransportParameters)
})
}
if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
@ -567,6 +568,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
if !readUint8LengthPrefixed(&extData, &m.pskModes) {
return false
}
case extensionQUICTransportParameters:
m.quicTransportParameters = make([]byte, len(extData))
if !extData.CopyBytes(m.quicTransportParameters) {
return false
}
case extensionPreSharedKey:
// RFC 8446, Section 4.2.11
if !extensions.Empty() {
@ -598,7 +604,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.pskBinders = append(m.pskBinders, binder)
}
default:
m.additionalExtensions = append(m.additionalExtensions, Extension{Type: extension, Data: extData})
// Ignore unknown extensions.
continue
}
@ -867,11 +873,10 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
}
type encryptedExtensionsMsg struct {
raw []byte
alpnProtocol string
earlyData bool
additionalExtensions []Extension
raw []byte
alpnProtocol string
quicTransportParameters []byte
earlyData bool
}
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
@ -893,17 +898,18 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
})
})
}
if m.quicTransportParameters != nil { // marshal zero-length parameters when present
// draft-ietf-quic-tls-32, Section 8.2
b.AddUint16(extensionQUICTransportParameters)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.quicTransportParameters)
})
}
if m.earlyData {
// RFC 8446, Section 4.2.10
b.AddUint16(extensionEarlyData)
b.AddUint16(0) // empty extension_data
}
for _, ext := range m.additionalExtensions {
b.AddUint16(ext.Type)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(ext.Data)
})
}
})
})
@ -923,14 +929,14 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
}
for !extensions.Empty() {
var ext uint16
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&ext) ||
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch ext {
switch extension {
case extensionALPN:
var protoList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
@ -942,10 +948,15 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
return false
}
m.alpnProtocol = string(proto)
case extensionQUICTransportParameters:
m.quicTransportParameters = make([]byte, len(extData))
if !extData.CopyBytes(m.quicTransportParameters) {
return false
}
case extensionEarlyData:
m.earlyData = true
default:
m.additionalExtensions = append(m.additionalExtensions, Extension{Type: ext, Data: extData})
// Ignore unknown extensions.
continue
}

View File

@ -39,8 +39,6 @@ type serverHandshakeState struct {
// serverHandshake performs a TLS handshake as a server.
func (c *Conn) serverHandshake(ctx context.Context) error {
c.setAlternativeRecordLayer()
clientHello, err := c.readClientHello(ctx)
if err != nil {
return err
@ -53,12 +51,6 @@ func (c *Conn) serverHandshake(ctx context.Context) error {
clientHello: clientHello,
}
return hs.handshake()
} else if c.extraConfig.usesAlternativeRecordLayer() {
// This should already have been caught by the check that the ClientHello doesn't
// offer any (supported) versions older than TLS 1.3.
// Check again to make sure we can't be tricked into using an older version.
c.sendAlert(alertProtocolVersion)
return errors.New("tls: negotiated TLS < 1.3 when using QUIC")
}
hs := serverHandshakeState{
@ -131,7 +123,6 @@ func (hs *serverHandshakeState) handshake() error {
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
c.isHandshakeComplete.Store(true)
c.updateConnectionState()
return nil
}
@ -167,27 +158,6 @@ func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
if len(clientHello.supportedVersions) == 0 {
clientVersions = supportedVersionsFromMax(clientHello.vers)
}
if c.extraConfig.usesAlternativeRecordLayer() {
// In QUIC, the client MUST NOT offer any old TLS versions.
// Here, we can only check that none of the other supported versions of this library
// (TLS 1.0 - TLS 1.2) is offered. We don't check for any SSL versions here.
for _, ver := range clientVersions {
if ver == VersionTLS13 {
continue
}
for _, v := range supportedVersions {
if ver == v {
c.sendAlert(alertProtocolVersion)
return nil, fmt.Errorf("tls: client offered old TLS version %#x", ver)
}
}
}
// Make the config we're using allows us to use TLS 1.3.
if c.config.maxSupportedVersion(roleServer) < VersionTLS13 {
c.sendAlert(alertInternalError)
return nil, errors.New("tls: MaxVersion prevents QUIC from using TLS 1.3")
}
}
c.vers, ok = c.config.mutualVersion(roleServer, clientVersions)
if !ok {
c.sendAlert(alertProtocolVersion)
@ -249,7 +219,7 @@ func (hs *serverHandshakeState) processClientHello() error {
c.serverName = hs.clientHello.serverName
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, false)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
@ -310,8 +280,12 @@ func (hs *serverHandshakeState) processClientHello() error {
// negotiateALPN picks a shared ALPN protocol that both sides support in server
// preference order. If ALPN is not configured or the peer doesn't support it,
// it returns "" and no error.
func negotiateALPN(serverProtos, clientProtos []string) (string, error) {
func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, error) {
if len(serverProtos) == 0 || len(clientProtos) == 0 {
if quic && len(serverProtos) != 0 {
// RFC 9001, Section 8.1
return "", fmt.Errorf("tls: client did not request an application protocol")
}
return "", nil
}
var http11fallback bool
@ -849,6 +823,10 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
c.sendAlert(alertBadCertificate)
return errors.New("tls: failed to parse client certificate: " + err.Error())
}
if certs[i].PublicKeyAlgorithm == x509.RSA && certs[i].PublicKey.(*rsa.PublicKey).N.BitLen() > maxRSAKeySize {
c.sendAlert(alertBadCertificate)
return fmt.Errorf("tls: client sent certificate containing RSA key larger than %d bits", maxRSAKeySize)
}
}
if len(certs) == 0 && requiresClientCert(c.config.ClientAuth) {

View File

@ -41,6 +41,7 @@ type serverHandshakeStateTLS13 struct {
trafficSecret []byte // client_application_traffic_secret_0
transcript hash.Hash
clientFinished []byte
earlyData bool
}
func (hs *serverHandshakeStateTLS13) handshake() error {
@ -59,7 +60,6 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
if err := hs.checkForResumption(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.pickCertificate(); err != nil {
return err
}
@ -82,7 +82,6 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
if err := hs.readClientCertificate(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.readClientFinished(); err != nil {
return err
}
@ -94,7 +93,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
})
c.isHandshakeComplete.Store(true)
c.updateConnectionState()
return nil
}
@ -236,13 +235,23 @@ GroupSelection:
return errors.New("tls: invalid client key share")
}
c.serverName = hs.clientHello.serverName
if c.extraConfig != nil && c.extraConfig.ReceivedExtensions != nil {
c.extraConfig.ReceivedExtensions(typeClientHello, hs.clientHello.additionalExtensions)
if c.quic != nil {
if hs.clientHello.quicTransportParameters == nil {
// RFC 9001 Section 8.2.
c.sendAlert(alertMissingExtension)
return errors.New("tls: client did not send a quic_transport_parameters extension")
}
c.quicSetTransportParameters(hs.clientHello.quicTransportParameters)
} else {
if hs.clientHello.quicTransportParameters != nil {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent an unexpected quic_transport_parameters extension")
}
}
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols)
c.serverName = hs.clientHello.serverName
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
if err != nil {
hs.alpnNegotiationErr = err
}
@ -299,10 +308,9 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
}
if hs.alpnNegotiationErr == nil && sessionState.alpn == c.clientProtocol &&
c.extraConfig != nil && c.extraConfig.MaxEarlyData > 0 &&
c.extraConfig != nil && c.extraConfig.Enable0RTT &&
c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) {
hs.encryptedExtensions.earlyData = true
c.used0RTT = true
}
}
@ -354,27 +362,23 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
return errors.New("tls: invalid PSK binder")
}
if c.quic != nil && hs.clientHello.earlyData && hs.encryptedExtensions.earlyData && i == 0 &&
sessionState.maxEarlyData > 0 && sessionState.cipherSuite == hs.suite.id {
hs.earlyData = true
transcript := hs.suite.hash.New()
if err := transcriptMsg(hs.clientHello, transcript); err != nil {
return err
}
earlyTrafficSecret := hs.suite.deriveSecret(hs.earlySecret, clientEarlyTrafficLabel, transcript)
c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret)
}
c.didResume = true
if err := c.processCertsFromClient(sessionState.certificate); err != nil {
return err
}
h := cloneHash(hs.transcript, hs.suite.hash)
clientHelloWithBindersBytes, err := hs.clientHello.marshal()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
h.Write(clientHelloWithBindersBytes)
if hs.encryptedExtensions.earlyData {
clientEarlySecret := hs.suite.deriveSecret(hs.earlySecret, "c e traffic", h)
c.in.exportKey(Encryption0RTT, hs.suite, clientEarlySecret)
if err := c.config.writeKeyLog(keyLogLabelEarlyTraffic, hs.clientHello.random, clientEarlySecret); err != nil {
c.sendAlert(alertInternalError)
return err
}
}
hs.hello.selectedIdentityPresent = true
hs.hello.selectedIdentity = uint16(i)
hs.usingPSK = true
@ -449,6 +453,9 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
if hs.c.quic != nil {
return nil
}
if hs.sentDummyCCS {
return nil
}
@ -517,9 +524,9 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
return errors.New("tls: client illegally modified second ClientHello")
}
if clientHello.earlyData {
if illegalClientHelloChange(clientHello, hs.clientHello) {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: client offered 0-RTT data in second ClientHello")
return errors.New("tls: client illegally modified second ClientHello")
}
hs.clientHello = clientHello
@ -607,12 +614,18 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
clientSecret := hs.suite.deriveSecret(hs.handshakeSecret,
clientHandshakeTrafficLabel, hs.transcript)
c.in.exportKey(EncryptionHandshake, hs.suite, clientSecret)
c.in.setTrafficSecret(hs.suite, clientSecret)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret)
serverSecret := hs.suite.deriveSecret(hs.handshakeSecret,
serverHandshakeTrafficLabel, hs.transcript)
c.out.exportKey(EncryptionHandshake, hs.suite, serverSecret)
c.out.setTrafficSecret(hs.suite, serverSecret)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret)
c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret)
}
err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret)
if err != nil {
@ -625,12 +638,20 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
return err
}
if hs.alpnNegotiationErr != nil {
selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return hs.alpnNegotiationErr
return err
}
if hs.c.extraConfig != nil && hs.c.extraConfig.GetExtensions != nil {
hs.encryptedExtensions.additionalExtensions = hs.c.extraConfig.GetExtensions(typeEncryptedExtensions)
hs.encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
if c.quic != nil {
p, err := c.quicGetTransportParameters()
if err != nil {
return err
}
hs.encryptedExtensions.quicTransportParameters = p
}
if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil {
@ -731,8 +752,15 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
clientApplicationTrafficLabel, hs.transcript)
serverSecret := hs.suite.deriveSecret(hs.masterSecret,
serverApplicationTrafficLabel, hs.transcript)
c.out.exportKey(EncryptionApplication, hs.suite, serverSecret)
c.out.setTrafficSecret(hs.suite, serverSecret)
c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret)
if c.quic != nil {
if c.hand.Len() != 0 {
// TODO: Handle this in setTrafficSecret?
c.sendAlert(alertUnexpectedMessage)
}
c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret)
}
err := c.config.writeKeyLog(keyLogLabelClientTraffic, hs.clientHello.random, hs.trafficSecret)
if err != nil {
@ -764,6 +792,10 @@ func (hs *serverHandshakeStateTLS13) shouldSendSessionTickets() bool {
return false
}
// QUIC tickets are sent by QUICConn.SendSessionTicket, not automatically.
if hs.c.quic != nil {
return false
}
// Don't send tickets the client wouldn't use. See RFC 8446, Section 4.2.9.
for _, pskMode := range hs.clientHello.pskModes {
if pskMode == pskModeDHE {
@ -783,25 +815,66 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
return err
}
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
if !hs.shouldSendSessionTickets() {
return nil
}
return c.sendSessionTicket(false)
}
c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
resumptionLabel, hs.transcript)
// Don't send session tickets when the alternative record layer is set.
// Instead, save the resumption secret on the Conn.
// Session tickets can then be generated by calling Conn.GetSessionTicket().
if hs.c.extraConfig != nil && hs.c.extraConfig.AlternativeRecordLayer != nil {
return nil
func (c *Conn) sendSessionTicket(earlyData bool) error {
suite := cipherSuiteTLS13ByID(c.cipherSuite)
if suite == nil {
return errors.New("tls: internal error: unknown cipher suite")
}
m, err := hs.c.getSessionTicketMsg(nil)
m := new(newSessionTicketMsgTLS13)
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionStateTLS13{
cipherSuite: suite.id,
createdAt: uint64(c.config.time().Unix()),
resumptionSecret: c.resumptionSecret,
certificate: Certificate{
Certificate: certsFromClient,
OCSPStaple: c.ocspResponse,
SignedCertificateTimestamps: c.scts,
},
alpn: c.clientProtocol,
}
if earlyData {
state.maxEarlyData = 0xffffffff
state.appData = c.extraConfig.GetAppDataForSessionTicket()
}
stateBytes, err := state.marshal()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
m.label, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
// ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1
// The value is not stored anywhere; we never need to check the ticket age
// because 0-RTT is not supported.
ageAdd := make([]byte, 4)
_, err = c.config.rand().Read(ageAdd)
if err != nil {
return err
}
if earlyData {
// RFC 9001, Section 4.6.1
m.maxEarlyData = 0xffffffff
}
if _, err := c.writeHandshakeRecord(m, nil); err != nil {
return err
@ -919,8 +992,7 @@ func (hs *serverHandshakeStateTLS13) readClientFinished() error {
return errors.New("tls: invalid client finished hash")
}
c.in.exportKey(EncryptionApplication, hs.suite, hs.trafficSecret)
c.in.setTrafficSecret(hs.suite, hs.trafficSecret)
c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret)
return nil
}

View File

@ -21,6 +21,7 @@ import (
const (
resumptionBinderLabel = "res binder"
clientEarlyTrafficLabel = "c e traffic"
clientHandshakeTrafficLabel = "c hs traffic"
serverHandshakeTrafficLabel = "s hs traffic"
clientApplicationTrafficLabel = "c ap traffic"

418
vendor/github.com/quic-go/qtls-go1-20/quic.go generated vendored Normal file
View File

@ -0,0 +1,418 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package qtls
import (
"context"
"errors"
"fmt"
)
// QUICEncryptionLevel represents a QUIC encryption level used to transmit
// handshake messages.
type QUICEncryptionLevel int
const (
QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
QUICEncryptionLevelEarly
QUICEncryptionLevelHandshake
QUICEncryptionLevelApplication
)
func (l QUICEncryptionLevel) String() string {
switch l {
case QUICEncryptionLevelInitial:
return "Initial"
case QUICEncryptionLevelEarly:
return "Early"
case QUICEncryptionLevelHandshake:
return "Handshake"
case QUICEncryptionLevelApplication:
return "Application"
default:
return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
}
}
// A QUICConn represents a connection which uses a QUIC implementation as the underlying
// transport as described in RFC 9001.
//
// Methods of QUICConn are not safe for concurrent use.
type QUICConn struct {
conn *Conn
sessionTicketSent bool
}
// A QUICConfig configures a QUICConn.
type QUICConfig struct {
TLSConfig *Config
ExtraConfig *ExtraConfig
}
// A QUICEventKind is a type of operation on a QUIC connection.
type QUICEventKind int
const (
// QUICNoEvent indicates that there are no events available.
QUICNoEvent QUICEventKind = iota
// QUICSetReadSecret and QUICSetWriteSecret provide the read and write
// secrets for a given encryption level.
// QUICEvent.Level, QUICEvent.Data, and QUICEvent.Suite are set.
//
// Secrets for the Initial encryption level are derived from the initial
// destination connection ID, and are not provided by the QUICConn.
QUICSetReadSecret
QUICSetWriteSecret
// QUICWriteData provides data to send to the peer in CRYPTO frames.
// QUICEvent.Data is set.
QUICWriteData
// QUICTransportParameters provides the peer's QUIC transport parameters.
// QUICEvent.Data is set.
QUICTransportParameters
// QUICTransportParametersRequired indicates that the caller must provide
// QUIC transport parameters to send to the peer. The caller should set
// the transport parameters with QUICConn.SetTransportParameters and call
// QUICConn.NextEvent again.
//
// If transport parameters are set before calling QUICConn.Start, the
// connection will never generate a QUICTransportParametersRequired event.
QUICTransportParametersRequired
// QUICRejectedEarlyData indicates that the server rejected 0-RTT data even
// if we offered it. It's returned before QUICEncryptionLevelApplication
// keys are returned.
QUICRejectedEarlyData
// QUICHandshakeDone indicates that the TLS handshake has completed.
QUICHandshakeDone
)
// A QUICEvent is an event occurring on a QUIC connection.
//
// The type of event is specified by the Kind field.
// The contents of the other fields are kind-specific.
type QUICEvent struct {
Kind QUICEventKind
// Set for QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
Level QUICEncryptionLevel
// Set for QUICTransportParameters, QUICSetReadSecret, QUICSetWriteSecret, and QUICWriteData.
// The contents are owned by crypto/tls, and are valid until the next NextEvent call.
Data []byte
// Set for QUICSetReadSecret and QUICSetWriteSecret.
Suite uint16
}
type quicState struct {
events []QUICEvent
nextEvent int
// eventArr is a statically allocated event array, large enough to handle
// the usual maximum number of events resulting from a single call: transport
// parameters, Initial data, Early read secret, Handshake write and read
// secrets, Handshake data, Application write secret, Application data.
eventArr [8]QUICEvent
started bool
signalc chan struct{} // handshake data is available to be read
blockedc chan struct{} // handshake is waiting for data, closed when done
cancelc <-chan struct{} // handshake has been canceled
cancel context.CancelFunc
// readbuf is shared between HandleData and the handshake goroutine.
// HandshakeCryptoData passes ownership to the handshake goroutine by
// reading from signalc, and reclaims ownership by reading from blockedc.
readbuf []byte
transportParams []byte // to send to the peer
}
// QUICClient returns a new TLS client side connection using QUICTransport as the
// underlying transport. The config cannot be nil.
//
// The config's MinVersion must be at least TLS 1.3.
func QUICClient(config *QUICConfig) *QUICConn {
return newQUICConn(Client(nil, config.TLSConfig), config.ExtraConfig)
}
// QUICServer returns a new TLS server side connection using QUICTransport as the
// underlying transport. The config cannot be nil.
//
// The config's MinVersion must be at least TLS 1.3.
func QUICServer(config *QUICConfig) *QUICConn {
return newQUICConn(Server(nil, config.TLSConfig), config.ExtraConfig)
}
func newQUICConn(conn *Conn, extraConfig *ExtraConfig) *QUICConn {
conn.quic = &quicState{
signalc: make(chan struct{}),
blockedc: make(chan struct{}),
}
conn.quic.events = conn.quic.eventArr[:0]
conn.extraConfig = extraConfig
return &QUICConn{
conn: conn,
}
}
// Start starts the client or server handshake protocol.
// It may produce connection events, which may be read with NextEvent.
//
// Start must be called at most once.
func (q *QUICConn) Start(ctx context.Context) error {
if q.conn.quic.started {
return quicError(errors.New("tls: Start called more than once"))
}
q.conn.quic.started = true
if q.conn.config.MinVersion < VersionTLS13 {
return quicError(errors.New("tls: Config MinVersion must be at least TLS 1.13"))
}
go q.conn.HandshakeContext(ctx)
if _, ok := <-q.conn.quic.blockedc; !ok {
return q.conn.handshakeErr
}
return nil
}
// NextEvent returns the next event occurring on the connection.
// It returns an event with a Kind of QUICNoEvent when no events are available.
func (q *QUICConn) NextEvent() QUICEvent {
qs := q.conn.quic
if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
// Write over some of the previous event's data,
// to catch callers erroniously retaining it.
qs.events[last].Data[0] = 0
}
if qs.nextEvent >= len(qs.events) {
qs.events = qs.events[:0]
qs.nextEvent = 0
return QUICEvent{Kind: QUICNoEvent}
}
e := qs.events[qs.nextEvent]
qs.events[qs.nextEvent] = QUICEvent{} // zero out references to data
qs.nextEvent++
return e
}
// Close closes the connection and stops any in-progress handshake.
func (q *QUICConn) Close() error {
if q.conn.quic.cancel == nil {
return nil // never started
}
q.conn.quic.cancel()
for range q.conn.quic.blockedc {
// Wait for the handshake goroutine to return.
}
return q.conn.handshakeErr
}
// HandleData handles handshake bytes received from the peer.
// It may produce connection events, which may be read with NextEvent.
func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
c := q.conn
if c.in.level != level {
return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
}
c.quic.readbuf = data
<-c.quic.signalc
_, ok := <-c.quic.blockedc
if ok {
// The handshake goroutine is waiting for more data.
return nil
}
// The handshake goroutine has exited.
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil
for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
b := q.conn.hand.Bytes()
n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if n > maxHandshake {
q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
break
}
if len(b) < 4+n {
return nil
}
if err := q.conn.handlePostHandshakeMessage(); err != nil {
q.conn.handshakeErr = err
}
}
if q.conn.handshakeErr != nil {
return quicError(q.conn.handshakeErr)
}
return nil
}
// SendSessionTicket sends a session ticket to the client.
// It produces connection events, which may be read with NextEvent.
// Currently, it can only be called once.
func (q *QUICConn) SendSessionTicket(earlyData bool) error {
c := q.conn
if !c.isHandshakeComplete.Load() {
return quicError(errors.New("tls: SendSessionTicket called before handshake completed"))
}
if c.isClient {
return quicError(errors.New("tls: SendSessionTicket called on the client"))
}
if q.sessionTicketSent {
return quicError(errors.New("tls: SendSessionTicket called multiple times"))
}
q.sessionTicketSent = true
return quicError(c.sendSessionTicket(earlyData))
}
// ConnectionState returns basic TLS details about the connection.
func (q *QUICConn) ConnectionState() ConnectionState {
return q.conn.ConnectionState()
}
// SetTransportParameters sets the transport parameters to send to the peer.
//
// Server connections may delay setting the transport parameters until after
// receiving the client's transport parameters. See QUICTransportParametersRequired.
func (q *QUICConn) SetTransportParameters(params []byte) {
if params == nil {
params = []byte{}
}
q.conn.quic.transportParams = params
if q.conn.quic.started {
<-q.conn.quic.signalc
<-q.conn.quic.blockedc
}
}
// quicError ensures err is an AlertError.
// If err is not already, quicError wraps it with alertInternalError.
func quicError(err error) error {
if err == nil {
return nil
}
var ae AlertError
if errors.As(err, &ae) {
return err
}
var a alert
if !errors.As(err, &a) {
a = alertInternalError
}
// Return an error wrapping the original error and an AlertError.
// Truncate the text of the alert to 0 characters.
return fmt.Errorf("%w%.0w", err, AlertError(a))
}
func (c *Conn) quicReadHandshakeBytes(n int) error {
for c.hand.Len() < n {
if err := c.quicWaitForSignal(); err != nil {
return err
}
}
return nil
}
func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICSetReadSecret,
Level: level,
Suite: suite,
Data: secret,
})
}
func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICSetWriteSecret,
Level: level,
Suite: suite,
Data: secret,
})
}
func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
var last *QUICEvent
if len(c.quic.events) > 0 {
last = &c.quic.events[len(c.quic.events)-1]
}
if last == nil || last.Kind != QUICWriteData || last.Level != level {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICWriteData,
Level: level,
})
last = &c.quic.events[len(c.quic.events)-1]
}
last.Data = append(last.Data, data...)
}
func (c *Conn) quicSetTransportParameters(params []byte) {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICTransportParameters,
Data: params,
})
}
func (c *Conn) quicGetTransportParameters() ([]byte, error) {
if c.quic.transportParams == nil {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICTransportParametersRequired,
})
}
for c.quic.transportParams == nil {
if err := c.quicWaitForSignal(); err != nil {
return nil, err
}
}
return c.quic.transportParams, nil
}
func (c *Conn) quicHandshakeComplete() {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICHandshakeDone,
})
}
func (c *Conn) quicRejectedEarlyData() {
c.quic.events = append(c.quic.events, QUICEvent{
Kind: QUICRejectedEarlyData,
})
}
// quicWaitForSignal notifies the QUICConn that handshake progress is blocked,
// and waits for a signal that the handshake should proceed.
//
// The handshake may become blocked waiting for handshake bytes
// or for the user to provide transport parameters.
func (c *Conn) quicWaitForSignal() error {
// Drop the handshake mutex while blocked to allow the user
// to call ConnectionState before the handshake completes.
c.handshakeMutex.Unlock()
defer c.handshakeMutex.Lock()
// Send on blockedc to notify the QUICConn that the handshake is blocked.
// Exported methods of QUICConn wait for the handshake to become blocked
// before returning to the user.
select {
case c.quic.blockedc <- struct{}{}:
case <-c.quic.cancelc:
return c.sendAlertLocked(alertCloseNotify)
}
// The QUICConn reads from signalc to notify us that the handshake may
// be able to proceed. (The QUICConn reads, because we close signalc to
// indicate that the handshake has completed.)
select {
case c.quic.signalc <- struct{}{}:
c.hand.Write(c.quic.readbuf)
c.quic.readbuf = nil
case <-c.quic.cancelc:
return c.sendAlertLocked(alertCloseNotify)
}
return nil
}

View File

@ -11,12 +11,9 @@ import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/binary"
"errors"
"io"
"time"
"golang.org/x/crypto/cryptobyte"
"io"
)
// sessionState contains the information that is serialized into a session
@ -204,74 +201,3 @@ func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey boo
return plaintext, keyIndex > 0
}
func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) {
m := new(newSessionTicketMsgTLS13)
var certsFromClient [][]byte
for _, cert := range c.peerCertificates {
certsFromClient = append(certsFromClient, cert.Raw)
}
state := sessionStateTLS13{
cipherSuite: c.cipherSuite,
createdAt: uint64(c.config.time().Unix()),
resumptionSecret: c.resumptionSecret,
certificate: Certificate{
Certificate: certsFromClient,
OCSPStaple: c.ocspResponse,
SignedCertificateTimestamps: c.scts,
},
appData: appData,
alpn: c.clientProtocol,
}
if c.extraConfig != nil {
state.maxEarlyData = c.extraConfig.MaxEarlyData
}
stateBytes, err := state.marshal()
if err != nil {
return nil, err
}
m.label, err = c.encryptTicket(stateBytes)
if err != nil {
return nil, err
}
m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
// ticket_age_add is a random 32-bit value. See RFC 8446, section 4.6.1
// The value is not stored anywhere; we never need to check the ticket age
// because 0-RTT is not supported.
ageAdd := make([]byte, 4)
_, err = c.config.rand().Read(ageAdd)
if err != nil {
return nil, err
}
m.ageAdd = binary.LittleEndian.Uint32(ageAdd)
// ticket_nonce, which must be unique per connection, is always left at
// zero because we only ever send one ticket per connection.
if c.extraConfig != nil {
m.maxEarlyData = c.extraConfig.MaxEarlyData
}
return m, nil
}
// GetSessionTicket generates a new session ticket.
// It should only be called after the handshake completes.
// It can only be used for servers, and only if the alternative record layer is set.
// The ticket may be nil if config.SessionTicketsDisabled is set,
// or if the client isn't able to receive session tickets.
func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) {
if c.isClient || !c.isHandshakeComplete.Load() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil {
return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.")
}
if c.config.SessionTicketsDisabled {
return nil, nil
}
m, err := c.getSessionTicketMsg(appData)
if err != nil {
return nil, err
}
return m.marshal()
}

View File

@ -31,11 +31,10 @@ import (
// using conn as the underlying transport.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
func Server(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: fromConfig(config),
extraConfig: extraConfig,
conn: conn,
config: fromConfig(config),
}
c.handshakeFn = c.serverHandshake
return c
@ -45,12 +44,11 @@ func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
// using conn as the underlying transport.
// The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
func Client(conn net.Conn, config *Config) *Conn {
c := &Conn{
conn: conn,
config: fromConfig(config),
extraConfig: extraConfig,
isClient: true,
conn: conn,
config: fromConfig(config),
isClient: true,
}
c.handshakeFn = c.clientHandshake
return c
@ -59,8 +57,7 @@ func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
// A listener implements a network listener (net.Listener) for TLS connections.
type listener struct {
net.Listener
config *Config
extraConfig *ExtraConfig
config *Config
}
// Accept waits for and returns the next incoming TLS connection.
@ -70,18 +67,17 @@ func (l *listener) Accept() (net.Conn, error) {
if err != nil {
return nil, err
}
return Server(c, l.config, l.extraConfig), nil
return Server(c, l.config), nil
}
// NewListener creates a Listener which accepts connections from an inner
// Listener and wraps each connection with Server.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) net.Listener {
func NewListener(inner net.Listener, config *Config) net.Listener {
l := new(listener)
l.Listener = inner
l.config = config
l.extraConfig = extraConfig
return l
}
@ -89,7 +85,7 @@ func NewListener(inner net.Listener, config *Config, extraConfig *ExtraConfig) n
// given network address using net.Listen.
// The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate.
func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (net.Listener, error) {
func Listen(network, laddr string, config *Config) (net.Listener, error) {
if config == nil || len(config.Certificates) == 0 &&
config.GetCertificate == nil && config.GetConfigForClient == nil {
return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
@ -98,7 +94,7 @@ func Listen(network, laddr string, config *Config, extraConfig *ExtraConfig) (ne
if err != nil {
return nil, err
}
return NewListener(l, config, extraConfig), nil
return NewListener(l, config), nil
}
type timeoutError struct{}
@ -117,11 +113,11 @@ func (timeoutError) Temporary() bool { return true }
//
// DialWithDialer uses context.Background internally; to specify the context,
// use Dialer.DialContext with NetDialer set to the desired dialer.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
return dial(context.Background(), dialer, network, addr, config, extraConfig)
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
return dial(context.Background(), dialer, network, addr, config)
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
if netDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
@ -157,7 +153,7 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
config = c
}
conn := Client(rawConn, config, extraConfig)
conn := Client(rawConn, config)
if err := conn.HandshakeContext(ctx); err != nil {
rawConn.Close()
return nil, err
@ -171,8 +167,8 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
// Dial interprets a nil configuration as equivalent to
// the zero configuration; see the documentation of Config
// for the defaults.
func Dial(network, addr string, config *Config, extraConfig *ExtraConfig) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config, extraConfig)
func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}
// Dialer dials TLS connections given a configuration and a Dialer for the
@ -188,8 +184,6 @@ type Dialer struct {
// configuration; see the documentation of Config for the
// defaults.
Config *Config
ExtraConfig *ExtraConfig
}
// Dial connects to the given network address and initiates a TLS
@ -220,7 +214,7 @@ func (d *Dialer) netDialer() *net.Dialer {
//
// The returned Conn, if any, will always be of type *Conn.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := dial(ctx, d.netDialer(), network, addr, d.Config, d.ExtraConfig)
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
if err != nil {
// Don't return c (a typed nil) in an interface.
return nil, err

View File

@ -94,3 +94,8 @@ func compareStruct(a, b reflect.Type) bool {
}
return true
}
// InitSessionTicketKeys triggers the initialization of session ticket keys.
func InitSessionTicketKeys(conf *Config) {
fromConfig(conf).ticketKeys(nil)
}

View File

@ -1,4 +1,6 @@
run:
skip-files:
- internal/handshake/cipher_suite.go
linters-settings:
depguard:
type: blacklist

View File

@ -4,29 +4,210 @@
[![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go)
[![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/)
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/quic-go.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go)
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go, including the Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) and Datagram Packetization Layer Path MTU
Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899)). It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
In addition to the RFCs listed above, it currently implements the [IETF QUIC draft-29](https://tools.ietf.org/html/draft-ietf-quic-transport-29). Support for draft-29 will eventually be dropped, as it is phased out of the ecosystem.
In addition to these base RFCs, it also implements the following RFCs:
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
* QUIC Event Logging using qlog ([draft-ietf-quic-qlog-main-schema](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-main-schema/) and [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/))
## Guides
Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go).
*We currently support Go 1.19.x and Go 1.20.x*
## Using QUIC
Running tests:
### Running a Server
go test ./...
The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket.
### QUIC without HTTP/3
```go
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234})
// ... error handling
tr := quic.Transport{
Conn: udpConn,
}
ln, err := tr.Listen(tlsConf, quicConf)
// ... error handling
go func() {
for {
conn, err := ln.Accept()
// ... error handling
// handle the connection, usually in a new Go routine
}
}()
```
Take a look at [this echo example](example/echo/echo.go).
The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`).
## Usage
As a shortcut, `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`:
```
ln, err := quic.Listen(udpConn, tlsConf, quicConf)
```
When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections.
### Running a Client
As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections.
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := tr.Dial(ctx, <server address>, <tls.Config>, <quic.Config>)
// ... error handling
```
As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explictly initializing a `quic.Transport`:
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := quic.Dial(ctx, conn, <server address>, <tls.Config>, <quic.Config>)
```
Just as we saw before when used a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections.
### Using a QUIC Connection
#### Accepting Streams
QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3).
Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides.
On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most user cases, it makes sense to call these functions in a loop:
```go
for {
str, err := conn.AcceptStream(context.Background()) // for bidirectional streams
// ... error handling
// handle the stream, usually in a new Go routine
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Opening Streams
There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so.
Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away:
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream
```
The asynchronous version never blocks. If it's currently not possible to open a new stream, it returns a `net.Error` timeout error:
```go
str, err := conn.OpenStream()
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
// It's currently not possible to open another stream,
// but it might be possible later, once the peer allowed us to do so.
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Using Streams
Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions.
Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream` , it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream.
A bidirectional stream is only closed once both the read and the write side of the stream have been either closed or reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`.
### Configuring QUIC
The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details.
The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommend to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)).
### Closing a Connection
#### When the remote Peer closes the Connection
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Additionally, it is set as cancellation cause of the connection context. Users can use errors assertions to find out what exactly went wrong:
* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur.
* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer.
* `quic.TransportError`: Happens if when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error.
* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below.
#### Initiated by the Application
A `quic.Connection` can be closed using `CloseWithError`:
```go
conn.CloseWithError(0x42, "error 0x42 occurred")
```
Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes.
On the receiver side, this is surfaced as a `quic.ApplicationError`.
### QUIC Datagrams
Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`.
QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not retransmitted.
Datagrams are sent using the `SendDatagram` method on the `quic.Connection`:
```go
conn.SendDatagram([]byte("foobar"))
```
And received using `ReceiveDatagram`:
```go
msg, err := conn.ReceiveDatagram()
```
Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599).
### QUIC Event Logging using qlog
quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/), providing comprehensive insights in the internals of a QUIC connection.
qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
qlog is activated by setting a `Tracer` callback on the `Config`. It is called as soon as quic-go decides to starts the QUIC handshake on a new connection.
A useful implementation of this callback could look like this:
```go
quic.Config{
Tracer: func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("./log_%s_%s.qlog", connID, role)
f, err := os.Create(filename)
// handle the error
return qlog.NewConnectionTracer(f, p, connID)
}
}
```
This implementation of the callback creates a new qlog file in the current directory named `log_<client / server>_<QUIC connection ID>.qlog`.
## Using HTTP/3
### As a server
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard lib http in go:
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go:
```go
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
@ -46,12 +227,13 @@ http.Client{
## Projects using quic-go
| Project | Description | Stars |
|-----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
| --------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. | ![GitHub Repo stars](https://img.shields.io/github/stars/AdguardTeam/AdGuardHome?style=flat-square) |
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support | ![GitHub Repo stars](https://img.shields.io/github/stars/xyproto/algernon?style=flat-square) |
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS | ![GitHub Repo stars](https://img.shields.io/github/stars/caddyserver/caddy?style=flat-square) |
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins | ![GitHub Repo stars](https://img.shields.io/github/stars/cloudflare/cloudflared?style=flat-square) |
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others | ![GitHub Repo stars](https://img.shields.io/github/stars/libp2p/go-libp2p?style=flat-square) |
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy | ![GitHub Repo stars](https://img.shields.io/github/stars/apernet/hysteria?style=flat-square) |
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications | ![GitHub Repo stars](https://img.shields.io/github/stars/dunglas/mercure?style=flat-square) |
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. | ![GitHub Repo stars](https://img.shields.io/github/stars/ooni/probe-cli?style=flat-square) |
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization | ![GitHub Repo stars](https://img.shields.io/github/stars/syncthing/syncthing?style=flat-square) |
@ -59,6 +241,17 @@ http.Client{
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions | ![GitHub Repo stars](https://img.shields.io/github/stars/v2fly/v2ray-core?style=flat-square) |
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System | ![GitHub Repo stars](https://img.shields.io/github/stars/yomorun/yomo?style=flat-square) |
If you'd like to see your project added to this list, please send us a PR.
## Release Policy
quic-go always aims to support the latest two Go releases.
### Dependency on forked crypto/tls
Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20).
This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward.
## Contributing
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.

View File

@ -51,18 +51,22 @@ func (b *packetBuffer) Release() {
}
// Len returns the length of Data
func (b *packetBuffer) Len() protocol.ByteCount {
return protocol.ByteCount(len(b.Data))
}
func (b *packetBuffer) Len() protocol.ByteCount { return protocol.ByteCount(len(b.Data)) }
func (b *packetBuffer) Cap() protocol.ByteCount { return protocol.ByteCount(cap(b.Data)) }
func (b *packetBuffer) putBack() {
if cap(b.Data) != int(protocol.MaxPacketBufferSize) {
panic("putPacketBuffer called with packet of wrong size!")
if cap(b.Data) == protocol.MaxPacketBufferSize {
bufferPool.Put(b)
return
}
bufferPool.Put(b)
if cap(b.Data) == protocol.MaxLargePacketBufferSize {
largeBufferPool.Put(b)
return
}
panic("putPacketBuffer called with packet of wrong size!")
}
var bufferPool sync.Pool
var bufferPool, largeBufferPool sync.Pool
func getPacketBuffer() *packetBuffer {
buf := bufferPool.Get().(*packetBuffer)
@ -71,10 +75,18 @@ func getPacketBuffer() *packetBuffer {
return buf
}
func getLargePacketBuffer() *packetBuffer {
buf := largeBufferPool.Get().(*packetBuffer)
buf.refCount = 1
buf.Data = buf.Data[:0]
return buf
}
func init() {
bufferPool.New = func() interface{} {
return &packetBuffer{
Data: make([]byte, 0, protocol.MaxPacketBufferSize),
}
bufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
}
largeBufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
}
}

View File

@ -34,7 +34,7 @@ type client struct {
conn quicConn
tracer logging.ConnectionTracer
tracer *logging.ConnectionTracer
tracingID uint64
logger utils.Logger
}
@ -43,7 +43,9 @@ type client struct {
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
// When the QUIC connection is closed, this UDP connection is closed.
// See Dial for more details.
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
@ -53,15 +55,15 @@ func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Confi
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
tr, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
return dl.Dial(ctx, udpAddr, tlsConf, conf)
return tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// It uses a new UDP connection and closes this connection when the QUIC connection is closed.
// See DialAddr for more details.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
@ -71,20 +73,20 @@ func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *
if err != nil {
return nil, err
}
dl, err := setupTransport(udpConn, tlsConf, true)
tr, err := setupTransport(udpConn, tlsConf, true)
if err != nil {
return nil, err
}
conn, err := dl.DialEarly(ctx, udpAddr, tlsConf, conf)
conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
if err != nil {
dl.Close()
tr.Close()
return nil, err
}
return conn, nil
}
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn using the provided context.
// See DialEarly for details.
// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn.
// See Dial for more details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
@ -98,12 +100,15 @@ func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tl
return conn, nil
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn. If
// the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn
// does), ECN and packet info support will be enabled. In this case, ReadMsgUDP
// and WriteMsgUDP will be used instead of ReadFrom and WriteTo to read/write
// packets.
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// will be used instead of ReadFrom and WriteTo to read/write packets.
// The tls.Config must define an application protocol (using NextProtos).
//
// This is a convenience function. More advanced use cases should instantiate a Transport,
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for multiple QUIC connections.
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
dl, err := setupTransport(c, tlsConf, false)
if err != nil {
@ -148,7 +153,7 @@ func dial(
if c.config.Tracer != nil {
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
}
if c.tracer != nil {
if c.tracer != nil && c.tracer.StartedConnection != nil {
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
}
if err := c.dial(ctx); err != nil {
@ -158,12 +163,6 @@ func dial(
}
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
} else {
tlsConf = tlsConf.Clone()
}
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
@ -234,7 +233,7 @@ func (c *client) dial(ctx context.Context) error {
select {
case <-ctx.Done():
c.conn.shutdown()
return ctx.Err()
return context.Cause(ctx)
case err := <-errorChan:
return err
case recreateErr := <-recreateChan:

View File

@ -16,13 +16,13 @@ type closedLocalConn struct {
perspective protocol.Perspective
logger utils.Logger
sendPacket func(net.Addr, *packetInfo)
sendPacket func(net.Addr, packetInfo)
}
var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
return &closedLocalConn{
sendPacket: sendPacket,
perspective: pers,
@ -30,7 +30,7 @@ func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Pe
}
}
func (c *closedLocalConn) handlePacket(p *receivedPacket) {
func (c *closedLocalConn) handlePacket(p receivedPacket) {
c.counter++
// exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
@ -58,7 +58,7 @@ func newClosedRemoteConn(pers protocol.Perspective) packetHandler {
return &closedRemoteConn{perspective: pers}
}
func (s *closedRemoteConn) handlePacket(*receivedPacket) {}
func (s *closedRemoteConn) handlePacket(receivedPacket) {}
func (s *closedRemoteConn) shutdown() {}
func (s *closedRemoteConn) destroy(error) {}
func (s *closedRemoteConn) getPerspective() protocol.Perspective { return s.perspective }

View File

@ -1,18 +1,10 @@
coverage:
round: nearest
ignore:
- streams_map_incoming_bidi.go
- streams_map_incoming_uni.go
- streams_map_outgoing_bidi.go
- streams_map_outgoing_uni.go
- http3/gzip_reader.go
- interop/
- internal/ackhandler/packet_linkedlist.go
- internal/utils/byteinterval_linkedlist.go
- internal/utils/newconnectionid_linkedlist.go
- internal/utils/packetinterval_linkedlist.go
- internal/handshake/cipher_suite.go
- internal/utils/linkedlist/linkedlist.go
- logging/null_tracer.go
- fuzzing/
- metrics/
status:

View File

@ -1,13 +1,12 @@
package quic
import (
"errors"
"fmt"
"net"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
// Clone clones a Config
@ -17,18 +16,29 @@ func (c *Config) Clone() *Config {
}
func (c *Config) handshakeTimeout() time.Duration {
return utils.Max(protocol.DefaultHandshakeTimeout, 2*c.HandshakeIdleTimeout)
return 2 * c.HandshakeIdleTimeout
}
func (c *Config) maxRetryTokenAge() time.Duration {
return c.handshakeTimeout()
}
func validateConfig(config *Config) error {
if config == nil {
return nil
}
if config.MaxIncomingStreams > 1<<60 {
return errors.New("invalid value for Config.MaxIncomingStreams")
const maxStreams = 1 << 60
if config.MaxIncomingStreams > maxStreams {
config.MaxIncomingStreams = maxStreams
}
if config.MaxIncomingUniStreams > 1<<60 {
return errors.New("invalid value for Config.MaxIncomingUniStreams")
if config.MaxIncomingUniStreams > maxStreams {
config.MaxIncomingUniStreams = maxStreams
}
if config.MaxStreamReceiveWindow > quicvarint.Max {
config.MaxStreamReceiveWindow = quicvarint.Max
}
if config.MaxConnectionReceiveWindow > quicvarint.Max {
config.MaxConnectionReceiveWindow = quicvarint.Max
}
// check that all QUIC versions are actually supported
for _, v := range config.Versions {
@ -43,12 +53,6 @@ func validateConfig(config *Config) error {
// it may be called with nil
func populateServerConfig(config *Config) *Config {
config = populateConfig(config)
if config.MaxTokenAge == 0 {
config.MaxTokenAge = protocol.TokenValidity
}
if config.MaxRetryTokenAge == 0 {
config.MaxRetryTokenAge = protocol.RetryTokenValidity
}
if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return false }
}
@ -101,33 +105,25 @@ func populateConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
maxDatagrameFrameSize := config.MaxDatagramFrameSize
if maxDatagrameFrameSize == 0 {
maxDatagrameFrameSize = int64(protocol.DefaultMaxDatagramFrameSize)
}
return &Config{
GetConfigForClient: config.GetConfigForClient,
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
MaxTokenAge: config.MaxTokenAge,
MaxRetryTokenAge: config.MaxRetryTokenAge,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow,
InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
MaxConnectionReceiveWindow: maxConnectionReceiveWindow,
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
MaxDatagramFrameSize: maxDatagrameFrameSize,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,
GetConfigForClient: config.GetConfigForClient,
Versions: versions,
HandshakeIdleTimeout: handshakeIdleTimeout,
MaxIdleTimeout: idleTimeout,
RequireAddressValidation: config.RequireAddressValidation,
KeepAlivePeriod: config.KeepAlivePeriod,
InitialStreamReceiveWindow: initialStreamReceiveWindow,
MaxStreamReceiveWindow: maxStreamReceiveWindow,
InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
MaxConnectionReceiveWindow: maxConnectionReceiveWindow,
AllowConnectionWindowIncrease: config.AllowConnectionWindowIncrease,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,
}
}

File diff suppressed because it is too large Load Diff

View File

@ -71,17 +71,9 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error {
// GetCryptoData retrieves data that was received in CRYPTO frames
func (s *cryptoStreamImpl) GetCryptoData() []byte {
if len(s.msgBuf) < 4 {
return nil
}
msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3])
if len(s.msgBuf) < msgLen {
return nil
}
msg := make([]byte, msgLen)
copy(msg, s.msgBuf[:msgLen])
s.msgBuf = s.msgBuf[msgLen:]
return msg
b := s.msgBuf
s.msgBuf = nil
return b
}
func (s *cryptoStreamImpl) Finish() error {

View File

@ -3,12 +3,14 @@ package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
type cryptoDataHandler interface {
HandleMessage([]byte, protocol.EncryptionLevel) bool
HandleMessage([]byte, protocol.EncryptionLevel) error
NextEvent() handshake.Event
}
type cryptoStreamManager struct {
@ -33,7 +35,7 @@ func newCryptoStreamManager(
}
}
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) {
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
var str cryptoStream
//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
switch encLevel {
@ -44,18 +46,37 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve
case protocol.Encryption1RTT:
str = m.oneRTTStream
default:
return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
}
if err := str.HandleCryptoFrame(frame); err != nil {
return false, err
return err
}
for {
data := str.GetCryptoData()
if data == nil {
return false, nil
return nil
}
if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished {
return true, str.Finish()
if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil {
return err
}
}
}
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
if !m.oneRTTStream.HasData() {
return nil
}
return m.oneRTTStream.PopCryptoFrame(maxSize)
}
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
//nolint:exhaustive // 1-RTT keys should never get dropped.
switch encLevel {
case protocol.EncryptionInitial:
return m.initialStream.Finish()
case protocol.EncryptionHandshake:
return m.handshakeStream.Finish()
default:
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
}
}

View File

@ -1,6 +1,7 @@
package quic
import (
"context"
"sync"
"github.com/quic-go/quic-go/internal/protocol"
@ -98,7 +99,7 @@ func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
}
// Receive gets a received DATAGRAM frame.
func (h *datagramQueue) Receive() ([]byte, error) {
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
for {
h.rcvMx.Lock()
if len(h.rcvQueue) > 0 {
@ -113,6 +114,8 @@ func (h *datagramQueue) Receive() ([]byte, error) {
continue
case <-h.closed:
return nil, h.closeErr
case <-ctx.Done():
return nil, ctx.Err()
}
}
}

View File

@ -61,3 +61,15 @@ func (e *StreamError) Error() string {
}
return fmt.Sprintf("stream %d canceled by %s with error code %d", e.StreamID, pers, e.ErrorCode)
}
// DatagramTooLargeError is returned from Connection.SendDatagram if the payload is too large to be sent.
type DatagramTooLargeError struct {
PeerMaxDatagramFrameSize int64
}
func (e *DatagramTooLargeError) Is(target error) bool {
_, ok := target.(*DatagramTooLargeError)
return ok
}
func (e *DatagramTooLargeError) Error() string { return "DATAGRAM frame too large" }

Some files were not shown because too many files have changed in this diff Show More