TUN-3019: Remove declarative tunnel entry code

pull/173/head^2
cthuang 3 years ago
parent be0514c5c9
commit fb82b2ced5

@ -21,7 +21,6 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/dbconnect"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/hello"
@ -32,10 +31,8 @@ import (
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/sshlog"
"github.com/cloudflare/cloudflared/sshserver"
"github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
"github.com/coreos/go-systemd/daemon"
@ -93,8 +90,6 @@ const (
// bastionFlag is to enable bastion, or jump host, operation
bastionFlag = "bastion"
noIntentMsg = "The --intent argument is required. Cloudflared looks up an Intent to determine what configuration to use (i.e. which tunnels to start). If you don't have any Intents yet, you can use a placeholder Intent Label for now. Then, when you make an Intent with that label, cloudflared will get notified and open the tunnels you specified in that Intent."
debugLevelWarning = "At debug level, request URL, method, protocol, content legnth and header will be logged. " +
"Response status, content length and header will also be logged in debug level."
)
@ -322,10 +317,6 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
cancel()
}()
if c.IsSet("use-declarative-tunnels") {
return startDeclarativeTunnel(ctx, c, cloudflaredID, buildInfo, &listeners, logger)
}
// update needs to be after DNS proxy is up to resolve equinox server address
if updater.IsAutoupdateEnabled(c, logger) {
logger.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
@ -500,124 +491,6 @@ func isProxyDestinationConfigured(staticHost string, c *cli.Context) bool {
return staticHost != "" || c.IsSet(bastionFlag)
}
func startDeclarativeTunnel(ctx context.Context,
c *cli.Context,
cloudflaredID uuid.UUID,
buildInfo *buildinfo.BuildInfo,
listeners *gracenet.Net,
logger logger.Service,
) error {
reverseProxyOrigin, err := defaultOriginConfig(c)
if err != nil {
logger.Errorf("%s", err)
return err
}
reverseProxyConfig, err := pogs.NewReverseProxyConfig(
c.String("hostname"),
reverseProxyOrigin,
c.Uint64("retries"),
c.Duration("proxy-connection-timeout"),
c.Uint64("compression-quality"),
)
if err != nil {
logger.Errorf("Cannot initialize default client config because reverse proxy config is invalid: %s", err)
return err
}
defaultClientConfig := &pogs.ClientConfig{
Version: pogs.InitVersion(),
SupervisorConfig: &pogs.SupervisorConfig{
AutoUpdateFrequency: c.Duration("autoupdate-freq"),
MetricsUpdateFrequency: c.Duration("metrics-update-freq"),
GracePeriod: c.Duration("grace-period"),
},
EdgeConnectionConfig: &pogs.EdgeConnectionConfig{
NumHAConnections: uint8(c.Int("ha-connections")),
HeartbeatInterval: c.Duration("heartbeat-interval"),
Timeout: c.Duration("dial-edge-timeout"),
MaxFailedHeartbeats: c.Uint64("heartbeat-count"),
},
DoHProxyConfigs: []*pogs.DoHProxyConfig{},
ReverseProxyConfigs: []*pogs.ReverseProxyConfig{reverseProxyConfig},
}
autoupdater := updater.NewAutoUpdater(defaultClientConfig.SupervisorConfig.AutoUpdateFrequency, listeners, logger)
originCert, err := getOriginCert(c, logger)
if err != nil {
logger.Errorf("error getting origin cert: %s", err)
return err
}
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c)
if err != nil {
logger.Errorf("unable to create TLS config to connect with edge: %s", err)
return err
}
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil {
logger.Errorf("unable to parse tag: %s", err)
return err
}
intentLabel := c.String("intent")
if intentLabel == "" {
logger.Error("--intent was empty")
return fmt.Errorf(noIntentMsg)
}
cloudflaredConfig := &connection.CloudflaredConfig{
BuildInfo: buildInfo,
CloudflaredID: cloudflaredID,
IntentLabel: intentLabel,
Tags: tags,
}
serviceDiscoverer, err := serviceDiscoverer(c, logger)
if err != nil {
logger.Errorf("unable to create service discoverer: %s", err)
return err
}
supervisor, err := supervisor.NewSupervisor(defaultClientConfig, originCert, toEdgeTLSConfig,
serviceDiscoverer, cloudflaredConfig, autoupdater, updater.SupportAutoUpdate(logger), logger)
if err != nil {
logger.Errorf("unable to create Supervisor: %s", err)
return err
}
return supervisor.Run(ctx)
}
func defaultOriginConfig(c *cli.Context) (pogs.OriginConfig, error) {
if c.IsSet("hello-world") {
return &pogs.HelloWorldOriginConfig{}, nil
}
originConfig := &pogs.HTTPOriginConfig{
TCPKeepAlive: c.Duration("proxy-tcp-keepalive"),
DialDualStack: !c.Bool("proxy-no-happy-eyeballs"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
TLSVerify: !c.Bool("no-tls-verify"),
OriginCAPool: c.String("origin-ca-pool"),
OriginServerName: c.String("origin-server-name"),
MaxIdleConnections: c.Uint64("proxy-keepalive-connections"),
IdleConnectionTimeout: c.Duration("proxy-keepalive-timeout"),
ProxyConnectionTimeout: c.Duration("proxy-connection-timeout"),
ExpectContinueTimeout: c.Duration("proxy-expect-continue-timeout"),
ChunkedEncoding: c.Bool("no-chunked-encoding"),
}
if c.IsSet("unix-socket") {
unixSocket, err := config.ValidateUnixSocket(c)
if err != nil {
return nil, errors.Wrap(err, "error validating --unix-socket")
}
originConfig.URLString = unixSocket
}
originAddr, err := config.ValidateUrl(c)
if err != nil {
return nil, errors.Wrap(err, "error validating origin URL")
}
originConfig.URLString = originAddr
return originConfig, nil
}
func waitToShutdown(wg *sync.WaitGroup,
errC chan error,
shutdownC, graceShutdownC chan struct{},
@ -1064,18 +937,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"TUNNEL_TRACE_OUTPUT"},
Hidden: shouldHide,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "use-declarative-tunnels",
Usage: "Test establishing connections with declarative tunnel methods.",
EnvVars: []string{"TUNNEL_USE_DECLARATIVE"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "intent",
Usage: "The label of an Intent from which `cloudflared` should gets its tunnels from. Intents can be created in the Origin Registry UI.",
EnvVars: []string{"TUNNEL_INTENT"},
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "use-reconnect-token",
Usage: "Test reestablishing connections with the new 'reconnect token' flow.",

@ -14,7 +14,6 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tlsconfig"
@ -290,15 +289,6 @@ func prepareTunnelConfig(
}, nil
}
func serviceDiscoverer(c *cli.Context, logger logger.Service) (*edgediscovery.Edge, error) {
// If --edge is specfied, resolve edge server addresses
if len(c.StringSlice("edge")) > 0 {
return edgediscovery.StaticEdge(logger, c.StringSlice("edge"))
}
// Otherwise lookup edge server addresses through service discovery
return edgediscovery.ResolveEdge(logger)
}
func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd()))
}

