TUN-5405: Update net package to v0.0.0-20211109214657-ef0fda0de508

This version contains fix to https://github.com/golang/go/issues/43989
This commit is contained in:
cthuang 2021-11-10 17:20:10 +00:00
parent 794635fb54
commit 7024d193c9
28 changed files with 1679 additions and 982 deletions

2
go.mod
View File

@ -46,7 +46,7 @@ require (
github.com/urfave/cli/v2 v2.2.0 github.com/urfave/cli/v2 v2.2.0
go.uber.org/automaxprocs v1.4.0 go.uber.org/automaxprocs v1.4.0
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 golang.org/x/net v0.0.0-20211109214657-ef0fda0de508
golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 // indirect golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1

2
go.sum
View File

@ -818,6 +818,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20211109214657-ef0fda0de508 h1:v3NKo+t/Kc3EASxaKZ82lwK6mCf4ZeObQBduYFZHo7c=
golang.org/x/net v0.0.0-20211109214657-ef0fda0de508/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=

20
vendor/golang.org/x/net/http2/README generated vendored
View File

@ -1,20 +0,0 @@
This is a work-in-progress HTTP/2 implementation for Go.
It will eventually live in the Go standard library and won't require
any changes to your code to use. It will just be automatic.
Status:
* The server support is pretty good. A few things are missing
but are being worked on.
* The client work has just started but shares a lot of code
is coming along much quicker.
Docs are at https://godoc.org/golang.org/x/net/http2
Demo test server at https://http2.golang.org/
Help & bug reports welcome!
Contributing: https://golang.org/doc/contribute.html
Bugs: https://golang.org/issue/new?title=x/net/http2:+

53
vendor/golang.org/x/net/http2/ascii.go generated vendored Normal file
View File

@ -0,0 +1,53 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import "strings"
// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
// contains helper functions which may use Unicode-aware functions which would
// otherwise be unsafe and could introduce vulnerabilities if used improperly.
// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
// are equal, ASCII-case-insensitively.
func asciiEqualFold(s, t string) bool {
if len(s) != len(t) {
return false
}
for i := 0; i < len(s); i++ {
if lower(s[i]) != lower(t[i]) {
return false
}
}
return true
}
// lower returns the ASCII lowercase version of b.
func lower(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// isASCIIPrint returns whether s is ASCII and printable according to
// https://tools.ietf.org/html/rfc20#section-4.2.
func isASCIIPrint(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < ' ' || s[i] > '~' {
return false
}
}
return true
}
// asciiToLower returns the lowercase version of s if s is ASCII and printable,
// and whether or not it was.
func asciiToLower(s string) (lower string, ok bool) {
if !isASCIIPrint(s) {
return "", false
}
return strings.ToLower(s), true
}

View File

@ -7,13 +7,21 @@
package http2 package http2
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors"
"net/http" "net/http"
"sync" "sync"
) )
// ClientConnPool manages a pool of HTTP/2 client connections. // ClientConnPool manages a pool of HTTP/2 client connections.
type ClientConnPool interface { type ClientConnPool interface {
// GetClientConn returns a specific HTTP/2 connection (usually
// a TLS-TCP connection) to an HTTP/2 server. On success, the
// returned ClientConn accounts for the upcoming RoundTrip
// call, so the caller should not omit it. If the caller needs
// to, ClientConn.RoundTrip can be called with a bogus
// new(http.Request) to release the stream reservation.
GetClientConn(req *http.Request, addr string) (*ClientConn, error) GetClientConn(req *http.Request, addr string) (*ClientConn, error)
MarkDead(*ClientConn) MarkDead(*ClientConn)
} }
@ -40,7 +48,7 @@ type clientConnPool struct {
conns map[string][]*ClientConn // key is host:port conns map[string][]*ClientConn // key is host:port
dialing map[string]*dialCall // currently in-flight dials dialing map[string]*dialCall // currently in-flight dials
keys map[*ClientConn][]string keys map[*ClientConn][]string
addConnCalls map[string]*addConnCall // in-flight addConnIfNeede calls addConnCalls map[string]*addConnCall // in-flight addConnIfNeeded calls
} }
func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
@ -52,87 +60,85 @@ const (
noDialOnMiss = false noDialOnMiss = false
) )
// shouldTraceGetConn reports whether getClientConn should call any
// ClientTrace.GetConn hook associated with the http.Request.
//
// This complexity is needed to avoid double calls of the GetConn hook
// during the back-and-forth between net/http and x/net/http2 (when the
// net/http.Transport is upgraded to also speak http2), as well as support
// the case where x/net/http2 is being used directly.
func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
// If our Transport wasn't made via ConfigureTransport, always
// trace the GetConn hook if provided, because that means the
// http2 package is being used directly and it's the one
// dialing, as opposed to net/http.
if _, ok := p.t.ConnPool.(noDialClientConnPool); !ok {
return true
}
// Otherwise, only use the GetConn hook if this connection has
// been used previously for other requests. For fresh
// connections, the net/http package does the dialing.
return !st.freshConn
}
func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) { func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
// TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
if isConnectionCloseRequest(req) && dialOnMiss { if isConnectionCloseRequest(req) && dialOnMiss {
// It gets its own connection. // It gets its own connection.
traceGetConn(req, addr) traceGetConn(req, addr)
const singleUse = true const singleUse = true
cc, err := p.t.dialClientConn(addr, singleUse) cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return cc, nil return cc, nil
} }
p.mu.Lock() for {
for _, cc := range p.conns[addr] { p.mu.Lock()
if st := cc.idleState(); st.canTakeNewRequest { for _, cc := range p.conns[addr] {
if p.shouldTraceGetConn(st) { if cc.ReserveNewRequest() {
traceGetConn(req, addr) // When a connection is presented to us by the net/http package,
// the GetConn hook has already been called.
// Don't call it a second time here.
if !cc.getConnCalled {
traceGetConn(req, addr)
}
cc.getConnCalled = false
p.mu.Unlock()
return cc, nil
} }
}
if !dialOnMiss {
p.mu.Unlock() p.mu.Unlock()
return nil, ErrNoCachedConn
}
traceGetConn(req, addr)
call := p.getStartDialLocked(req.Context(), addr)
p.mu.Unlock()
<-call.done
if shouldRetryDial(call, req) {
continue
}
cc, err := call.res, call.err
if err != nil {
return nil, err
}
if cc.ReserveNewRequest() {
return cc, nil return cc, nil
} }
} }
if !dialOnMiss {
p.mu.Unlock()
return nil, ErrNoCachedConn
}
traceGetConn(req, addr)
call := p.getStartDialLocked(addr)
p.mu.Unlock()
<-call.done
return call.res, call.err
} }
// dialCall is an in-flight Transport dial call to a host. // dialCall is an in-flight Transport dial call to a host.
type dialCall struct { type dialCall struct {
_ incomparable _ incomparable
p *clientConnPool p *clientConnPool
// the context associated with the request
// that created this dialCall
ctx context.Context
done chan struct{} // closed when done done chan struct{} // closed when done
res *ClientConn // valid after done is closed res *ClientConn // valid after done is closed
err error // valid after done is closed err error // valid after done is closed
} }
// requires p.mu is held. // requires p.mu is held.
func (p *clientConnPool) getStartDialLocked(addr string) *dialCall { func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
if call, ok := p.dialing[addr]; ok { if call, ok := p.dialing[addr]; ok {
// A dial is already in-flight. Don't start another. // A dial is already in-flight. Don't start another.
return call return call
} }
call := &dialCall{p: p, done: make(chan struct{})} call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
if p.dialing == nil { if p.dialing == nil {
p.dialing = make(map[string]*dialCall) p.dialing = make(map[string]*dialCall)
} }
p.dialing[addr] = call p.dialing[addr] = call
go call.dial(addr) go call.dial(call.ctx, addr)
return call return call
} }
// run in its own goroutine. // run in its own goroutine.
func (c *dialCall) dial(addr string) { func (c *dialCall) dial(ctx context.Context, addr string) {
const singleUse = false // shared conn const singleUse = false // shared conn
c.res, c.err = c.p.t.dialClientConn(addr, singleUse) c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
close(c.done) close(c.done)
c.p.mu.Lock() c.p.mu.Lock()
@ -195,6 +201,7 @@ func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
if err != nil { if err != nil {
c.err = err c.err = err
} else { } else {
cc.getConnCalled = true // already called by the net/http package
p.addConnLocked(key, cc) p.addConnLocked(key, cc)
} }
delete(p.addConnCalls, key) delete(p.addConnCalls, key)
@ -276,3 +283,28 @@ type noDialClientConnPool struct{ *clientConnPool }
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) { func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
return p.getClientConn(req, addr, noDialOnMiss) return p.getClientConn(req, addr, noDialOnMiss)
} }
// shouldRetryDial reports whether the current request should
// retry dialing after the call finished unsuccessfully, for example
// if the dial was canceled because of a context cancellation or
// deadline expiry.
func shouldRetryDial(call *dialCall, req *http.Request) bool {
if call.err == nil {
// No error, no need to retry
return false
}
if call.ctx == req.Context() {
// If the call has the same context as the request, the dial
// should not be retried, since any cancellation will have come
// from this request.
return false
}
if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
// If the call error is not because of a context cancellation or a deadline expiry,
// the dial should not be retried.
return false
}
// Only retry if the error is a context cancellation error or deadline expiry
// and the context associated with the call was canceled or expired.
return call.ctx.Err() != nil
}

View File

@ -53,6 +53,13 @@ func (e ErrCode) String() string {
return fmt.Sprintf("unknown error code 0x%x", uint32(e)) return fmt.Sprintf("unknown error code 0x%x", uint32(e))
} }
func (e ErrCode) stringToken() string {
if s, ok := errCodeName[e]; ok {
return s
}
return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e))
}
// ConnectionError is an error that results in the termination of the // ConnectionError is an error that results in the termination of the
// entire connection. // entire connection.
type ConnectionError ErrCode type ConnectionError ErrCode
@ -67,6 +74,11 @@ type StreamError struct {
Cause error // optional additional detail Cause error // optional additional detail
} }
// errFromPeer is a sentinel error value for StreamError.Cause to
// indicate that the StreamError was sent from the peer over the wire
// and wasn't locally generated in the Transport.
var errFromPeer = errors.New("received from peer")
func streamError(id uint32, code ErrCode) StreamError { func streamError(id uint32, code ErrCode) StreamError {
return StreamError{StreamID: id, Code: code} return StreamError{StreamID: id, Code: code}
} }

View File

