parent
be0514c5c9
commit
fb82b2ced5
@ -1,57 +0,0 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
const (
|
||||
openStreamTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
id uuid.UUID
|
||||
muxer *h2mux.Muxer
|
||||
addr *net.TCPAddr
|
||||
isLongLived bool
|
||||
longLivedID int
|
||||
}
|
||||
|
||||
func newConnection(muxer *h2mux.Muxer, addr *net.TCPAddr) (*Connection, error) {
|
||||
id, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Connection{
|
||||
id: id,
|
||||
muxer: muxer,
|
||||
addr: addr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Serve(ctx context.Context) error {
|
||||
// Serve doesn't return until h2mux is shutdown
|
||||
return c.muxer.Serve(ctx)
|
||||
}
|
||||
|
||||
// Connect is used to establish connections with cloudflare's edge network
|
||||
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger logger.Service) (tunnelpogs.ConnectResult, error) {
|
||||
tsClient, err := NewRPCClient(ctx, c.muxer, logger, openStreamTimeout)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot create new RPC connection")
|
||||
}
|
||||
defer tsClient.Close()
|
||||
return tsClient.Connect(ctx, parameters)
|
||||
}
|
||||
|
||||
func (c *Connection) Shutdown() {
|
||||
c.muxer.Shutdown()
|
||||
}
|
@ -1,302 +0,0 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/streamhandler"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
const (
|
||||
quickStartLink = "https://developers.cloudflare.com/argo-tunnel/quickstart/"
|
||||
faqLink = "https://developers.cloudflare.com/argo-tunnel/faq/"
|
||||
defaultRetryAfter = time.Second * 5
|
||||
packageNamespace = "connection"
|
||||
edgeManagerSubsystem = "edgemanager"
|
||||
)
|
||||
|
||||
// EdgeManager manages connections with the edge
|
||||
type EdgeManager struct {
|
||||
// streamHandler handles stream opened by the edge
|
||||
streamHandler *streamhandler.StreamHandler
|
||||
// TLSConfig is the TLS configuration to connect with edge
|
||||
tlsConfig *tls.Config
|
||||
// cloudflaredConfig is the cloudflared configuration that is determined when the process first starts
|
||||
cloudflaredConfig *CloudflaredConfig
|
||||
// serviceDiscoverer returns the next edge addr to connect to
|
||||
serviceDiscoverer *edgediscovery.Edge
|
||||
// state is attributes of ConnectionManager that can change during runtime.
|
||||
state *edgeManagerState
|
||||
|
||||
logger logger.Service
|
||||
|
||||
metrics *metrics
|
||||
}
|
||||
|
||||
type metrics struct {
|
||||
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
|
||||
activeStreams prometheus.Gauge
|
||||
}
|
||||
|
||||
func newMetrics(namespace, subsystem string) *metrics {
|
||||
return &metrics{
|
||||
activeStreams: h2mux.NewActiveStreamsMetrics(namespace, subsystem),
|
||||
}
|
||||
}
|
||||
|
||||
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
|
||||
type EdgeManagerConfigurable struct {
|
||||
TunnelHostnames []h2mux.TunnelHostname
|
||||
*tunnelpogs.EdgeConnectionConfig
|
||||
}
|
||||
|
||||
type CloudflaredConfig struct {
|
||||
CloudflaredID uuid.UUID
|
||||
Tags []tunnelpogs.Tag
|
||||
BuildInfo *buildinfo.BuildInfo
|
||||
IntentLabel string
|
||||
}
|
||||
|
||||
func NewEdgeManager(
|
||||
streamHandler *streamhandler.StreamHandler,
|
||||
edgeConnMgrConfigurable *EdgeManagerConfigurable,
|
||||
userCredential []byte,
|
||||
tlsConfig *tls.Config,
|
||||
serviceDiscoverer *edgediscovery.Edge,
|
||||
cloudflaredConfig *CloudflaredConfig,
|
||||
logger logger.Service,
|
||||
) *EdgeManager {
|
||||
return &EdgeManager{
|
||||
streamHandler: streamHandler,
|
||||
tlsConfig: tlsConfig,
|
||||
cloudflaredConfig: cloudflaredConfig,
|
||||
serviceDiscoverer: serviceDiscoverer,
|
||||
state: newEdgeConnectionManagerState(edgeConnMgrConfigurable, userCredential),
|
||||
logger: logger,
|
||||
metrics: newMetrics(packageNamespace, edgeManagerSubsystem),
|
||||
}
|
||||
}
|
||||
|
||||
func (em *EdgeManager) Run(ctx context.Context) error {
|
||||
defer em.shutdown()
|
||||
|
||||
// Currently, declarative tunnels don't have any concept of a stable connection
|
||||
// Each edge connection is transient and when it dies, it is replaced by a different one,
|
||||
// not restarted.
|
||||
// So in the future we should really change this so that n connections are stored individually
|
||||
connIndex := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "EdgeConnectionManager terminated")
|
||||
default:
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
// Create/delete connection one at a time, so we don't need to adjust for connections that are being created/deleted
|
||||
// in shouldCreateConnection or shouldReduceConnection calculation
|
||||
if em.state.shouldCreateConnection(em.serviceDiscoverer.AvailableAddrs()) {
|
||||
if connErr := em.newConnection(ctx, connIndex); connErr != nil {
|
||||
if !connErr.ShouldRetry {
|
||||
em.logger.Errorf("connectionManager: %s with error: %s", em.noRetryMessage(), connErr)
|
||||
return connErr
|
||||
}
|
||||
em.logger.Errorf("connectionManager: cannot create new connection: %s", connErr)
|
||||
} else {
|
||||
connIndex++
|
||||
}
|
||||
} else if em.state.shouldReduceConnection() {
|
||||
if err := em.closeConnection(ctx); err != nil {
|
||||
em.logger.Errorf("connectionManager: cannot close connection: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurable) {
|
||||
em.logger.Infof("New edge connection manager configuration %+v", newConfigurable)
|
||||
em.state.updateConfigurable(newConfigurable)
|
||||
}
|
||||
|
||||
func (em *EdgeManager) newConnection(ctx context.Context, index int) *tunnelpogs.ConnectError {
|
||||
edgeTCPAddr, err := em.serviceDiscoverer.GetAddr(index)
|
||||
if err != nil {
|
||||
return retryConnection(fmt.Sprintf("edge address discovery error: %v", err))
|
||||
}
|
||||
configurable := em.state.getConfigurable()
|
||||
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
|
||||
if err != nil {
|
||||
return retryConnection(fmt.Sprintf("dial edge error: %v", err))
|
||||
}
|
||||
// Establish a muxed connection with the edge
|
||||
// Client mux handshake with agent server
|
||||
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||
Timeout: configurable.Timeout,
|
||||
Handler: em.streamHandler,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: configurable.HeartbeatInterval,
|
||||
MaxHeartbeats: configurable.MaxFailedHeartbeats,
|
||||
Logger: em.logger,
|
||||
}, em.metrics.activeStreams)
|
||||
if err != nil {
|
||||
retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err))
|
||||
}
|
||||
|
||||
h2muxConn, err := newConnection(muxer, edgeTCPAddr)
|
||||
if err != nil {
|
||||
return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err))
|
||||
}
|
||||
|
||||
go em.serveConn(ctx, h2muxConn)
|
||||
|
||||
connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
|
||||
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
|
||||
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
|
||||
NumPreviousAttempts: 0,
|
||||
OriginCert: em.state.getUserCredential(),
|
||||
IntentLabel: em.cloudflaredConfig.IntentLabel,
|
||||
Tags: em.cloudflaredConfig.Tags,
|
||||
}, em.logger)
|
||||
if err != nil {
|
||||
h2muxConn.Shutdown()
|
||||
return retryConnection(fmt.Sprintf("couldn't connect to edge: %v", err))
|
||||
}
|
||||
|
||||
if connErr := connResult.ConnectError(); connErr != nil {
|
||||
return connErr
|
||||
}
|
||||
|
||||
em.state.newConnection(h2muxConn)
|
||||
em.logger.Infof("connectionManager: connected to %s", connResult.ConnectedTo())
|
||||
|
||||
if connResult.ClientConfig() != nil {
|
||||
em.streamHandler.UseConfiguration(ctx, connResult.ClientConfig())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (em *EdgeManager) closeConnection(ctx context.Context) error {
|
||||
conn := em.state.getFirstConnection()
|
||||
if conn == nil {
|
||||
return fmt.Errorf("no connection to close")
|
||||
}
|
||||
conn.Shutdown()
|
||||
// teardown will be handled by EdgeManager.serveConn in another goroutine
|
||||
return nil
|
||||
}
|
||||
|
||||
func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) {
|
||||
err := conn.Serve(ctx)
|
||||
em.logger.Errorf("connectionManager: Connection closed: %s", err)
|
||||
em.state.closeConnection(conn)
|
||||
em.serviceDiscoverer.GiveBack(conn.addr)
|
||||
}
|
||||
|
||||
func (em *EdgeManager) noRetryMessage() string {
|
||||
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
|
||||
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
|
||||
"2. Your credential at %s is still valid. See %s."
|
||||
return fmt.Sprintf(messageTemplate, quickStartLink, em.state.getConfigurable().UserCredentialPath, faqLink)
|
||||
}
|
||||
|
||||
func (em *EdgeManager) shutdown() {
|
||||
em.state.shutdown()
|
||||
}
|
||||
|
||||
type edgeManagerState struct {
|
||||
sync.RWMutex
|
||||
configurable *EdgeManagerConfigurable
|
||||
userCredential []byte
|
||||
conns map[uuid.UUID]*Connection
|
||||
}
|
||||
|
||||
func newEdgeConnectionManagerState(configurable *EdgeManagerConfigurable, userCredential []byte) *edgeManagerState {
|
||||
return &edgeManagerState{
|
||||
configurable: configurable,
|
||||
userCredential: userCredential,
|
||||
conns: make(map[uuid.UUID]*Connection),
|
||||
}
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) shouldCreateConnection(availableEdgeAddrs int) bool {
|
||||
ems.RLock()
|
||||
defer ems.RUnlock()
|
||||
expectedHAConns := int(ems.configurable.NumHAConnections)
|
||||
if availableEdgeAddrs < expectedHAConns {
|
||||
expectedHAConns = availableEdgeAddrs
|
||||
}
|
||||
return len(ems.conns) < expectedHAConns
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) shouldReduceConnection() bool {
|
||||
ems.RLock()
|
||||
defer ems.RUnlock()
|
||||
return uint8(len(ems.conns)) > ems.configurable.NumHAConnections
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) newConnection(conn *Connection) {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
ems.conns[conn.id] = conn
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) closeConnection(conn *Connection) {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
delete(ems.conns, conn.id)
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) getFirstConnection() *Connection {
|
||||
ems.RLock()
|
||||
defer ems.RUnlock()
|
||||
|
||||
for _, conn := range ems.conns {
|
||||
return conn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) shutdown() {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
for _, conn := range ems.conns {
|
||||
conn.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) getConfigurable() *EdgeManagerConfigurable {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
return ems.configurable
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) updateConfigurable(newConfigurable *EdgeManagerConfigurable) {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
ems.configurable = newConfigurable
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) getUserCredential() []byte {
|
||||
ems.RLock()
|
||||
defer ems.RUnlock()
|
||||
return ems.userCredential
|
||||
}
|
||||
|
||||
func retryConnection(cause string) *tunnelpogs.ConnectError {
|
||||
return &tunnelpogs.ConnectError{
|
||||
Cause: cause,
|
||||
RetryAfter: defaultRetryAfter,
|
||||
ShouldRetry: true,
|
||||
}
|
||||
}
|
@ -1,77 +0,0 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/streamhandler"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
var (
|
||||
configurable = &EdgeManagerConfigurable{
|
||||
[]h2mux.TunnelHostname{
|
||||
"http.example.com",
|
||||
"ws.example.com",
|
||||
"hello.example.com",
|
||||
},
|
||||
&pogs.EdgeConnectionConfig{
|
||||
NumHAConnections: 1,
|
||||
HeartbeatInterval: 1 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
MaxFailedHeartbeats: 3,
|
||||
UserCredentialPath: "/etc/cloudflared/cert.pem",
|
||||
},
|
||||
}
|
||||
cloudflaredConfig = &CloudflaredConfig{
|
||||
CloudflaredID: uuid.New(),
|
||||
Tags: []pogs.Tag{
|
||||
{Name: "pool", Value: "east-6"},
|
||||
},
|
||||
BuildInfo: &buildinfo.BuildInfo{
|
||||
GoOS: "linux",
|
||||
GoVersion: "1.12",
|
||||
GoArch: "amd64",
|
||||
CloudflaredVersion: "2019.6.0",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func mockEdgeManager() *EdgeManager {
|
||||
newConfigChan := make(chan<- *pogs.ClientConfig)
|
||||
useConfigResultChan := make(<-chan *pogs.UseConfigurationResult)
|
||||
logger := logger.NewOutputWriter(logger.NewMockWriteManager())
|
||||
edge := edgediscovery.MockEdge(logger, []*net.TCPAddr{})
|
||||
return NewEdgeManager(
|
||||
streamhandler.NewStreamHandler(newConfigChan, useConfigResultChan, logger),
|
||||
configurable,
|
||||
[]byte{},
|
||||
nil,
|
||||
edge,
|
||||
cloudflaredConfig,
|
||||
logger,
|
||||
)
|
||||
}
|
||||
|
||||
func TestUpdateConfigurable(t *testing.T) {
|
||||
m := mockEdgeManager()
|
||||
newConfigurable := &EdgeManagerConfigurable{
|
||||
[]h2mux.TunnelHostname{
|
||||
"second.example.com",
|
||||
},
|
||||
&pogs.EdgeConnectionConfig{
|
||||
NumHAConnections: 2,
|
||||
},
|
||||
}
|
||||
m.UpdateConfigurable(newConfigurable)
|
||||
|
||||
assert.Equal(t, newConfigurable, m.state.getConfigurable())
|
||||
}
|
@ -1,247 +0,0 @@
|
||||
// Package client defines and implements interface to proxy to HTTP, websocket and hello world origins
|
||||
package originservice
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/buffer"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// OriginService is an interface to proxy requests to different type of origins
|
||||
type OriginService interface {
|
||||
Proxy(stream *h2mux.MuxedStream, req *http.Request) (resp *http.Response, err error)
|
||||
URL() *url.URL
|
||||
Summary() string
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
// HTTPService talks to origin using HTTP/HTTPS
|
||||
type HTTPService struct {
|
||||
client http.RoundTripper
|
||||
originURL *url.URL
|
||||
chunkedEncoding bool
|
||||
bufferPool *buffer.Pool
|
||||
}
|
||||
|
||||
func NewHTTPService(transport http.RoundTripper, url *url.URL, chunkedEncoding bool) OriginService {
|
||||
return &HTTPService{
|
||||
client: transport,
|
||||
originURL: url,
|
||||
chunkedEncoding: chunkedEncoding,
|
||||
bufferPool: buffer.NewPool(512 * 1024),
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
const responseSourceOrigin = "origin"
|
||||
|
||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||
if !hc.chunkedEncoding {
|
||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||
if err == nil {
|
||||
req.ContentLength = int64(cLength)
|
||||
}
|
||||
}
|
||||
|
||||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
resp, err := hc.client.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error proxying request to HTTP origin")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
responseHeaders := h1ResponseToH2Response(resp)
|
||||
responseHeaders = append(responseHeaders, h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, responseSourceOrigin))
|
||||
err = stream.WriteHeaders(responseHeaders)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error writing response header to HTTP origin")
|
||||
}
|
||||
if isEventStream(resp) {
|
||||
writeEventStream(stream, resp.Body)
|
||||
} else {
|
||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||
// compression generates dictionary on first write
|
||||
buf := hc.bufferPool.Get()
|
||||
defer hc.bufferPool.Put(buf)
|
||||
io.CopyBuffer(stream, resp.Body, buf)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (hc *HTTPService) URL() *url.URL {
|
||||
return hc.originURL
|
||||
}
|
||||
|
||||
func (hc *HTTPService) Summary() string {
|
||||
return fmt.Sprintf("HTTP service listening on %s", hc.originURL)
|
||||
}
|
||||
|
||||
func (hc *HTTPService) Shutdown() {}
|
||||
|
||||
// WebsocketService talks to origin using WS/WSS
|
||||
type WebsocketService struct {
|
||||
tlsConfig *tls.Config
|
||||
originURL *url.URL
|
||||
shutdownC chan struct{}
|
||||
}
|
||||
|
||||
func NewWebSocketService(tlsConfig *tls.Config, url *url.URL, logger logger.Service) (OriginService, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot start Websocket Proxy Server")
|
||||
}
|
||||
shutdownC := make(chan struct{})
|
||||
go func() {
|
||||
websocket.StartProxyServer(logger, listener, url.String(), shutdownC, websocket.DefaultStreamHandler)
|
||||
}()
|
||||
return &WebsocketService{
|
||||
tlsConfig: tlsConfig,
|
||||
originURL: url,
|
||||
shutdownC: shutdownC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
if !websocket.IsWebSocketUpgrade(req) {
|
||||
return nil, fmt.Errorf("request is not a websocket connection")
|
||||
}
|
||||
conn, response, err := websocket.ClientConnect(req, wsc.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
err = stream.WriteHeaders(h1ResponseToH2Response(response))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error writing response header to websocket origin")
|
||||
}
|
||||
// Copy to/from stream to the undelying connection. Use the underlying
|
||||
// connection because cloudflared doesn't operate on the message themselves
|
||||
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (wsc *WebsocketService) URL() *url.URL {
|
||||
return wsc.originURL
|
||||
}
|
||||
|
||||
func (wsc *WebsocketService) Summary() string {
|
||||
return fmt.Sprintf("Websocket listening on %s", wsc.originURL)
|
||||
}
|
||||
|
||||
func (wsc *WebsocketService) Shutdown() {
|
||||
close(wsc.shutdownC)
|
||||
}
|
||||
|
||||
// HelloWorldService talks to the hello world example origin
|
||||
type HelloWorldService struct {
|
||||
client http.RoundTripper
|
||||
listener net.Listener
|
||||
originURL *url.URL
|
||||
shutdownC chan struct{}
|
||||
bufferPool *buffer.Pool
|
||||
}
|
||||
|
||||
func NewHelloWorldService(transport http.RoundTripper, logger logger.Service) (OriginService, error) {
|
||||
listener, err := hello.CreateTLSListener("127.0.0.1:")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot start Hello World Server")
|
||||
}
|
||||
shutdownC := make(chan struct{})
|
||||
go func() {
|
||||
hello.StartHelloWorldServer(logger, listener, shutdownC)
|
||||
}()
|
||||
return &HelloWorldService{
|
||||
client: transport,
|
||||
listener: listener,
|
||||
originURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: listener.Addr().String(),
|
||||
},
|
||||
shutdownC: shutdownC,
|
||||
bufferPool: buffer.NewPool(512 * 1024),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
resp, err := hwc.client.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error proxying request to Hello World origin")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = stream.WriteHeaders(h1ResponseToH2Response(resp))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error writing response header to Hello World origin")
|
||||
}
|
||||
|
||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||
// compression generates dictionary on first write
|
||||
buf := hwc.bufferPool.Get()
|
||||
defer hwc.bufferPool.Put(buf)
|
||||
io.CopyBuffer(stream, resp.Body, buf)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (hwc *HelloWorldService) URL() *url.URL {
|
||||
return hwc.originURL
|
||||
}
|
||||
|
||||
func (hwc *HelloWorldService) Summary() string {
|
||||
return fmt.Sprintf("Hello World service listening on %s", hwc.originURL)
|
||||
}
|
||||
|
||||
func (hwc *HelloWorldService) Shutdown() {
|
||||
hwc.listener.Close()
|
||||
}
|
||||
|
||||
func isEventStream(resp *http.Response) bool {
|
||||
// Check if content-type is text/event-stream. We need to check if the header value starts with text/event-stream
|
||||
// because text/event-stream; charset=UTF-8 is also valid
|
||||
// Ref: https://tools.ietf.org/html/rfc7231#section-3.1.1.1
|
||||
for _, contentType := range resp.Header["content-type"] {
|
||||
if strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeEventStream(stream *h2mux.MuxedStream, respBody io.ReadCloser) {
|
||||
reader := bufio.NewReader(respBody)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
stream.Write(line)
|
||||
}
|
||||
}
|
||||
|
||||
func h1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
|
||||
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
|
||||
for headerName, headerValues := range h1.Header {
|
||||
for _, headerValue := range headerValues {
|
||||
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
@ -1,60 +0,0 @@
|
||||
package originservice
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsEventStream(t *testing.T) {
|
||||
tests := []struct {
|
||||
resp *http.Response
|
||||
isEventStream bool
|
||||
}{
|
||||
{
|
||||
resp: &http.Response{},
|
||||
isEventStream: false,
|
||||
},
|
||||
{
|
||||
// isEventStream checks all headers
|
||||
resp: &http.Response{
|
||||
Header: http.Header{
|
||||
"accept": []string{"text/html"},
|
||||
"content-type": []string{"text/event-stream"},
|
||||
},
|
||||
},
|
||||
isEventStream: true,
|
||||
},
|
||||
{
|
||||
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
|
||||
resp: &http.Response{
|
||||
Header: http.Header{
|
||||
"content-type": []string{"Text/event-stream;charset=utf-8"},
|
||||
},
|
||||
},
|
||||
isEventStream: true,
|
||||
},
|
||||
{
|
||||
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
|
||||
resp: &http.Response{
|
||||
Header: http.Header{
|
||||
"content-type": []string{"appication/json", "text/html", "Text/event-stream;charset=utf-8"},
|
||||
},
|
||||
},
|
||||
isEventStream: true,
|
||||
},
|
||||
{
|
||||
// Not an event stream because the content-type value doesn't start with text/event-stream
|
||||
resp: &http.Response{
|
||||
Header: http.Header{
|
||||
"content-type": []string{" text/event-stream"},
|
||||
},
|
||||
},
|
||||
isEventStream: false,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.isEventStream, isEventStream(test.resp), "Header: %v", test.resp.Header)
|
||||
}
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package streamhandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||
)
|
||||
|
||||
func FindCfRayHeader(h1 *http.Request) string {
|
||||
return h1.Header.Get("Cf-Ray")
|
||||
}
|
||||
|
||||
func IsLBProbeRequest(req *http.Request) bool {
|
||||
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
||||
}
|
||||
|
||||
func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url.String(), h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unexpected error from http.NewRequest")
|
||||
}
|
||||
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid request received")
|
||||
}
|
||||
return req, nil
|
||||
}
|
@ -1,189 +0,0 @@
|
||||
package streamhandler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/tunnelhostnamemapper"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/pkg/errors"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
const (
|
||||
statusPseudoHeader = ":status"
|
||||
)
|
||||
|
||||
type httpErrorStatus struct {
|
||||
status string
|
||||
text []byte
|
||||
}
|
||||
|
||||
var (
|
||||
statusBadRequest = newHTTPErrorStatus(http.StatusBadRequest)
|
||||
statusNotFound = newHTTPErrorStatus(http.StatusNotFound)
|
||||
statusBadGateway = newHTTPErrorStatus(http.StatusBadGateway)
|
||||
)
|
||||
|
||||
func newHTTPErrorStatus(status int) *httpErrorStatus {
|
||||
return &httpErrorStatus{
|
||||
status: strconv.Itoa(status),
|
||||
text: []byte(http.StatusText(status)),
|
||||
}
|
||||
}
|
||||
|
||||
// StreamHandler handles new stream opened by the edge. The streams can be used to proxy requests or make RPC.
|
||||
type StreamHandler struct {
|
||||
// newConfigChan is a send-only channel to notify Supervisor of a new ClientConfig
|
||||
newConfigChan chan<- *pogs.ClientConfig
|
||||
// useConfigResultChan is a receive-only channel for Supervisor to communicate the result of applying a new ClientConfig
|
||||
useConfigResultChan <-chan *pogs.UseConfigurationResult
|
||||
// originMapper maps tunnel hostname to origin service
|
||||
tunnelHostnameMapper *tunnelhostnamemapper.TunnelHostnameMapper
|
||||
logger logger.Service
|
||||
}
|
||||
|
||||
// NewStreamHandler creates a new StreamHandler
|
||||
func NewStreamHandler(newConfigChan chan<- *pogs.ClientConfig,
|
||||
useConfigResultChan <-chan *pogs.UseConfigurationResult,
|
||||
logger logger.Service,
|
||||
) *StreamHandler {
|
||||
return &StreamHandler{
|
||||
newConfigChan: newConfigChan,
|
||||
useConfigResultChan: useConfigResultChan,
|
||||
tunnelHostnameMapper: tunnelhostnamemapper.NewTunnelHostnameMapper(),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UseConfiguration implements ClientService
|
||||
func (s *StreamHandler) UseConfiguration(ctx context.Context, config *pogs.ClientConfig) (*pogs.UseConfigurationResult, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := fmt.Errorf("Timeout while sending new config to Supervisor")
|
||||
s.logger.Errorf("streamHandler: %s", err)
|
||||
return nil, err
|
||||
case s.newConfigChan <- config:
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := fmt.Errorf("Timeout applying new configuration")
|
||||
s.logger.Errorf("streamHandler: %s", err)
|
||||
return nil, err
|
||||
case result := <-s.useConfigResultChan:
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig replaces current originmapper mapping with mappings from newConfig
|
||||
func (s *StreamHandler) UpdateConfig(newConfig []*pogs.ReverseProxyConfig) (failedConfigs []*pogs.FailedConfig) {
|
||||
|
||||
// Delete old configs that aren't in the `newConfig`
|
||||
toRemove := s.tunnelHostnameMapper.ToRemove(newConfig)
|
||||
for _, hostnameToRemove := range toRemove {
|
||||
s.tunnelHostnameMapper.Delete(hostnameToRemove)
|
||||
}
|
||||
|
||||
// Add new configs that weren't in the old mapper
|
||||
toAdd := s.tunnelHostnameMapper.ToAdd(newConfig)
|
||||
for _, tunnelConfig := range toAdd {
|
||||
tunnelHostname := tunnelConfig.TunnelHostname
|
||||
originSerice, err := tunnelConfig.OriginConfig.Service(s.logger)
|
||||
if err != nil {
|
||||
s.logger.Errorf("streamHandler: tunnelHostname: %s Invalid origin service config: %s", tunnelHostname, err)
|
||||
failedConfigs = append(failedConfigs, &pogs.FailedConfig{
|
||||
Config: tunnelConfig,
|
||||
Reason: tunnelConfig.FailReason(err),
|
||||
})
|
||||
continue
|
||||
}
|
||||
s.tunnelHostnameMapper.Add(tunnelConfig.TunnelHostname, originSerice)
|
||||
s.logger.Infof("streamHandler: tunnelHostname: %s New origin service config: %v", tunnelHostname, originSerice.Summary())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ServeStream implements MuxedStreamHandler interface
|
||||
func (s *StreamHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||
if stream.IsRPCStream() {
|
||||
return s.serveRPC(stream)
|
||||
}
|
||||
if err := s.serveRequest(stream); err != nil {
|
||||
s.logger.Errorf("streamHandler: %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StreamHandler) serveRPC(stream *h2mux.MuxedStream) error {
|
||||
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "200"}})
|
||||
main := pogs.ClientService_ServerToClient(s)
|
||||
rpcConn := rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(s.logger, rpc.StreamTransport(stream)),
|
||||
rpc.MainInterface(main.Client),
|
||||
tunnelrpc.ConnLog(s.logger),
|
||||
)
|
||||
return rpcConn.Wait()
|
||||
}
|
||||
|
||||
func (s *StreamHandler) serveRequest(stream *h2mux.MuxedStream) error {
|
||||
tunnelHostname := stream.TunnelHostname()
|
||||
if !tunnelHostname.IsSet() {
|
||||
s.writeErrorStatus(stream, statusBadRequest)
|
||||
return fmt.Errorf("stream doesn't have tunnelHostname")
|
||||
}
|
||||
|
||||
originService, ok := s.tunnelHostnameMapper.Get(tunnelHostname)
|
||||
if !ok {
|
||||
s.writeErrorStatus(stream, statusNotFound)
|
||||
return fmt.Errorf("cannot map tunnel hostname %s to origin", tunnelHostname)
|
||||
}
|
||||
|
||||
req, err := createRequest(stream, originService.URL())
|
||||
if err != nil {
|
||||
s.writeErrorStatus(stream, statusBadRequest)
|
||||
return errors.Wrap(err, "cannot create request")
|
||||
}
|
||||
|
||||
cfRay := s.logRequest(req, tunnelHostname)
|
||||
s.logger.Debugf("streamHandler: tunnelHostname: %s CF-RAY: %s Request Headers %+v", tunnelHostname, cfRay, req.Header)
|
||||
|
||||
resp, err := originService.Proxy(stream, req)
|
||||
if err != nil {
|
||||
s.writeErrorStatus(stream, statusBadGateway)
|
||||
return errors.Wrap(err, "cannot proxy request")
|
||||
}
|
||||
|
||||
s.logger.Debugf("streamHandler: tunnelHostname: %s CF-RAY: %s status: %s Response Headers %+v", tunnelHostname, cfRay, resp.Status, resp.Header)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StreamHandler) logRequest(req *http.Request, tunnelHostname h2mux.TunnelHostname) string {
|
||||
cfRay := FindCfRayHeader(req)
|
||||
lbProbe := IsLBProbeRequest(req)
|
||||
logger := s.logger
|
||||
if cfRay != "" {
|
||||
logger.Debugf("streamHandler: tunnelHostname: %s CF-RAY: %s %s %s %s", tunnelHostname, cfRay, req.Method, req.URL, req.Proto)
|
||||
} else if lbProbe {
|
||||
logger.Debugf("streamHandler: tunnelHostname: %s CF-RAY: %s Load Balancer health check %s %s %s", tunnelHostname, cfRay, req.Method, req.URL, req.Proto)
|
||||
} else {
|
||||
logger.Infof("streamHandler: tunnelHostname: %s CF-RAY: %s Requests %v does not have CF-RAY header. Please open a support ticket with Cloudflare.", tunnelHostname, cfRay, req)
|
||||
}
|
||||
return cfRay
|
||||
}
|
||||
|
||||
func (s *StreamHandler) writeErrorStatus(stream *h2mux.MuxedStream, status *httpErrorStatus) {
|
||||
_ = stream.WriteHeaders([]h2mux.Header{
|
||||
{
|
||||
Name: statusPseudoHeader,
|
||||
Value: status.status,
|
||||
},
|
||||
h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared),
|
||||
})
|
||||
_, _ = stream.Write(status.text)
|
||||
}
|
@ -1,261 +0,0 @@
|
||||
package streamhandler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
testOpenStreamTimeout = time.Millisecond * 5000
|
||||
testHandshakeTimeout = time.Millisecond * 1000
|
||||
)
|
||||
|
||||
var (
|
||||
testTunnelHostname = h2mux.TunnelHostname("123.cftunnel.com")
|
||||
baseHeaders = []h2mux.Header{
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: ":scheme", Value: "http"},
|
||||
{Name: ":authority", Value: "example.com"},
|
||||
{Name: ":path", Value: "/"},
|
||||
|
||||
// Regular headers must always come after the pseudoheaders
|
||||
{Name: h2mux.RequestUserHeadersField, Value: ""},
|
||||
}
|
||||
tunnelHostnameHeader = h2mux.Header{Name: h2mux.CloudflaredProxyTunnelHostnameHeader, Value: testTunnelHostname.String()}
|
||||
)
|
||||
|
||||
func TestServeRequest(t *testing.T) {
|
||||
l := logger.NewOutputWriter(logger.NewMockWriteManager())
|
||||
configChan := make(chan *pogs.ClientConfig)
|
||||
useConfigResultChan := make(chan *pogs.UseConfigurationResult)
|
||||
streamHandler := NewStreamHandler(configChan, useConfigResultChan, l)
|
||||
|
||||
message := []byte("Hello cloudflared")
|
||||
httpServer := httptest.NewServer(&mockHTTPHandler{message})
|
||||
|
||||
reverseProxyConfigs := []*pogs.ReverseProxyConfig{
|
||||
{
|
||||
TunnelHostname: testTunnelHostname,
|
||||
OriginConfig: &pogs.HTTPOriginConfig{
|
||||
URLString: httpServer.URL,
|
||||
},
|
||||
},
|
||||
}
|
||||
streamHandler.UpdateConfig(reverseProxyConfigs)
|
||||
|
||||
muxPair := NewDefaultMuxerPair(t, streamHandler)
|
||||
muxPair.Serve(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
headers := append(baseHeaders, tunnelHostnameHeader)
|
||||
stream, err := muxPair.EdgeMux.OpenStream(ctx, headers, nil)
|
||||
assert.NoError(t, err)
|
||||
assertStatusHeader(t, http.StatusOK, stream.Headers)
|
||||
assertRespBody(t, message, stream)
|
||||
}
|
||||
|
||||
func createStreamHandler() *StreamHandler {
|
||||
configChan := make(chan *pogs.ClientConfig)
|
||||
useConfigResultChan := make(chan *pogs.UseConfigurationResult)
|
||||
l := logger.NewOutputWriter(logger.NewMockWriteManager())
|
||||
|
||||
return NewStreamHandler(configChan, useConfigResultChan, l)
|
||||
}
|
||||
|
||||
func createRequestMuxPair(t *testing.T, streamHandler *StreamHandler) *DefaultMuxerPair {
|
||||
muxPair := NewDefaultMuxerPair(t, streamHandler)
|
||||
muxPair.Serve(t)
|
||||
|
||||
return muxPair
|
||||
}
|
||||
|
||||
func TestServeStatusBadRequest(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
// No tunnel hostname header, expect to get 400 Bad Request
|
||||
stream, err := createRequestMuxPair(t, createStreamHandler()).EdgeMux.OpenStream(ctx, baseHeaders, nil)
|
||||
assert.NoError(t, err)
|
||||
assertStatusHeader(t, http.StatusBadRequest, stream.Headers)
|
||||
assertRespBody(t, statusBadRequest.text, stream)
|
||||
}
|
||||
|
||||
func TestServeInvalidContentLength(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Invalid content-length, wouldn't be able to create a request
|
||||
// Expect to get 400 Bad Request
|
||||
headers := append(baseHeaders, tunnelHostnameHeader)
|
||||
headers = append(headers, h2mux.Header{
|
||||
Name: "content-length",
|
||||
Value: "x",
|
||||
})
|
||||
streamHandler := createStreamHandler()
|
||||
streamHandler.UpdateConfig([]*pogs.ReverseProxyConfig{
|
||||
{
|
||||
TunnelHostname: testTunnelHostname,
|
||||
OriginConfig: &pogs.HTTPOriginConfig{
|
||||
URLString: "",
|
||||
},
|
||||
},
|
||||
})
|
||||
mux := createRequestMuxPair(t, streamHandler).EdgeMux
|
||||
stream, err := mux.OpenStream(ctx, headers, nil)
|
||||
assert.NoError(t, err)
|
||||
assertStatusHeader(t, http.StatusBadRequest, stream.Headers)
|
||||
assertRespBody(t, statusBadRequest.text, stream)
|
||||
}
|
||||
|
||||
func TestServeStatusNotFound(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
// No mapping for the tunnel hostname, expect to get 404 Not Found
|
||||
headers := append(baseHeaders, tunnelHostnameHeader)
|
||||
stream, err := createRequestMuxPair(t, createStreamHandler()).EdgeMux.OpenStream(ctx, headers, nil)
|
||||
assert.NoError(t, err)
|
||||
assertStatusHeader(t, http.StatusNotFound, stream.Headers)
|
||||
assertRespBody(t, statusNotFound.text, stream)
|
||||
}
|
||||
|
||||
func TestServeStatusBadGateway(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Nothing listening on empty url, so proxy would fail. Expect to get 502 Bad Gateway
|
||||
reverseProxyConfigs := []*pogs.ReverseProxyConfig{
|
||||
{
|
||||
TunnelHostname: testTunnelHostname,
|
||||
OriginConfig: &pogs.HTTPOriginConfig{
|
||||
URLString: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
streamHandler := createStreamHandler()
|
||||
streamHandler.UpdateConfig(reverseProxyConfigs)
|
||||
headers := append(baseHeaders, tunnelHostnameHeader)
|
||||
stream, err := createRequestMuxPair(t, streamHandler).EdgeMux.OpenStream(ctx, headers, nil)
|
||||
assert.NoError(t, err)
|
||||
assertStatusHeader(t, http.StatusBadGateway, stream.Headers)
|
||||
assertRespBody(t, statusBadGateway.text, stream)
|
||||
}
|
||||
|
||||
func assertStatusHeader(t *testing.T, expectedStatus int, headers []h2mux.Header) {
|
||||
assert.Equal(t, statusPseudoHeader, headers[0].Name)
|
||||
assert.Equal(t, strconv.Itoa(expectedStatus), headers[0].Value)
|
||||
}
|
||||
|
||||
func assertRespBody(t *testing.T, expectedRespBody []byte, stream *h2mux.MuxedStream) {
|
||||
respBody := make([]byte, len(expectedRespBody))
|
||||
_, err := stream.Read(respBody)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedRespBody, respBody)
|
||||
}
|
||||
|
||||
type DefaultMuxerPair struct {
|
||||
OriginMuxConfig h2mux.MuxerConfig
|
||||
OriginMux *h2mux.Muxer
|
||||
OriginConn net.Conn
|
||||
EdgeMuxConfig h2mux.MuxerConfig
|
||||
EdgeMux *h2mux.Muxer
|
||||
EdgeConn net.Conn
|
||||
doneC chan struct{}
|
||||
}
|
||||
|
||||
func NewDefaultMuxerPair(t *testing.T, h h2mux.MuxedStreamHandler) *DefaultMuxerPair {
|
||||
origin, edge := net.Pipe()
|
||||
p := &DefaultMuxerPair{
|
||||
OriginMuxConfig: h2mux.MuxerConfig{
|
||||
Timeout: testHandshakeTimeout,
|
||||
Handler: h,
|
||||
IsClient: true,
|
||||
Name: "origin",
|
||||
Logger: logger.NewOutputWriter(logger.NewMockWriteManager()),
|
||||
< |