@ -1,57 +0,0 @@
package connection
import (
"context"
"net"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
openStreamTimeout = 30 * time.Second
)
type Connection struct {
id uuid.UUID
muxer *h2mux.Muxer
addr *net.TCPAddr
isLongLived bool
longLivedID int
}
func newConnection(muxer *h2mux.Muxer, addr *net.TCPAddr) (*Connection, error) {
id, err := uuid.NewRandom()
if err != nil {
return nil, err
}
return &Connection{
id: id,
muxer: muxer,
addr: addr,
}, nil
}
func (c *Connection) Serve(ctx context.Context) error {
// Serve doesn't return until h2mux is shutdown
return c.muxer.Serve(ctx)
}
// Connect is used to establish connections with cloudflare's edge network
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger logger.Service) (tunnelpogs.ConnectResult, error) {
tsClient, err := NewRPCClient(ctx, c.muxer, logger, openStreamTimeout)
if err != nil {
return nil, errors.Wrap(err, "cannot create new RPC connection")
}
defer tsClient.Close()
return tsClient.Connect(ctx, parameters)
}
func (c *Connection) Shutdown() {
c.muxer.Shutdown()
}

@ -1,302 +0,0 @@
package connection
import (
"context"
"crypto/tls"
"fmt"
"sync"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/streamhandler"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
quickStartLink = "https://developers.cloudflare.com/argo-tunnel/quickstart/"
faqLink = "https://developers.cloudflare.com/argo-tunnel/faq/"
defaultRetryAfter = time.Second * 5
packageNamespace = "connection"
edgeManagerSubsystem = "edgemanager"
)
// EdgeManager manages connections with the edge
type EdgeManager struct {
// streamHandler handles stream opened by the edge
streamHandler *streamhandler.StreamHandler
// TLSConfig is the TLS configuration to connect with edge
tlsConfig *tls.Config
// cloudflaredConfig is the cloudflared configuration that is determined when the process first starts
cloudflaredConfig *CloudflaredConfig
// serviceDiscoverer returns the next edge addr to connect to
serviceDiscoverer *edgediscovery.Edge
// state is attributes of ConnectionManager that can change during runtime.
state *edgeManagerState
logger logger.Service
metrics *metrics
}
type metrics struct {
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
}
func newMetrics(namespace, subsystem string) *metrics {
return &metrics{
activeStreams: h2mux.NewActiveStreamsMetrics(namespace, subsystem),
}
}
// EdgeManagerConfigurable is the configurable attributes of a EdgeConnectionManager
type EdgeManagerConfigurable struct {
TunnelHostnames []h2mux.TunnelHostname
*tunnelpogs.EdgeConnectionConfig
}
type CloudflaredConfig struct {
CloudflaredID uuid.UUID
Tags []tunnelpogs.Tag
BuildInfo *buildinfo.BuildInfo
IntentLabel string
}
func NewEdgeManager(
streamHandler *streamhandler.StreamHandler,
edgeConnMgrConfigurable *EdgeManagerConfigurable,
userCredential []byte,
tlsConfig *tls.Config,
serviceDiscoverer *edgediscovery.Edge,
cloudflaredConfig *CloudflaredConfig,
logger logger.Service,
) *EdgeManager {
return &EdgeManager{
streamHandler: streamHandler,
tlsConfig: tlsConfig,
cloudflaredConfig: cloudflaredConfig,
serviceDiscoverer: serviceDiscoverer,
state: newEdgeConnectionManagerState(edgeConnMgrConfigurable, userCredential),
logger: logger,
metrics: newMetrics(packageNamespace, edgeManagerSubsystem),
}
}
func (em *EdgeManager) Run(ctx context.Context) error {
defer em.shutdown()
// Currently, declarative tunnels don't have any concept of a stable connection
// Each edge connection is transient and when it dies, it is replaced by a different one,
// not restarted.
// So in the future we should really change this so that n connections are stored individually
connIndex := 0
for {
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "EdgeConnectionManager terminated")
default:
time.Sleep(1 * time.Second)
}
// Create/delete connection one at a time, so we don't need to adjust for connections that are being created/deleted
// in shouldCreateConnection or shouldReduceConnection calculation
if em.state.shouldCreateConnection(em.serviceDiscoverer.AvailableAddrs()) {
if connErr := em.newConnection(ctx, connIndex); connErr != nil {
if !connErr.ShouldRetry {
em.logger.Errorf("connectionManager: %s with error: %s", em.noRetryMessage(), connErr)
return connErr
}
em.logger.Errorf("connectionManager: cannot create new connection: %s", connErr)
} else {
connIndex++
}
} else if em.state.shouldReduceConnection() {
if err := em.closeConnection(ctx); err != nil {
em.logger.Errorf("connectionManager: cannot close connection: %s", err)
}
}
}
}
func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurable) {
em.logger.Infof("New edge connection manager configuration %+v", newConfigurable)
em.state.updateConfigurable(newConfigurable)
}
func (em *EdgeManager) newConnection(ctx context.Context, index int) *tunnelpogs.ConnectError {
edgeTCPAddr, err := em.serviceDiscoverer.GetAddr(index)
if err != nil {
return retryConnection(fmt.Sprintf("edge address discovery error: %v", err))
}
configurable := em.state.getConfigurable()
edgeConn, err := DialEdge(ctx, configurable.Timeout, em.tlsConfig, edgeTCPAddr)
if err != nil {
return retryConnection(fmt.Sprintf("dial edge error: %v", err))
}
// Establish a muxed connection with the edge
// Client mux handshake with agent server
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
Timeout: configurable.Timeout,
Handler: em.streamHandler,
IsClient: true,
HeartbeatInterval: configurable.HeartbeatInterval,
MaxHeartbeats: configurable.MaxFailedHeartbeats,
Logger: em.logger,
}, em.metrics.activeStreams)
if err != nil {
retryConnection(fmt.Sprintf("couldn't perform handshake with edge: %v", err))
}
h2muxConn, err := newConnection(muxer, edgeTCPAddr)
if err != nil {
return retryConnection(fmt.Sprintf("couldn't create h2mux connection: %v", err))
}
go em.serveConn(ctx, h2muxConn)
connResult, err := h2muxConn.Connect(ctx, &tunnelpogs.ConnectParameters{
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
NumPreviousAttempts: 0,
OriginCert: em.state.getUserCredential(),
IntentLabel: em.cloudflaredConfig.IntentLabel,
Tags: em.cloudflaredConfig.Tags,
}, em.logger)
if err != nil {
h2muxConn.Shutdown()
return retryConnection(fmt.Sprintf("couldn't connect to edge: %v", err))
}
if connErr := connResult.ConnectError(); connErr != nil {
return connErr
}
em.state.newConnection(h2muxConn)
em.logger.Infof("connectionManager: connected to %s", connResult.ConnectedTo())
if connResult.ClientConfig() != nil {
em.streamHandler.UseConfiguration(ctx, connResult.ClientConfig())
}
return nil
}
func (em *EdgeManager) closeConnection(ctx context.Context) error {
conn := em.state.getFirstConnection()
if conn == nil {
return fmt.Errorf("no connection to close")
}
conn.Shutdown()
// teardown will be handled by EdgeManager.serveConn in another goroutine
return nil
}
func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) {
err := conn.Serve(ctx)
em.logger.Errorf("connectionManager: Connection closed: %s", err)
em.state.closeConnection(conn)
em.serviceDiscoverer.GiveBack(conn.addr)
}
func (em *EdgeManager) noRetryMessage() string {
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
"2. Your credential at %s is still valid. See %s."
return fmt.Sprintf(messageTemplate, quickStartLink, em.state.getConfigurable().UserCredentialPath, faqLink)
}
func (em *EdgeManager) shutdown() {
em.state.shutdown()
}
type edgeManagerState struct {
sync.RWMutex
configurable *EdgeManagerConfigurable
userCredential []byte
conns map[uuid.UUID]*Connection
}
func newEdgeConnectionManagerState(configurable *EdgeManagerConfigurable, userCredential []byte) *edgeManagerState {
return &edgeManagerState{
configurable: configurable,
userCredential: userCredential,
conns: make(map[uuid.UUID]*Connection),
}
}
func (ems *edgeManagerState) shouldCreateConnection(availableEdgeAddrs int) bool {
ems.RLock()
defer ems.RUnlock()
expectedHAConns := int(ems.configurable.NumHAConnections)
if availableEdgeAddrs < expectedHAConns {
expectedHAConns = availableEdgeAddrs
}
return len(ems.conns) < expectedHAConns
}
func (ems *edgeManagerState) shouldReduceConnection() bool {
ems.RLock()
defer ems.RUnlock()
return uint8(len(ems.conns)) > ems.configurable.NumHAConnections
}
func (ems *edgeManagerState) newConnection(conn *Connection) {
ems.Lock()
defer ems.Unlock()
ems.conns[conn.id] = conn
}
func (ems *edgeManagerState) closeConnection(conn *Connection) {
ems.Lock()
defer ems.Unlock()
delete(ems.conns, conn.id)
}
func (ems *edgeManagerState) getFirstConnection() *Connection {
ems.RLock()
defer ems.RUnlock()
for _, conn := range ems.conns {
return conn
}
return nil
}
func (ems *edgeManagerState) shutdown() {
ems.Lock()
defer ems.Unlock()
for _, conn := range ems.conns {
conn.Shutdown()
}
}
func (ems *edgeManagerState) getConfigurable() *EdgeManagerConfigurable {
ems.Lock()
defer ems.Unlock()
return ems.configurable
}
func (ems *edgeManagerState) updateConfigurable(newConfigurable *EdgeManagerConfigurable) {
ems.Lock()
defer ems.Unlock()
ems.configurable = newConfigurable
}
func (ems *edgeManagerState) getUserCredential() []byte {
ems.RLock()
defer ems.RUnlock()
return ems.userCredential
}
func retryConnection(cause string) *tunnelpogs.ConnectError {
return &tunnelpogs.ConnectError{
Cause: cause,
RetryAfter: defaultRetryAfter,
ShouldRetry: true,
}
}