@ -122,7 +122,7 @@ var flagName = map[FrameType]map[Flags]string{
// a frameParser parses a frame given its FrameHeader and payload // a frameParser parses a frame given its FrameHeader and payload
// bytes. The length of payload will always equal fh.Length (which // bytes. The length of payload will always equal fh.Length (which
// might be 0). // might be 0).
type frameParser func(fc *frameCache, fh FrameHeader, payload []byte) (Frame, error) type frameParser func(fc *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error)
var frameParsers = map[FrameType]frameParser{ var frameParsers = map[FrameType]frameParser{
FrameData: parseDataFrame, FrameData: parseDataFrame,
@ -267,6 +267,11 @@ type Framer struct {
lastFrame Frame lastFrame Frame
errDetail error errDetail error
// countError is a non-nil func that's called on a frame parse
// error with some unique error path token. It's initialized
// from Transport.CountError or Server.CountError.
countError func(errToken string)
// lastHeaderStream is non-zero if the last frame was an // lastHeaderStream is non-zero if the last frame was an
// unfinished HEADERS/CONTINUATION. // unfinished HEADERS/CONTINUATION.
lastHeaderStream uint32 lastHeaderStream uint32
@ -426,6 +431,7 @@ func NewFramer(w io.Writer, r io.Reader) *Framer {
fr := &Framer{ fr := &Framer{
w: w, w: w,
r: r, r: r,
countError: func(string) {},
logReads: logFrameReads, logReads: logFrameReads,
logWrites: logFrameWrites, logWrites: logFrameWrites,
debugReadLoggerf: log.Printf, debugReadLoggerf: log.Printf,
@ -500,7 +506,7 @@ func (fr *Framer) ReadFrame() (Frame, error) {
if _, err := io.ReadFull(fr.r, payload); err != nil { if _, err := io.ReadFull(fr.r, payload); err != nil {
return nil, err return nil, err
} }
f, err := typeFrameParser(fh.Type)(fr.frameCache, fh, payload) f, err := typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload)
if err != nil { if err != nil {
if ce, ok := err.(connError); ok { if ce, ok := err.(connError); ok {
return nil, fr.connError(ce.Code, ce.Reason) return nil, fr.connError(ce.Code, ce.Reason)
@ -588,13 +594,14 @@ func (f *DataFrame) Data() []byte {
return f.data return f.data
} }
func parseDataFrame(fc *frameCache, fh FrameHeader, payload []byte) (Frame, error) { func parseDataFrame(fc *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) {
if fh.StreamID == 0 { if fh.StreamID == 0 {
// DATA frames MUST be associated with a stream. If a // DATA frames MUST be associated with a stream. If a
// DATA frame is received whose stream identifier // DATA frame is received whose stream identifier
// field is 0x0, the recipient MUST respond with a // field is 0x0, the recipient MUST respond with a
// connection error (Section 5.4.1) of type // connection error (Section 5.4.1) of type
// PROTOCOL_ERROR. // PROTOCOL_ERROR.
countError("frame_data_stream_0")
return nil, connError{ErrCodeProtocol, "DATA frame with stream ID 0"} return nil, connError{ErrCodeProtocol, "DATA frame with stream ID 0"}
} }
f := fc.getDataFrame() f := fc.getDataFrame()
@ -605,6 +612,7 @@ func parseDataFrame(fc *frameCache, fh FrameHeader, payload []byte) (Frame, erro
var err error var err error
payload, padSize, err = readByte(payload) payload, padSize, err = readByte(payload)
if err != nil { if err != nil {
countError("frame_data_pad_byte_short")
return nil, err return nil, err
} }
} }
@ -613,6 +621,7 @@ func parseDataFrame(fc *frameCache, fh FrameHeader, payload []byte) (Frame, erro
// length of the frame payload, the recipient MUST // length of the frame payload, the recipient MUST
// treat this as a connection error. // treat this as a connection error.
// Filed: https://github.com/http2/http2-spec/issues/610 // Filed: https://github.com/http2/http2-spec/issues/610
countError("frame_data_pad_too_big")
return nil, connError{ErrCodeProtocol, "pad size larger than data payload"} return nil, connError{ErrCodeProtocol, "pad size larger than data payload"}
} }
f.data = payload[:len(payload)-int(padSize)] f.data = payload[:len(payload)-int(padSize)]
@ -695,7 +704,7 @@ type SettingsFrame struct {
p []byte p []byte
} }
func parseSettingsFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, error) { func parseSettingsFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) {
if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 { if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 {
// When this (ACK 0x1) bit is set, the payload of the // When this (ACK 0x1) bit is set, the payload of the
// SETTINGS frame MUST be empty. Receipt of a // SETTINGS frame MUST be empty. Receipt of a
@ -703,6 +712,7 @@ func parseSettingsFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, error)
// field value other than 0 MUST be treated as a // field value other than 0 MUST be treated as a
// connection error (Section 5.4.1) of type // connection error (Section 5.4.1) of type
// FRAME_SIZE_ERROR. // FRAME_SIZE_ERROR.
countError("frame_settings_ack_with_length")
return nil, ConnectionError(ErrCodeFrameSize) return nil, ConnectionError(ErrCodeFrameSize)
} }
if fh.StreamID != 0 { if fh.StreamID != 0 {
@ -713,14 +723,17 @@ func parseSettingsFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, error)
// field is anything other than 0x0, the endpoint MUST // field is anything other than 0x0, the endpoint MUST
// respond with a connection error (Section 5.4.1) of // respond with a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR. // type PROTOCOL_ERROR.
countError("frame_settings_has_stream")
return nil, ConnectionError(ErrCodeProtocol) return nil, ConnectionError(ErrCodeProtocol)
} }
if len(p)%6 != 0 { if len(p)%6 != 0 {
countError("frame_settings_mod_6")
// Expecting even number of 6 byte settings. // Expecting even number of 6 byte settings.
return nil, ConnectionError(ErrCodeFrameSize) return nil, ConnectionError(ErrCodeFrameSize)
} }
f := &SettingsFrame{FrameHeader: fh, p: p} f := &SettingsFrame{FrameHeader: fh, p: p}
if v, ok := f.Value(SettingInitialWindowSize); ok && v > (1<<31)-1 { if v, ok := f.Value(SettingInitialWindowSize); ok && v > (1<<31)-1 {
countError("frame_settings_window_size_too_big")
// Values above the maximum flow control window size of 2^31 - 1 MUST // Values above the maximum flow control window size of 2^31 - 1 MUST
// be treated as a connection error (Section 5.4.1) of type // be treated as a connection error (Section 5.4.1) of type
// FLOW_CONTROL_ERROR. // FLOW_CONTROL_ERROR.
@ -832,11 +845,13 @@ type PingFrame struct {
func (f *PingFrame) IsAck() bool { return f.Flags.Has(FlagPingAck) } func (f *PingFrame) IsAck() bool { return f.Flags.Has(FlagPingAck) }
func parsePingFrame(_ *frameCache, fh FrameHeader, payload []byte) (Frame, error) { func parsePingFrame(_ *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) {
if len(payload) != 8 { if len(payload) != 8 {
countError("frame_ping_length")
return nil, ConnectionError(ErrCodeFrameSize) return nil, ConnectionError(ErrCodeFrameSize)
} }
if fh.StreamID != 0 { if fh.StreamID != 0 {
countError("frame_ping_has_stream")
return nil, ConnectionError(ErrCodeProtocol) return nil, ConnectionError(ErrCodeProtocol)
} }
f := &PingFrame{FrameHeader: fh} f := &PingFrame{FrameHeader: fh}
@ -872,11 +887,13 @@ func (f *GoAwayFrame) DebugData() []byte {
return f.debugData return f.debugData
} }
func parseGoAwayFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, error) { func parseGoAwayFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) {
if fh.StreamID != 0 { if fh.StreamID != 0 {
countError("frame_goaway_has_stream")
return nil, ConnectionError(ErrCodeProtocol) return nil, ConnectionError(ErrCodeProtocol)
} }
if len(p) < 8 { if len(p) < 8 {
countError("frame_goaway_short")
return nil, ConnectionError(ErrCodeFrameSize) return nil, ConnectionError(ErrCodeFrameSize)
} }
return &GoAwayFrame{ return &GoAwayFrame{
@ -912,7 +929,7 @@ func (f *UnknownFrame) Payload() []byte {
return f.p return f.p
} }
func parseUnknownFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, error) { func parseUnknownFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) {
return &UnknownFrame{fh, p}, nil return &UnknownFrame{fh, p}, nil
} }
@ -923,8 +940,9 @@ type WindowUpdateFrame struct {
Increment uint32 // never read with high bit set Increment uint32 // never read with high bit set
} }
func parseWindowUpdateFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, error) { func parseWindowUpdateFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) {
if len(p) != 4 { if len(p) != 4 {
countError("frame_windowupdate_bad_len")
return nil, ConnectionError(ErrCodeFrameSize) return nil, ConnectionError(ErrCodeFrameSize)
} }
inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit
@ -936,8 +954,10 @@ func parseWindowUpdateFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, err
// control window MUST be treated as a connection // control window MUST be treated as a connection
// error (Section 5.4.1). // error (Section 5.4.1).
if fh.StreamID == 0 { if fh.StreamID == 0 {
countError("frame_windowupdate_zero_inc_conn")
return nil, ConnectionError(ErrCodeProtocol) return nil, ConnectionError(ErrCodeProtocol)
} }
countError("frame_windowupdate_zero_inc_stream")
return nil, streamError(fh.StreamID, ErrCodeProtocol) return nil, streamError(fh.StreamID, ErrCodeProtocol)
} }
return &WindowUpdateFrame{ return &WindowUpdateFrame{
@ -988,7 +1008,7 @@ func (f *HeadersFrame) HasPriority() bool {
return f.FrameHeader.Flags.Has(FlagHeadersPriority) return f.FrameHeader.Flags.Has(FlagHeadersPriority)
} }
func parseHeadersFrame(_ *frameCache, fh FrameHeader, p []byte) (_ Frame, err error) { func parseHeadersFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (_ Frame, err error) {
hf := &HeadersFrame{ hf := &HeadersFrame{
FrameHeader: fh, FrameHeader: fh,
} }
@ -997,11 +1017,13 @@ func parseHeadersFrame(_ *frameCache, fh FrameHeader, p []byte) (_ Frame, err er
// is received whose stream identifier field is 0x0, the recipient MUST // is received whose stream identifier field is 0x0, the recipient MUST
// respond with a connection error (Section 5.4.1) of type // respond with a connection error (Section 5.4.1) of type
// PROTOCOL_ERROR. // PROTOCOL_ERROR.
countError("frame_headers_zero_stream")
return nil, connError{ErrCodeProtocol, "HEADERS frame with stream ID 0"} return nil, connError{ErrCodeProtocol, "HEADERS frame with stream ID 0"}
} }
var padLength uint8 var padLength uint8
if fh.Flags.Has(FlagHeadersPadded) { if fh.Flags.Has(FlagHeadersPadded) {
if p, padLength, err = readByte(p); err != nil { if p, padLength, err = readByte(p); err != nil {
countError("frame_headers_pad_short")
return return
} }
} }
@ -1009,16 +1031,19 @@ func parseHeadersFrame(_ *frameCache, fh FrameHeader, p []byte) (_ Frame, err er
var v uint32 var v uint32
p, v, err = readUint32(p) p, v, err = readUint32(p)
if err != nil { if err != nil {
countError("frame_headers_prio_short")
return nil, err return nil, err
} }
hf.Priority.StreamDep = v & 0x7fffffff hf.Priority.StreamDep = v & 0x7fffffff
hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set
p, hf.Priority.Weight, err = readByte(p) p, hf.Priority.Weight, err = readByte(p)
if err != nil { if err != nil {
countError("frame_headers_prio_weight_short")
return nil, err return nil, err
} }
} }
if len(p)-int(padLength) <= 0 { if len(p)-int(padLength) < 0 {
countError("frame_headers_pad_too_big")
return nil, streamError(fh.StreamID, ErrCodeProtocol) return nil, streamError(fh.StreamID, ErrCodeProtocol)
} }
hf.headerFragBuf = p[:len(p)-int(padLength)] hf.headerFragBuf = p[:len(p)-int(padLength)]
@ -1125,11 +1150,13 @@ func (p PriorityParam) IsZero() bool {
return p == PriorityParam{} return p == PriorityParam{}
} }
func parsePriorityFrame(_ *frameCache, fh FrameHeader, payload []byte) (Frame, error) { func parsePriorityFrame(_ *frameCache, fh FrameHeader, countError func(string), payload []byte) (Frame, error) {
if fh.StreamID == 0 { if fh.StreamID == 0 {
countError("frame_priority_zero_stream")
return nil, connError{ErrCodeProtocol, "PRIORITY frame with stream ID 0"} return nil, connError{ErrCodeProtocol, "PRIORITY frame with stream ID 0"}
} }
if len(payload) != 5 { if len(payload) != 5 {
countError("frame_priority_bad_length")
return nil, connError{ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} return nil, connError{ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))}
} }
v := binary.BigEndian.Uint32(payload[:4]) v := binary.BigEndian.Uint32(payload[:4])
@ -1172,11 +1199,13 @@ type RSTStreamFrame struct {
ErrCode ErrCode ErrCode ErrCode
} }
func parseRSTStreamFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, error) { func parseRSTStreamFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) {
if len(p) != 4 { if len(p) != 4 {
countError("frame_rststream_bad_len")
return nil, ConnectionError(ErrCodeFrameSize) return nil, ConnectionError(ErrCodeFrameSize)
} }
if fh.StreamID == 0 { if fh.StreamID == 0 {
countError("frame_rststream_zero_stream")
return nil, ConnectionError(ErrCodeProtocol) return nil, ConnectionError(ErrCodeProtocol)
} }
return &RSTStreamFrame{fh, ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil return &RSTStreamFrame{fh, ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil
@ -1202,8 +1231,9 @@ type ContinuationFrame struct {
headerFragBuf []byte headerFragBuf []byte
} }
func parseContinuationFrame(_ *frameCache, fh FrameHeader, p []byte) (Frame, error) { func parseContinuationFrame(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (Frame, error) {
if fh.StreamID == 0 { if fh.StreamID == 0 {
countError("frame_continuation_zero_stream")
return nil, connError{ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} return nil, connError{ErrCodeProtocol, "CONTINUATION frame with stream ID 0"}
} }
return &ContinuationFrame{fh, p}, nil return &ContinuationFrame{fh, p}, nil
@ -1252,7 +1282,7 @@ func (f *PushPromiseFrame) HeadersEnded() bool {
return f.FrameHeader.Flags.Has(FlagPushPromiseEndHeaders) return f.FrameHeader.Flags.Has(FlagPushPromiseEndHeaders)
} }
func parsePushPromise(_ *frameCache, fh FrameHeader, p []byte) (_ Frame, err error) { func parsePushPromise(_ *frameCache, fh FrameHeader, countError func(string), p []byte) (_ Frame, err error) {
pp := &PushPromiseFrame{ pp := &PushPromiseFrame{
FrameHeader: fh, FrameHeader: fh,
} }
@ -1263,6 +1293,7 @@ func parsePushPromise(_ *frameCache, fh FrameHeader, p []byte) (_ Frame, err err
// with. If the stream identifier field specifies the value // with. If the stream identifier field specifies the value
// 0x0, a recipient MUST respond with a connection error // 0x0, a recipient MUST respond with a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR. // (Section 5.4.1) of type PROTOCOL_ERROR.
countError("frame_pushpromise_zero_stream")
return nil, ConnectionError(ErrCodeProtocol) return nil, ConnectionError(ErrCodeProtocol)
} }
// The PUSH_PROMISE frame includes optional padding. // The PUSH_PROMISE frame includes optional padding.
@ -1270,18 +1301,21 @@ func parsePushPromise(_ *frameCache, fh FrameHeader, p []byte) (_ Frame, err err
var padLength uint8 var padLength uint8
if fh.Flags.Has(FlagPushPromisePadded) { if fh.Flags.Has(FlagPushPromisePadded) {
if p, padLength, err = readByte(p); err != nil { if p, padLength, err = readByte(p); err != nil {
countError("frame_pushpromise_pad_short")
return return
} }
} }
p, pp.PromiseID, err = readUint32(p) p, pp.PromiseID, err = readUint32(p)
if err != nil { if err != nil {
countError("frame_pushpromise_promiseid_short")
return return
} }
pp.PromiseID = pp.PromiseID & (1<<31 - 1) pp.PromiseID = pp.PromiseID & (1<<31 - 1)
if int(padLength) > len(p) { if int(padLength) > len(p) {
// like the DATA frame, error out if padding is longer than the body. // like the DATA frame, error out if padding is longer than the body.
countError("frame_pushpromise_pad_too_big")
return nil, ConnectionError(ErrCodeProtocol) return nil, ConnectionError(ErrCodeProtocol)
} }
pp.headerFragBuf = p[:len(p)-int(padLength)] pp.headerFragBuf = p[:len(p)-int(padLength)]

27
vendor/golang.org/x/net/http2/go115.go generated vendored Normal file
View File

@ -0,0 +1,27 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.15
// +build go1.15
package http2
import (
"context"
"crypto/tls"
)
// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
// connection.
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
cn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
return tlsCn, nil
}

View File

@ -6,7 +6,6 @@ package http2
import ( import (
"net/http" "net/http"
"strings"
"sync" "sync"
) )
@ -79,10 +78,10 @@ func buildCommonHeaderMaps() {
} }
} }
func lowerHeader(v string) string { func lowerHeader(v string) (lower string, ascii bool) {
buildCommonHeaderMapsOnce() buildCommonHeaderMapsOnce()
if s, ok := commonLowerHeader[v]; ok { if s, ok := commonLowerHeader[v]; ok {
return s return s, true
} }
return strings.ToLower(v) return asciiToLower(v)
} }

View File

@ -140,25 +140,29 @@ func buildRootHuffmanNode() {
panic("unexpected size") panic("unexpected size")
} }
lazyRootHuffmanNode = newInternalNode() lazyRootHuffmanNode = newInternalNode()
for i, code := range huffmanCodes { // allocate a leaf node for each of the 256 symbols
addDecoderNode(byte(i), code, huffmanCodeLen[i]) leaves := new([256]node)
}
}
func addDecoderNode(sym byte, code uint32, codeLen uint8) { for sym, code := range huffmanCodes {
cur := lazyRootHuffmanNode codeLen := huffmanCodeLen[sym]
for codeLen > 8 {
codeLen -= 8 cur := lazyRootHuffmanNode
i := uint8(code >> codeLen) for codeLen > 8 {
if cur.children[i] == nil { codeLen -= 8
cur.children[i] = newInternalNode() i := uint8(code >> codeLen)
if cur.children[i] == nil {
cur.children[i] = newInternalNode()
}
cur = cur.children[i]
}
shift := 8 - codeLen
start, end := int(uint8(code<<shift)), int(1<<shift)
leaves[sym].sym = byte(sym)
leaves[sym].codeLen = codeLen
for i := start; i < start+end; i++ {
cur.children[i] = &leaves[sym]
} }
cur = cur.children[i]
}
shift := 8 - codeLen
start, end := int(uint8(code<<shift)), int(1<<shift)
for i := start; i < start+end; i++ {
cur.children[i] = &node{sym: sym, codeLen: codeLen}
} }
} }

31
vendor/golang.org/x/net/http2/not_go115.go generated vendored Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.15
// +build !go1.15
package http2
import (
"context"
"crypto/tls"
)
// dialTLSWithContext opens a TLS connection.
func (t *Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
cn, err := tls.Dial(network, addr, cfg)
if err != nil {
return nil, err
}
if err := cn.Handshake(); err != nil {
return nil, err
}
if cfg.InsecureSkipVerify {
return cn, nil
}
if err := cn.VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
return cn, nil
}

View File

@ -30,6 +30,17 @@ type pipeBuffer interface {
io.Reader io.Reader
} }
// setBuffer initializes the pipe buffer.
// It has no effect if the pipe is already closed.
func (p *pipe) setBuffer(b pipeBuffer) {
p.mu.Lock()
defer p.mu.Unlock()
if p.err != nil || p.breakErr != nil {
return
}
p.b = b
}
func (p *pipe) Len() int { func (p *pipe) Len() int {
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()

View File

@ -130,6 +130,12 @@ type Server struct {
// If nil, a default scheduler is chosen. // If nil, a default scheduler is chosen.
NewWriteScheduler func() WriteScheduler NewWriteScheduler func() WriteScheduler
// CountError, if non-nil, is called on HTTP/2 server errors.
// It's intended to increment a metric for monitoring, such
// as an expvar or Prometheus metric.
// The errType consists of only ASCII word characters.
CountError func(errType string)
// Internal state. This is a pointer (rather than embedded directly) // Internal state. This is a pointer (rather than embedded directly)
// so that we don't embed a Mutex in this struct, which will make the // so that we don't embed a Mutex in this struct, which will make the
// struct non-copyable, which might break some callers. // struct non-copyable, which might break some callers.
@ -231,13 +237,12 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if s.TLSConfig == nil { if s.TLSConfig == nil {
s.TLSConfig = new(tls.Config) s.TLSConfig = new(tls.Config)
} else if s.TLSConfig.CipherSuites != nil { } else if s.TLSConfig.CipherSuites != nil && s.TLSConfig.MinVersion < tls.VersionTLS13 {
// If they already provided a CipherSuite list, return // If they already provided a TLS 1.01.2 CipherSuite list, return an
// an error if it has a bad order or is missing // error if it is missing ECDHE_RSA_WITH_AES_128_GCM_SHA256 or
// ECDHE_RSA_WITH_AES_128_GCM_SHA256 or ECDHE_ECDSA_WITH_AES_128_GCM_SHA256. // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.
haveRequired := false haveRequired := false
sawBad := false for _, cs := range s.TLSConfig.CipherSuites {
for i, cs := range s.TLSConfig.CipherSuites {
switch cs { switch cs {
case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
// Alternative MTI cipher to not discourage ECDSA-only servers. // Alternative MTI cipher to not discourage ECDSA-only servers.
@ -245,14 +250,9 @@ func ConfigureServer(s *http.Server, conf *Server) error {
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
haveRequired = true haveRequired = true
} }
if isBadCipher(cs) {
sawBad = true
} else if sawBad {
return fmt.Errorf("http2: TLSConfig.CipherSuites index %d contains an HTTP/2-approved cipher suite (%#04x), but it comes after unapproved cipher suites. With this configuration, clients that don't support previous, approved cipher suites may be given an unapproved one and reject the connection.", i, cs)
}
} }
if !haveRequired { if !haveRequired {
return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256).") return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)")
} }
} }
@ -265,16 +265,12 @@ func ConfigureServer(s *http.Server, conf *Server) error {
s.TLSConfig.PreferServerCipherSuites = true s.TLSConfig.PreferServerCipherSuites = true
haveNPN := false if !strSliceContains(s.TLSConfig.NextProtos, NextProtoTLS) {
for _, p := range s.TLSConfig.NextProtos {
if p == NextProtoTLS {
haveNPN = true
break
}
}
if !haveNPN {
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, NextProtoTLS) s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, NextProtoTLS)
} }
if !strSliceContains(s.TLSConfig.NextProtos, "http/1.1") {
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1")
}
if s.TLSNextProto == nil { if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){}
@ -415,6 +411,9 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
fr := NewFramer(sc.bw, c) fr := NewFramer(sc.bw, c)
if s.CountError != nil {
fr.countError = s.CountError
}
fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize()) fr.SetMaxReadFrameSize(s.maxReadFrameSize())
@ -826,7 +825,7 @@ func (sc *serverConn) serve() {
}) })
sc.unackedSettings++ sc.unackedSettings++
// Each connection starts with intialWindowSize inflow tokens. // Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens. // If a higher value is configured, we add more tokens.
if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 { if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff)) sc.sendWindowUpdate(nil, int(diff))
@ -866,6 +865,15 @@ func (sc *serverConn) serve() {
case res := <-sc.wroteFrameCh: case res := <-sc.wroteFrameCh:
sc.wroteFrame(res) sc.wroteFrame(res)
case res := <-sc.readFrameCh: case res := <-sc.readFrameCh:
// Process any written frames before reading new frames from the client since a
// written frame could have triggered a new stream to be started.
if sc.writingFrameAsync {
select {
case wroteRes := <-sc.wroteFrameCh:
sc.wroteFrame(wroteRes)
default:
}
}
if !sc.processFrameFromReader(res) { if !sc.processFrameFromReader(res) {
return return
} }
@ -1400,7 +1408,7 @@ func (sc *serverConn) processFrame(f Frame) error {
// First frame received must be SETTINGS. // First frame received must be SETTINGS.
if !sc.sawFirstSettings { if !sc.sawFirstSettings {
if _, ok := f.(*SettingsFrame); !ok { if _, ok := f.(*SettingsFrame); !ok {
return ConnectionError(ErrCodeProtocol) return sc.countError("first_settings", ConnectionError(ErrCodeProtocol))
} }
sc.sawFirstSettings = true sc.sawFirstSettings = true
} }
@ -1425,7 +1433,7 @@ func (sc *serverConn) processFrame(f Frame) error {
case *PushPromiseFrame: case *PushPromiseFrame:
// A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE
// frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
return ConnectionError(ErrCodeProtocol) return sc.countError("push_promise", ConnectionError(ErrCodeProtocol))
default: default:
sc.vlogf("http2: server ignoring frame: %v", f.Header()) sc.vlogf("http2: server ignoring frame: %v", f.Header())
return nil return nil
@ -1445,7 +1453,7 @@ func (sc *serverConn) processPing(f *PingFrame) error {
// identifier field value other than 0x0, the recipient MUST // identifier field value other than 0x0, the recipient MUST
// respond with a connection error (Section 5.4.1) of type // respond with a connection error (Section 5.4.1) of type
// PROTOCOL_ERROR." // PROTOCOL_ERROR."
return ConnectionError(ErrCodeProtocol) return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol))
} }
if sc.inGoAway && sc.goAwayCode != ErrCodeNo { if sc.inGoAway && sc.goAwayCode != ErrCodeNo {
return nil return nil
@ -1464,7 +1472,7 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
// or PRIORITY on a stream in this state MUST be // or PRIORITY on a stream in this state MUST be
// treated as a connection error (Section 5.4.1) of // treated as a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR." // type PROTOCOL_ERROR."
return ConnectionError(ErrCodeProtocol) return sc.countError("stream_idle", ConnectionError(ErrCodeProtocol))
} }
if st == nil { if st == nil {
// "WINDOW_UPDATE can be sent by a peer that has sent a // "WINDOW_UPDATE can be sent by a peer that has sent a
@ -1475,7 +1483,7 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
return nil return nil
} }
if !st.flow.add(int32(f.Increment)) { if !st.flow.add(int32(f.Increment)) {
return streamError(f.StreamID, ErrCodeFlowControl) return sc.countError("bad_flow", streamError(f.StreamID, ErrCodeFlowControl))
} }
default: // connection-level flow control default: // connection-level flow control
if !sc.flow.add(int32(f.Increment)) { if !sc.flow.add(int32(f.Increment)) {
@ -1496,7 +1504,7 @@ func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
// identifying an idle stream is received, the // identifying an idle stream is received, the
// recipient MUST treat this as a connection error // recipient MUST treat this as a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR. // (Section 5.4.1) of type PROTOCOL_ERROR.
return ConnectionError(ErrCodeProtocol) return sc.countError("reset_idle_stream", ConnectionError(ErrCodeProtocol))
} }
if st != nil { if st != nil {
st.cancelCtx() st.cancelCtx()
@ -1548,7 +1556,7 @@ func (sc *serverConn) processSettings(f *SettingsFrame) error {
// Why is the peer ACKing settings we never sent? // Why is the peer ACKing settings we never sent?
// The spec doesn't mention this case, but // The spec doesn't mention this case, but
// hang up on them anyway. // hang up on them anyway.
return ConnectionError(ErrCodeProtocol) return sc.countError("ack_mystery", ConnectionError(ErrCodeProtocol))
} }
return nil return nil
} }
@ -1556,7 +1564,7 @@ func (sc *serverConn) processSettings(f *SettingsFrame) error {
// This isn't actually in the spec, but hang up on // This isn't actually in the spec, but hang up on
// suspiciously large settings frames or those with // suspiciously large settings frames or those with
// duplicate entries. // duplicate entries.
return ConnectionError(ErrCodeProtocol) return sc.countError("settings_big_or_dups", ConnectionError(ErrCodeProtocol))
} }
if err := f.ForeachSetting(sc.processSetting); err != nil { if err := f.ForeachSetting(sc.processSetting); err != nil {
return err return err
@ -1623,7 +1631,7 @@ func (sc *serverConn) processSettingInitialWindowSize(val uint32) error {
// control window to exceed the maximum size as a // control window to exceed the maximum size as a
// connection error (Section 5.4.1) of type // connection error (Section 5.4.1) of type
// FLOW_CONTROL_ERROR." // FLOW_CONTROL_ERROR."
return ConnectionError(ErrCodeFlowControl) return sc.countError("setting_win_size", ConnectionError(ErrCodeFlowControl))
} }
} }
return nil return nil
@ -1656,7 +1664,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
// or PRIORITY on a stream in this state MUST be // or PRIORITY on a stream in this state MUST be
// treated as a connection error (Section 5.4.1) of // treated as a connection error (Section 5.4.1) of
// type PROTOCOL_ERROR." // type PROTOCOL_ERROR."
return ConnectionError(ErrCodeProtocol) return sc.countError("data_on_idle", ConnectionError(ErrCodeProtocol))
} }
// "If a DATA frame is received whose stream is not in "open" // "If a DATA frame is received whose stream is not in "open"
@ -1673,7 +1681,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
// and return any flow control bytes since we're not going // and return any flow control bytes since we're not going
// to consume them. // to consume them.
if sc.inflow.available() < int32(f.Length) { if sc.inflow.available() < int32(f.Length) {
return streamError(id, ErrCodeFlowControl) return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
} }
// Deduct the flow control from inflow, since we're // Deduct the flow control from inflow, since we're
// going to immediately add it back in // going to immediately add it back in
@ -1686,7 +1694,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
// Already have a stream error in flight. Don't send another. // Already have a stream error in flight. Don't send another.
return nil return nil
} }
return streamError(id, ErrCodeStreamClosed) return sc.countError("closed", streamError(id, ErrCodeStreamClosed))
} }
if st.body == nil { if st.body == nil {
panic("internal error: should have a body in this state") panic("internal error: should have a body in this state")
@ -1698,12 +1706,12 @@ func (sc *serverConn) processData(f *DataFrame) error {
// RFC 7540, sec 8.1.2.6: A request or response is also malformed if the // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the
// value of a content-length header field does not equal the sum of the // value of a content-length header field does not equal the sum of the
// DATA frame payload lengths that form the body. // DATA frame payload lengths that form the body.
return streamError(id, ErrCodeProtocol) return sc.countError("send_too_much", streamError(id, ErrCodeProtocol))
} }
if f.Length > 0 { if f.Length > 0 {
// Check whether the client has flow control quota. // Check whether the client has flow control quota.
if st.inflow.available() < int32(f.Length) { if st.inflow.available() < int32(f.Length) {
return streamError(id, ErrCodeFlowControl) return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl))
} }
st.inflow.take(int32(f.Length)) st.inflow.take(int32(f.Length))
@ -1711,7 +1719,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
wrote, err := st.body.Write(data) wrote, err := st.body.Write(data)
if err != nil { if err != nil {
sc.sendWindowUpdate(nil, int(f.Length)-wrote) sc.sendWindowUpdate(nil, int(f.Length)-wrote)
return streamError(id, ErrCodeStreamClosed) return sc.countError("body_write_err", streamError(id, ErrCodeStreamClosed))
} }
if wrote != len(data) { if wrote != len(data) {
panic("internal error: bad Writer") panic("internal error: bad Writer")
@ -1797,7 +1805,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// stream identifier MUST respond with a connection error // stream identifier MUST respond with a connection error
// (Section 5.4.1) of type PROTOCOL_ERROR. // (Section 5.4.1) of type PROTOCOL_ERROR.
if id%2 != 1 { if id%2 != 1 {
return ConnectionError(ErrCodeProtocol) return sc.countError("headers_even", ConnectionError(ErrCodeProtocol))
} }
// A HEADERS frame can be used to create a new stream or // A HEADERS frame can be used to create a new stream or
// send a trailer for an open one. If we already have a stream // send a trailer for an open one. If we already have a stream
@ -1814,7 +1822,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// this state, it MUST respond with a stream error (Section 5.4.2) of // this state, it MUST respond with a stream error (Section 5.4.2) of
// type STREAM_CLOSED. // type STREAM_CLOSED.
if st.state == stateHalfClosedRemote { if st.state == stateHalfClosedRemote {
return streamError(id, ErrCodeStreamClosed) return sc.countError("headers_half_closed", streamError(id, ErrCodeStreamClosed))
} }
return st.processTrailerHeaders(f) return st.processTrailerHeaders(f)
} }
@ -1825,7 +1833,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// receives an unexpected stream identifier MUST respond with // receives an unexpected stream identifier MUST respond with
// a connection error (Section 5.4.1) of type PROTOCOL_ERROR. // a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
if id <= sc.maxClientStreamID { if id <= sc.maxClientStreamID {
return ConnectionError(ErrCodeProtocol) return sc.countError("stream_went_down", ConnectionError(ErrCodeProtocol))
} }
sc.maxClientStreamID = id sc.maxClientStreamID = id
@ -1842,14 +1850,14 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
if sc.curClientStreams+1 > sc.advMaxStreams { if sc.curClientStreams+1 > sc.advMaxStreams {
if sc.unackedSettings == 0 { if sc.unackedSettings == 0 {
// They should know better. // They should know better.
return streamError(id, ErrCodeProtocol) return sc.countError("over_max_streams", streamError(id, ErrCodeProtocol))
} }
// Assume it's a network race, where they just haven't // Assume it's a network race, where they just haven't
// received our last SETTINGS update. But actually // received our last SETTINGS update. But actually
// this can't happen yet, because we don't yet provide // this can't happen yet, because we don't yet provide
// a way for users to adjust server parameters at // a way for users to adjust server parameters at
// runtime. // runtime.
return streamError(id, ErrCodeRefusedStream) return sc.countError("over_max_streams_race", streamError(id, ErrCodeRefusedStream))
} }
initialState := stateOpen initialState := stateOpen
@ -1859,7 +1867,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
st := sc.newStream(id, 0, initialState) st := sc.newStream(id, 0, initialState)
if f.HasPriority() { if f.HasPriority() {
if err := checkPriority(f.StreamID, f.Priority); err != nil { if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
return err return err
} }
sc.writeSched.AdjustStream(st.id, f.Priority) sc.writeSched.AdjustStream(st.id, f.Priority)
@ -1903,15 +1911,15 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
sc := st.sc sc := st.sc
sc.serveG.check() sc.serveG.check()
if st.gotTrailerHeader { if st.gotTrailerHeader {
return ConnectionError(ErrCodeProtocol) return sc.countError("dup_trailers", ConnectionError(ErrCodeProtocol))
} }
st.gotTrailerHeader = true st.gotTrailerHeader = true
if !f.StreamEnded() { if !f.StreamEnded() {
return streamError(st.id, ErrCodeProtocol) return sc.countError("trailers_not_ended", streamError(st.id, ErrCodeProtocol))
} }
if len(f.PseudoFields()) > 0 { if len(f.PseudoFields()) > 0 {
return streamError(st.id, ErrCodeProtocol) return sc.countError("trailers_pseudo", streamError(st.id, ErrCodeProtocol))
} }
if st.trailer != nil { if st.trailer != nil {
for _, hf := range f.RegularFields() { for _, hf := range f.RegularFields() {
@ -1920,7 +1928,7 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
// TODO: send more details to the peer somehow. But http2 has // TODO: send more details to the peer somehow. But http2 has
// no way to send debug data at a stream level. Discuss with // no way to send debug data at a stream level. Discuss with
// HTTP folk. // HTTP folk.
return streamError(st.id, ErrCodeProtocol) return sc.countError("trailers_bogus", streamError(st.id, ErrCodeProtocol))
} }
st.trailer[key] = append(st.trailer[key], hf.Value) st.trailer[key] = append(st.trailer[key], hf.Value)
} }
@ -1929,13 +1937,13 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
return nil return nil
} }
func checkPriority(streamID uint32, p PriorityParam) error { func (sc *serverConn) checkPriority(streamID uint32, p PriorityParam) error {
if streamID == p.StreamDep { if streamID == p.StreamDep {
// Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat
// this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR."
// Section 5.3.3 says that a stream can depend on one of its dependencies, // Section 5.3.3 says that a stream can depend on one of its dependencies,
// so it's only self-dependencies that are forbidden. // so it's only self-dependencies that are forbidden.
return streamError(streamID, ErrCodeProtocol) return sc.countError("priority", streamError(streamID, ErrCodeProtocol))
} }
return nil return nil
} }
@ -1944,7 +1952,7 @@ func (sc *serverConn) processPriority(f *PriorityFrame) error {
if sc.inGoAway { if sc.inGoAway {
return nil return nil
} }
if err := checkPriority(f.StreamID, f.PriorityParam); err != nil { if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
return err return err
} }
sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam) sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam)
@ -2001,7 +2009,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
isConnect := rp.method == "CONNECT" isConnect := rp.method == "CONNECT"
if isConnect { if isConnect {
if rp.path != "" || rp.scheme != "" || rp.authority == "" { if rp.path != "" || rp.scheme != "" || rp.authority == "" {
return nil, nil, streamError(f.StreamID, ErrCodeProtocol) return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
} }
} else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses: // See 8.1.2.6 Malformed Requests and Responses:
@ -2014,13 +2022,13 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
// "All HTTP/2 requests MUST include exactly one valid // "All HTTP/2 requests MUST include exactly one valid
// value for the :method, :scheme, and :path // value for the :method, :scheme, and :path
// pseudo-header fields" // pseudo-header fields"
return nil, nil, streamError(f.StreamID, ErrCodeProtocol) return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
} }
bodyOpen := !f.StreamEnded() bodyOpen := !f.StreamEnded()
if rp.method == "HEAD" && bodyOpen { if rp.method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies // HEAD requests can't have bodies
return nil, nil, streamError(f.StreamID, ErrCodeProtocol) return nil, nil, sc.countError("head_body", streamError(f.StreamID, ErrCodeProtocol))
} }
rp.header = make(http.Header) rp.header = make(http.Header)
@ -2103,7 +2111,7 @@ func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*r
var err error var err error
url_, err = url.ParseRequestURI(rp.path) url_, err = url.ParseRequestURI(rp.path)
if err != nil { if err != nil {
return nil, nil, streamError(st.id, ErrCodeProtocol) return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol))
} }
requestURI = rp.path requestURI = rp.path
} }
@ -2789,8 +2797,12 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
// but PUSH_PROMISE requests cannot have a body. // but PUSH_PROMISE requests cannot have a body.
// http://tools.ietf.org/html/rfc7540#section-8.2 // http://tools.ietf.org/html/rfc7540#section-8.2
// Also disallow Host, since the promised URL must be absolute. // Also disallow Host, since the promised URL must be absolute.
switch strings.ToLower(k) { if asciiEqualFold(k, "content-length") ||
case "content-length", "content-encoding", "trailer", "te", "expect", "host": asciiEqualFold(k, "content-encoding") ||
asciiEqualFold(k, "trailer") ||
asciiEqualFold(k, "te") ||
asciiEqualFold(k, "expect") ||
asciiEqualFold(k, "host") {
return fmt.Errorf("promised request headers cannot include %q", k) return fmt.Errorf("promised request headers cannot include %q", k)
} }
} }
@ -2982,3 +2994,31 @@ func h1ServerKeepAlivesDisabled(hs *http.Server) bool {
} }
return false return false
} }
func (sc *serverConn) countError(name string, err error) error {
if sc == nil || sc.srv == nil {
return err
}
f := sc.srv.CountError
if f == nil {
return err
}
var typ string
var code ErrCode
switch e := err.(type) {
case ConnectionError:
typ = "conn"
code = ErrCode(e)
case StreamError:
typ = "stream"
code = ErrCode(e.Code)
default:
return err
}
codeStr := errCodeName[code]
if codeStr == "" {
codeStr = strconv.Itoa(int(code))
}
f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name))
return err
}

