TUN-7324: Add http.Hijacker to connection.ResponseWriter

Allows connection.ResponseWriter implemenations to be Hijacked to properly
handle WebSocket connection downgrades from proper HTTP requests.
This commit is contained in:
Devin Carr 2023-03-29 09:21:19 -07:00
parent be64362fdb
commit 87f81cc57c
4 changed files with 114 additions and 0 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"net"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -197,10 +198,55 @@ func (h *HTTPResponseReadWriteAcker) AckConnection(tracePropagation string) erro
return h.w.WriteRespHeaders(resp.StatusCode, resp.Header) return h.w.WriteRespHeaders(resp.StatusCode, resp.Header)
} }
// localProxyConnection emulates an incoming connection to cloudflared as a net.Conn.
// Used when handling a "hijacked" connection from connection.ResponseWriter
type localProxyConnection struct {
io.ReadWriteCloser
}
func (c *localProxyConnection) Read(b []byte) (int, error) {
return c.ReadWriteCloser.Read(b)
}
func (c *localProxyConnection) Write(b []byte) (int, error) {
return c.ReadWriteCloser.Write(b)
}
func (c *localProxyConnection) Close() error {
return c.ReadWriteCloser.Close()
}
func (c *localProxyConnection) LocalAddr() net.Addr {
// Unused LocalAddr
return &net.TCPAddr{IP: net.IPv6loopback, Port: 0, Zone: ""}
}
func (c *localProxyConnection) RemoteAddr() net.Addr {
// Unused RemoteAddr
return &net.TCPAddr{IP: net.IPv6loopback, Port: 0, Zone: ""}
}
func (c *localProxyConnection) SetDeadline(t time.Time) error {
// ignored since we can't set the read/write Deadlines for the tunnel back to origintunneld
return nil
}
func (c *localProxyConnection) SetReadDeadline(t time.Time) error {
// ignored since we can't set the read/write Deadlines for the tunnel back to origintunneld
return nil
}
func (c *localProxyConnection) SetWriteDeadline(t time.Time) error {
// ignored since we can't set the read/write Deadlines for the tunnel back to origintunneld
return nil
}
// ResponseWriter is the response path for a request back through cloudflared's tunnel.
type ResponseWriter interface { type ResponseWriter interface {
WriteRespHeaders(status int, header http.Header) error WriteRespHeaders(status int, header http.Header) error
AddTrailer(trailerName, trailerValue string) AddTrailer(trailerName, trailerValue string)
http.ResponseWriter http.ResponseWriter
http.Hijacker
io.Writer io.Writer
} }

View File

@ -1,6 +1,7 @@
package connection package connection
import ( import (
"bufio"
"context" "context"
gojson "encoding/json" gojson "encoding/json"
"fmt" "fmt"
@ -198,6 +199,8 @@ type http2RespWriter struct {
shouldFlush bool shouldFlush bool
statusWritten bool statusWritten bool
respHeaders http.Header respHeaders http.Header
hijackedMutex sync.Mutex
hijackedv bool
log *zerolog.Logger log *zerolog.Logger
} }
@ -233,6 +236,10 @@ func (rp *http2RespWriter) AddTrailer(trailerName, trailerValue string) {
} }
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error { func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
if rp.hijacked() {
rp.log.Warn().Msg("WriteRespHeaders after hijack")
return nil
}
dest := rp.w.Header() dest := rp.w.Header()
userHeaders := make(http.Header, len(header)) userHeaders := make(http.Header, len(header))
for name, values := range header { for name, values := range header {
@ -283,9 +290,43 @@ func (rp *http2RespWriter) Header() http.Header {
} }
func (rp *http2RespWriter) WriteHeader(status int) { func (rp *http2RespWriter) WriteHeader(status int) {
if rp.hijacked() {
rp.log.Warn().Msg("WriteHeader after hijack")
return
}
rp.WriteRespHeaders(status, rp.respHeaders) rp.WriteRespHeaders(status, rp.respHeaders)
} }
func (rp *http2RespWriter) hijacked() bool {
rp.hijackedMutex.Lock()
defer rp.hijackedMutex.Unlock()
return rp.hijackedv
}
func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if !rp.statusWritten {
return nil, nil, fmt.Errorf("status not yet written before attempting to hijack connection")
}
// Make sure to flush anything left in the buffer before hijacking
if rp.shouldFlush {
rp.flusher.Flush()
}
rp.hijackedMutex.Lock()
defer rp.hijackedMutex.Unlock()
if rp.hijackedv {
return nil, nil, http.ErrHijacked
}
rp.hijackedv = true
conn := &localProxyConnection{rp}
// We return the http2RespWriter here because we want to make sure that we flush after every write
// otherwise the HTTP2 write buffer waits a few seconds before sending.
readWriter := bufio.NewReadWriter(
bufio.NewReader(rp),
bufio.NewWriter(rp),
)
return conn, readWriter, nil
}
func (rp *http2RespWriter) WriteErrorResponse() bool { func (rp *http2RespWriter) WriteErrorResponse() bool {
if rp.statusWritten { if rp.statusWritten {
return false return false

View File

@ -1,6 +1,7 @@
package connection package connection
import ( import (
"bufio"
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
@ -435,6 +436,15 @@ func (hrw *httpResponseAdapter) WriteHeader(status int) {
hrw.WriteRespHeaders(status, hrw.headers) hrw.WriteRespHeaders(status, hrw.headers)
} }
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
conn := &localProxyConnection{hrw.ReadWriteCloser}
readWriter := bufio.NewReadWriter(
bufio.NewReader(hrw.ReadWriteCloser),
bufio.NewWriter(hrw.ReadWriteCloser),
)
return conn, readWriter, nil
}
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
} }

View File

@ -1,6 +1,7 @@
package proxy package proxy
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"flag" "flag"
@ -76,6 +77,10 @@ func (w *mockHTTPRespWriter) headers() http.Header {
return w.Header() return w.Header()
} }
func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
type mockWSRespWriter struct { type mockWSRespWriter struct {
*mockHTTPRespWriter *mockHTTPRespWriter
writeNotification chan []byte writeNotification chan []byte
@ -109,6 +114,10 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
return w.reader.Read(data) return w.reader.Read(data)
} }
func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
type mockSSERespWriter struct { type mockSSERespWriter struct {
*mockHTTPRespWriter *mockHTTPRespWriter
writeNotification chan []byte writeNotification chan []byte
@ -840,6 +849,10 @@ func (w *wsRespWriter) WriteHeader(status int) {
// unused // unused
} }
func (m *wsRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
type mockTCPRespWriter struct { type mockTCPRespWriter struct {
w io.Writer w io.Writer
responseHeaders http.Header responseHeaders http.Header
@ -879,6 +892,10 @@ func (m *mockTCPRespWriter) WriteHeader(status int) {
// do nothing // do nothing
} }
func (m *mockTCPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress { func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress {
ingressConfig := &config.Configuration{ ingressConfig := &config.Configuration{
Ingress: []config.UnvalidatedIngressRule{ Ingress: []config.UnvalidatedIngressRule{