@ -1,77 +0,0 @@
package connection
import (
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
var (
configurable = &EdgeManagerConfigurable{
[]h2mux.TunnelHostname{
"http.example.com",
"ws.example.com",
"hello.example.com",
},
&pogs.EdgeConnectionConfig{
NumHAConnections: 1,
HeartbeatInterval: 1 * time.Second,
Timeout: 5 * time.Second,
MaxFailedHeartbeats: 3,
UserCredentialPath: "/etc/cloudflared/cert.pem",
},
}
cloudflaredConfig = &CloudflaredConfig{
CloudflaredID: uuid.New(),
Tags: []pogs.Tag{
{Name: "pool", Value: "east-6"},
},
BuildInfo: &buildinfo.BuildInfo{
GoOS: "linux",
GoVersion: "1.12",
GoArch: "amd64",
CloudflaredVersion: "2019.6.0",
},
}
)
func mockEdgeManager() *EdgeManager {
newConfigChan := make(chan<- *pogs.ClientConfig)
useConfigResultChan := make(<-chan *pogs.UseConfigurationResult)
logger := logger.NewOutputWriter(logger.NewMockWriteManager())
edge := edgediscovery.MockEdge(logger, []*net.TCPAddr{})
return NewEdgeManager(
streamhandler.NewStreamHandler(newConfigChan, useConfigResultChan, logger),
configurable,
[]byte{},
nil,
edge,
cloudflaredConfig,
logger,
)
}
func TestUpdateConfigurable(t *testing.T) {
m := mockEdgeManager()
newConfigurable := &EdgeManagerConfigurable{
[]h2mux.TunnelHostname{
"second.example.com",
},
&pogs.EdgeConnectionConfig{
NumHAConnections: 2,
},
}
m.UpdateConfigurable(newConfigurable)
assert.Equal(t, newConfigurable, m.state.getConfigurable())
}

@ -70,7 +70,7 @@ require (
gopkg.in/square/go-jose.v2 v2.4.0 // indirect
gopkg.in/urfave/cli.v2 v2.0.0-20180128181224-d604b6ffeee8
gopkg.in/yaml.v2 v2.2.4
zombiezen.com/go/capnproto2 v0.0.0-20180616160808-7cfd211c19c7
zombiezen.com/go/capnproto2 v2.18.0+incompatible
)
// ../../go/pkg/mod/github.com/coredns/coredns@v1.2.0/plugin/metrics/metrics.go:40:49: too many arguments in call to prometheus.NewProcessCollector

@ -9,9 +9,9 @@ github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFD
github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0=
github.com/GeertJohan/go.rice v1.0.0/go.mod h1:eH6gbSOAUv07dQuZVnBmoDP8mgsM1rtixis4Tib9if0=
github.com/akavel/rsrc v0.8.0/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c=
github.com/acmacalister/skittles v0.0.0-20160609003031-7423546701e1 h1:RKnVV4C7qoN/sToLX2y1dqH7T6kKLMHcwRJlgwb9Ggk=
github.com/acmacalister/skittles v0.0.0-20160609003031-7423546701e1/go.mod h1:gI5CyA/CEnS6eqNV22rqs4dG3aGfaSbXgPORIlwr2r0=
github.com/akavel/rsrc v0.8.0/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 h1:Hs82Z41s6SdL1CELW+XaDYmOH4hkBN4/N9og/AsOv7E=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@ -285,3 +285,5 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
zombiezen.com/go/capnproto2 v0.0.0-20180616160808-7cfd211c19c7 h1:CZoOFlTPbKfAShKYrMuUfYbnXexFT1rYRUX1SPnrdE4=
zombiezen.com/go/capnproto2 v0.0.0-20180616160808-7cfd211c19c7/go.mod h1:TMGa8HWGJkXiq4nHe9Zu/JgRF5oUtg4XizFC+Vexbec=
zombiezen.com/go/capnproto2 v2.18.0+incompatible h1:mwfXZniffG5mXokQGHUJWGnqIBggoPfT/CEwon9Yess=
zombiezen.com/go/capnproto2 v2.18.0+incompatible/go.mod h1:XO5Pr2SbXgqZwn0m0Ru54QBqpOf4K5AYBO+8LAOBQEQ=

@ -25,7 +25,6 @@ import (
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
@ -618,8 +617,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
return reqErr
}
cfRay := streamhandler.FindCfRayHeader(req)
lbProbe := streamhandler.IsLBProbeRequest(req)
cfRay := findCfRayHeader(req)
lbProbe := isLBProbeRequest(req)
h.logRequest(req, cfRay, lbProbe)
var resp *http.Response
@ -833,3 +832,11 @@ func activeIncidentsMsg(incidents []Incident) string {
return preamble + " " + strings.Join(incidentStrings, "; ")
}
func findCfRayHeader(h1 *http.Request) string {
return h1.Header.Get("Cf-Ray")
}
func isLBProbeRequest(req *http.Request) bool {
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
}

@ -1,247 +0,0 @@
// Package client defines and implements interface to proxy to HTTP, websocket and hello world origins
package originservice
import (
"bufio"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
// OriginService is an interface to proxy requests to different type of origins
type OriginService interface {
Proxy(stream *h2mux.MuxedStream, req *http.Request) (resp *http.Response, err error)
URL() *url.URL
Summary() string
Shutdown()
}
// HTTPService talks to origin using HTTP/HTTPS
type HTTPService struct {
client http.RoundTripper
originURL *url.URL
chunkedEncoding bool
bufferPool *buffer.Pool
}
func NewHTTPService(transport http.RoundTripper, url *url.URL, chunkedEncoding bool) OriginService {
return &HTTPService{
client: transport,
originURL: url,
chunkedEncoding: chunkedEncoding,
bufferPool: buffer.NewPool(512 * 1024),
}
}
func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
const responseSourceOrigin = "origin"
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if !hc.chunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
req.ContentLength = int64(cLength)
}
}
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
resp, err := hc.client.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "error proxying request to HTTP origin")
}
defer resp.Body.Close()
responseHeaders := h1ResponseToH2Response(resp)
responseHeaders = append(responseHeaders, h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, responseSourceOrigin))
err = stream.WriteHeaders(responseHeaders)
if err != nil {
return nil, errors.Wrap(err, "error writing response header to HTTP origin")
}
if isEventStream(resp) {
writeEventStream(stream, resp.Body)
} else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
buf := hc.bufferPool.Get()
defer hc.bufferPool.Put(buf)
io.CopyBuffer(stream, resp.Body, buf)
}
return resp, nil
}
func (hc *HTTPService) URL() *url.URL {
return hc.originURL
}
func (hc *HTTPService) Summary() string {
return fmt.Sprintf("HTTP service listening on %s", hc.originURL)
}
func (hc *HTTPService) Shutdown() {}
// WebsocketService talks to origin using WS/WSS
type WebsocketService struct {
tlsConfig *tls.Config
originURL *url.URL
shutdownC chan struct{}
}
func NewWebSocketService(tlsConfig *tls.Config, url *url.URL, logger logger.Service) (OriginService, error) {
listener, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
return nil, errors.Wrap(err, "cannot start Websocket Proxy Server")
}
shutdownC := make(chan struct{})
go func() {
websocket.StartProxyServer(logger, listener, url.String(), shutdownC, websocket.DefaultStreamHandler)
}()
return &WebsocketService{
tlsConfig: tlsConfig,
originURL: url,
shutdownC: shutdownC,
}, nil
}
func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
if !websocket.IsWebSocketUpgrade(req) {
return nil, fmt.Errorf("request is not a websocket connection")
}
conn, response, err := websocket.ClientConnect(req, wsc.tlsConfig)
if err != nil {
return nil, err
}
defer conn.Close()
err = stream.WriteHeaders(h1ResponseToH2Response(response))
if err != nil {
return nil, errors.Wrap(err, "error writing response header to websocket origin")
}
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), stream)
return response, nil
}
func (wsc *WebsocketService) URL() *url.URL {
return wsc.originURL
}
func (wsc *WebsocketService) Summary() string {
return fmt.Sprintf("Websocket listening on %s", wsc.originURL)
}
func (wsc *WebsocketService) Shutdown() {
close(wsc.shutdownC)
}
// HelloWorldService talks to the hello world example origin
type HelloWorldService struct {
client http.RoundTripper
listener net.Listener
originURL *url.URL
shutdownC chan struct{}
bufferPool *buffer.Pool
}
func NewHelloWorldService(transport http.RoundTripper, logger logger.Service) (OriginService, error) {
listener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
return nil, errors.Wrap(err, "cannot start Hello World Server")
}
shutdownC := make(chan struct{})
go func() {
hello.StartHelloWorldServer(logger, listener, shutdownC)
}()
return &HelloWorldService{
client: transport,
listener: listener,
originURL: &url.URL{
Scheme: "https",
Host: listener.Addr().String(),
},
shutdownC: shutdownC,
bufferPool: buffer.NewPool(512 * 1024),
}, nil
}
func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
resp, err := hwc.client.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "error proxying request to Hello World origin")
}
defer resp.Body.Close()
err = stream.WriteHeaders(h1ResponseToH2Response(resp))
if err != nil {
return nil, errors.Wrap(err, "error writing response header to Hello World origin")
}
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
buf := hwc.bufferPool.Get()
defer hwc.bufferPool.Put(buf)
io.CopyBuffer(stream, resp.Body, buf)
return resp, nil
}
func (hwc *HelloWorldService) URL() *url.URL {
return hwc.originURL
}
func (hwc *HelloWorldService) Summary() string {
return fmt.Sprintf("Hello World service listening on %s", hwc.originURL)
}
func (hwc *HelloWorldService) Shutdown() {
hwc.listener.Close()
}
func isEventStream(resp *http.Response) bool {
// Check if content-type is text/event-stream. We need to check if the header value starts with text/event-stream
// because text/event-stream; charset=UTF-8 is also valid
// Ref: https://tools.ietf.org/html/rfc7231#section-3.1.1.1
for _, contentType := range resp.Header["content-type"] {
if strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") {
return true
}
}
return false
}
func writeEventStream(stream *h2mux.MuxedStream, respBody io.ReadCloser) {
reader := bufio.NewReader(respBody)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
break
}
stream.Write(line)
}
}
func h1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
for headerName, headerValues := range h1.Header {
for _, headerValue := range headerValues {
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
}
}
return
}

