diff --git a/Makefile b/Makefile index ee1b3101..4923aa83 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,10 @@ endif .PHONY: all all: cloudflared test +.PHONY: clean +clean: + go clean + .PHONY: cloudflared cloudflared: go build -v $(VERSION_FLAGS) $(IMPORT_PATH)/cmd/cloudflared diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index a80415d1..055348f0 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -15,11 +15,12 @@ import ( ) const ( - defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec - defaultWindowSize uint32 = 65535 - maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size specified in http2 spec - defaultTimeout time.Duration = 5 * time.Second - defaultRetries uint64 = 5 + defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec + defaultWindowSize uint32 = (1 << 16) - 1 // Minimum window size in http2 spec + maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size in http2 spec + defaultTimeout time.Duration = 5 * time.Second + defaultRetries uint64 = 5 + defaultWriteBufferMaxLen int = 1024 * 1024 * 512 // 500mb SettingMuxerMagic http2.SettingID = 0x42db MuxerMagicOrigin uint32 = 0xa2e43c8b @@ -49,6 +50,12 @@ type MuxerConfig struct { // Logger to use Logger *log.Entry CompressionQuality CompressionSetting + // Initial size for HTTP2 flow control windows + DefaultWindowSize uint32 + // Largest allowable size for HTTP2 flow control windows + MaxWindowSize uint32 + // Largest allowable capacity for the buffer of data to be sent + StreamWriteBufferMaxLen int } type Muxer struct { @@ -98,6 +105,15 @@ func Handshake( if config.Timeout == 0 { config.Timeout = defaultTimeout } + if config.DefaultWindowSize == 0 { + config.DefaultWindowSize = defaultWindowSize + } + if config.MaxWindowSize == 0 { + config.MaxWindowSize = maxWindowSize + } + if config.StreamWriteBufferMaxLen == 0 { + config.StreamWriteBufferMaxLen = defaultWriteBufferMaxLen + } // Initialise connection state fields m := &Muxer{ f: http2.NewFramer(w, r), // A framer that writes to w and reads from r @@ -179,8 +195,9 @@ func Handshake( abortChan: m.abortChan, pingTimestamp: pingTimestamp, connActive: connActive, - initialStreamWindow: defaultWindowSize, - streamWindowMax: maxWindowSize, + initialStreamWindow: m.config.DefaultWindowSize, + streamWindowMax: m.config.MaxWindowSize, + streamWriteBufferMaxLen: m.config.StreamWriteBufferMaxLen, r: m.r, updateRTTChan: updateRTTChan, updateReceiveWindowChan: updateReceiveWindowChan, @@ -375,10 +392,12 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro responseHeadersReceived: make(chan struct{}), readBuffer: NewSharedBuffer(), writeBuffer: &bytes.Buffer{}, - receiveWindow: defaultWindowSize, - receiveWindowCurrentMax: defaultWindowSize, // Initial window size limit. exponentially increase it when receiveWindow is exhausted - receiveWindowMax: maxWindowSize, - sendWindow: defaultWindowSize, + writeBufferMaxLen: m.config.StreamWriteBufferMaxLen, + writeBufferHasSpace: make(chan struct{}, 1), + receiveWindow: m.config.DefaultWindowSize, + receiveWindowCurrentMax: m.config.DefaultWindowSize, + receiveWindowMax: m.config.MaxWindowSize, + sendWindow: m.config.DefaultWindowSize, readyList: m.readyList, writeHeaders: headers, dictionaries: m.muxReader.dictionaries, diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 0415d225..6c5ac73a 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -39,20 +39,26 @@ func NewDefaultMuxerPair() *DefaultMuxerPair { origin, edge := net.Pipe() return &DefaultMuxerPair{ OriginMuxConfig: MuxerConfig{ - Timeout: time.Second, - IsClient: true, - Name: "origin", - Logger: log.NewEntry(log.New()), + Timeout: time.Second, + IsClient: true, + Name: "origin", + Logger: log.NewEntry(log.New()), + DefaultWindowSize: (1 << 8) - 1, + MaxWindowSize: (1 << 15) - 1, + StreamWriteBufferMaxLen: 1024, }, - OriginConn: origin, - EdgeMuxConfig: MuxerConfig{ - Timeout: time.Second, - IsClient: false, - Name: "edge", - Logger: log.NewEntry(log.New()), + OriginConn: origin, + EdgeMuxConfig: MuxerConfig{ + Timeout: time.Second, + IsClient: false, + Name: "edge", + Logger: log.NewEntry(log.New()), + DefaultWindowSize: (1 << 8) - 1, + MaxWindowSize: (1 << 15) - 1, + StreamWriteBufferMaxLen: 1024, }, - EdgeConn: edge, - doneC: make(chan struct{}), + EdgeConn: edge, + doneC: make(chan struct{}), } } @@ -60,22 +66,22 @@ func NewCompressedMuxerPair(quality CompressionSetting) *DefaultMuxerPair { origin, edge := net.Pipe() return &DefaultMuxerPair{ OriginMuxConfig: MuxerConfig{ - Timeout: time.Second, - IsClient: true, - Name: "origin", + Timeout: time.Second, + IsClient: true, + Name: "origin", CompressionQuality: quality, - Logger: log.NewEntry(log.New()), + Logger: log.NewEntry(log.New()), }, - OriginConn: origin, - EdgeMuxConfig: MuxerConfig{ - Timeout: time.Second, - IsClient: false, - Name: "edge", + OriginConn: origin, + EdgeMuxConfig: MuxerConfig{ + Timeout: time.Second, + IsClient: false, + Name: "edge", CompressionQuality: quality, - Logger: log.NewEntry(log.New()), + Logger: log.NewEntry(log.New()), }, - EdgeConn: edge, - doneC: make(chan struct{}), + EdgeConn: edge, + doneC: make(chan struct{}), } } @@ -230,7 +236,6 @@ func TestSingleStream(t *testing.T) { func TestSingleStreamLargeResponseBody(t *testing.T) { muxPair := NewDefaultMuxerPair() bodySize := 1 << 24 - streamReady := make(chan struct{}) muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { if len(stream.Headers) != 1 { t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) @@ -257,8 +262,6 @@ func TestSingleStreamLargeResponseBody(t *testing.T) { if n != len(payload) { t.Fatalf("origin short write: %d/%d bytes", n, len(payload)) } - t.Log("Payload written; signaling that the stream is ready") - streamReady <- struct{}{} return nil }) @@ -282,9 +285,6 @@ func TestSingleStreamLargeResponseBody(t *testing.T) { } responseBody := make([]byte, bodySize) - <-streamReady - t.Log("Received stream ready signal; resuming the test") - n, err := io.ReadFull(stream, responseBody) if err != nil { t.Fatalf("error from (*MuxedStream).Read: %s", err) @@ -367,14 +367,13 @@ func TestMultipleStreams(t *testing.T) { log.Error(err) } if testFail { - t.Fatalf("TestMultipleStreamsFlowControl failed") + t.Fatalf("TestMultipleStreams failed") } } func TestMultipleStreamsFlowControl(t *testing.T) { maxStreams := 32 errorsC := make(chan error, maxStreams) - streamReady := make(chan struct{}) responseSizes := make([]int32, maxStreams) for i := 0; i < maxStreams; i++ { responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4)) @@ -398,7 +397,6 @@ func TestMultipleStreamsFlowControl(t *testing.T) { payload[i] = byte(i % 256) } n, err := stream.Write(payload) - streamReady <- struct{}{} if err != nil { t.Fatalf("origin write error: %s", err) } @@ -435,7 +433,6 @@ func TestMultipleStreamsFlowControl(t *testing.T) { return } - <-streamReady responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) n, err := io.ReadFull(stream, responseBody) if err != nil { @@ -782,9 +779,11 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) { } wg.Add(len(paths)) + errorsC := make(chan error, len(paths)) for i, s := range paths { go func(i int, path string) { + defer wg.Done() stream, err := muxPair.EdgeMux.OpenStream( []Header{ {Name: ":method", Value: "GET"}, @@ -805,22 +804,30 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) { responseBody := make([]byte, len(expectBody)*2) n, err := stream.Read(responseBody) if err != nil { - log.Printf("error from (*MuxedStream).Read: %s", err) - t.Fatalf("error from (*MuxedStream).Read: %s", err) + errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) + return } if n != len(expectBody) { - log.Printf("expected response body to have %d bytes, got %d", len(expectBody), n) - t.Fatalf("expected response body to have %d bytes, got %d", len(expectBody), n) + errorsC <- fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(expectBody), n) + return } if string(responseBody[:n]) != expectBody { - log.Printf("expected response body %s, got %s", expectBody, responseBody[:n]) - t.Fatalf("expected response body %s, got %s", expectBody, responseBody[:n]) + errorsC <- fmt.Errorf("stream %d expected response body %s, got %s", stream.streamID, expectBody, responseBody[:n]) + return } - wg.Done() }(i, s) - time.Sleep(1 * time.Millisecond) } + wg.Wait() + close(errorsC) + testFail := false + for err := range errorsC { + testFail = true + log.Error(err) + } + if testFail { + t.Fatalf("TestMultipleStreams failed") + } if q > CompressionNone && muxPair.OriginMux.muxMetricsUpdater.compBytesBefore.Value() <= 10*muxPair.OriginMux.muxMetricsUpdater.compBytesAfter.Value() { t.Fatalf("Cross-stream compression is expected to give a better compression ratio") diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go index cb31f01c..2bb59db1 100644 --- a/h2mux/muxedstream.go +++ b/h2mux/muxedstream.go @@ -17,32 +17,51 @@ type ReadWriteClosedCloser interface { Closed() bool } +// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data. type MuxedStream struct { - Headers []Header - streamID uint32 + // The "Receive" end of the stream + readBufferLock sync.RWMutex + readBuffer ReadWriteClosedCloser + // This is the amount of bytes that are in our receive window + // (how much data we can receive into this stream). + receiveWindow uint32 + // current receive window size limit. Exponentially increase it when it's exhausted + receiveWindowCurrentMax uint32 + // hard limit set in http2 spec. 2^31-1 + receiveWindowMax uint32 + // The desired size increment for receiveWindow. + // If this is nonzero, a WINDOW_UPDATE frame needs to be sent. + windowUpdate uint32 + // The headers that were most recently received. + // Particularly: + // * for an eyeball-initiated stream (as passed to TunnelHandler::ServeStream), + // these are the request headers + // * for a cloudflared-initiated stream (as created by Register/UnregisterTunnel), + // these are the response headers. + // They are useful in both of these contexts; hence `Headers` is public. + Headers []Header + // For use in the context of a cloudflared-initiated stream. responseHeadersReceived chan struct{} - readBuffer ReadWriteClosedCloser - receiveWindow uint32 - // current window size limit. Exponentially increase it when it's exhausted - receiveWindowCurrentMax uint32 - // limit set in http2 spec. 2^31-1 - receiveWindowMax uint32 - - // nonzero if a WINDOW_UPDATE frame for a stream needs to be sent - windowUpdate uint32 - - writeLock sync.Mutex - // The zero value for Buffer is an empty buffer ready to use. + // The "Send" end of the stream + writeLock sync.Mutex writeBuffer ReadWriteLengther - + // The maximum capacity that the send buffer should grow to. + writeBufferMaxLen int + // A channel to be notified when the send buffer is not full. + writeBufferHasSpace chan struct{} + // This is the amount of bytes that are in the peer's receive window + // (how much data we can send from this stream). sendWindow uint32 - - readyList *ReadyList + // Reference to the muxer's readyList; signal this for stream data to be sent. + readyList *ReadyList + // The headers that should be sent, and a flag so we only send them once. headersSent bool writeHeaders []Header + + // EOF-related fields // true if the write end of this stream has been closed writeEOF bool // true if we have sent EOF to the peer @@ -50,40 +69,63 @@ type MuxedStream struct { // true if the peer sent us an EOF receivedEOF bool - // dictionary that was used to compress the stream + // Compression-related fields receivedUseDict bool method string contentType string path string dictionaries h2Dictionaries - readBufferLock sync.RWMutex } func (s *MuxedStream) Read(p []byte) (n int, err error) { + var readBuffer ReadWriteClosedCloser if s.dictionaries.read != nil { s.readBufferLock.RLock() - b := s.readBuffer + readBuffer = s.readBuffer s.readBufferLock.RUnlock() - return b.Read(p) + } else { + readBuffer = s.readBuffer } - return s.readBuffer.Read(p) + n, err = readBuffer.Read(p) + s.replenishReceiveWindow(uint32(n)) + return } -func (s *MuxedStream) Write(p []byte) (n int, err error) { +// Blocks until len(p) bytes have been written to the buffer +func (s *MuxedStream) Write(p []byte) (int, error) { + // If assignDictToStream returns success, then it will have acquired the + // writeLock. Otherwise we must acquire it ourselves. ok := assignDictToStream(s, p) if !ok { s.writeLock.Lock() } defer s.writeLock.Unlock() + if s.writeEOF { return 0, io.EOF } - n, err = s.writeBuffer.Write(p) - if n != len(p) || err != nil { - return n, err + totalWritten := 0 + for totalWritten < len(p) { + // If the buffer is full, block till there is more room. + // Use a loop to recheck the buffer size after the lock is reacquired. + for s.writeBufferMaxLen <= s.writeBuffer.Len() { + s.writeLock.Unlock() + <-s.writeBufferHasSpace + s.writeLock.Lock() + } + amountToWrite := len(p) - totalWritten + spaceAvailable := s.writeBufferMaxLen - s.writeBuffer.Len() + if spaceAvailable < amountToWrite { + amountToWrite = spaceAvailable + } + amountWritten, err := s.writeBuffer.Write(p[totalWritten : totalWritten+amountToWrite]) + totalWritten += amountWritten + if err != nil { + return totalWritten, err + } + s.writeNotify() } - s.writeNotify() - return n, nil + return totalWritten, nil } func (s *MuxedStream) Close() error { @@ -164,9 +206,9 @@ func (s *MuxedStream) writeNotify() { // receive window (how much data we can send). func (s *MuxedStream) replenishSendWindow(bytes uint32) { s.writeLock.Lock() + defer s.writeLock.Unlock() s.sendWindow += bytes s.writeNotify() - s.writeLock.Unlock() } // Call by muxreader when it receives a data frame @@ -178,17 +220,30 @@ func (s *MuxedStream) consumeReceiveWindow(bytes uint32) bool { return false } s.receiveWindow -= bytes - if s.receiveWindow < s.receiveWindowCurrentMax/2 { + if s.receiveWindow < s.receiveWindowCurrentMax/2 && s.receiveWindowCurrentMax < s.receiveWindowMax { // exhausting client send window (how much data client can send) - if s.receiveWindowCurrentMax < s.receiveWindowMax { - s.receiveWindowCurrentMax <<= 1 + // and there is room to grow the receive window + newMax := s.receiveWindowCurrentMax << 1 + if newMax > s.receiveWindowMax { + newMax = s.receiveWindowMax } - s.windowUpdate += s.receiveWindowCurrentMax - s.receiveWindow + s.windowUpdate += newMax - s.receiveWindowCurrentMax + s.receiveWindowCurrentMax = newMax + // notify MuxWriter to write WINDOW_UPDATE frame s.writeNotify() } return true } +// Arranges for the MuxWriter to send a WINDOW_UPDATE +// Called by MuxedStream::Read when data has left the read buffer. +func (s *MuxedStream) replenishReceiveWindow(bytes uint32) { + s.writeLock.Lock() + defer s.writeLock.Unlock() + s.windowUpdate += bytes + s.writeNotify() +} + // receiveEOF should be called when the peer indicates no more data will be sent. // Returns true if the socket is now closed (i.e. the write side is already closed). func (s *MuxedStream) receiveEOF() (closed bool) { @@ -226,7 +281,8 @@ type streamChunk struct { // true if a HEADERS frame should be sent sendHeaders bool headers []Header - // nonzero if a WINDOW_UPDATE frame should be sent + // nonzero if a WINDOW_UPDATE frame should be sent; + // in that case, it is the increment value to use windowUpdate uint32 // true if data frames should be sent sendData bool @@ -249,11 +305,23 @@ func (s *MuxedStream) getChunk() *streamChunk { eof: s.writeEOF && uint32(s.writeBuffer.Len()) <= s.sendWindow, } - // Copies at most s.sendWindow bytes + // Copy at most s.sendWindow bytes, adjust the sendWindow accordingly writeLen, _ := io.CopyN(&chunk.buffer, s.writeBuffer, int64(s.sendWindow)) s.sendWindow -= uint32(writeLen) + + // Non-blocking channel send. This will allow MuxedStream::Write() to continue, if needed + if s.writeBuffer.Len() < s.writeBufferMaxLen { + select { + case s.writeBufferHasSpace <- struct{}{}: + default: + } + } + + // When we write the chunk, we'll write the WINDOW_UPDATE frame if needed s.receiveWindow += s.windowUpdate s.windowUpdate = 0 + + // When we write the chunk, we'll write the headers if needed s.headersSent = true // if this chunk contains the end of the stream, close the stream now diff --git a/h2mux/muxedstream_test.go b/h2mux/muxedstream_test.go index 4c1091f0..3672b531 100644 --- a/h2mux/muxedstream_test.go +++ b/h2mux/muxedstream_test.go @@ -23,47 +23,55 @@ func TestFlowControlSingleStream(t *testing.T) { sendWindow: testWindowSize, readyList: NewReadyList(), } + var tempWindowUpdate uint32 + var tempStreamChunk *streamChunk + assert.True(t, stream.consumeReceiveWindow(testWindowSize/2)) dataSent := testWindowSize / 2 assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow) assert.Equal(t, testWindowSize, stream.receiveWindowCurrentMax) - assert.Equal(t, uint32(0), stream.windowUpdate) - tempWindowUpdate := stream.windowUpdate - - streamChunk := stream.getChunk() - assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate) - assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow) - assert.Equal(t, uint32(0), stream.windowUpdate) assert.Equal(t, testWindowSize, stream.sendWindow) + assert.Equal(t, uint32(0), stream.windowUpdate) + + tempStreamChunk = stream.getChunk() + assert.Equal(t, uint32(0), tempStreamChunk.windowUpdate) + assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow) + assert.Equal(t, testWindowSize, stream.receiveWindowCurrentMax) + assert.Equal(t, testWindowSize, stream.sendWindow) + assert.Equal(t, uint32(0), stream.windowUpdate) assert.True(t, stream.consumeReceiveWindow(2)) dataSent += 2 assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow) assert.Equal(t, testWindowSize<<1, stream.receiveWindowCurrentMax) - assert.Equal(t, (testWindowSize<<1)-stream.receiveWindow, stream.windowUpdate) + assert.Equal(t, testWindowSize, stream.sendWindow) + assert.Equal(t, testWindowSize, stream.windowUpdate) tempWindowUpdate = stream.windowUpdate - streamChunk = stream.getChunk() - assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate) - assert.Equal(t, testWindowSize<<1, stream.receiveWindow) - assert.Equal(t, uint32(0), stream.windowUpdate) + tempStreamChunk = stream.getChunk() + assert.Equal(t, tempWindowUpdate, tempStreamChunk.windowUpdate) + assert.Equal(t, (testWindowSize<<1)-dataSent, stream.receiveWindow) + assert.Equal(t, testWindowSize<<1, stream.receiveWindowCurrentMax) assert.Equal(t, testWindowSize, stream.sendWindow) + assert.Equal(t, uint32(0), stream.windowUpdate) assert.True(t, stream.consumeReceiveWindow(testWindowSize+10)) - dataSent = testWindowSize + 10 + dataSent += testWindowSize + 10 assert.Equal(t, (testWindowSize<<1)-dataSent, stream.receiveWindow) assert.Equal(t, testWindowSize<<2, stream.receiveWindowCurrentMax) - assert.Equal(t, (testWindowSize<<2)-stream.receiveWindow, stream.windowUpdate) + assert.Equal(t, testWindowSize, stream.sendWindow) + assert.Equal(t, testWindowSize<<1, stream.windowUpdate) tempWindowUpdate = stream.windowUpdate - streamChunk = stream.getChunk() - assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate) - assert.Equal(t, testWindowSize<<2, stream.receiveWindow) - assert.Equal(t, uint32(0), stream.windowUpdate) + tempStreamChunk = stream.getChunk() + assert.Equal(t, tempWindowUpdate, tempStreamChunk.windowUpdate) + assert.Equal(t, (testWindowSize<<2)-dataSent, stream.receiveWindow) + assert.Equal(t, testWindowSize<<2, stream.receiveWindowCurrentMax) assert.Equal(t, testWindowSize, stream.sendWindow) + assert.Equal(t, uint32(0), stream.windowUpdate) assert.False(t, stream.consumeReceiveWindow(testMaxWindowSize+1)) - assert.Equal(t, testWindowSize<<2, stream.receiveWindow) + assert.Equal(t, (testWindowSize<<2)-dataSent, stream.receiveWindow) assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax) } diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go index b258bd69..f80c1ab7 100644 --- a/h2mux/muxreader.go +++ b/h2mux/muxreader.go @@ -35,6 +35,8 @@ type MuxReader struct { initialStreamWindow uint32 // The max value for the send window of a stream. streamWindowMax uint32 + // The max size for the write buffer of a stream + streamWriteBufferMaxLen int // r is a reference to the underlying connection used when shutting down. r io.Closer // updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater @@ -153,6 +155,8 @@ func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream { streamID: streamID, readBuffer: NewSharedBuffer(), writeBuffer: &bytes.Buffer{}, + writeBufferMaxLen: r.streamWriteBufferMaxLen, + writeBufferHasSpace: make(chan struct{}, 1), receiveWindow: r.initialStreamWindow, receiveWindowCurrentMax: r.initialStreamWindow, receiveWindowMax: r.streamWindowMax,