TUN-6679: Allow client side of quic request to close body

In a previous commit, we fixed a bug where the client roundtrip code
could close the request body, which in fact would be the quic.Stream,
thus closing the write-side.
The way that was fixed, prevented the client roundtrip code from closing
also read-side (the body).

This fixes that, by allowing close to only close the read side, which
will guarantee that any subsquent will fail with an error or EOF it
occurred before the close.
This commit is contained in:
João Oliveirinha 2022-08-22 23:48:45 +01:00
parent 8e9e1d973e
commit 20ed7557f9
2 changed files with 86 additions and 4 deletions

View File

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -156,9 +157,10 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) {
defer stream.Close() defer stream.Close()
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
// code executed in the code path of handleStream don't trigger an earlier close to the downstream stream. // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
// So, we wrap the stream with a no-op closer and only this method can actually close the stream. // So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream.
noCloseStream := &nopCloserReadWriter{stream} // A call to close will simulate a close to the read-side, which will fail subsequent reads.
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
if err := q.handleStream(ctx, noCloseStream); err != nil { if err := q.handleStream(ctx, noCloseStream); err != nil {
q.logger.Err(err).Msg("Failed to handle QUIC stream") q.logger.Err(err).Msg("Failed to handle QUIC stream")
} }
@ -408,10 +410,39 @@ func isTransferEncodingChunked(req *http.Request) bool {
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
} }
// A helper struct that guarantees a call to close only affects read side, but not write side.
type nopCloserReadWriter struct { type nopCloserReadWriter struct {
io.ReadWriteCloser io.ReadWriteCloser
// for use by Read only
// we don't need a memory barrier here because there is an implicit assumption that
// Read calls can't happen concurrently by different go-routines.
sawEOF bool
// should be updated and read using atomic primitives.
// value is read in Read method and written in Close method, which could be done by different
// go-routines.
closed uint32
} }
func (n *nopCloserReadWriter) Close() error { func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) {
if np.sawEOF {
return 0, io.EOF
}
if atomic.LoadUint32(&np.closed) > 0 {
return 0, fmt.Errorf("closed by handler")
}
n, err = np.ReadWriteCloser.Read(p)
if err == io.EOF {
np.sawEOF = true
}
return
}
func (np *nopCloserReadWriter) Close() error {
atomic.StoreUint32(&np.closed, 1)
return nil return nil
} }

View File

@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -527,6 +528,44 @@ func TestServeUDPSession(t *testing.T) {
cancel() cancel()
} }
func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) {
readerWriter := nopCloserReadWriter{ReadWriteCloser: &mockReaderNoopWriter{Reader: strings.NewReader("123456789")}}
buffer := make([]byte, 5)
n, err := readerWriter.Read(buffer)
require.NoError(t, err)
require.Equal(t, n, 5)
// close
require.NoError(t, readerWriter.Close())
// read should get error
n, err = readerWriter.Read(buffer)
require.Equal(t, n, 0)
require.Equal(t, err, fmt.Errorf("closed by handler"))
}
func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
readerWriter := nopCloserReadWriter{ReadWriteCloser: &mockReaderNoopWriter{Reader: strings.NewReader("123456789")}}
buffer := make([]byte, 20)
n, err := readerWriter.Read(buffer)
require.NoError(t, err)
require.Equal(t, n, 9)
// force another read to read eof
n, err = readerWriter.Read(buffer)
require.Equal(t, err, io.EOF)
// close
require.NoError(t, readerWriter.Close())
// read should get EOF still
n, err = readerWriter.Read(buffer)
require.Equal(t, n, 0)
require.Equal(t, err, io.EOF)
}
func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) { func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) {
var ( var (
payload = []byte(t.Name()) payload = []byte(t.Name())
@ -647,3 +686,15 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
require.NoError(t, err) require.NoError(t, err)
return qc return qc
} }
type mockReaderNoopWriter struct {
io.Reader
}
func (m *mockReaderNoopWriter) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (m *mockReaderNoopWriter) Close() error {
return nil
}