@ -1,60 +0,0 @@
package originservice
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsEventStream(t *testing.T) {
tests := []struct {
resp *http.Response
isEventStream bool
}{
{
resp: &http.Response{},
isEventStream: false,
},
{
// isEventStream checks all headers
resp: &http.Response{
Header: http.Header{
"accept": []string{"text/html"},
"content-type": []string{"text/event-stream"},
},
},
isEventStream: true,
},
{
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
resp: &http.Response{
Header: http.Header{
"content-type": []string{"Text/event-stream;charset=utf-8"},
},
},
isEventStream: true,
},
{
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
resp: &http.Response{
Header: http.Header{
"content-type": []string{"appication/json", "text/html", "Text/event-stream;charset=utf-8"},
},
},
isEventStream: true,
},
{
// Not an event stream because the content-type value doesn't start with text/event-stream
resp: &http.Response{
Header: http.Header{
"content-type": []string{" text/event-stream"},
},
},
isEventStream: false,
},
}
for _, test := range tests {
assert.Equal(t, test.isEventStream, isEventStream(test.resp), "Header: %v", test.resp.Header)
}
}

@ -1,34 +0,0 @@
package streamhandler
import (
"net/http"
"net/url"
"strings"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/pkg/errors"
)
const (
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
)
func FindCfRayHeader(h1 *http.Request) string {
return h1.Header.Get("Cf-Ray")
}
func IsLBProbeRequest(req *http.Request) bool {
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
}
func createRequest(stream *h2mux.MuxedStream, url *url.URL) (*http.Request, error) {
req, err := http.NewRequest(http.MethodGet, url.String(), h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
return nil, errors.Wrap(err, "unexpected error from http.NewRequest")
}
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
return nil, errors.Wrap(err, "invalid request received")
}
return req, nil
}

