TUN-4701: Split Proxy into ProxyHTTP and ProxyTCP
http.Request now is only used by ProxyHTTP and not required if the proxying is TCP. The dest conversion is handled by the transport layer.
This commit is contained in:
parent
81dff44bb9
commit
8f3526289a
|
@ -1,6 +1,7 @@
|
||||||
package connection
|
package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -11,9 +12,15 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
const LogFieldConnIndex = "connIndex"
|
const (
|
||||||
|
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||||
|
LogFieldConnIndex = "connIndex"
|
||||||
|
)
|
||||||
|
|
||||||
|
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
OriginProxy OriginProxy
|
OriginProxy OriginProxy
|
||||||
|
@ -87,9 +94,64 @@ func (t Type) String() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OriginProxy is how data flows from cloudflared to the origin services running behind it.
|
||||||
type OriginProxy interface {
|
type OriginProxy interface {
|
||||||
// If Proxy returns an error, the caller is responsible for writing the error status to ResponseWriter
|
ProxyHTTP(w ResponseWriter, req *http.Request, isWebsocket bool) error
|
||||||
Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error
|
ProxyTCP(ctx context.Context, rwa ReadWriteAcker, req *TCPRequest) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPRequest defines the input format needed to perform a TCP proxy.
|
||||||
|
type TCPRequest struct {
|
||||||
|
Dest string
|
||||||
|
CFRay string
|
||||||
|
LBProbe bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadWriteAcker is a readwriter with the ability to Acknowledge to the downstream (edge) that the origin has
|
||||||
|
// accepted the connection.
|
||||||
|
type ReadWriteAcker interface {
|
||||||
|
io.ReadWriter
|
||||||
|
AckConnection() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPResponseReadWriteAcker is an HTTP implementation of ReadWriteAcker.
|
||||||
|
type HTTPResponseReadWriteAcker struct {
|
||||||
|
r io.Reader
|
||||||
|
w ResponseWriter
|
||||||
|
req *http.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPResponseReadWriterAcker returns a new instance of HTTPResponseReadWriteAcker.
|
||||||
|
func NewHTTPResponseReadWriterAcker(w ResponseWriter, req *http.Request) *HTTPResponseReadWriteAcker {
|
||||||
|
return &HTTPResponseReadWriteAcker{
|
||||||
|
r: req.Body,
|
||||||
|
w: w,
|
||||||
|
req: req,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HTTPResponseReadWriteAcker) Read(p []byte) (int, error) {
|
||||||
|
return h.r.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *HTTPResponseReadWriteAcker) Write(p []byte) (int, error) {
|
||||||
|
return h.w.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to
|
||||||
|
// upgrade to streams.
|
||||||
|
func (h *HTTPResponseReadWriteAcker) AckConnection() error {
|
||||||
|
resp := &http.Response{
|
||||||
|
Status: switchingProtocolText,
|
||||||
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
|
ContentLength: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
if secWebsocketKey := h.req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" {
|
||||||
|
resp.Header = websocket.NewResponseHeader(h.req)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseWriter interface {
|
type ResponseWriter interface {
|
||||||
|
@ -112,3 +174,11 @@ func IsServerSentEvent(headers http.Header) bool {
|
||||||
func uint8ToString(input uint8) string {
|
func uint8ToString(input uint8) string {
|
||||||
return strconv.FormatUint(uint64(input), 10)
|
return strconv.FormatUint(uint64(input), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FindCfRayHeader(req *http.Request) string {
|
||||||
|
return req.Header.Get("Cf-Ray")
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsLBProbeRequest(req *http.Request) bool {
|
||||||
|
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package connection
|
package connection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -11,6 +12,8 @@ import (
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -18,6 +21,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil)
|
||||||
testConfig = &Config{
|
testConfig = &Config{
|
||||||
OriginProxy: &mockOriginProxy{},
|
OriginProxy: &mockOriginProxy{},
|
||||||
GracePeriod: time.Millisecond * 100,
|
GracePeriod: time.Millisecond * 100,
|
||||||
|
@ -38,14 +42,17 @@ type testRequest struct {
|
||||||
isProxyError bool
|
isProxyError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockOriginProxy struct {
|
type mockOriginProxy struct{}
|
||||||
}
|
|
||||||
|
|
||||||
func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, sourceConnectionType Type) error {
|
func (moc *mockOriginProxy) ProxyHTTP(
|
||||||
if sourceConnectionType == TypeWebsocket {
|
w ResponseWriter,
|
||||||
return wsEndpoint(w, r)
|
req *http.Request,
|
||||||
|
isWebsocket bool,
|
||||||
|
) error {
|
||||||
|
if isWebsocket {
|
||||||
|
return wsEndpoint(w, req)
|
||||||
}
|
}
|
||||||
switch r.URL.Path {
|
switch req.URL.Path {
|
||||||
case "/ok":
|
case "/ok":
|
||||||
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
|
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
|
||||||
case "/large_file":
|
case "/large_file":
|
||||||
|
@ -60,6 +67,15 @@ func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, sourceConne
|
||||||
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
|
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (moc *mockOriginProxy) ProxyTCP(
|
||||||
|
ctx context.Context,
|
||||||
|
rwa ReadWriteAcker,
|
||||||
|
r *TCPRequest,
|
||||||
|
) error {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type nowriter struct {
|
type nowriter struct {
|
||||||
|
|
|
@ -33,6 +33,8 @@ type h2muxConnection struct {
|
||||||
gracefulShutdownC <-chan struct{}
|
gracefulShutdownC <-chan struct{}
|
||||||
stoppedGracefully bool
|
stoppedGracefully bool
|
||||||
|
|
||||||
|
log *zerolog.Logger
|
||||||
|
|
||||||
// newRPCClientFunc allows us to mock RPCs during testing
|
// newRPCClientFunc allows us to mock RPCs during testing
|
||||||
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
||||||
}
|
}
|
||||||
|
@ -222,12 +224,11 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
sourceConnectionType = TypeWebsocket
|
sourceConnectionType = TypeWebsocket
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.config.OriginProxy.Proxy(respWriter, req, sourceConnectionType)
|
err := h.config.OriginProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respWriter.WriteErrorResponse()
|
respWriter.WriteErrorResponse()
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
|
||||||
|
@ -26,7 +27,9 @@ const (
|
||||||
|
|
||||||
var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
|
var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
|
||||||
|
|
||||||
type http2Connection struct {
|
// HTTP2Connection represents a net.Conn that uses HTTP2 frames to proxy traffic from the edge to cloudflared on the
|
||||||
|
// origin.
|
||||||
|
type HTTP2Connection struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
server *http2.Server
|
server *http2.Server
|
||||||
config *Config
|
config *Config
|
||||||
|
@ -38,6 +41,7 @@ type http2Connection struct {
|
||||||
// newRPCClientFunc allows us to mock RPCs during testing
|
// newRPCClientFunc allows us to mock RPCs during testing
|
||||||
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
|
||||||
|
|
||||||
|
log *zerolog.Logger
|
||||||
activeRequestsWG sync.WaitGroup
|
activeRequestsWG sync.WaitGroup
|
||||||
connectedFuse ConnectedFuse
|
connectedFuse ConnectedFuse
|
||||||
gracefulShutdownC <-chan struct{}
|
gracefulShutdownC <-chan struct{}
|
||||||
|
@ -45,6 +49,7 @@ type http2Connection struct {
|
||||||
controlStreamErr error // result of running control stream handler
|
controlStreamErr error // result of running control stream handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewHTTP2Connection returns a new instance of HTTP2Connection.
|
||||||
func NewHTTP2Connection(
|
func NewHTTP2Connection(
|
||||||
conn net.Conn,
|
conn net.Conn,
|
||||||
config *Config,
|
config *Config,
|
||||||
|
@ -53,9 +58,10 @@ func NewHTTP2Connection(
|
||||||
observer *Observer,
|
observer *Observer,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
connectedFuse ConnectedFuse,
|
connectedFuse ConnectedFuse,
|
||||||
|
log *zerolog.Logger,
|
||||||
gracefulShutdownC <-chan struct{},
|
gracefulShutdownC <-chan struct{},
|
||||||
) *http2Connection {
|
) *HTTP2Connection {
|
||||||
return &http2Connection{
|
return &HTTP2Connection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
server: &http2.Server{
|
server: &http2.Server{
|
||||||
MaxConcurrentStreams: math.MaxUint32,
|
MaxConcurrentStreams: math.MaxUint32,
|
||||||
|
@ -68,11 +74,13 @@ func NewHTTP2Connection(
|
||||||
connIndex: connIndex,
|
connIndex: connIndex,
|
||||||
newRPCClientFunc: newRegistrationRPCClient,
|
newRPCClientFunc: newRegistrationRPCClient,
|
||||||
connectedFuse: connectedFuse,
|
connectedFuse: connectedFuse,
|
||||||
|
log: log,
|
||||||
gracefulShutdownC: gracefulShutdownC,
|
gracefulShutdownC: gracefulShutdownC,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *http2Connection) Serve(ctx context.Context) error {
|
// Serve serves an HTTP2 server that the edge can talk to.
|
||||||
|
func (c *HTTP2Connection) Serve(ctx context.Context) error {
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
c.close()
|
c.close()
|
||||||
|
@ -93,7 +101,7 @@ func (c *http2Connection) Serve(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
c.activeRequestsWG.Add(1)
|
c.activeRequestsWG.Add(1)
|
||||||
defer c.activeRequestsWG.Done()
|
defer c.activeRequestsWG.Done()
|
||||||
|
|
||||||
|
@ -106,23 +114,47 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var proxyErr error
|
|
||||||
switch connType {
|
switch connType {
|
||||||
case TypeControlStream:
|
case TypeControlStream:
|
||||||
proxyErr = c.serveControlStream(r.Context(), respWriter)
|
if err := c.serveControlStream(r.Context(), respWriter); err != nil {
|
||||||
c.controlStreamErr = proxyErr
|
c.controlStreamErr = err
|
||||||
case TypeWebsocket:
|
c.log.Error().Err(err)
|
||||||
stripWebsocketUpgradeHeader(r)
|
respWriter.WriteErrorResponse()
|
||||||
proxyErr = c.config.OriginProxy.Proxy(respWriter, r, TypeWebsocket)
|
|
||||||
default:
|
|
||||||
proxyErr = c.config.OriginProxy.Proxy(respWriter, r, connType)
|
|
||||||
}
|
}
|
||||||
if proxyErr != nil {
|
|
||||||
|
case TypeWebsocket, TypeHTTP:
|
||||||
|
stripWebsocketUpgradeHeader(r)
|
||||||
|
if err := c.config.OriginProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil {
|
||||||
|
err := fmt.Errorf("Failed to proxy HTTP: %w", err)
|
||||||
|
c.log.Error().Err(err)
|
||||||
|
respWriter.WriteErrorResponse()
|
||||||
|
}
|
||||||
|
|
||||||
|
case TypeTCP:
|
||||||
|
host, err := getRequestHost(r)
|
||||||
|
if err != nil {
|
||||||
|
err := fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %w`, err)
|
||||||
|
c.log.Error().Err(err)
|
||||||
|
respWriter.WriteErrorResponse()
|
||||||
|
}
|
||||||
|
|
||||||
|
rws := NewHTTPResponseReadWriterAcker(respWriter, r)
|
||||||
|
if err := c.config.OriginProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
|
||||||
|
Dest: host,
|
||||||
|
CFRay: FindCfRayHeader(r),
|
||||||
|
LBProbe: IsLBProbeRequest(r),
|
||||||
|
}); err != nil {
|
||||||
|
respWriter.WriteErrorResponse()
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
err := fmt.Errorf("Received unknown connection type: %s", connType)
|
||||||
|
c.log.Error().Err(err)
|
||||||
respWriter.WriteErrorResponse()
|
respWriter.WriteErrorResponse()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
||||||
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
|
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log)
|
||||||
defer rpcClient.Close()
|
defer rpcClient.Close()
|
||||||
|
|
||||||
|
@ -145,7 +177,7 @@ func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *ht
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *http2Connection) close() {
|
func (c *HTTP2Connection) close() {
|
||||||
// Wait for all serve HTTP handlers to return
|
// Wait for all serve HTTP handlers to return
|
||||||
c.activeRequestsWG.Wait()
|
c.activeRequestsWG.Wait()
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
|
@ -287,3 +319,14 @@ func IsTCPStream(r *http.Request) bool {
|
||||||
func stripWebsocketUpgradeHeader(r *http.Request) {
|
func stripWebsocketUpgradeHeader(r *http.Request) {
|
||||||
r.Header.Del(InternalUpgradeHeader)
|
r.Header.Del(InternalUpgradeHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRequestHost returns the host of the http.Request.
|
||||||
|
func getRequestHost(r *http.Request) (string, error) {
|
||||||
|
if r.Host != "" {
|
||||||
|
return r.Host, nil
|
||||||
|
}
|
||||||
|
if r.URL != nil {
|
||||||
|
return r.URL.Host, nil
|
||||||
|
}
|
||||||
|
return "", errors.New("host not set in incoming request")
|
||||||
|
}
|
||||||
|
|
|
@ -26,9 +26,10 @@ var (
|
||||||
testTransport = http2.Transport{}
|
testTransport = http2.Transport{}
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestHTTP2Connection() (*http2Connection, net.Conn) {
|
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
||||||
edgeConn, originConn := net.Pipe()
|
edgeConn, originConn := net.Pipe()
|
||||||
var connIndex = uint8(0)
|
var connIndex = uint8(0)
|
||||||
|
log := zerolog.Nop()
|
||||||
return NewHTTP2Connection(
|
return NewHTTP2Connection(
|
||||||
originConn,
|
originConn,
|
||||||
testConfig,
|
testConfig,
|
||||||
|
@ -37,6 +38,7 @@ func newTestHTTP2Connection() (*http2Connection, net.Conn) {
|
||||||
NewObserver(&log, &log, false),
|
NewObserver(&log, &log, false),
|
||||||
connIndex,
|
connIndex,
|
||||||
mockConnectedFuse{},
|
mockConnectedFuse{},
|
||||||
|
&log,
|
||||||
nil,
|
nil,
|
||||||
), edgeConn
|
), edgeConn
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package ingress
|
package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
@ -9,7 +8,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
|
||||||
errUnsupportedConnectionType = errors.New("internal error: unsupported connection type")
|
errUnsupportedConnectionType = errors.New("internal error: unsupported connection type")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
169
origin/proxy.go
169
origin/proxy.go
|
@ -7,7 +7,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
@ -20,13 +19,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// TagHeaderNamePrefix indicates a Cloudflared Warp Tag prefix that gets appended for warp traffic stream headers.
|
||||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||||
LogFieldCFRay = "cfRay"
|
LogFieldCFRay = "cfRay"
|
||||||
LogFieldRule = "ingressRule"
|
LogFieldRule = "ingressRule"
|
||||||
LogFieldOriginService = "originService"
|
LogFieldOriginService = "originService"
|
||||||
)
|
)
|
||||||
|
|
||||||
type proxy struct {
|
// Proxy represents a means to Proxy between cloudflared and the origin services.
|
||||||
|
type Proxy struct {
|
||||||
ingressRules ingress.Ingress
|
ingressRules ingress.Ingress
|
||||||
warpRouting *ingress.WarpRoutingService
|
warpRouting *ingress.WarpRoutingService
|
||||||
tags []tunnelpogs.Tag
|
tags []tunnelpogs.Tag
|
||||||
|
@ -34,15 +35,14 @@ type proxy struct {
|
||||||
bufferPool *bufferPool
|
bufferPool *bufferPool
|
||||||
}
|
}
|
||||||
|
|
||||||
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
// NewOriginProxy returns a new instance of the Proxy struct.
|
||||||
|
|
||||||
func NewOriginProxy(
|
func NewOriginProxy(
|
||||||
ingressRules ingress.Ingress,
|
ingressRules ingress.Ingress,
|
||||||
warpRouting *ingress.WarpRoutingService,
|
warpRouting *ingress.WarpRoutingService,
|
||||||
tags []tunnelpogs.Tag,
|
tags []tunnelpogs.Tag,
|
||||||
log *zerolog.Logger) connection.OriginProxy {
|
log *zerolog.Logger,
|
||||||
|
) *Proxy {
|
||||||
return &proxy{
|
return &Proxy{
|
||||||
ingressRules: ingressRules,
|
ingressRules: ingressRules,
|
||||||
warpRouting: warpRouting,
|
warpRouting: warpRouting,
|
||||||
tags: tags,
|
tags: tags,
|
||||||
|
@ -51,41 +51,18 @@ func NewOriginProxy(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Caller is responsible for writing any error to ResponseWriter
|
// ProxyHTTP further depends on ingress rules to establish a connection with the origin service. This may be
|
||||||
func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error {
|
// a simple roundtrip or a tcp/websocket dial depending on ingres rule setup.
|
||||||
|
func (p *Proxy) ProxyHTTP(
|
||||||
|
w connection.ResponseWriter,
|
||||||
|
req *http.Request,
|
||||||
|
isWebsocket bool,
|
||||||
|
) error {
|
||||||
incrementRequests()
|
incrementRequests()
|
||||||
defer decrementConcurrentRequests()
|
defer decrementConcurrentRequests()
|
||||||
|
|
||||||
cfRay := findCfRayHeader(req)
|
cfRay := connection.FindCfRayHeader(req)
|
||||||
lbProbe := isLBProbeRequest(req)
|
lbProbe := connection.IsLBProbeRequest(req)
|
||||||
|
|
||||||
serveCtx, cancel := context.WithCancel(req.Context())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
p.appendTagHeaders(req)
|
|
||||||
if sourceConnectionType == connection.TypeTCP {
|
|
||||||
if p.warpRouting == nil {
|
|
||||||
err := errors.New(`cloudflared received a request from WARP client, but your configuration has disabled ingress from WARP clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`)
|
|
||||||
p.log.Error().Msg(err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logFields := logFields{
|
|
||||||
cfRay: cfRay,
|
|
||||||
lbProbe: lbProbe,
|
|
||||||
rule: ingress.ServiceWarpRouting,
|
|
||||||
}
|
|
||||||
|
|
||||||
host, err := getRequestHost(req)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil {
|
|
||||||
p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
||||||
logFields := logFields{
|
logFields := logFields{
|
||||||
|
@ -97,8 +74,14 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
|
|
||||||
switch originProxy := rule.Service.(type) {
|
switch originProxy := rule.Service.(type) {
|
||||||
case ingress.HTTPOriginProxy:
|
case ingress.HTTPOriginProxy:
|
||||||
if err := p.proxyHTTPRequest(w, req, originProxy, sourceConnectionType == connection.TypeWebsocket,
|
if err := p.proxyHTTPRequest(
|
||||||
rule.Config.DisableChunkedEncoding, logFields); err != nil {
|
w,
|
||||||
|
req,
|
||||||
|
originProxy,
|
||||||
|
isWebsocket,
|
||||||
|
rule.Config.DisableChunkedEncoding,
|
||||||
|
logFields,
|
||||||
|
); err != nil {
|
||||||
rule, srv := ruleField(p.ingressRules, ruleNum)
|
rule, srv := ruleField(p.ingressRules, ruleNum)
|
||||||
p.logRequestError(err, cfRay, rule, srv)
|
p.logRequestError(err, cfRay, rule, srv)
|
||||||
return err
|
return err
|
||||||
|
@ -110,7 +93,9 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil {
|
|
||||||
|
rws := connection.NewHTTPResponseReadWriterAcker(w, req)
|
||||||
|
if err := p.proxyStream(req.Context(), rws, dest, originProxy, logFields); err != nil {
|
||||||
rule, srv := ruleField(p.ingressRules, ruleNum)
|
rule, srv := ruleField(p.ingressRules, ruleNum)
|
||||||
p.logRequestError(err, cfRay, rule, srv)
|
p.logRequestError(err, cfRay, rule, srv)
|
||||||
return err
|
return err
|
||||||
|
@ -121,24 +106,36 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
|
// ProxyTCP proxies to a TCP connection between the origin service and cloudflared.
|
||||||
switch rule.Service.String() {
|
func (p *Proxy) ProxyTCP(
|
||||||
case ingress.ServiceBastion:
|
ctx context.Context,
|
||||||
return carrier.ResolveBastionDest(req)
|
rwa connection.ReadWriteAcker,
|
||||||
default:
|
req *connection.TCPRequest,
|
||||||
return rule.Service.String(), nil
|
) error {
|
||||||
}
|
incrementRequests()
|
||||||
|
defer decrementConcurrentRequests()
|
||||||
|
|
||||||
|
if p.warpRouting == nil {
|
||||||
|
err := errors.New(`cloudflared received a request from WARP client, but your configuration has disabled ingress from WARP clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`)
|
||||||
|
p.log.Error().Msg(err.Error())
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRequestHost returns the host of the http.Request.
|
serveCtx, cancel := context.WithCancel(ctx)
|
||||||
func getRequestHost(r *http.Request) (string, error) {
|
defer cancel()
|
||||||
if r.Host != "" {
|
|
||||||
return r.Host, nil
|
logFields := logFields{
|
||||||
|
cfRay: req.CFRay,
|
||||||
|
lbProbe: req.LBProbe,
|
||||||
|
rule: ingress.ServiceWarpRouting,
|
||||||
}
|
}
|
||||||
if r.URL != nil {
|
|
||||||
return r.URL.Host, nil
|
if err := p.proxyStream(serveCtx, rwa, req.Dest, p.warpRouting.Proxy, logFields); err != nil {
|
||||||
|
p.logRequestError(err, req.CFRay, "", ingress.ServiceWarpRouting)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return "", errors.New("host not set in incoming request")
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
|
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
|
||||||
|
@ -149,13 +146,15 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
|
||||||
return fmt.Sprintf("%d", ruleNum), srv
|
return fmt.Sprintf("%d", ruleNum), srv
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) proxyHTTPRequest(
|
// ProxyHTTPRequest proxies requests of underlying type http and websocket to the origin service.
|
||||||
|
func (p *Proxy) proxyHTTPRequest(
|
||||||
w connection.ResponseWriter,
|
w connection.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
httpService ingress.HTTPOriginProxy,
|
httpService ingress.HTTPOriginProxy,
|
||||||
isWebsocket bool,
|
isWebsocket bool,
|
||||||
disableChunkedEncoding bool,
|
disableChunkedEncoding bool,
|
||||||
fields logFields) error {
|
fields logFields,
|
||||||
|
) error {
|
||||||
roundTripReq := req
|
roundTripReq := req
|
||||||
if isWebsocket {
|
if isWebsocket {
|
||||||
roundTripReq = req.Clone(req.Context())
|
roundTripReq = req.Clone(req.Context())
|
||||||
|
@ -214,17 +213,17 @@ func (p *proxy) proxyHTTPRequest(
|
||||||
defer p.bufferPool.Put(buf)
|
defer p.bufferPool.Put(buf)
|
||||||
_, _ = io.CopyBuffer(w, resp.Body, buf)
|
_, _ = io.CopyBuffer(w, resp.Body, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.logOriginResponse(resp, fields)
|
p.logOriginResponse(resp, fields)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// proxyStreamRequest first establish a connection with origin, then it writes the status code and headers, and finally it streams data between
|
// proxyStream proxies type TCP and other underlying types if the connection is defined as a stream oriented
|
||||||
// eyeball and origin.
|
// ingress rule.
|
||||||
func (p *proxy) proxyStreamRequest(
|
func (p *Proxy) proxyStream(
|
||||||
serveCtx context.Context,
|
ctx context.Context,
|
||||||
w connection.ResponseWriter,
|
rwa connection.ReadWriteAcker,
|
||||||
dest string,
|
dest string,
|
||||||
req *http.Request,
|
|
||||||
connectionProxy ingress.StreamBasedOriginProxy,
|
connectionProxy ingress.StreamBasedOriginProxy,
|
||||||
fields logFields,
|
fields logFields,
|
||||||
) error {
|
) error {
|
||||||
|
@ -233,21 +232,11 @@ func (p *proxy) proxyStreamRequest(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := &http.Response{
|
if err := rwa.AckConnection(); err != nil {
|
||||||
Status: switchingProtocolText,
|
|
||||||
StatusCode: http.StatusSwitchingProtocols,
|
|
||||||
ContentLength: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" {
|
|
||||||
resp.Header = websocket.NewResponseHeader(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
streamCtx, cancel := context.WithCancel(serveCtx)
|
streamCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -256,12 +245,7 @@ func (p *proxy) proxyStreamRequest(
|
||||||
originConn.Close()
|
originConn.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
eyeballStream := &bidirectionalStream{
|
originConn.Stream(ctx, rwa, p.log)
|
||||||
writer: w,
|
|
||||||
reader: req.Body,
|
|
||||||
}
|
|
||||||
originConn.Stream(serveCtx, eyeballStream, p.log)
|
|
||||||
p.logOriginResponse(resp, fields)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -278,7 +262,7 @@ func (wr *bidirectionalStream) Write(p []byte) (n int, err error) {
|
||||||
return wr.writer.Write(p)
|
return wr.writer.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
func (p *Proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
||||||
reader := bufio.NewReader(respBody)
|
reader := bufio.NewReader(respBody)
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadBytes('\n')
|
line, err := reader.ReadBytes('\n')
|
||||||
|
@ -289,7 +273,7 @@ func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) appendTagHeaders(r *http.Request) {
|
func (p *Proxy) appendTagHeaders(r *http.Request) {
|
||||||
for _, tag := range p.tags {
|
for _, tag := range p.tags {
|
||||||
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
||||||
}
|
}
|
||||||
|
@ -301,7 +285,7 @@ type logFields struct {
|
||||||
rule interface{}
|
rule interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) logRequest(r *http.Request, fields logFields) {
|
func (p *Proxy) logRequest(r *http.Request, fields logFields) {
|
||||||
if fields.cfRay != "" {
|
if fields.cfRay != "" {
|
||||||
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
|
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
|
||||||
} else if fields.lbProbe {
|
} else if fields.lbProbe {
|
||||||
|
@ -324,7 +308,7 @@ func (p *proxy) logRequest(r *http.Request, fields logFields) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) {
|
func (p *Proxy) logOriginResponse(resp *http.Response, fields logFields) {
|
||||||
responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
|
responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
|
||||||
if fields.cfRay != "" {
|
if fields.cfRay != "" {
|
||||||
p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", fields.cfRay, resp.Status, fields.rule)
|
p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", fields.cfRay, resp.Status, fields.rule)
|
||||||
|
@ -342,7 +326,7 @@ func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) logRequestError(err error, cfRay string, rule, service string) {
|
func (p *Proxy) logRequestError(err error, cfRay string, rule, service string) {
|
||||||
requestErrors.Inc()
|
requestErrors.Inc()
|
||||||
log := p.log.Error().Err(err)
|
log := p.log.Error().Err(err)
|
||||||
if cfRay != "" {
|
if cfRay != "" {
|
||||||
|
@ -357,10 +341,11 @@ func (p *proxy) logRequestError(err error, cfRay string, rule, service string) {
|
||||||
log.Msg("")
|
log.Msg("")
|
||||||
}
|
}
|
||||||
|
|
||||||
func findCfRayHeader(req *http.Request) string {
|
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
|
||||||
return req.Header.Get("Cf-Ray")
|
switch rule.Service.String() {
|
||||||
|
case ingress.ServiceBastion:
|
||||||
|
return carrier.ResolveBastionDest(req)
|
||||||
|
default:
|
||||||
|
return rule.Service.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLBProbeRequest(req *http.Request) bool {
|
|
||||||
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,6 +46,10 @@ func newMockHTTPRespWriter() *mockHTTPRespWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *mockHTTPRespWriter) WriteResponse() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
for header, val := range header {
|
for header, val := range header {
|
||||||
|
@ -146,7 +150,7 @@ func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
|
err = proxy.ProxyHTTP(responseWriter, req, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, responseWriter.Code)
|
assert.Equal(t, http.StatusOK, responseWriter.Code)
|
||||||
|
@ -170,7 +174,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
|
|
||||||
errGroup, ctx := errgroup.WithContext(ctx)
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
|
err = proxy.ProxyHTTP(responseWriter, req, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
|
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
|
||||||
|
@ -231,7 +235,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
|
err = proxy.ProxyHTTP(responseWriter, req, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, responseWriter.Code)
|
require.Equal(t, http.StatusOK, responseWriter.Code)
|
||||||
|
@ -330,7 +334,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
|
||||||
req, err := http.NewRequest(http.MethodGet, test.url, nil)
|
req, err := http.NewRequest(http.MethodGet, test.url, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
|
err = proxy.ProxyHTTP(responseWriter, req, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, test.expectedStatus, responseWriter.Code)
|
assert.Equal(t, test.expectedStatus, responseWriter.Code)
|
||||||
|
@ -358,7 +362,7 @@ func (errorOriginTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyError(t *testing.T) {
|
func TestProxyError(t *testing.T) {
|
||||||
ingress := ingress.Ingress{
|
ing := ingress.Ingress{
|
||||||
Rules: []ingress.Rule{
|
Rules: []ingress.Rule{
|
||||||
{
|
{
|
||||||
Hostname: "*",
|
Hostname: "*",
|
||||||
|
@ -372,13 +376,13 @@ func TestProxyError(t *testing.T) {
|
||||||
|
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log)
|
proxy := NewOriginProxy(ing, unusedWarpRoutingService, testTags, &log)
|
||||||
|
|
||||||
responseWriter := newMockHTTPRespWriter()
|
responseWriter := newMockHTTPRespWriter()
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Error(t, proxy.Proxy(responseWriter, req, connection.TypeHTTP))
|
assert.Error(t, proxy.ProxyHTTP(responseWriter, req, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
type replayer struct {
|
type replayer struct {
|
||||||
|
@ -617,6 +621,7 @@ func TestConnections(t *testing.T) {
|
||||||
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
|
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
|
||||||
proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger)
|
proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger)
|
||||||
|
|
||||||
|
dest := ln.Addr().String()
|
||||||
req, err := http.NewRequest(
|
req, err := http.NewRequest(
|
||||||
http.MethodGet,
|
http.MethodGet,
|
||||||
test.args.ingressServiceScheme+ln.Addr().String(),
|
test.args.ingressServiceScheme+ln.Addr().String(),
|
||||||
|
@ -634,8 +639,12 @@ func TestConnections(t *testing.T) {
|
||||||
replayer.Write(resp)
|
replayer.Write(resp)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
if test.args.connectionType == connection.TypeTCP {
|
||||||
err = proxy.Proxy(respWriter, req, test.args.connectionType)
|
rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req)
|
||||||
|
err = proxy.ProxyTCP(ctx, rws, &connection.TCPRequest{Dest: dest})
|
||||||
|
} else {
|
||||||
|
err = proxy.ProxyHTTP(respWriter, req, test.args.connectionType == connection.TypeWebsocket)
|
||||||
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
assert.Equal(t, test.want.err, err != nil)
|
assert.Equal(t, test.want.err, err != nil)
|
||||||
|
@ -829,6 +838,10 @@ func newTCPRespWriter(w io.Writer) *mockTCPRespWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
|
func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
|
||||||
return m.w.Write(p)
|
return m.w.Write(p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,6 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dialTimeout = 15 * time.Second
|
dialTimeout = 15 * time.Second
|
||||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
|
||||||
FeatureSerializedHeaders = "serialized_headers"
|
FeatureSerializedHeaders = "serialized_headers"
|
||||||
FeatureQuickReconnects = "quick_reconnects"
|
FeatureQuickReconnects = "quick_reconnects"
|
||||||
)
|
)
|
||||||
|
@ -417,6 +416,7 @@ func ServeHTTP2(
|
||||||
config.Observer,
|
config.Observer,
|
||||||
connIndex,
|
connIndex,
|
||||||
connectedFuse,
|
connectedFuse,
|
||||||
|
config.Log,
|
||||||
gracefulShutdownC,
|
gracefulShutdownC,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue