// +build !js package websocket import ( "bufio" "context" "errors" "fmt" "io" "io/ioutil" "strings" "time" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/xsync" ) // Reader reads from the connection until until there is a WebSocket // data message to be read. It will handle ping, pong and close frames as appropriate. // // It returns the type of the message and an io.Reader to read it. // The passed context will also bound the reader. // Ensure you read to EOF otherwise the connection will hang. // // Call CloseRead if you do not expect any data messages from the peer. // // Only one Reader may be open at a time. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return c.reader(ctx) } // Read is a convenience method around Reader to read a single message // from the connection. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { typ, r, err := c.Reader(ctx) if err != nil { return 0, nil, err } b, err := ioutil.ReadAll(r) return typ, b, err } // CloseRead starts a goroutine to read from the connection until it is closed // or a data message is received. // // Once CloseRead is called you cannot read any messages from the connection. // The returned context will be cancelled when the connection is closed. // // If a data message is received, the connection will be closed with StatusPolicyViolation. // // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. func (c *Conn) CloseRead(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) go func() { defer cancel() c.Reader(ctx) c.Close(StatusPolicyViolation, "unexpected data message") }() return ctx } // SetReadLimit sets the max number of bytes to read for a single message. // It applies to the Reader and Read methods. // // By default, the connection has a message read limit of 32768 bytes. // // When the limit is hit, the connection will be closed with StatusMessageTooBig. func (c *Conn) SetReadLimit(n int64) { // We add read one more byte than the limit in case // there is a fin frame that needs to be read. c.msgReader.limitReader.limit.Store(n + 1) } const defaultReadLimit = 32768 func newMsgReader(c *Conn) *msgReader { mr := &msgReader{ c: c, fin: true, } mr.readFunc = mr.read mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) return mr } func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { mr.dict.init(32768) } if mr.flateBufio == nil { mr.flateBufio = getBufioReader(mr.readFunc) } mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } func (mr *msgReader) putFlateReader() { if mr.flateReader != nil { putFlateReader(mr.flateReader) mr.flateReader = nil } } func (mr *msgReader) close() { mr.c.readMu.forceLock() mr.putFlateReader() mr.dict.close() if mr.flateBufio != nil { putBufioReader(mr.flateBufio) } if mr.c.client { putBufioReader(mr.c.br) mr.c.br = nil } } func (mr *msgReader) flateContextTakeover() bool { if mr.c.client { return !mr.c.copts.serverNoContextTakeover } return !mr.c.copts.clientNoContextTakeover } func (c *Conn) readRSV1Illegal(h header) bool { // If compression is disabled, rsv1 is illegal. if !c.flate() { return true } // rsv1 is only allowed on data frames beginning messages. if h.opcode != opText && h.opcode != opBinary { return true } return false } func (c *Conn) readLoop(ctx context.Context) (header, error) { for { h, err := c.readFrameHeader(ctx) if err != nil { return header{}, err } if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) c.writeError(StatusProtocolError, err) return header{}, err } if !c.client && !h.masked { return header{}, errors.New("received unmasked frame from client") } switch h.opcode { case opClose, opPing, opPong: err = c.handleControl(ctx, h) if err != nil { // Pass through CloseErrors when receiving a close frame. if h.opcode == opClose && CloseStatus(err) != -1 { return header{}, err } return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) } case opContinuation, opText, opBinary: return h, nil default: err := fmt.Errorf("received unknown opcode %v", h.opcode) c.writeError(StatusProtocolError, err) return header{}, err } } } func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { case <-c.closed: return header{}, c.closeErr case c.readTimeout <- ctx: } h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { select { case <-c.closed: return header{}, c.closeErr case <-ctx.Done(): return header{}, ctx.Err() default: c.close(err) return header{}, err } } select { case <-c.closed: return header{}, c.closeErr case c.readTimeout <- context.Background(): } return h, nil } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { case <-c.closed: return 0, c.closeErr case c.readTimeout <- ctx: } n, err := io.ReadFull(c.br, p) if err != nil { select { case <-c.closed: return n, c.closeErr case <-ctx.Done(): return n, ctx.Err() default: err = fmt.Errorf("failed to read frame payload: %w", err) c.close(err) return n, err } } select { case <-c.closed: return n, c.closeErr case c.readTimeout <- context.Background(): } return n, err } func (c *Conn) handleControl(ctx context.Context, h header) (err error) { if h.payloadLength < 0 || h.payloadLength > maxControlPayload { err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) c.writeError(StatusProtocolError, err) return err } if !h.fin { err := errors.New("received fragmented control frame") c.writeError(StatusProtocolError, err) return err } ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() b := c.readControlBuf[:h.payloadLength] _, err = c.readFramePayload(ctx, b) if err != nil { return err } if h.masked { mask(h.maskKey, b) } switch h.opcode { case opPing: return c.writeControl(ctx, opPong, b) case opPong: c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock() if ok { select { case pong <- struct{}{}: default: } } return nil } defer func() { c.readCloseFrameErr = err }() ce, err := parseClosePayload(b) if err != nil { err = fmt.Errorf("received invalid close payload: %w", err) c.writeError(StatusProtocolError, err) return err } err = fmt.Errorf("received close frame: %w", ce) c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) c.close(err) return err } func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { defer errd.Wrap(&err, "failed to get reader") err = c.readMu.lock(ctx) if err != nil { return 0, nil, err } defer c.readMu.unlock() if !c.msgReader.fin { err = errors.New("previous message not read to completion") c.close(fmt.Errorf("failed to get reader: %w", err)) return 0, nil, err } h, err := c.readLoop(ctx) if err != nil { return 0, nil, err } if h.opcode == opContinuation { err := errors.New("received continuation frame without text or binary frame") c.writeError(StatusProtocolError, err) return 0, nil, err } c.msgReader.reset(ctx, h) return MessageType(h.opcode), c.msgReader, nil } type msgReader struct { c *Conn ctx context.Context flate bool flateReader io.Reader flateBufio *bufio.Reader flateTail strings.Reader limitReader *limitReader dict slidingWindow fin bool payloadLength int64 maskKey uint32 // readerFunc(mr.Read) to avoid continuous allocations. readFunc readerFunc } func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.flate = h.rsv1 mr.limitReader.reset(mr.readFunc) if mr.flate { mr.resetFlate() } mr.setFrame(h) } func (mr *msgReader) setFrame(h header) { mr.fin = h.fin mr.payloadLength = h.payloadLength mr.maskKey = h.maskKey } func (mr *msgReader) Read(p []byte) (n int, err error) { err = mr.c.readMu.lock(mr.ctx) if err != nil { return 0, fmt.Errorf("failed to read: %w", err) } defer mr.c.readMu.unlock() n, err = mr.limitReader.Read(p) if mr.flate && mr.flateContextTakeover() { p = p[:n] mr.dict.write(p) } if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { mr.putFlateReader() return n, io.EOF } if err != nil { err = fmt.Errorf("failed to read: %w", err) mr.c.close(err) } return n, err } func (mr *msgReader) read(p []byte) (int, error) { for { if mr.payloadLength == 0 { if mr.fin { if mr.flate { return mr.flateTail.Read(p) } return 0, io.EOF } h, err := mr.c.readLoop(mr.ctx) if err != nil { return 0, err } if h.opcode != opContinuation { err := errors.New("received new data message without finishing the previous message") mr.c.writeError(StatusProtocolError, err) return 0, err } mr.setFrame(h) continue } if int64(len(p)) > mr.payloadLength { p = p[:mr.payloadLength] } n, err := mr.c.readFramePayload(mr.ctx, p) if err != nil { return n, err } mr.payloadLength -= int64(n) if !mr.c.client { mr.maskKey = mask(mr.maskKey, p) } return n, nil } } type limitReader struct { c *Conn r io.Reader limit xsync.Int64 n int64 } func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { lr := &limitReader{ c: c, } lr.limit.Store(limit) lr.reset(r) return lr } func (lr *limitReader) reset(r io.Reader) { lr.n = lr.limit.Load() lr.r = r } func (lr *limitReader) Read(p []byte) (int, error) { if lr.n <= 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) lr.c.writeError(StatusMessageTooBig, err) return 0, err } if int64(len(p)) > lr.n { p = p[:lr.n] } n, err := lr.r.Read(p) lr.n -= int64(n) return n, err } type readerFunc func(p []byte) (int, error) func (f readerFunc) Read(p []byte) (int, error) { return f(p) }