@ -1,189 +0,0 @@
package streamhandler
import (
"context"
"fmt"
"net/http"
"strconv"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelhostnamemapper"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/pkg/errors"
"zombiezen.com/go/capnproto2/rpc"
)
const (
statusPseudoHeader = ":status"
)
type httpErrorStatus struct {
status string
text []byte
}
var (
statusBadRequest = newHTTPErrorStatus(http.StatusBadRequest)
statusNotFound = newHTTPErrorStatus(http.StatusNotFound)
statusBadGateway = newHTTPErrorStatus(http.StatusBadGateway)
)
func newHTTPErrorStatus(status int) *httpErrorStatus {
return &httpErrorStatus{
status: strconv.Itoa(status),
text: []byte(http.StatusText(status)),
}
}
// StreamHandler handles new stream opened by the edge. The streams can be used to proxy requests or make RPC.
type StreamHandler struct {
// newConfigChan is a send-only channel to notify Supervisor of a new ClientConfig
newConfigChan chan<- *pogs.ClientConfig
// useConfigResultChan is a receive-only channel for Supervisor to communicate the result of applying a new ClientConfig
useConfigResultChan <-chan *pogs.UseConfigurationResult
// originMapper maps tunnel hostname to origin service
tunnelHostnameMapper *tunnelhostnamemapper.TunnelHostnameMapper
logger logger.Service
}
// NewStreamHandler creates a new StreamHandler
func NewStreamHandler(newConfigChan chan<- *pogs.ClientConfig,
useConfigResultChan <-chan *pogs.UseConfigurationResult,
logger logger.Service,
) *StreamHandler {
return &StreamHandler{
newConfigChan: newConfigChan,
useConfigResultChan: useConfigResultChan,
tunnelHostnameMapper: tunnelhostnamemapper.NewTunnelHostnameMapper(),
logger: logger,
}
}
// UseConfiguration implements ClientService
func (s *StreamHandler) UseConfiguration(ctx context.Context, config *pogs.ClientConfig) (*pogs.UseConfigurationResult, error) {
select {
case <-ctx.Done():
err := fmt.Errorf("Timeout while sending new config to Supervisor")
s.logger.Errorf("streamHandler: %s", err)
return nil, err
case s.newConfigChan <- config:
}
select {
case <-ctx.Done():
err := fmt.Errorf("Timeout applying new configuration")
s.logger.Errorf("streamHandler: %s", err)
return nil, err
case result := <-s.useConfigResultChan:
return result, nil
}
}
// UpdateConfig replaces current originmapper mapping with mappings from newConfig
func (s *StreamHandler) UpdateConfig(newConfig []*pogs.ReverseProxyConfig) (failedConfigs []*pogs.FailedConfig) {
// Delete old configs that aren't in the `newConfig`
toRemove := s.tunnelHostnameMapper.ToRemove(newConfig)
for _, hostnameToRemove := range toRemove {
s.tunnelHostnameMapper.Delete(hostnameToRemove)
}
// Add new configs that weren't in the old mapper
toAdd := s.tunnelHostnameMapper.ToAdd(newConfig)
for _, tunnelConfig := range toAdd {
tunnelHostname := tunnelConfig.TunnelHostname
originSerice, err := tunnelConfig.OriginConfig.Service(s.logger)
if err != nil {
s.logger.Errorf("streamHandler: tunnelHostname: %s Invalid origin service config: %s", tunnelHostname, err)
failedConfigs = append(failedConfigs, &pogs.FailedConfig{
Config: tunnelConfig,
Reason: tunnelConfig.FailReason(err),
})
continue
}
s.tunnelHostnameMapper.Add(tunnelConfig.TunnelHostname, originSerice)
s.logger.Infof("streamHandler: tunnelHostname: %s New origin service config: %v", tunnelHostname, originSerice.Summary())
}
return
}
// ServeStream implements MuxedStreamHandler interface
func (s *StreamHandler) ServeStream(stream *h2mux.MuxedStream) error {
if stream.IsRPCStream() {
return s.serveRPC(stream)
}
if err := s.serveRequest(stream); err != nil {
s.logger.Errorf("streamHandler: %s", err)
return err
}
return nil
}
func (s *StreamHandler) serveRPC(stream *h2mux.MuxedStream) error {
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "200"}})
main := pogs.ClientService_ServerToClient(s)
rpcConn := rpc.NewConn(
tunnelrpc.NewTransportLogger(s.logger, rpc.StreamTransport(stream)),
rpc.MainInterface(main.Client),
tunnelrpc.ConnLog(s.logger),
)
return rpcConn.Wait()
}
func (s *StreamHandler) serveRequest(stream *h2mux.MuxedStream) error {
tunnelHostname := stream.TunnelHostname()
if !tunnelHostname.IsSet() {
s.writeErrorStatus(stream, statusBadRequest)
return fmt.Errorf("stream doesn't have tunnelHostname")
}
originService, ok := s.tunnelHostnameMapper.Get(tunnelHostname)
if !ok {
s.writeErrorStatus(stream, statusNotFound)
return fmt.Errorf("cannot map tunnel hostname %s to origin", tunnelHostname)
}
req, err := createRequest(stream, originService.URL())
if err != nil {
s.writeErrorStatus(stream, statusBadRequest)
return errors.Wrap(err, "cannot create request")
}
cfRay := s.logRequest(req, tunnelHostname)
s.logger.Debugf("streamHandler: tunnelHostname: %s CF-RAY: %s Request Headers %+v", tunnelHostname, cfRay, req.Header)
resp, err := originService.Proxy(stream, req)
if err != nil {
s.writeErrorStatus(stream, statusBadGateway)
return errors.Wrap(err, "cannot proxy request")
}
s.logger.Debugf("streamHandler: tunnelHostname: %s CF-RAY: %s status: %s Response Headers %+v", tunnelHostname, cfRay, resp.Status, resp.Header)
return nil
}
func (s *StreamHandler) logRequest(req *http.Request, tunnelHostname h2mux.TunnelHostname) string {
cfRay := FindCfRayHeader(req)
lbProbe := IsLBProbeRequest(req)
logger := s.logger
if cfRay != "" {
logger.Debugf("streamHandler: tunnelHostname: %s CF-RAY: %s %s %s %s", tunnelHostname, cfRay, req.Method, req.URL, req.Proto)
} else if lbProbe {
logger.Debugf("streamHandler: tunnelHostname: %s CF-RAY: %s Load Balancer health check %s %s %s", tunnelHostname, cfRay, req.Method, req.URL, req.Proto)
} else {
logger.Infof("streamHandler: tunnelHostname: %s CF-RAY: %s Requests %v does not have CF-RAY header. Please open a support ticket with Cloudflare.", tunnelHostname, cfRay, req)
}
return cfRay
}
func (s *StreamHandler) writeErrorStatus(stream *h2mux.MuxedStream, status *httpErrorStatus) {
_ = stream.WriteHeaders([]h2mux.Header{
{
Name: statusPseudoHeader,
Value: status.status,
},
h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared),
})
_, _ = stream.Write(status.text)
}

