TUN-6679: Allow client side of quic request to close body
In a previous commit, we fixed a bug where the client roundtrip code could close the request body, which in fact would be the quic.Stream, thus closing the write-side. The way that was fixed, prevented the client roundtrip code from closing also read-side (the body). This fixes that, by allowing close to only close the read side, which will guarantee that any subsquent will fail with an error or EOF it occurred before the close.
This commit is contained in:
parent
8e9e1d973e
commit
20ed7557f9
|
@ -9,6 +9,7 @@ import (
|
|||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -156,9 +157,10 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) {
|
|||
defer stream.Close()
|
||||
|
||||
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
|
||||
// code executed in the code path of handleStream don't trigger an earlier close to the downstream stream.
|
||||
// So, we wrap the stream with a no-op closer and only this method can actually close the stream.
|
||||
noCloseStream := &nopCloserReadWriter{stream}
|
||||
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
|
||||
// So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream.
|
||||
// A call to close will simulate a close to the read-side, which will fail subsequent reads.
|
||||
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
|
||||
if err := q.handleStream(ctx, noCloseStream); err != nil {
|
||||
q.logger.Err(err).Msg("Failed to handle QUIC stream")
|
||||
}
|
||||
|
@ -408,10 +410,39 @@ func isTransferEncodingChunked(req *http.Request) bool {
|
|||
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
|
||||
}
|
||||
|
||||
// A helper struct that guarantees a call to close only affects read side, but not write side.
|
||||
type nopCloserReadWriter struct {
|
||||
io.ReadWriteCloser
|
||||
|
||||
// for use by Read only
|
||||
// we don't need a memory barrier here because there is an implicit assumption that
|
||||
// Read calls can't happen concurrently by different go-routines.
|
||||
sawEOF bool
|
||||
// should be updated and read using atomic primitives.
|
||||
// value is read in Read method and written in Close method, which could be done by different
|
||||
// go-routines.
|
||||
closed uint32
|
||||
}
|
||||
|
||||
func (n *nopCloserReadWriter) Close() error {
|
||||
func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) {
|
||||
if np.sawEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&np.closed) > 0 {
|
||||
return 0, fmt.Errorf("closed by handler")
|
||||
}
|
||||
|
||||
n, err = np.ReadWriteCloser.Read(p)
|
||||
if err == io.EOF {
|
||||
np.sawEOF = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (np *nopCloserReadWriter) Close() error {
|
||||
atomic.StoreUint32(&np.closed, 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -527,6 +528,44 @@ func TestServeUDPSession(t *testing.T) {
|
|||
cancel()
|
||||
}
|
||||
|
||||
func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) {
|
||||
readerWriter := nopCloserReadWriter{ReadWriteCloser: &mockReaderNoopWriter{Reader: strings.NewReader("123456789")}}
|
||||
buffer := make([]byte, 5)
|
||||
|
||||
n, err := readerWriter.Read(buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 5)
|
||||
|
||||
// close
|
||||
require.NoError(t, readerWriter.Close())
|
||||
|
||||
// read should get error
|
||||
n, err = readerWriter.Read(buffer)
|
||||
require.Equal(t, n, 0)
|
||||
require.Equal(t, err, fmt.Errorf("closed by handler"))
|
||||
}
|
||||
|
||||
func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
|
||||
readerWriter := nopCloserReadWriter{ReadWriteCloser: &mockReaderNoopWriter{Reader: strings.NewReader("123456789")}}
|
||||
buffer := make([]byte, 20)
|
||||
|
||||
n, err := readerWriter.Read(buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 9)
|
||||
|
||||
// force another read to read eof
|
||||
n, err = readerWriter.Read(buffer)
|
||||
require.Equal(t, err, io.EOF)
|
||||
|
||||
// close
|
||||
require.NoError(t, readerWriter.Close())
|
||||
|
||||
// read should get EOF still
|
||||
n, err = readerWriter.Read(buffer)
|
||||
require.Equal(t, n, 0)
|
||||
require.Equal(t, err, io.EOF)
|
||||
}
|
||||
|
||||
func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) {
|
||||
var (
|
||||
payload = []byte(t.Name())
|
||||
|
@ -647,3 +686,15 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
|
|||
require.NoError(t, err)
|
||||
return qc
|
||||
}
|
||||
|
||||
type mockReaderNoopWriter struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (m *mockReaderNoopWriter) Write(p []byte) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *mockReaderNoopWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue