TUN-3615: added support to proxy tcp streams
added ingress.DefaultStreamHandler and a basic test for tcp stream proxy moved websocket.Stream to ingress cloudflared no longer picks tcpstream host from header
This commit is contained in:
parent
e2262085e5
commit
368066a966
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/token"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/token"
|
||||||
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/socks"
|
"github.com/cloudflare/cloudflared/socks"
|
||||||
cfwebsocket "github.com/cloudflare/cloudflared/websocket"
|
cfwebsocket "github.com/cloudflare/cloudflared/websocket"
|
||||||
|
|
||||||
|
@ -61,7 +62,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro
|
||||||
|
|
||||||
_ = socksServer.Serve(conn)
|
_ = socksServer.Serve(conn)
|
||||||
} else {
|
} else {
|
||||||
cfwebsocket.Stream(wsConn, conn)
|
ingress.Stream(wsConn, conn)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -69,7 +70,7 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro
|
||||||
// StartServer creates a Websocket server to listen for connections.
|
// StartServer creates a Websocket server to listen for connections.
|
||||||
// This is used on the origin (tunnel) side to take data from the muxer and send it to the origin
|
// This is used on the origin (tunnel) side to take data from the muxer and send it to the origin
|
||||||
func (ws *Websocket) StartServer(listener net.Listener, remote string, shutdownC <-chan struct{}) error {
|
func (ws *Websocket) StartServer(listener net.Listener, remote string, shutdownC <-chan struct{}) error {
|
||||||
return cfwebsocket.StartProxyServer(ws.log, listener, remote, shutdownC, cfwebsocket.DefaultStreamHandler)
|
return cfwebsocket.StartProxyServer(ws.log, listener, remote, shutdownC, ingress.DefaultStreamHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createWebsocketStream will create a WebSocket connection to stream data over
|
// createWebsocketStream will create a WebSocket connection to stream data over
|
||||||
|
|
|
@ -50,8 +50,17 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool {
|
||||||
return c.Hostname == ""
|
return c.Hostname == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Type indicates the connection type of the connection.
|
||||||
|
type Type int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeWebsocket Type = iota
|
||||||
|
TypeTCP
|
||||||
|
TypeHTTP
|
||||||
|
)
|
||||||
|
|
||||||
type OriginProxy interface {
|
type OriginProxy interface {
|
||||||
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
|
Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseWriter interface {
|
type ResponseWriter interface {
|
||||||
|
|
|
@ -41,8 +41,8 @@ type testRequest struct {
|
||||||
type mockOriginProxy struct {
|
type mockOriginProxy struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
|
func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, sourceConnectionType Type) error {
|
||||||
if isWebsocket {
|
if sourceConnectionType == TypeWebsocket {
|
||||||
return wsEndpoint(w, r)
|
return wsEndpoint(w, r)
|
||||||
}
|
}
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
|
|
|
@ -216,7 +216,12 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
return reqErr
|
return reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.config.OriginProxy.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
|
var sourceConnectionType = TypeHTTP
|
||||||
|
if websocket.IsWebSocketUpgrade(req) {
|
||||||
|
sourceConnectionType = TypeWebsocket
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.config.OriginProxy.Proxy(respWriter, req, sourceConnectionType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respWriter.WriteErrorResponse()
|
respWriter.WriteErrorResponse()
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
|
internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
|
||||||
|
tcpStreamHeader = "Cf-Cloudflared-Proxy-Src"
|
||||||
websocketUpgrade = "websocket"
|
websocketUpgrade = "websocket"
|
||||||
controlStreamUpgrade = "control-stream"
|
controlStreamUpgrade = "control-stream"
|
||||||
)
|
)
|
||||||
|
@ -107,21 +108,33 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
respWriter.flusher = flusher
|
respWriter.flusher = flusher
|
||||||
var err error
|
|
||||||
if isControlStreamUpgrade(r) {
|
switch {
|
||||||
|
case isControlStreamUpgrade(r):
|
||||||
respWriter.shouldFlush = true
|
respWriter.shouldFlush = true
|
||||||
err = c.serveControlStream(r.Context(), respWriter)
|
if err := c.serveControlStream(r.Context(), respWriter); err != nil {
|
||||||
c.controlStreamErr = err
|
respWriter.WriteErrorResponse()
|
||||||
} else if isWebsocketUpgrade(r) {
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
case isWebsocketUpgrade(r):
|
||||||
respWriter.shouldFlush = true
|
respWriter.shouldFlush = true
|
||||||
stripWebsocketUpgradeHeader(r)
|
stripWebsocketUpgradeHeader(r)
|
||||||
err = c.config.OriginProxy.Proxy(respWriter, r, true)
|
if err := c.config.OriginProxy.Proxy(respWriter, r, TypeWebsocket); err != nil {
|
||||||
} else {
|
respWriter.WriteErrorResponse()
|
||||||
err = c.config.OriginProxy.Proxy(respWriter, r, false)
|
}
|
||||||
}
|
return
|
||||||
|
|
||||||
if err != nil {
|
case IsTCPStream(r):
|
||||||
respWriter.WriteErrorResponse()
|
if err := c.config.OriginProxy.Proxy(respWriter, r, TypeTCP); err != nil {
|
||||||
|
respWriter.WriteErrorResponse()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
|
if err := c.config.OriginProxy.Proxy(respWriter, r, TypeHTTP); err != nil {
|
||||||
|
respWriter.WriteErrorResponse()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -231,11 +244,16 @@ func (rp *http2RespWriter) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isControlStreamUpgrade(r *http.Request) bool {
|
func isControlStreamUpgrade(r *http.Request) bool {
|
||||||
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlStreamUpgrade
|
return r.Header.Get(internalUpgradeHeader) == controlStreamUpgrade
|
||||||
}
|
}
|
||||||
|
|
||||||
func isWebsocketUpgrade(r *http.Request) bool {
|
func isWebsocketUpgrade(r *http.Request) bool {
|
||||||
return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade
|
return r.Header.Get(internalUpgradeHeader) == websocketUpgrade
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTCPStream discerns if the connection request needs a tcp stream proxy.
|
||||||
|
func IsTCPStream(r *http.Request) bool {
|
||||||
|
return r.Header.Get(tcpStreamHeader) != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func stripWebsocketUpgradeHeader(r *http.Request) {
|
func stripWebsocketUpgradeHeader(r *http.Request) {
|
||||||
|
|
|
@ -24,6 +24,11 @@ var (
|
||||||
ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules")
|
ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ServiceBastion = "bastion"
|
||||||
|
ServiceTeamnet = "teamnet-proxy"
|
||||||
|
)
|
||||||
|
|
||||||
// FindMatchingRule returns the index of the Ingress Rule which matches the given
|
// FindMatchingRule returns the index of the Ingress Rule which matches the given
|
||||||
// hostname and path. This function assumes the last rule matches everything,
|
// hostname and path. This function assumes the last rule matches everything,
|
||||||
// which is the case if the rules were instantiated via the ingress#Validate method
|
// which is the case if the rules were instantiated via the ingress#Validate method
|
||||||
|
@ -90,7 +95,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
|
||||||
return new(helloWorld), nil
|
return new(helloWorld), nil
|
||||||
}
|
}
|
||||||
if c.IsSet(config.BastionFlag) {
|
if c.IsSet(config.BastionFlag) {
|
||||||
return newBridgeService(), nil
|
return newBridgeService(nil), nil
|
||||||
}
|
}
|
||||||
if c.IsSet("url") {
|
if c.IsSet("url") {
|
||||||
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
|
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
|
||||||
|
@ -159,12 +164,14 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||||
service = &srv
|
service = &srv
|
||||||
} else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" {
|
} else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" {
|
||||||
service = new(helloWorld)
|
service = new(helloWorld)
|
||||||
} else if r.Service == "bastion" || cfg.BastionMode {
|
} else if r.Service == ServiceBastion || cfg.BastionMode {
|
||||||
// Bastion mode will always start a Websocket proxy server, which will
|
// Bastion mode will always start a Websocket proxy server, which will
|
||||||
// overwrite the localService.URL field when `start` is called. So,
|
// overwrite the localService.URL field when `start` is called. So,
|
||||||
// leave the URL field empty for now.
|
// leave the URL field empty for now.
|
||||||
cfg.BastionMode = true
|
cfg.BastionMode = true
|
||||||
service = newBridgeService()
|
service = newBridgeService(nil)
|
||||||
|
} else if r.Service == ServiceTeamnet {
|
||||||
|
service = newBridgeService(DefaultStreamHandler)
|
||||||
} else {
|
} else {
|
||||||
// Validate URL services
|
// Validate URL services
|
||||||
u, err := url.Parse(r.Service)
|
u, err := url.Parse(r.Service)
|
||||||
|
|
|
@ -315,7 +315,7 @@ ingress:
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Hostname: "bastion.foo.com",
|
Hostname: "bastion.foo.com",
|
||||||
Service: newBridgeService(),
|
Service: newBridgeService(nil),
|
||||||
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -335,7 +335,7 @@ ingress:
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Hostname: "bastion.foo.com",
|
Hostname: "bastion.foo.com",
|
||||||
Service: newBridgeService(),
|
Service: newBridgeService(nil),
|
||||||
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -17,10 +17,36 @@ type OriginConnection interface {
|
||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn)
|
||||||
|
|
||||||
|
// Stream copies copy data to & from provided io.ReadWriters.
|
||||||
|
func Stream(conn, backendConn io.ReadWriter) {
|
||||||
|
proxyDone := make(chan struct{}, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
io.Copy(conn, backendConn)
|
||||||
|
proxyDone <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
io.Copy(backendConn, conn)
|
||||||
|
proxyDone <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// If one side is done, we are done.
|
||||||
|
<-proxyDone
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultStreamHandler is an implementation of streamHandlerFunc that
|
||||||
|
// performs a two way io.Copy between originConn and remoteConn.
|
||||||
|
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) {
|
||||||
|
Stream(originConn, remoteConn)
|
||||||
|
}
|
||||||
|
|
||||||
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
||||||
type tcpConnection struct {
|
type tcpConnection struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
streamHandler func(tunnelConn io.ReadWriter, originConn net.Conn)
|
streamHandler streamHandlerFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter) {
|
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter) {
|
||||||
|
@ -39,7 +65,7 @@ type wsConnection struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter) {
|
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter) {
|
||||||
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn())
|
Stream(tunnelConn, wsc.wsConn.UnderlyingConn())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wsc *wsConnection) Close() {
|
func (wsc *wsConnection) Close() {
|
||||||
|
|
|
@ -2,13 +2,14 @@ package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||||
|
@ -63,7 +64,21 @@ func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection,
|
||||||
return o.client.connect(r, dest)
|
return o.client.connect(r, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getRequestHost returns the host of the http.Request.
|
||||||
|
func getRequestHost(r *http.Request) (string, error) {
|
||||||
|
if r.Host != "" {
|
||||||
|
return r.Host, nil
|
||||||
|
}
|
||||||
|
if r.URL != nil {
|
||||||
|
return r.URL.Host, nil
|
||||||
|
}
|
||||||
|
return "", errors.New("host not found")
|
||||||
|
}
|
||||||
|
|
||||||
func (o *bridgeService) destination(r *http.Request) (string, error) {
|
func (o *bridgeService) destination(r *http.Request) (string, error) {
|
||||||
|
if connection.IsTCPStream(r) {
|
||||||
|
return getRequestHost(r)
|
||||||
|
}
|
||||||
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
|
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
|
||||||
if jumpDestination == "" {
|
if jumpDestination == "" {
|
||||||
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
||||||
|
@ -85,7 +100,7 @@ func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnectio
|
||||||
}
|
}
|
||||||
|
|
||||||
type tcpClient struct {
|
type tcpClient struct {
|
||||||
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn)
|
streamHandler streamHandlerFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
|
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
|
||||||
|
|
|
@ -91,7 +91,7 @@ func TestBridgeServiceDestination(t *testing.T) {
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
s := newBridgeService()
|
s := newBridgeService(nil)
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
r := &http.Request{
|
r := &http.Request{
|
||||||
Header: test.header,
|
Header: test.header,
|
||||||
|
|
|
@ -81,9 +81,12 @@ type bridgeService struct {
|
||||||
client *tcpClient
|
client *tcpClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBridgeService() *bridgeService {
|
// if streamHandler is nil, a default one is set.
|
||||||
|
func newBridgeService(streamHandler streamHandlerFunc) *bridgeService {
|
||||||
return &bridgeService{
|
return &bridgeService{
|
||||||
client: &tcpClient{},
|
client: &tcpClient{
|
||||||
|
streamHandler: streamHandler,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,10 +95,15 @@ func (o *bridgeService) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
|
// streamHandler is already set by the constructor.
|
||||||
|
if o.client.streamHandler != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if cfg.ProxyType == socksProxy {
|
if cfg.ProxyType == socksProxy {
|
||||||
o.client.streamHandler = socks.StreamHandler
|
o.client.streamHandler = socks.StreamHandler
|
||||||
} else {
|
} else {
|
||||||
o.client.streamHandler = websocket.DefaultStreamHandler
|
o.client.streamHandler = DefaultStreamHandler
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -136,7 +144,7 @@ func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdo
|
||||||
if cfg.ProxyType == socksProxy {
|
if cfg.ProxyType == socksProxy {
|
||||||
o.client.streamHandler = socks.StreamHandler
|
o.client.streamHandler = socks.StreamHandler
|
||||||
} else {
|
} else {
|
||||||
o.client.streamHandler = websocket.DefaultStreamHandler
|
o.client.streamHandler = DefaultStreamHandler
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ func NewOriginProxy(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *ze
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
|
func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error {
|
||||||
incrementRequests()
|
incrementRequests()
|
||||||
defer decrementConcurrentRequests()
|
defer decrementConcurrentRequests()
|
||||||
|
|
||||||
|
@ -49,43 +49,50 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocke
|
||||||
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
||||||
p.logRequest(req, cfRay, lbProbe, ruleNum)
|
p.logRequest(req, cfRay, lbProbe, ruleNum)
|
||||||
|
|
||||||
var (
|
if sourceConnectionType == connection.TypeHTTP {
|
||||||
resp *http.Response
|
resp, err := p.proxyHTTP(w, req, rule)
|
||||||
err error
|
if err != nil {
|
||||||
)
|
p.logErrorAndWriteResponse(w, err, cfRay, ruleNum)
|
||||||
|
return err
|
||||||
if isWebsocket {
|
|
||||||
go websocket.NewConn(w, p.log).Pinger(req.Context())
|
|
||||||
|
|
||||||
connClosedChan := make(chan struct{})
|
|
||||||
err = p.proxyConnection(connClosedChan, w, req, rule)
|
|
||||||
if err == nil {
|
|
||||||
respHeader := websocket.NewResponseHeader(req)
|
|
||||||
status := http.StatusSwitchingProtocols
|
|
||||||
resp = &http.Response{
|
|
||||||
Status: http.StatusText(status),
|
|
||||||
StatusCode: status,
|
|
||||||
Header: respHeader,
|
|
||||||
ContentLength: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader)
|
|
||||||
<-connClosedChan
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
resp, err = p.proxyHTTP(w, req, rule)
|
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
respHeader := http.Header{}
|
||||||
|
if sourceConnectionType == connection.TypeWebsocket {
|
||||||
|
go websocket.NewConn(w, p.log).Pinger(req.Context())
|
||||||
|
respHeader = websocket.NewResponseHeader(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
connClosedChan := make(chan struct{})
|
||||||
|
err := p.proxyConnection(connClosedChan, w, req, rule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.logRequestError(err, cfRay, ruleNum)
|
p.logErrorAndWriteResponse(w, err, cfRay, ruleNum)
|
||||||
w.WriteErrorResponse()
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
|
status := http.StatusSwitchingProtocols
|
||||||
|
resp := &http.Response{
|
||||||
|
Status: http.StatusText(status),
|
||||||
|
StatusCode: status,
|
||||||
|
Header: respHeader,
|
||||||
|
ContentLength: -1,
|
||||||
|
}
|
||||||
|
w.WriteRespHeaders(http.StatusSwitchingProtocols, nil)
|
||||||
|
|
||||||
|
<-connClosedChan
|
||||||
|
|
||||||
|
p.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) {
|
||||||
|
p.logRequestError(err, cfRay, ruleNum)
|
||||||
|
w.WriteErrorResponse()
|
||||||
|
}
|
||||||
|
|
||||||
func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
||||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||||
if rule.Config.DisableChunkedEncoding {
|
if rule.Config.DisableChunkedEncoding {
|
||||||
|
|
|
@ -143,7 +143,7 @@ func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = proxy.Proxy(respWriter, req, false)
|
err = proxy.Proxy(respWriter, req, connection.TypeHTTP)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, respWriter.Code)
|
assert.Equal(t, http.StatusOK, respWriter.Code)
|
||||||
|
@ -163,7 +163,7 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err = proxy.Proxy(respWriter, req, true)
|
err = proxy.Proxy(respWriter, req, connection.TypeWebsocket)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
||||||
|
@ -205,7 +205,7 @@ func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T)
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err = proxy.Proxy(respWriter, req, false)
|
err = proxy.Proxy(respWriter, req, connection.TypeHTTP)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, respWriter.Code)
|
require.Equal(t, http.StatusOK, respWriter.Code)
|
||||||
|
@ -298,7 +298,7 @@ func TestProxyMultipleOrigins(t *testing.T) {
|
||||||
req, err := http.NewRequest(http.MethodGet, test.url, nil)
|
req, err := http.NewRequest(http.MethodGet, test.url, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = proxy.Proxy(respWriter, req, false)
|
err = proxy.Proxy(respWriter, req, connection.TypeHTTP)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, test.expectedStatus, respWriter.Code)
|
assert.Equal(t, test.expectedStatus, respWriter.Code)
|
||||||
|
@ -346,7 +346,7 @@ func TestProxyError(t *testing.T) {
|
||||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = proxy.Proxy(respWriter, req, false)
|
err = proxy.Proxy(respWriter, req, connection.TypeHTTP)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
|
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
|
||||||
assert.Equal(t, "http response error", respWriter.Body.String())
|
assert.Equal(t, "http response error", respWriter.Body.String())
|
||||||
|
@ -376,12 +376,10 @@ func TestProxyBastionMode(t *testing.T) {
|
||||||
|
|
||||||
t.Run("testBastionWebsocket", testBastionWebsocket(proxy))
|
t.Run("testBastionWebsocket", testBastionWebsocket(proxy))
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
// WSRoute is a websocket echo handler
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
readPipe, _ := io.Pipe()
|
readPipe, _ := io.Pipe()
|
||||||
respWriter := newMockWSRespWriter(readPipe)
|
respWriter := newMockWSRespWriter(readPipe)
|
||||||
|
@ -389,14 +387,15 @@ func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
msgFromConn := []byte("data from websocket proxy")
|
msgFromConn := []byte("data from websocket proxy")
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
server, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
conn := websocket.NewConn(server, nil)
|
wsConn := websocket.NewConn(conn, nil)
|
||||||
conn.Write(msgFromConn)
|
wsConn.Write(msgFromConn)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil)
|
||||||
|
@ -405,7 +404,7 @@ func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err = proxy.Proxy(respWriter, req, true)
|
err = proxy.Proxy(respWriter, req, connection.TypeWebsocket)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
||||||
|
@ -422,3 +421,92 @@ func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTCPStream(t *testing.T) {
|
||||||
|
logger := logger.Create(nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
ingressConfig := &config.Configuration{
|
||||||
|
Ingress: []config.UnvalidatedIngressRule{
|
||||||
|
config.UnvalidatedIngressRule{
|
||||||
|
Hostname: "*",
|
||||||
|
Service: ingress.ServiceTeamnet,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ingressRule, err := ingress.ParseIngress(ingressConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errC := make(chan error)
|
||||||
|
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
|
||||||
|
|
||||||
|
proxy := NewOriginProxy(ingressRule, testTags, logger)
|
||||||
|
|
||||||
|
t.Run("testTCPStream", testTCPStreamProxy(proxy))
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockTCPRespWriter struct {
|
||||||
|
w io.Writer
|
||||||
|
code int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
|
||||||
|
return m.w.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTCPRespWriter) WriteErrorResponse() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
|
m.code = status
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTCPStreamProxy(proxy connection.OriginProxy) func(t *testing.T) {
|
||||||
|
return func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
readPipe, writePipe := io.Pipe()
|
||||||
|
respWriter := &mockTCPRespWriter{
|
||||||
|
w: writePipe,
|
||||||
|
}
|
||||||
|
msgFromConn := []byte("data from tcp proxy")
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
go func() {
|
||||||
|
defer ln.Close()
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
_, err = conn.Write(msgFromConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value")
|
||||||
|
req.Host = ln.Addr().String()
|
||||||
|
err = proxy.Proxy(respWriter, req, connection.TypeTCP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusSwitchingProtocols, respWriter.code)
|
||||||
|
|
||||||
|
returnedMsg := make([]byte, len(msgFromConn))
|
||||||
|
|
||||||
|
_, err = readPipe.Read(returnedMsg)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, msgFromConn, returnedMsg)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -47,30 +47,6 @@ func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Con
|
||||||
return conn, response, nil
|
return conn, response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream copies copy data to & from provided io.ReadWriters.
|
|
||||||
func Stream(conn, backendConn io.ReadWriter) {
|
|
||||||
proxyDone := make(chan struct{}, 2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(conn, backendConn)
|
|
||||||
proxyDone <- struct{}{}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, _ = io.Copy(backendConn, conn)
|
|
||||||
proxyDone <- struct{}{}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// If one side is done, we are done.
|
|
||||||
<-proxyDone
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultStreamHandler is provided to the the standard websocket to origin stream
|
|
||||||
// This exist to allow SOCKS to deframe data before it gets to the origin
|
|
||||||
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) {
|
|
||||||
Stream(originConn, remoteConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartProxyServer will start a websocket server that will decode
|
// StartProxyServer will start a websocket server that will decode
|
||||||
// the websocket data and write the resulting data to the provided
|
// the websocket data and write the resulting data to the provided
|
||||||
func StartProxyServer(
|
func StartProxyServer(
|
||||||
|
|
Loading…
Reference in New Issue