diff --git a/connection/connection.go b/connection/connection.go index e089483b..25d796dd 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "math" + "net" "net/http" "strconv" "strings" @@ -197,10 +198,55 @@ func (h *HTTPResponseReadWriteAcker) AckConnection(tracePropagation string) erro return h.w.WriteRespHeaders(resp.StatusCode, resp.Header) } +// localProxyConnection emulates an incoming connection to cloudflared as a net.Conn. +// Used when handling a "hijacked" connection from connection.ResponseWriter +type localProxyConnection struct { + io.ReadWriteCloser +} + +func (c *localProxyConnection) Read(b []byte) (int, error) { + return c.ReadWriteCloser.Read(b) +} + +func (c *localProxyConnection) Write(b []byte) (int, error) { + return c.ReadWriteCloser.Write(b) +} + +func (c *localProxyConnection) Close() error { + return c.ReadWriteCloser.Close() +} + +func (c *localProxyConnection) LocalAddr() net.Addr { + // Unused LocalAddr + return &net.TCPAddr{IP: net.IPv6loopback, Port: 0, Zone: ""} +} + +func (c *localProxyConnection) RemoteAddr() net.Addr { + // Unused RemoteAddr + return &net.TCPAddr{IP: net.IPv6loopback, Port: 0, Zone: ""} +} + +func (c *localProxyConnection) SetDeadline(t time.Time) error { + // ignored since we can't set the read/write Deadlines for the tunnel back to origintunneld + return nil +} + +func (c *localProxyConnection) SetReadDeadline(t time.Time) error { + // ignored since we can't set the read/write Deadlines for the tunnel back to origintunneld + return nil +} + +func (c *localProxyConnection) SetWriteDeadline(t time.Time) error { + // ignored since we can't set the read/write Deadlines for the tunnel back to origintunneld + return nil +} + +// ResponseWriter is the response path for a request back through cloudflared's tunnel. type ResponseWriter interface { WriteRespHeaders(status int, header http.Header) error AddTrailer(trailerName, trailerValue string) http.ResponseWriter + http.Hijacker io.Writer } diff --git a/connection/http2.go b/connection/http2.go index 1b80e5f9..4945855b 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -1,6 +1,7 @@ package connection import ( + "bufio" "context" gojson "encoding/json" "fmt" @@ -198,6 +199,8 @@ type http2RespWriter struct { shouldFlush bool statusWritten bool respHeaders http.Header + hijackedMutex sync.Mutex + hijackedv bool log *zerolog.Logger } @@ -233,6 +236,10 @@ func (rp *http2RespWriter) AddTrailer(trailerName, trailerValue string) { } func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error { + if rp.hijacked() { + rp.log.Warn().Msg("WriteRespHeaders after hijack") + return nil + } dest := rp.w.Header() userHeaders := make(http.Header, len(header)) for name, values := range header { @@ -283,9 +290,43 @@ func (rp *http2RespWriter) Header() http.Header { } func (rp *http2RespWriter) WriteHeader(status int) { + if rp.hijacked() { + rp.log.Warn().Msg("WriteHeader after hijack") + return + } rp.WriteRespHeaders(status, rp.respHeaders) } +func (rp *http2RespWriter) hijacked() bool { + rp.hijackedMutex.Lock() + defer rp.hijackedMutex.Unlock() + return rp.hijackedv +} + +func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if !rp.statusWritten { + return nil, nil, fmt.Errorf("status not yet written before attempting to hijack connection") + } + // Make sure to flush anything left in the buffer before hijacking + if rp.shouldFlush { + rp.flusher.Flush() + } + rp.hijackedMutex.Lock() + defer rp.hijackedMutex.Unlock() + if rp.hijackedv { + return nil, nil, http.ErrHijacked + } + rp.hijackedv = true + conn := &localProxyConnection{rp} + // We return the http2RespWriter here because we want to make sure that we flush after every write + // otherwise the HTTP2 write buffer waits a few seconds before sending. + readWriter := bufio.NewReadWriter( + bufio.NewReader(rp), + bufio.NewWriter(rp), + ) + return conn, readWriter, nil +} + func (rp *http2RespWriter) WriteErrorResponse() bool { if rp.statusWritten { return false diff --git a/connection/quic.go b/connection/quic.go index 510cd4b0..e1e5869e 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -1,6 +1,7 @@ package connection import ( + "bufio" "context" "crypto/tls" "fmt" @@ -435,6 +436,15 @@ func (hrw *httpResponseAdapter) WriteHeader(status int) { hrw.WriteRespHeaders(status, hrw.headers) } +func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + conn := &localProxyConnection{hrw.ReadWriteCloser} + readWriter := bufio.NewReadWriter( + bufio.NewReader(hrw.ReadWriteCloser), + bufio.NewWriter(hrw.ReadWriteCloser), + ) + return conn, readWriter, nil +} + func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index b3b1e30b..942bbe30 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1,6 +1,7 @@ package proxy import ( + "bufio" "bytes" "context" "flag" @@ -76,6 +77,10 @@ func (w *mockHTTPRespWriter) headers() http.Header { return w.Header() } +func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + panic("Hijack not implemented") +} + type mockWSRespWriter struct { *mockHTTPRespWriter writeNotification chan []byte @@ -109,6 +114,10 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) { return w.reader.Read(data) } +func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + panic("Hijack not implemented") +} + type mockSSERespWriter struct { *mockHTTPRespWriter writeNotification chan []byte @@ -840,6 +849,10 @@ func (w *wsRespWriter) WriteHeader(status int) { // unused } +func (m *wsRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + panic("Hijack not implemented") +} + type mockTCPRespWriter struct { w io.Writer responseHeaders http.Header @@ -879,6 +892,10 @@ func (m *mockTCPRespWriter) WriteHeader(status int) { // do nothing } +func (m *mockTCPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + panic("Hijack not implemented") +} + func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress { ingressConfig := &config.Configuration{ Ingress: []config.UnvalidatedIngressRule{