package ws import ( "bufio" "bytes" "context" "crypto/tls" "fmt" "io" "net" "net/url" "strconv" "strings" "time" "github.com/gobwas/httphead" "github.com/gobwas/pool/pbufio" ) // Constants used by Dialer. const ( DefaultClientReadBufferSize = 4096 DefaultClientWriteBufferSize = 4096 ) // Handshake represents handshake result. type Handshake struct { // Protocol is the subprotocol selected during handshake. Protocol string // Extensions is the list of negotiated extensions. Extensions []httphead.Option } // Errors used by the websocket client. var ( ErrHandshakeBadStatus = fmt.Errorf("unexpected http status") ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol) ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol) ) // DefaultDialer is dialer that holds no options and is used by Dial function. var DefaultDialer Dialer // Dial is like Dialer{}.Dial(). func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) { return DefaultDialer.Dial(ctx, urlstr) } // Dialer contains options for establishing websocket connection to an url. type Dialer struct { // ReadBufferSize and WriteBufferSize is an I/O buffer sizes. // They used to read and write http data while upgrading to WebSocket. // Allocated buffers are pooled with sync.Pool to avoid extra allocations. // // If a size is zero then default value is used. ReadBufferSize, WriteBufferSize int // Timeout is the maximum amount of time a Dial() will wait for a connect // and an handshake to complete. // // The default is no timeout. Timeout time.Duration // Protocols is the list of subprotocols that the client wants to speak, // ordered by preference. // // See https://tools.ietf.org/html/rfc6455#section-4.1 Protocols []string // Extensions is the list of extensions that client wants to speak. // // Note that if server decides to use some of this extensions, Dial() will // return Handshake struct containing a slice of items, which are the // shallow copies of the items from this list. That is, internals of // Extensions items are shared during Dial(). // // See https://tools.ietf.org/html/rfc6455#section-4.1 // See https://tools.ietf.org/html/rfc6455#section-9.1 Extensions []httphead.Option // Header is an optional HandshakeHeader instance that could be used to // write additional headers to the handshake request. // // It used instead of any key-value mappings to avoid allocations in user // land. Header HandshakeHeader // OnStatusError is the callback that will be called after receiving non // "101 Continue" HTTP response status. It receives an io.Reader object // representing server response bytes. That is, it gives ability to parse // HTTP response somehow (probably with http.ReadResponse call) and make a // decision of further logic. // // The arguments are only valid until the callback returns. OnStatusError func(status int, reason []byte, resp io.Reader) // OnHeader is the callback that will be called after successful parsing of // header, that is not used during WebSocket handshake procedure. That is, // it will be called with non-websocket headers, which could be relevant // for application-level logic. // // The arguments are only valid until the callback returns. // // Returned value could be used to prevent processing response. OnHeader func(key, value []byte) (err error) // NetDial is the function that is used to get plain tcp connection. // If it is not nil, then it is used instead of net.Dialer. NetDial func(ctx context.Context, network, addr string) (net.Conn, error) // TLSClient is the callback that will be called after successful dial with // received connection and its remote host name. If it is nil, then the // default tls.Client() will be used. // If it is not nil, then TLSConfig field is ignored. TLSClient func(conn net.Conn, hostname string) net.Conn // TLSConfig is passed to tls.Client() to start TLS over established // connection. If TLSClient is not nil, then it is ignored. If TLSConfig is // non-nil and its ServerName is empty, then for every Dial() it will be // cloned and appropriate ServerName will be set. TLSConfig *tls.Config // WrapConn is the optional callback that will be called when connection is // ready for an i/o. That is, it will be called after successful dial and // TLS initialization (for "wss" schemes). It may be helpful for different // user land purposes such as end to end encryption. // // Note that for debugging purposes of an http handshake (e.g. sent request // and received response), there is an wsutil.DebugDialer struct. WrapConn func(conn net.Conn) net.Conn } // Dial connects to the url host and upgrades connection to WebSocket. // // If server has sent frames right after successful handshake then returned // buffer will be non-nil. In other cases buffer is always nil. For better // memory efficiency received non-nil bufio.Reader should be returned to the // inner pool with PutReader() function after use. // // Note that Dialer does not implement IDNA (RFC5895) logic as net/http does. // If you want to dial non-ascii host name, take care of its name serialization // avoiding bad request issues. For more info see net/http Request.Write() // implementation, especially cleanHost() function. func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) { u, err := url.ParseRequestURI(urlstr) if err != nil { return } // Prepare context to dial with. Initially it is the same as original, but // if d.Timeout is non-zero and points to time that is before ctx.Deadline, // we use more shorter context for dial. dialctx := ctx var deadline time.Time if t := d.Timeout; t != 0 { deadline = time.Now().Add(t) if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { var cancel context.CancelFunc dialctx, cancel = context.WithDeadline(ctx, deadline) defer cancel() } } if conn, err = d.dial(dialctx, u); err != nil { return } defer func() { if err != nil { conn.Close() } }() if ctx == context.Background() { // No need to start I/O interrupter goroutine which is not zero-cost. conn.SetDeadline(deadline) defer conn.SetDeadline(noDeadline) } else { // Context could be canceled or its deadline could be exceeded. // Start the interrupter goroutine to handle context cancelation. done := setupContextDeadliner(ctx, conn) defer func() { // Map Upgrade() error to a possible context expiration error. That // is, even if Upgrade() err is nil, context could be already // expired and connection be "poisoned" by SetDeadline() call. // In that case we must not return ctx.Err() error. done(&err) }() } br, hs, err = d.Upgrade(conn, u) return } var ( // netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if // Dialer.NetDial is not provided. netEmptyDialer net.Dialer // tlsEmptyConfig is an empty tls.Config used as default one. tlsEmptyConfig tls.Config ) func tlsDefaultConfig() *tls.Config { return &tlsEmptyConfig } func hostport(host string, defaultPort string) (hostname, addr string) { var ( colon = strings.LastIndexByte(host, ':') bracket = strings.IndexByte(host, ']') ) if colon > bracket { return host[:colon], host } return host, host + defaultPort } func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) { dial := d.NetDial if dial == nil { dial = netEmptyDialer.DialContext } switch u.Scheme { case "ws": _, addr := hostport(u.Host, ":80") conn, err = dial(ctx, "tcp", addr) case "wss": hostname, addr := hostport(u.Host, ":443") conn, err = dial(ctx, "tcp", addr) if err != nil { return } tlsClient := d.TLSClient if tlsClient == nil { tlsClient = d.tlsClient } conn = tlsClient(conn, hostname) default: return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme) } if wrap := d.WrapConn; wrap != nil { conn = wrap(conn) } return } func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn { config := d.TLSConfig if config == nil { config = tlsDefaultConfig() } if config.ServerName == "" { config = tlsCloneConfig(config) config.ServerName = hostname } // Do not make conn.Handshake() here because downstairs we will prepare // i/o on this conn with proper context's timeout handling. return tls.Client(conn, config) } var ( // This variables are set like in net/net.go. // noDeadline is just zero value for readability. noDeadline = time.Time{} // aLongTimeAgo is a non-zero time, far in the past, used for immediate // cancelation of dials. aLongTimeAgo = time.Unix(42, 0) ) // Upgrade writes an upgrade request to the given io.ReadWriter conn at given // url u and reads a response from it. // // It is a caller responsibility to manage I/O deadlines on conn. // // It returns handshake info and some bytes which could be written by the peer // right after response and be caught by us during buffered read. func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) { // headerSeen constants helps to report whether or not some header was seen // during reading request bytes. const ( headerSeenUpgrade = 1 << iota headerSeenConnection headerSeenSecAccept // headerSeenAll is the value that we expect to receive at the end of // headers read/parse loop. headerSeenAll = 0 | headerSeenUpgrade | headerSeenConnection | headerSeenSecAccept ) br = pbufio.GetReader(conn, nonZero(d.ReadBufferSize, DefaultClientReadBufferSize), ) bw := pbufio.GetWriter(conn, nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize), ) defer func() { pbufio.PutWriter(bw) if br.Buffered() == 0 || err != nil { // Server does not wrote additional bytes to the connection or // error occurred. That is, no reason to return buffer. pbufio.PutReader(br) br = nil } }() nonce := make([]byte, nonceSize) initNonce(nonce) httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header) if err = bw.Flush(); err != nil { return } // Read HTTP status line like "HTTP/1.1 101 Switching Protocols". sl, err := readLine(br) if err != nil { return } // Begin validation of the response. // See https://tools.ietf.org/html/rfc6455#section-4.2.2 // Parse request line data like HTTP version, uri and method. resp, err := httpParseResponseLine(sl) if err != nil { return } // Even if RFC says "1.1 or higher" without mentioning the part of the // version, we apply it only to minor part. if resp.major != 1 || resp.minor < 1 { err = ErrHandshakeBadProtocol return } if resp.status != 101 { err = StatusError(resp.status) if onStatusError := d.OnStatusError; onStatusError != nil { // Invoke callback with multireader of status-line bytes br. onStatusError(resp.status, resp.reason, io.MultiReader( bytes.NewReader(sl), strings.NewReader(crlf), br, ), ) } return } // If response status is 101 then we expect all technical headers to be // valid. If not, then we stop processing response without giving user // ability to read non-technical headers. That is, we do not distinguish // technical errors (such as parsing error) and protocol errors. var headerSeen byte for { line, e := readLine(br) if e != nil { err = e return } if len(line) == 0 { // Blank line, no more lines to read. break } k, v, ok := httpParseHeaderLine(line) if !ok { err = ErrMalformedResponse return } switch btsToString(k) { case headerUpgradeCanonical: headerSeen |= headerSeenUpgrade if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) { err = ErrHandshakeBadUpgrade return } case headerConnectionCanonical: headerSeen |= headerSeenConnection // Note that as RFC6455 says: // > A |Connection| header field with value "Upgrade". // That is, in server side, "Connection" header could contain // multiple token. But in response it must contains exactly one. if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) { err = ErrHandshakeBadConnection return } case headerSecAcceptCanonical: headerSeen |= headerSeenSecAccept if !checkAcceptFromNonce(v, nonce) { err = ErrHandshakeBadSecAccept return } case headerSecProtocolCanonical: // RFC6455 1.3: // "The server selects one or none of the acceptable protocols // and echoes that value in its handshake to indicate that it has // selected that protocol." for _, want := range d.Protocols { if string(v) == want { hs.Protocol = want break } } if hs.Protocol == "" { // Server echoed subprotocol that is not present in client // requested protocols. err = ErrHandshakeBadSubProtocol return } case headerSecExtensionsCanonical: hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions) if err != nil { return } default: if onHeader := d.OnHeader; onHeader != nil { if e := onHeader(k, v); e != nil { err = e return } } } } if err == nil && headerSeen != headerSeenAll { switch { case headerSeen&headerSeenUpgrade == 0: err = ErrHandshakeBadUpgrade case headerSeen&headerSeenConnection == 0: err = ErrHandshakeBadConnection case headerSeen&headerSeenSecAccept == 0: err = ErrHandshakeBadSecAccept default: panic("unknown headers state") } } return } // PutReader returns bufio.Reader instance to the inner reuse pool. // It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which // contains unprocessed buffered data, that was sent by the server quickly // right after handshake. func PutReader(br *bufio.Reader) { pbufio.PutReader(br) } // StatusError contains an unexpected status-line code from the server. type StatusError int func (s StatusError) Error() string { return "unexpected HTTP response status: " + strconv.Itoa(int(s)) } func isTimeoutError(err error) bool { t, ok := err.(net.Error) return ok && t.Timeout() } func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) { if len(selected) == 0 { return received, nil } var ( index int option httphead.Option err error ) index = -1 match := func() (ok bool) { for _, want := range wanted { if option.Equal(want) { // Check parsed extension to be present in client // requested extensions. We move matched extension // from client list to avoid allocation. received = append(received, want) return true } } return false } ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control { if i != index { // Met next option. index = i if i != 0 && !match() { // Server returned non-requested extension. err = ErrHandshakeBadExtensions return httphead.ControlBreak } option = httphead.Option{Name: name} } if attr != nil { option.Parameters.Set(attr, val) } return httphead.ControlContinue }) if !ok { err = ErrMalformedResponse return received, err } if !match() { return received, ErrHandshakeBadExtensions } return received, err } // setupContextDeadliner is a helper function that starts connection I/O // interrupter goroutine. // // Started goroutine calls SetDeadline() with long time ago value when context // become expired to make any I/O operations failed. It returns done function // that stops started goroutine and maps error received from conn I/O methods // to possible context expiration error. // // In concern with possible SetDeadline() call inside interrupter goroutine, // caller passes pointer to its I/O error (even if it is nil) to done(&err). // That is, even if I/O error is nil, context could be already expired and // connection "poisoned" by SetDeadline() call. In that case done(&err) will // store at *err ctx.Err() result. If err is caused not by timeout, it will // leaved untouched. func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) { var ( quit = make(chan struct{}) interrupt = make(chan error, 1) ) go func() { select { case <-quit: interrupt <- nil case <-ctx.Done(): // Cancel i/o immediately. conn.SetDeadline(aLongTimeAgo) interrupt <- ctx.Err() } }() return func(err *error) { close(quit) // If ctx.Err() is non-nil and the original err is net.Error with // Timeout() == true, then it means that I/O was canceled by us by // SetDeadline(aLongTimeAgo) call, or by somebody else previously // by conn.SetDeadline(x). // // Even on race condition when both deadlines are expired // (SetDeadline() made not by us and context's), we prefer ctx.Err() to // be returned. if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) { *err = ctxErr } } }