@ -1,261 +0,0 @@
package streamhandler
import (
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)
const (
testOpenStreamTimeout = time.Millisecond * 5000
testHandshakeTimeout = time.Millisecond * 1000
)
var (
testTunnelHostname = h2mux.TunnelHostname("123.cftunnel.com")
baseHeaders = []h2mux.Header{
{Name: ":method", Value: "GET"},
{Name: ":scheme", Value: "http"},
{Name: ":authority", Value: "example.com"},
{Name: ":path", Value: "/"},
// Regular headers must always come after the pseudoheaders
{Name: h2mux.RequestUserHeadersField, Value: ""},
}
tunnelHostnameHeader = h2mux.Header{Name: h2mux.CloudflaredProxyTunnelHostnameHeader, Value: testTunnelHostname.String()}
)
func TestServeRequest(t *testing.T) {
l := logger.NewOutputWriter(logger.NewMockWriteManager())
configChan := make(chan *pogs.ClientConfig)
useConfigResultChan := make(chan *pogs.UseConfigurationResult)
streamHandler := NewStreamHandler(configChan, useConfigResultChan, l)
message := []byte("Hello cloudflared")
httpServer := httptest.NewServer(&mockHTTPHandler{message})
reverseProxyConfigs := []*pogs.ReverseProxyConfig{
{
TunnelHostname: testTunnelHostname,
OriginConfig: &pogs.HTTPOriginConfig{
URLString: httpServer.URL,
},
},
}
streamHandler.UpdateConfig(reverseProxyConfigs)
muxPair := NewDefaultMuxerPair(t, streamHandler)
muxPair.Serve(t)
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
defer cancel()
headers := append(baseHeaders, tunnelHostnameHeader)
stream, err := muxPair.EdgeMux.OpenStream(ctx, headers, nil)
assert.NoError(t, err)
assertStatusHeader(t, http.StatusOK, stream.Headers)
assertRespBody(t, message, stream)
}
func createStreamHandler() *StreamHandler {
configChan := make(chan *pogs.ClientConfig)
useConfigResultChan := make(chan *pogs.UseConfigurationResult)
l := logger.NewOutputWriter(logger.NewMockWriteManager())
return NewStreamHandler(configChan, useConfigResultChan, l)
}
func createRequestMuxPair(t *testing.T, streamHandler *StreamHandler) *DefaultMuxerPair {
muxPair := NewDefaultMuxerPair(t, streamHandler)
muxPair.Serve(t)
return muxPair
}
func TestServeStatusBadRequest(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
defer cancel()
// No tunnel hostname header, expect to get 400 Bad Request
stream, err := createRequestMuxPair(t, createStreamHandler()).EdgeMux.OpenStream(ctx, baseHeaders, nil)
assert.NoError(t, err)
assertStatusHeader(t, http.StatusBadRequest, stream.Headers)
assertRespBody(t, statusBadRequest.text, stream)
}
func TestServeInvalidContentLength(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
defer cancel()
// Invalid content-length, wouldn't be able to create a request
// Expect to get 400 Bad Request
headers := append(baseHeaders, tunnelHostnameHeader)
headers = append(headers, h2mux.Header{
Name: "content-length",
Value: "x",
})
streamHandler := createStreamHandler()
streamHandler.UpdateConfig([]*pogs.ReverseProxyConfig{
{
TunnelHostname: testTunnelHostname,
OriginConfig: &pogs.HTTPOriginConfig{
URLString: "",
},
},
})
mux := createRequestMuxPair(t, streamHandler).EdgeMux
stream, err := mux.OpenStream(ctx, headers, nil)
assert.NoError(t, err)
assertStatusHeader(t, http.StatusBadRequest, stream.Headers)
assertRespBody(t, statusBadRequest.text, stream)
}
func TestServeStatusNotFound(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
defer cancel()
// No mapping for the tunnel hostname, expect to get 404 Not Found
headers := append(baseHeaders, tunnelHostnameHeader)
stream, err := createRequestMuxPair(t, createStreamHandler()).EdgeMux.OpenStream(ctx, headers, nil)
assert.NoError(t, err)
assertStatusHeader(t, http.StatusNotFound, stream.Headers)
assertRespBody(t, statusNotFound.text, stream)
}
func TestServeStatusBadGateway(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
defer cancel()
// Nothing listening on empty url, so proxy would fail. Expect to get 502 Bad Gateway
reverseProxyConfigs := []*pogs.ReverseProxyConfig{
{
TunnelHostname: testTunnelHostname,
OriginConfig: &pogs.HTTPOriginConfig{
URLString: "",
},
},
}
streamHandler := createStreamHandler()
streamHandler.UpdateConfig(reverseProxyConfigs)
headers := append(baseHeaders, tunnelHostnameHeader)
stream, err := createRequestMuxPair(t, streamHandler).EdgeMux.OpenStream(ctx, headers, nil)
assert.NoError(t, err)
assertStatusHeader(t, http.StatusBadGateway, stream.Headers)
assertRespBody(t, statusBadGateway.text, stream)
}
func assertStatusHeader(t *testing.T, expectedStatus int, headers []h2mux.Header) {
assert.Equal(t, statusPseudoHeader, headers[0].Name)
assert.Equal(t, strconv.Itoa(expectedStatus), headers[0].Value)
}
func assertRespBody(t *testing.T, expectedRespBody []byte, stream *h2mux.MuxedStream) {
respBody := make([]byte, len(expectedRespBody))
_, err := stream.Read(respBody)
assert.NoError(t, err)
assert.Equal(t, expectedRespBody, respBody)
}
type DefaultMuxerPair struct {
OriginMuxConfig h2mux.MuxerConfig
OriginMux *h2mux.Muxer
OriginConn net.Conn
EdgeMuxConfig h2mux.MuxerConfig
EdgeMux *h2mux.Muxer
EdgeConn net.Conn
doneC chan struct{}
}
func NewDefaultMuxerPair(t *testing.T, h h2mux.MuxedStreamHandler) *DefaultMuxerPair {
origin, edge := net.Pipe()
p := &DefaultMuxerPair{
OriginMuxConfig: h2mux.MuxerConfig{
Timeout: testHandshakeTimeout,
Handler: h,
IsClient: true,
Name: "origin",
Logger: logger.NewOutputWriter(logger.NewMockWriteManager()),
<