File diff suppressed because it is too large Load Diff

View File

@ -341,7 +341,12 @@ func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
} }
for _, k := range keys { for _, k := range keys {
vv := h[k] vv := h[k]
k = lowerHeader(k) k, ascii := lowerHeader(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
continue
}
if !validWireHeaderFieldName(k) { if !validWireHeaderFieldName(k) {
// Skip it as backup paranoia. Per // Skip it as backup paranoia. Per
// golang.org/issue/14048, these should // golang.org/issue/14048, these should

14
vendor/golang.org/x/net/idna/go118.go generated vendored Normal file
View File

@ -0,0 +1,14 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.18
// +build go1.18
package idna
// Transitional processing is disabled by default in Go 1.18.
// https://golang.org/issue/47510
const transitionalLookup = false

View File

@ -59,23 +59,22 @@ type Option func(*options)
// Transitional sets a Profile to use the Transitional mapping as defined in UTS // Transitional sets a Profile to use the Transitional mapping as defined in UTS
// #46. This will cause, for example, "ß" to be mapped to "ss". Using the // #46. This will cause, for example, "ß" to be mapped to "ss". Using the
// transitional mapping provides a compromise between IDNA2003 and IDNA2008 // transitional mapping provides a compromise between IDNA2003 and IDNA2008
// compatibility. It is used by most browsers when resolving domain names. This // compatibility. It is used by some browsers when resolving domain names. This
// option is only meaningful if combined with MapForLookup. // option is only meaningful if combined with MapForLookup.
func Transitional(transitional bool) Option { func Transitional(transitional bool) Option {
return func(o *options) { o.transitional = true } return func(o *options) { o.transitional = transitional }
} }
// VerifyDNSLength sets whether a Profile should fail if any of the IDN parts // VerifyDNSLength sets whether a Profile should fail if any of the IDN parts
// are longer than allowed by the RFC. // are longer than allowed by the RFC.
//
// This option corresponds to the VerifyDnsLength flag in UTS #46.
func VerifyDNSLength(verify bool) Option { func VerifyDNSLength(verify bool) Option {
return func(o *options) { o.verifyDNSLength = verify } return func(o *options) { o.verifyDNSLength = verify }
} }
// RemoveLeadingDots removes leading label separators. Leading runes that map to // RemoveLeadingDots removes leading label separators. Leading runes that map to
// dots, such as U+3002 IDEOGRAPHIC FULL STOP, are removed as well. // dots, such as U+3002 IDEOGRAPHIC FULL STOP, are removed as well.
//
// This is the behavior suggested by the UTS #46 and is adopted by some
// browsers.
func RemoveLeadingDots(remove bool) Option { func RemoveLeadingDots(remove bool) Option {
return func(o *options) { o.removeLeadingDots = remove } return func(o *options) { o.removeLeadingDots = remove }
} }
@ -83,6 +82,8 @@ func RemoveLeadingDots(remove bool) Option {
// ValidateLabels sets whether to check the mandatory label validation criteria // ValidateLabels sets whether to check the mandatory label validation criteria
// as defined in Section 5.4 of RFC 5891. This includes testing for correct use // as defined in Section 5.4 of RFC 5891. This includes testing for correct use
// of hyphens ('-'), normalization, validity of runes, and the context rules. // of hyphens ('-'), normalization, validity of runes, and the context rules.
// In particular, ValidateLabels also sets the CheckHyphens and CheckJoiners flags
// in UTS #46.
func ValidateLabels(enable bool) Option { func ValidateLabels(enable bool) Option {
return func(o *options) { return func(o *options) {
// Don't override existing mappings, but set one that at least checks // Don't override existing mappings, but set one that at least checks
@ -91,25 +92,48 @@ func ValidateLabels(enable bool) Option {
o.mapping = normalize o.mapping = normalize
} }
o.trie = trie o.trie = trie
o.validateLabels = enable o.checkJoiners = enable
o.fromPuny = validateFromPunycode o.checkHyphens = enable
if enable {
o.fromPuny = validateFromPunycode
} else {
o.fromPuny = nil
}
}
}
// CheckHyphens sets whether to check for correct use of hyphens ('-') in
// labels. Most web browsers do not have this option set, since labels such as
// "r3---sn-apo3qvuoxuxbt-j5pe" are in common use.
//
// This option corresponds to the CheckHyphens flag in UTS #46.
func CheckHyphens(enable bool) Option {
return func(o *options) { o.checkHyphens = enable }
}
// CheckJoiners sets whether to check the ContextJ rules as defined in Appendix
// A of RFC 5892, concerning the use of joiner runes.
//
// This option corresponds to the CheckJoiners flag in UTS #46.
func CheckJoiners(enable bool) Option {
return func(o *options) {
o.trie = trie
o.checkJoiners = enable
} }
} }
// StrictDomainName limits the set of permissible ASCII characters to those // StrictDomainName limits the set of permissible ASCII characters to those
// allowed in domain names as defined in RFC 1034 (A-Z, a-z, 0-9 and the // allowed in domain names as defined in RFC 1034 (A-Z, a-z, 0-9 and the
// hyphen). This is set by default for MapForLookup and ValidateForRegistration. // hyphen). This is set by default for MapForLookup and ValidateForRegistration,
// but is only useful if ValidateLabels is set.
// //
// This option is useful, for instance, for browsers that allow characters // This option is useful, for instance, for browsers that allow characters
// outside this range, for example a '_' (U+005F LOW LINE). See // outside this range, for example a '_' (U+005F LOW LINE). See
// http://www.rfc-editor.org/std/std3.txt for more details This option // http://www.rfc-editor.org/std/std3.txt for more details.
// corresponds to the UseSTD3ASCIIRules option in UTS #46. //
// This option corresponds to the UseSTD3ASCIIRules flag in UTS #46.
func StrictDomainName(use bool) Option { func StrictDomainName(use bool) Option {
return func(o *options) { return func(o *options) { o.useSTD3Rules = use }
o.trie = trie
o.useSTD3Rules = use
o.fromPuny = validateFromPunycode
}
} }
// NOTE: the following options pull in tables. The tables should not be linked // NOTE: the following options pull in tables. The tables should not be linked
@ -117,6 +141,8 @@ func StrictDomainName(use bool) Option {
// BidiRule enables the Bidi rule as defined in RFC 5893. Any application // BidiRule enables the Bidi rule as defined in RFC 5893. Any application
// that relies on proper validation of labels should include this rule. // that relies on proper validation of labels should include this rule.
//
// This option corresponds to the CheckBidi flag in UTS #46.
func BidiRule() Option { func BidiRule() Option {
return func(o *options) { o.bidirule = bidirule.ValidString } return func(o *options) { o.bidirule = bidirule.ValidString }
} }
@ -152,7 +178,8 @@ func MapForLookup() Option {
type options struct { type options struct {
transitional bool transitional bool
useSTD3Rules bool useSTD3Rules bool
validateLabels bool checkHyphens bool
checkJoiners bool
verifyDNSLength bool verifyDNSLength bool
removeLeadingDots bool removeLeadingDots bool
@ -225,8 +252,11 @@ func (p *Profile) String() string {
if p.useSTD3Rules { if p.useSTD3Rules {
s += ":UseSTD3Rules" s += ":UseSTD3Rules"
} }
if p.validateLabels { if p.checkHyphens {
s += ":ValidateLabels" s += ":CheckHyphens"
}
if p.checkJoiners {
s += ":CheckJoiners"
} }
if p.verifyDNSLength { if p.verifyDNSLength {
s += ":VerifyDNSLength" s += ":VerifyDNSLength"
@ -254,26 +284,29 @@ var (
punycode = &Profile{} punycode = &Profile{}
lookup = &Profile{options{ lookup = &Profile{options{
transitional: true, transitional: transitionalLookup,
useSTD3Rules: true, useSTD3Rules: true,
validateLabels: true, checkHyphens: true,
trie: trie, checkJoiners: true,
fromPuny: validateFromPunycode, trie: trie,
mapping: validateAndMap, fromPuny: validateFromPunycode,
bidirule: bidirule.ValidString, mapping: validateAndMap,
bidirule: bidirule.ValidString,
}} }}
display = &Profile{options{ display = &Profile{options{
useSTD3Rules: true, useSTD3Rules: true,
validateLabels: true, checkHyphens: true,
trie: trie, checkJoiners: true,
fromPuny: validateFromPunycode, trie: trie,
mapping: validateAndMap, fromPuny: validateFromPunycode,
bidirule: bidirule.ValidString, mapping: validateAndMap,
bidirule: bidirule.ValidString,
}} }}
registration = &Profile{options{ registration = &Profile{options{
useSTD3Rules: true, useSTD3Rules: true,
validateLabels: true,
verifyDNSLength: true, verifyDNSLength: true,
checkHyphens: true,
checkJoiners: true,
trie: trie, trie: trie,
fromPuny: validateFromPunycode, fromPuny: validateFromPunycode,
mapping: validateRegistration, mapping: validateRegistration,
@ -340,7 +373,7 @@ func (p *Profile) process(s string, toASCII bool) (string, error) {
} }
isBidi = isBidi || bidirule.DirectionString(u) != bidi.LeftToRight isBidi = isBidi || bidirule.DirectionString(u) != bidi.LeftToRight
labels.set(u) labels.set(u)
if err == nil && p.validateLabels { if err == nil && p.fromPuny != nil {
err = p.fromPuny(p, u) err = p.fromPuny(p, u)
} }
if err == nil { if err == nil {
@ -681,16 +714,18 @@ func (p *Profile) validateLabel(s string) (err error) {
} }
return nil return nil
} }
if !p.validateLabels { if p.checkHyphens {
if len(s) > 4 && s[2] == '-' && s[3] == '-' {
return &labelError{s, "V2"}
}
if s[0] == '-' || s[len(s)-1] == '-' {
return &labelError{s, "V3"}
}
}
if !p.checkJoiners {
return nil return nil
} }
trie := p.trie // p.validateLabels is only set if trie is set. trie := p.trie // p.checkJoiners is only set if trie is set.
if len(s) > 4 && s[2] == '-' && s[3] == '-' {
return &labelError{s, "V2"}
}
if s[0] == '-' || s[len(s)-1] == '-' {
return &labelError{s, "V3"}
}
// TODO: merge the use of this in the trie. // TODO: merge the use of this in the trie.
v, sz := trie.lookupString(s) v, sz := trie.lookupString(s)
x := info(v) x := info(v)

View File

@ -58,23 +58,22 @@ type Option func(*options)
// Transitional sets a Profile to use the Transitional mapping as defined in UTS // Transitional sets a Profile to use the Transitional mapping as defined in UTS
// #46. This will cause, for example, "ß" to be mapped to "ss". Using the // #46. This will cause, for example, "ß" to be mapped to "ss". Using the
// transitional mapping provides a compromise between IDNA2003 and IDNA2008 // transitional mapping provides a compromise between IDNA2003 and IDNA2008
// compatibility. It is used by most browsers when resolving domain names. This // compatibility. It is used by some browsers when resolving domain names. This
// option is only meaningful if combined with MapForLookup. // option is only meaningful if combined with MapForLookup.
func Transitional(transitional bool) Option { func Transitional(transitional bool) Option {
return func(o *options) { o.transitional = true } return func(o *options) { o.transitional = transitional }
} }
// VerifyDNSLength sets whether a Profile should fail if any of the IDN parts // VerifyDNSLength sets whether a Profile should fail if any of the IDN parts
// are longer than allowed by the RFC. // are longer than allowed by the RFC.
//
// This option corresponds to the VerifyDnsLength flag in UTS #46.
func VerifyDNSLength(verify bool) Option { func VerifyDNSLength(verify bool) Option {
return func(o *options) { o.verifyDNSLength = verify } return func(o *options) { o.verifyDNSLength = verify }
} }
// RemoveLeadingDots removes leading label separators. Leading runes that map to // RemoveLeadingDots removes leading label separators. Leading runes that map to
// dots, such as U+3002 IDEOGRAPHIC FULL STOP, are removed as well. // dots, such as U+3002 IDEOGRAPHIC FULL STOP, are removed as well.
//
// This is the behavior suggested by the UTS #46 and is adopted by some
// browsers.
func RemoveLeadingDots(remove bool) Option { func RemoveLeadingDots(remove bool) Option {
return func(o *options) { o.removeLeadingDots = remove } return func(o *options) { o.removeLeadingDots = remove }
} }
@ -82,6 +81,8 @@ func RemoveLeadingDots(remove bool) Option {
// ValidateLabels sets whether to check the mandatory label validation criteria // ValidateLabels sets whether to check the mandatory label validation criteria
// as defined in Section 5.4 of RFC 5891. This includes testing for correct use // as defined in Section 5.4 of RFC 5891. This includes testing for correct use
// of hyphens ('-'), normalization, validity of runes, and the context rules. // of hyphens ('-'), normalization, validity of runes, and the context rules.
// In particular, ValidateLabels also sets the CheckHyphens and CheckJoiners flags
// in UTS #46.
func ValidateLabels(enable bool) Option { func ValidateLabels(enable bool) Option {
return func(o *options) { return func(o *options) {
// Don't override existing mappings, but set one that at least checks // Don't override existing mappings, but set one that at least checks
@ -90,25 +91,48 @@ func ValidateLabels(enable bool) Option {
o.mapping = normalize o.mapping = normalize
} }
o.trie = trie o.trie = trie
o.validateLabels = enable o.checkJoiners = enable
o.fromPuny = validateFromPunycode o.checkHyphens = enable
if enable {
o.fromPuny = validateFromPunycode
} else {
o.fromPuny = nil
}
}
}
// CheckHyphens sets whether to check for correct use of hyphens ('-') in
// labels. Most web browsers do not have this option set, since labels such as
// "r3---sn-apo3qvuoxuxbt-j5pe" are in common use.
//
// This option corresponds to the CheckHyphens flag in UTS #46.
func CheckHyphens(enable bool) Option {
return func(o *options) { o.checkHyphens = enable }
}
// CheckJoiners sets whether to check the ContextJ rules as defined in Appendix
// A of RFC 5892, concerning the use of joiner runes.
//
// This option corresponds to the CheckJoiners flag in UTS #46.
func CheckJoiners(enable bool) Option {
return func(o *options) {
o.trie = trie
o.checkJoiners = enable
} }
} }
// StrictDomainName limits the set of permissable ASCII characters to those // StrictDomainName limits the set of permissable ASCII characters to those
// allowed in domain names as defined in RFC 1034 (A-Z, a-z, 0-9 and the // allowed in domain names as defined in RFC 1034 (A-Z, a-z, 0-9 and the
// hyphen). This is set by default for MapForLookup and ValidateForRegistration. // hyphen). This is set by default for MapForLookup and ValidateForRegistration,
// but is only useful if ValidateLabels is set.
// //
// This option is useful, for instance, for browsers that allow characters // This option is useful, for instance, for browsers that allow characters
// outside this range, for example a '_' (U+005F LOW LINE). See // outside this range, for example a '_' (U+005F LOW LINE). See
// http://www.rfc-editor.org/std/std3.txt for more details This option // http://www.rfc-editor.org/std/std3.txt for more details.
// corresponds to the UseSTD3ASCIIRules option in UTS #46. //
// This option corresponds to the UseSTD3ASCIIRules flag in UTS #46.
func StrictDomainName(use bool) Option { func StrictDomainName(use bool) Option {
return func(o *options) { return func(o *options) { o.useSTD3Rules = use }
o.trie = trie
o.useSTD3Rules = use
o.fromPuny = validateFromPunycode
}
} }
// NOTE: the following options pull in tables. The tables should not be linked // NOTE: the following options pull in tables. The tables should not be linked
@ -116,6 +140,8 @@ func StrictDomainName(use bool) Option {
// BidiRule enables the Bidi rule as defined in RFC 5893. Any application // BidiRule enables the Bidi rule as defined in RFC 5893. Any application
// that relies on proper validation of labels should include this rule. // that relies on proper validation of labels should include this rule.
//
// This option corresponds to the CheckBidi flag in UTS #46.
func BidiRule() Option { func BidiRule() Option {
return func(o *options) { o.bidirule = bidirule.ValidString } return func(o *options) { o.bidirule = bidirule.ValidString }
} }
@ -152,7 +178,8 @@ func MapForLookup() Option {
type options struct { type options struct {
transitional bool transitional bool
useSTD3Rules bool useSTD3Rules bool
validateLabels bool checkHyphens bool
checkJoiners bool
verifyDNSLength bool verifyDNSLength bool
removeLeadingDots bool removeLeadingDots bool
@ -225,8 +252,11 @@ func (p *Profile) String() string {
if p.useSTD3Rules { if p.useSTD3Rules {
s += ":UseSTD3Rules" s += ":UseSTD3Rules"
} }
if p.validateLabels { if p.checkHyphens {
s += ":ValidateLabels" s += ":CheckHyphens"
}
if p.checkJoiners {
s += ":CheckJoiners"
} }
if p.verifyDNSLength { if p.verifyDNSLength {
s += ":VerifyDNSLength" s += ":VerifyDNSLength"
@ -255,9 +285,10 @@ var (
punycode = &Profile{} punycode = &Profile{}
lookup = &Profile{options{ lookup = &Profile{options{
transitional: true, transitional: true,
useSTD3Rules: true,
validateLabels: true,
removeLeadingDots: true, removeLeadingDots: true,
useSTD3Rules: true,
checkHyphens: true,
checkJoiners: true,
trie: trie, trie: trie,
fromPuny: validateFromPunycode, fromPuny: validateFromPunycode,
mapping: validateAndMap, mapping: validateAndMap,
@ -265,8 +296,9 @@ var (
}} }}
display = &Profile{options{ display = &Profile{options{
useSTD3Rules: true, useSTD3Rules: true,
validateLabels: true,
removeLeadingDots: true, removeLeadingDots: true,
checkHyphens: true,
checkJoiners: true,
trie: trie, trie: trie,
fromPuny: validateFromPunycode, fromPuny: validateFromPunycode,
mapping: validateAndMap, mapping: validateAndMap,
@ -274,8 +306,9 @@ var (
}} }}
registration = &Profile{options{ registration = &Profile{options{
useSTD3Rules: true, useSTD3Rules: true,
validateLabels: true,
verifyDNSLength: true, verifyDNSLength: true,
checkHyphens: true,
checkJoiners: true,
trie: trie, trie: trie,
fromPuny: validateFromPunycode, fromPuny: validateFromPunycode,
mapping: validateRegistration, mapping: validateRegistration,
@ -339,7 +372,7 @@ func (p *Profile) process(s string, toASCII bool) (string, error) {
continue continue
} }
labels.set(u) labels.set(u)
if err == nil && p.validateLabels { if err == nil && p.fromPuny != nil {
err = p.fromPuny(p, u) err = p.fromPuny(p, u)
} }
if err == nil { if err == nil {
@ -629,16 +662,18 @@ func (p *Profile) validateLabel(s string) error {
if p.bidirule != nil && !p.bidirule(s) { if p.bidirule != nil && !p.bidirule(s) {
return &labelError{s, "B"} return &labelError{s, "B"}
} }
if !p.validateLabels { if p.checkHyphens {
if len(s) > 4 && s[2] == '-' && s[3] == '-' {
return &labelError{s, "V2"}
}
if s[0] == '-' || s[len(s)-1] == '-' {
return &labelError{s, "V3"}
}
}
if !p.checkJoiners {
return nil return nil
} }
trie := p.trie // p.validateLabels is only set if trie is set. trie := p.trie // p.checkJoiners is only set if trie is set.
if len(s) > 4 && s[2] == '-' && s[3] == '-' {
return &labelError{s, "V2"}
}
if s[0] == '-' || s[len(s)-1] == '-' {
return &labelError{s, "V3"}
}
// TODO: merge the use of this in the trie. // TODO: merge the use of this in the trie.
v, sz := trie.lookupString(s) v, sz := trie.lookupString(s)
x := info(v) x := info(v)

12
vendor/golang.org/x/net/idna/pre_go118.go generated vendored Normal file
View File

@ -0,0 +1,12 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.18
// +build !go1.18
package idna
const transitionalLookup = true

View File

@ -49,6 +49,7 @@ func decode(encoded string) (string, error) {
} }
} }
i, n, bias := int32(0), initialN, initialBias i, n, bias := int32(0), initialN, initialBias
overflow := false
for pos < len(encoded) { for pos < len(encoded) {
oldI, w := i, int32(1) oldI, w := i, int32(1)
for k := base; ; k += base { for k := base; ; k += base {
@ -60,29 +61,32 @@ func decode(encoded string) (string, error) {
return "", punyError(encoded) return "", punyError(encoded)
} }
pos++ pos++
i += digit * w i, overflow = madd(i, digit, w)
if i < 0 { if overflow {
return "", punyError(encoded) return "", punyError(encoded)
} }
t := k - bias t := k - bias
if t < tmin { if k <= bias {
t = tmin t = tmin
} else if t > tmax { } else if k >= bias+tmax {
t = tmax t = tmax
} }
if digit < t { if digit < t {
break break
} }
w *= base - t w, overflow = madd(0, w, base-t)
if w >= math.MaxInt32/base { if overflow {
return "", punyError(encoded) return "", punyError(encoded)
} }
} }
if len(output) >= 1024 {
return "", punyError(encoded)
}
x := int32(len(output) + 1) x := int32(len(output) + 1)
bias = adapt(i-oldI, x, oldI == 0) bias = adapt(i-oldI, x, oldI == 0)
n += i / x n += i / x
i %= x i %= x
if n > utf8.MaxRune || len(output) >= 1024 { if n < 0 || n > utf8.MaxRune {
return "", punyError(encoded) return "", punyError(encoded)
} }
output = append(output, 0) output = append(output, 0)
@ -115,6 +119,7 @@ func encode(prefix, s string) (string, error) {
if b > 0 { if b > 0 {
output = append(output, '-') output = append(output, '-')
} }
overflow := false
for remaining != 0 { for remaining != 0 {
m := int32(0x7fffffff) m := int32(0x7fffffff)
for _, r := range s { for _, r := range s {
@ -122,8 +127,8 @@ func encode(prefix, s string) (string, error) {
m = r m = r
} }
} }
delta += (m - n) * (h + 1) delta, overflow = madd(delta, m-n, h+1)
if delta < 0 { if overflow {
return "", punyError(s) return "", punyError(s)
} }
n = m n = m
@ -141,9 +146,9 @@ func encode(prefix, s string) (string, error) {
q := delta q := delta
for k := base; ; k += base { for k := base; ; k += base {
t := k - bias t := k - bias
if t < tmin { if k <= bias {
t = tmin t = tmin
} else if t > tmax { } else if k >= bias+tmax {
t = tmax t = tmax
} }
if q < t { if q < t {
@ -164,6 +169,15 @@ func encode(prefix, s string) (string, error) {
return string(output), nil return string(output), nil
} }
// madd computes a + (b * c), detecting overflow.
func madd(a, b, c int32) (next int32, overflow bool) {
p := int64(b) * int64(c)
if p > math.MaxInt32-int64(a) {
return 0, true
}
return a + int32(p), false
}
func decodeDigit(x byte) (digit int32, ok bool) { func decodeDigit(x byte) (digit int32, ok bool) {
switch { switch {
case '0' <= x && x <= '9': case '0' <= x && x <= '9':

View File

@ -0,0 +1,26 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// +build darwin dragonfly freebsd linux netbsd openbsd solaris
package socket
import (
"syscall"
)
// ioComplete checks the flags and result of a syscall, to be used as return
// value in a syscall.RawConn.Read or Write callback.
func ioComplete(flags int, operr error) bool {
if flags&syscall.MSG_DONTWAIT != 0 {
// Caller explicitly said don't wait, so always return immediately.
return true
}
if operr == syscall.EAGAIN || operr == syscall.EWOULDBLOCK {
// No data available, block for I/O and try again.
return false
}
return true
}

View File

@ -0,0 +1,22 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || windows || zos
// +build aix windows zos
package socket
import (
"syscall"
)
// ioComplete checks the flags and result of a syscall, to be used as return
// value in a syscall.RawConn.Read or Write callback.
func ioComplete(flags int, operr error) bool {
if operr == syscall.EAGAIN || operr == syscall.EWOULDBLOCK {
// No data available, block for I/O and try again.
return false
}
return true
}

View File

@ -7,25 +7,13 @@
package socket package socket
import "net" import (
"net"
"sync"
)
type mmsghdrs []mmsghdr type mmsghdrs []mmsghdr
func (hs mmsghdrs) pack(ms []Message, parseFn func([]byte, string) (net.Addr, error), marshalFn func(net.Addr) []byte) error {
for i := range hs {
vs := make([]iovec, len(ms[i].Buffers))
var sa []byte
if parseFn != nil {
sa = make([]byte, sizeofSockaddrInet6)
}
if marshalFn != nil {
sa = marshalFn(ms[i].Addr)
}
hs[i].Hdr.pack(vs, ms[i].Buffers, ms[i].OOB, sa)
}
return nil
}
func (hs mmsghdrs) unpack(ms []Message, parseFn func([]byte, string) (net.Addr, error), hint string) error { func (hs mmsghdrs) unpack(ms []Message, parseFn func([]byte, string) (net.Addr, error), hint string) error {
for i := range hs { for i := range hs {
ms[i].N = int(hs[i].Len) ms[i].N = int(hs[i].Len)
@ -41,3 +29,86 @@ func (hs mmsghdrs) unpack(ms []Message, parseFn func([]byte, string) (net.Addr,
} }
return nil return nil
} }
// mmsghdrsPacker packs Message-slices into mmsghdrs (re-)using pre-allocated buffers.
type mmsghdrsPacker struct {
// hs are the pre-allocated mmsghdrs.
hs mmsghdrs
// sockaddrs is the pre-allocated buffer for the Hdr.Name buffers.
// We use one large buffer for all messages and slice it up.
sockaddrs []byte
// vs are the pre-allocated iovecs.
// We allocate one large buffer for all messages and slice it up. This allows to reuse the buffer
// if the number of buffers per message is distributed differently between calls.
vs []iovec
}
func (p *mmsghdrsPacker) prepare(ms []Message) {
n := len(ms)
if n <= cap(p.hs) {
p.hs = p.hs[:n]
} else {
p.hs = make(mmsghdrs, n)
}
if n*sizeofSockaddrInet6 <= cap(p.sockaddrs) {
p.sockaddrs = p.sockaddrs[:n*sizeofSockaddrInet6]
} else {
p.sockaddrs = make([]byte, n*sizeofSockaddrInet6)
}
nb := 0
for _, m := range ms {
nb += len(m.Buffers)
}
if nb <= cap(p.vs) {
p.vs = p.vs[:nb]
} else {
p.vs = make([]iovec, nb)
}
}
func (p *mmsghdrsPacker) pack(ms []Message, parseFn func([]byte, string) (net.Addr, error), marshalFn func(net.Addr, []byte) int) mmsghdrs {
p.prepare(ms)
hs := p.hs
vsRest := p.vs
saRest := p.sockaddrs
for i := range hs {
nvs := len(ms[i].Buffers)
vs := vsRest[:nvs]
vsRest = vsRest[nvs:]
var sa []byte
if parseFn != nil {
sa = saRest[:sizeofSockaddrInet6]
saRest = saRest[sizeofSockaddrInet6:]
} else if marshalFn != nil {
n := marshalFn(ms[i].Addr, saRest)
if n > 0 {
sa = saRest[:n]
saRest = saRest[n:]
}
}
hs[i].Hdr.pack(vs, ms[i].Buffers, ms[i].OOB, sa)
}
return hs
}
var defaultMmsghdrsPool = mmsghdrsPool{
p: sync.Pool{
New: func() interface{} {
return new(mmsghdrsPacker)
},
},
}
type mmsghdrsPool struct {
p sync.Pool
}
func (p *mmsghdrsPool) Get() *mmsghdrsPacker {
return p.p.Get().(*mmsghdrsPacker)
}
func (p *mmsghdrsPool) Put(packer *mmsghdrsPacker) {
p.p.Put(packer)
}

View File

@ -17,6 +17,9 @@ func (h *msghdr) pack(vs []iovec, bs [][]byte, oob []byte, sa []byte) {
if sa != nil { if sa != nil {
h.Name = (*byte)(unsafe.Pointer(&sa[0])) h.Name = (*byte)(unsafe.Pointer(&sa[0]))
h.Namelen = uint32(len(sa)) h.Namelen = uint32(len(sa))
} else {
h.Name = nil
h.Namelen = 0
} }
} }

View File

@ -10,29 +10,24 @@ package socket
import ( import (
"net" "net"
"os" "os"
"syscall"
) )
func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) { func (c *Conn) recvMsgs(ms []Message, flags int) (int, error) {
for i := range ms { for i := range ms {
ms[i].raceWrite() ms[i].raceWrite()
} }
hs := make(mmsghdrs, len(ms)) packer := defaultMmsghdrsPool.Get()
defer defaultMmsghdrsPool.Put(packer)
var parseFn func([]byte, string) (net.Addr, error) var parseFn func([]byte, string) (net.Addr, error)
if c.network != "tcp" { if c.network != "tcp" {
parseFn = parseInetAddr parseFn = parseInetAddr
} }
if err := hs.pack(ms, parseFn, nil); err != nil { hs := packer.pack(ms, parseFn, nil)
return 0, err
}
var operr error var operr error
var n int var n int
fn := func(s uintptr) bool { fn := func(s uintptr) bool {
n, operr = recvmmsg(s, hs, flags) n, operr = recvmmsg(s, hs, flags)
if operr == syscall.EAGAIN { return ioComplete(flags, operr)
return false
}
return true
} }
if err := c.c.Read(fn); err != nil { if err := c.c.Read(fn); err != nil {
return n, err return n, err
@ -50,22 +45,18 @@ func (c *Conn) sendMsgs(ms []Message, flags int) (int, error) {
for i := range ms { for i := range ms {
ms[i].raceRead() ms[i].raceRead()
} }
hs := make(mmsghdrs, len(ms)) packer := defaultMmsghdrsPool.Get()
var marshalFn func(net.Addr) []byte defer defaultMmsghdrsPool.Put(packer)
var marshalFn func(net.Addr, []byte) int
if c.network != "tcp" { if c.network != "tcp" {
marshalFn = marshalInetAddr marshalFn = marshalInetAddr
} }
if err := hs.pack(ms, nil, marshalFn); err != nil { hs := packer.pack(ms, nil, marshalFn)
return 0, err
}
var operr error var operr error
var n int var n int
fn := func(s uintptr) bool { fn := func(s uintptr) bool {
n, operr = sendmmsg(s, hs, flags) n, operr = sendmmsg(s, hs, flags)
if operr == syscall.EAGAIN { return ioComplete(flags, operr)
return false
}
return true
} }
if err := c.c.Write(fn); err != nil { if err := c.c.Write(fn); err != nil {
return n, err return n, err

View File

@ -9,7 +9,6 @@ package socket
import ( import (
"os" "os"
"syscall"
) )
func (c *Conn) recvMsg(m *Message, flags int) error { func (c *Conn) recvMsg(m *Message, flags int) error {
@ -25,10 +24,7 @@ func (c *Conn) recvMsg(m *Message, flags int) error {
var n int var n int
fn := func(s uintptr) bool { fn := func(s uintptr) bool {
n, operr = recvmsg(s, &h, flags) n, operr = recvmsg(s, &h, flags)
if operr == syscall.EAGAIN || operr == syscall.EWOULDBLOCK { return ioComplete(flags, operr)
return false
}
return true
} }
if err := c.c.Read(fn); err != nil { if err := c.c.Read(fn); err != nil {
return err return err
@ -55,17 +51,16 @@ func (c *Conn) sendMsg(m *Message, flags int) error {
vs := make([]iovec, len(m.Buffers)) vs := make([]iovec, len(m.Buffers))
var sa []byte var sa []byte
if m.Addr != nil { if m.Addr != nil {
sa = marshalInetAddr(m.Addr) var a [sizeofSockaddrInet6]byte
n := marshalInetAddr(m.Addr, a[:])
sa = a[:n]
} }
h.pack(vs, m.Buffers, m.OOB, sa) h.pack(vs, m.Buffers, m.OOB, sa)
var operr error var operr error
var n int var n int
fn := func(s uintptr) bool { fn := func(s uintptr) bool {
n, operr = sendmsg(s, &h, flags) n, operr = sendmsg(s, &h, flags)
if operr == syscall.EAGAIN || operr == syscall.EWOULDBLOCK { return ioComplete(flags, operr)
return false
}
return true
} }
if err := c.c.Write(fn); err != nil { if err := c.c.Write(fn); err != nil {
return err return err

View File

@ -17,22 +17,24 @@ import (
"time" "time"
) )
func marshalInetAddr(a net.Addr) []byte { // marshalInetAddr writes a in sockaddr format into the buffer b.
// The buffer must be sufficiently large (sizeofSockaddrInet4/6).
// Returns the number of bytes written.
func marshalInetAddr(a net.Addr, b []byte) int {
switch a := a.(type) { switch a := a.(type) {
case *net.TCPAddr: case *net.TCPAddr:
return marshalSockaddr(a.IP, a.Port, a.Zone) return marshalSockaddr(a.IP, a.Port, a.Zone, b)
case *net.UDPAddr: case *net.UDPAddr:
return marshalSockaddr(a.IP, a.Port, a.Zone) return marshalSockaddr(a.IP, a.Port, a.Zone, b)
case *net.IPAddr: case *net.IPAddr:
return marshalSockaddr(a.IP, 0, a.Zone) return marshalSockaddr(a.IP, 0, a.Zone, b)
default: default:
return nil return 0
} }
} }
func marshalSockaddr(ip net.IP, port int, zone string) []byte { func marshalSockaddr(ip net.IP, port int, zone string, b []byte) int {
if ip4 := ip.To4(); ip4 != nil { if ip4 := ip.To4(); ip4 != nil {
b := make([]byte, sizeofSockaddrInet4)
switch runtime.GOOS { switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows": case "android", "illumos", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET)) NativeEndian.PutUint16(b[:2], uint16(sysAF_INET))
@ -42,10 +44,9 @@ func marshalSockaddr(ip net.IP, port int, zone string) []byte {
} }
binary.BigEndian.PutUint16(b[2:4], uint16(port)) binary.BigEndian.PutUint16(b[2:4], uint16(port))
copy(b[4:8], ip4) copy(b[4:8], ip4)
return b return sizeofSockaddrInet4
} }
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil { if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
b := make([]byte, sizeofSockaddrInet6)
switch runtime.GOOS { switch runtime.GOOS {
case "android", "illumos", "linux", "solaris", "windows": case "android", "illumos", "linux", "solaris", "windows":
NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6)) NativeEndian.PutUint16(b[:2], uint16(sysAF_INET6))
@ -58,9 +59,9 @@ func marshalSockaddr(ip net.IP, port int, zone string) []byte {
if zone != "" { if zone != "" {
NativeEndian.PutUint32(b[24:28], uint32(zoneCache.index(zone))) NativeEndian.PutUint32(b[24:28], uint32(zoneCache.index(zone)))
} }
return b return sizeofSockaddrInet6
} }
return nil return 0
} }
func parseInetAddr(b []byte, network string) (net.Addr, error) { func parseInetAddr(b []byte, network string) (net.Addr, error) {

4
vendor/modules.txt vendored
View File

@ -364,8 +364,8 @@ golang.org/x/crypto/ssh/terminal
## explicit; go 1.12 ## explicit; go 1.12
golang.org/x/mod/module golang.org/x/mod/module
golang.org/x/mod/semver golang.org/x/mod/semver
# golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 # golang.org/x/net v0.0.0-20211109214657-ef0fda0de508
## explicit; go 1.11 ## explicit; go 1.17
golang.org/x/net/bpf golang.org/x/net/bpf
golang.org/x/net/context golang.org/x/net/context
golang.org/x/net/context/ctxhttp golang.org/x/net/context/ctxhttp