From d26a8c5d44efdf44f7f96037964611d7b63aca60 Mon Sep 17 00:00:00 2001 From: Chung-Ting Huang Date: Wed, 5 Jun 2019 10:08:55 -0500 Subject: [PATCH] TUN-1893: Proxy requests to the origin based on tunnel hostname --- connection/connection.go | 18 +--- connection/supervisor.go | 21 ++++- h2mux/h2mux.go | 43 +++++++++ h2mux/muxedstream.go | 22 ++++- h2mux/muxedstream_test.go | 27 ++++++ origin/tunnel.go | 48 +--------- originservice/originservice.go | 56 ++++++++---- streamhandler/request.go | 69 ++++++++++++++ streamhandler/stream_handler.go | 91 +++++++++++++++++++ tunnelhostnamemapper/tunnelhostnamemapper.go | 49 ++++++++++ .../tunnelhostnamemapper_test.go | 69 ++++++++++++++ 11 files changed, 431 insertions(+), 82 deletions(-) create mode 100644 streamhandler/request.go create mode 100644 streamhandler/stream_handler.go create mode 100644 tunnelhostnamemapper/tunnelhostnamemapper.go create mode 100644 tunnelhostnamemapper/tunnelhostnamemapper_test.go diff --git a/connection/connection.go b/connection/connection.go index f9dd6125..0e830e91 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -8,6 +8,7 @@ import ( "time" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/pkg/errors" @@ -53,14 +54,6 @@ type h2muxHandler struct { logger *logrus.Entry } -type muxedStreamHandler struct { -} - -// Implements MuxedStreamHandler interface -func (h *muxedStreamHandler) ServeStream(stream *h2mux.MuxedStream) error { - return nil -} - func (h *h2muxHandler) serve(ctx context.Context) error { // Serve doesn't return until h2mux is shutdown if err := h.muxer.Serve(ctx); err != nil { @@ -87,11 +80,7 @@ func (h *h2muxHandler) shutdown() { } func (h *h2muxHandler) newRPConn(ctx context.Context) (*rpc.Conn, error) { - stream, err := h.muxer.OpenStream(ctx, []h2mux.Header{ - {Name: ":method", Value: "RPC"}, - {Name: ":scheme", Value: "capnp"}, - {Name: ":path", Value: "*"}, - }, nil) + stream, err := h.muxer.OpenRPCStream(ctx) if err != nil { return nil, err } @@ -103,6 +92,7 @@ func (h *h2muxHandler) newRPConn(ctx context.Context) (*rpc.Conn, error) { // NewConnectionHandler returns a connectionHandler, wrapping h2mux to make RPC calls func newH2MuxHandler(ctx context.Context, + streamHandler *streamhandler.StreamHandler, config *ConnectionConfig, edgeIP *net.TCPAddr, ) (connectionHandler, error) { @@ -126,7 +116,7 @@ func newH2MuxHandler(ctx context.Context, // Client mux handshake with agent server muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ Timeout: dialTimeout, - Handler: &muxedStreamHandler{}, + Handler: streamHandler, IsClient: true, HeartbeatInterval: config.HeartbeatInterval, MaxHeartbeats: config.MaxHeartbeats, diff --git a/connection/supervisor.go b/connection/supervisor.go index 50855d3f..ba39043a 100644 --- a/connection/supervisor.go +++ b/connection/supervisor.go @@ -5,6 +5,9 @@ import ( "net" "time" + "github.com/cloudflare/cloudflared/streamhandler" + + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/google/uuid" "github.com/pkg/errors" @@ -32,9 +35,12 @@ type CloudflaredConfig struct { // Supervisor is a stateful object that manages connections with the edge type Supervisor struct { - config *CloudflaredConfig - state *supervisorState - connErrors chan error + streamHandler *streamhandler.StreamHandler + newConfigChan chan<- *pogs.ClientConfig + useConfigResultChan <-chan *pogs.UseConfigurationResult + config *CloudflaredConfig + state *supervisorState + connErrors chan error } type supervisorState struct { @@ -57,8 +63,13 @@ func (s *supervisorState) getNextEdgeIP() *net.TCPAddr { } func NewSupervisor(config *CloudflaredConfig) *Supervisor { + newConfigChan := make(chan *pogs.ClientConfig) + useConfigResultChan := make(chan *pogs.UseConfigurationResult) return &Supervisor{ - config: config, + streamHandler: streamhandler.NewStreamHandler(newConfigChan, useConfigResultChan, config.Logger), + newConfigChan: newConfigChan, + useConfigResultChan: useConfigResultChan, + config: config, state: &supervisorState{ connectionPool: &connectionPool{}, }, @@ -91,7 +102,7 @@ func (s *Supervisor) Run(ctx context.Context) error { time.Sleep(5 * time.Second) } if currentConnectionCount < expectedConnectionCount { - h, err := newH2MuxHandler(ctx, s.config.ConnectionConfig, s.state.getNextEdgeIP()) + h, err := newH2MuxHandler(ctx, s.streamHandler, s.config.ConnectionConfig, s.state.getNextEdgeIP()) if err != nil { logger.WithError(err).Error("Failed to create new connection handler") continue diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 2ac07b62..6e0905c2 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -94,6 +94,14 @@ type Header struct { Name, Value string } +func RPCHeaders() []Header { + return []Header{ + {Name: ":method", Value: "RPC"}, + {Name: ":scheme", Value: "capnp"}, + {Name: ":path", Value: "*"}, + } +} + // Handshake establishes a muxed connection with the peer. // After the handshake completes, it is possible to open and accept streams. func Handshake( @@ -414,6 +422,41 @@ func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader } } +func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) { + stream := &MuxedStream{ + responseHeadersReceived: make(chan struct{}), + readBuffer: NewSharedBuffer(), + writeBuffer: &bytes.Buffer{}, + writeBufferMaxLen: m.config.StreamWriteBufferMaxLen, + writeBufferHasSpace: make(chan struct{}, 1), + receiveWindow: m.config.DefaultWindowSize, + receiveWindowCurrentMax: m.config.DefaultWindowSize, + receiveWindowMax: m.config.MaxWindowSize, + sendWindow: m.config.DefaultWindowSize, + readyList: m.readyList, + writeHeaders: RPCHeaders(), + dictionaries: m.muxReader.dictionaries, + } + + select { + // Will be received by mux writer + case <-ctx.Done(): + return nil, ErrOpenStreamTimeout + case <-m.abortChan: + return nil, ErrConnectionClosed + case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: nil}: + } + + select { + case <-ctx.Done(): + return nil, ErrResponseHeadersTimeout + case <-m.abortChan: + return nil, ErrConnectionClosed + case <-stream.responseHeadersReceived: + return stream, nil + } +} + func (m *Muxer) Metrics() *MuxerMetrics { return m.muxMetricsUpdater.metrics() } diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go index 8fb94817..44d6f1e2 100644 --- a/h2mux/muxedstream.go +++ b/h2mux/muxedstream.go @@ -68,7 +68,8 @@ type MuxedStream struct { sentEOF bool // true if the peer sent us an EOF receivedEOF bool - + // If valid, tunnelHostname is used to identify which origin service is the intended recipient of the request + tunnelHostname TunnelHostname // Compression-related fields receivedUseDict bool method string @@ -195,6 +196,25 @@ func (s *MuxedStream) WriteHeaders(headers []Header) error { return nil } +// IsRPCStream returns if the stream is used to transport RPC. +func (s *MuxedStream) IsRPCStream() bool { + rpcHeaders := RPCHeaders() + if len(s.Headers) != len(rpcHeaders) { + return false + } + // The headers order matters, so RPC stream should be opened with OpenRPCStream method and let MuxWriter serializes the headers. + for i, rpcHeader := range rpcHeaders { + if s.Headers[i] != rpcHeader { + return false + } + } + return true +} + +func (s *MuxedStream) TunnelHostname() TunnelHostname { + return s.tunnelHostname +} + func (s *MuxedStream) getReceiveWindow() uint32 { s.writeLock.Lock() defer s.writeLock.Unlock() diff --git a/h2mux/muxedstream_test.go b/h2mux/muxedstream_test.go index 3672b531..b0e0ac13 100644 --- a/h2mux/muxedstream_test.go +++ b/h2mux/muxedstream_test.go @@ -98,3 +98,30 @@ func TestMuxedStreamEOF(t *testing.T) { assert.Equal(t, 0, n) } } + +func TestIsRPCStream(t *testing.T) { + tests := []struct { + stream *MuxedStream + isRPCStream bool + }{ + { + stream: &MuxedStream{}, + isRPCStream: false, + }, + { + stream: &MuxedStream{Headers: RPCHeaders()}, + isRPCStream: true, + }, + { + stream: &MuxedStream{Headers: []Header{ + {Name: ":method", Value: "rpc"}, + {Name: ":scheme", Value: "Capnp"}, + {Name: ":path", Value: "/"}, + }}, + isRPCStream: false, + }, + } + for _, test := range tests { + assert.Equal(t, test.isRPCStream, test.stream.IsRPCStream()) + } +} diff --git a/origin/tunnel.go b/origin/tunnel.go index bbc80046..64539048 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -17,6 +17,7 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/signal" + "github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/validation" @@ -471,39 +472,6 @@ func LogServerInfo( metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) } -func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { - for _, header := range h2 { - switch header.Name { - case ":method": - h1.Method = header.Value - case ":scheme": - case ":authority": - // Otherwise the host header will be based on the origin URL - h1.Host = header.Value - case ":path": - u, err := url.Parse(header.Value) - if err != nil { - return fmt.Errorf("unparseable path") - } - resolved := h1.URL.ResolveReference(u) - // prevent escaping base URL - if !strings.HasPrefix(resolved.String(), h1.URL.String()) { - return fmt.Errorf("invalid path") - } - h1.URL = resolved - case "content-length": - contentLength, err := strconv.ParseInt(header.Value, 10, 64) - if err != nil { - return fmt.Errorf("unparseable content length") - } - h1.ContentLength = contentLength - default: - h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value) - } - } - return nil -} - func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) { h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}} for headerName, headerValues := range h1.Header { @@ -514,10 +482,6 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) { return } -func FindCfRayHeader(h1 *http.Request) string { - return h1.Header.Get("Cf-Ray") -} - type TunnelHandler struct { originUrl string muxer *h2mux.Muxer @@ -605,8 +569,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { return reqErr } - cfRay := FindCfRayHeader(req) - lbProbe := isLBProbeRequest(req) + cfRay := streamhandler.FindCfRayHeader(req) + lbProbe := streamhandler.IsLBProbeRequest(req) h.logRequest(req, cfRay, lbProbe) var resp *http.Response @@ -629,7 +593,7 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, if err != nil { return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") } - err = H2RequestHeadersToH1Request(stream.Headers, req) + err = streamhandler.H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { return nil, errors.Wrap(err, "invalid request received") } @@ -759,10 +723,6 @@ func uint8ToString(input uint8) string { return strconv.FormatUint(uint64(input), 10) } -func isLBProbeRequest(req *http.Request) bool { - return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) -} - // Print out the given lines in a nice ASCII box. func asciiBox(lines []string, padding int) (box []string) { maxLen := maxLen(lines) diff --git a/originservice/originservice.go b/originservice/originservice.go index 34aadb8f..9bbf4e02 100644 --- a/originservice/originservice.go +++ b/originservice/originservice.go @@ -22,6 +22,7 @@ import ( // OriginService is an interface to proxy requests to different type of origins type OriginService interface { Proxy(stream *h2mux.MuxedStream, req *http.Request) (resp *http.Response, err error) + OriginAddr() string Shutdown() } @@ -55,13 +56,13 @@ func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*htt resp, err := hc.client.RoundTrip(req) if err != nil { - return nil, errors.Wrap(err, "Error proxying request to HTTP origin") + return nil, errors.Wrap(err, "error proxying request to HTTP origin") } defer resp.Body.Close() err = stream.WriteHeaders(h1ResponseToH2Response(resp)) if err != nil { - return nil, errors.Wrap(err, "Error writing response header to HTTP origin") + return nil, errors.Wrap(err, "error writing response header to HTTP origin") } if isEventStream(resp) { writeEventStream(stream, resp.Body) @@ -73,30 +74,39 @@ func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*htt return resp, nil } +func (hc *HTTPService) OriginAddr() string { + return hc.originAddr +} + func (hc *HTTPService) Shutdown() {} // WebsocketService talks to origin using WS/WSS type WebsocketService struct { - tlsConfig *tls.Config - shutdownC chan struct{} + tlsConfig *tls.Config + originAddr string + shutdownC chan struct{} } func NewWebSocketService(tlsConfig *tls.Config, url string) (OriginService, error) { listener, err := net.Listen("tcp", "127.0.0.1:") if err != nil { - return nil, errors.Wrap(err, "Cannot start Websocket Proxy Server") + return nil, errors.Wrap(err, "cannot start Websocket Proxy Server") } shutdownC := make(chan struct{}) go func() { websocket.StartProxyServer(log.CreateLogger(), listener, url, shutdownC) }() return &WebsocketService{ - tlsConfig: tlsConfig, - shutdownC: shutdownC, + tlsConfig: tlsConfig, + originAddr: url, + shutdownC: shutdownC, }, nil } -func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (response *http.Response, err error) { +func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { + if !websocket.IsWebSocketUpgrade(req) { + return nil, fmt.Errorf("request is not a websocket connection") + } conn, response, err := websocket.ClientConnect(req, wsc.tlsConfig) if err != nil { return nil, err @@ -104,7 +114,7 @@ func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) defer conn.Close() err = stream.WriteHeaders(h1ResponseToH2Response(response)) if err != nil { - return nil, errors.Wrap(err, "Error writing response header to websocket origin") + return nil, errors.Wrap(err, "error writing response header to websocket origin") } // Copy to/from stream to the undelying connection. Use the underlying // connection because cloudflared doesn't operate on the message themselves @@ -112,30 +122,36 @@ func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) return response, nil } +func (wsc *WebsocketService) OriginAddr() string { + return wsc.originAddr +} + func (wsc *WebsocketService) Shutdown() { close(wsc.shutdownC) } // HelloWorldService talks to the hello world example origin type HelloWorldService struct { - client http.RoundTripper - listener net.Listener - shutdownC chan struct{} + client http.RoundTripper + listener net.Listener + originAddr string + shutdownC chan struct{} } func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) { listener, err := hello.CreateTLSListener("127.0.0.1:") if err != nil { - return nil, errors.Wrap(err, "Cannot start Hello World Server") + return nil, errors.Wrap(err, "cannot start Hello World Server") } shutdownC := make(chan struct{}) go func() { hello.StartHelloWorldServer(log.CreateLogger(), listener, shutdownC) }() return &HelloWorldService{ - client: transport, - listener: listener, - shutdownC: shutdownC, + client: transport, + listener: listener, + originAddr: listener.Addr().String(), + shutdownC: shutdownC, }, nil } @@ -145,13 +161,13 @@ func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request resp, err := hwc.client.RoundTrip(req) if err != nil { - return nil, errors.Wrap(err, "Error proxying request to Hello World origin") + return nil, errors.Wrap(err, "error proxying request to Hello World origin") } defer resp.Body.Close() err = stream.WriteHeaders(h1ResponseToH2Response(resp)) if err != nil { - return nil, errors.Wrap(err, "Error writing response header to Hello World origin") + return nil, errors.Wrap(err, "error writing response header to Hello World origin") } // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream @@ -161,6 +177,10 @@ func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request return resp, nil } +func (hwc *HelloWorldService) OriginAddr() string { + return hwc.originAddr +} + func (hwc *HelloWorldService) Shutdown() { hwc.listener.Close() } diff --git a/streamhandler/request.go b/streamhandler/request.go new file mode 100644 index 00000000..6d2004bd --- /dev/null +++ b/streamhandler/request.go @@ -0,0 +1,69 @@ +package streamhandler + +import ( + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/pkg/errors" +) + +const ( + lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" +) + +func FindCfRayHeader(h1 *http.Request) string { + return h1.Header.Get("Cf-Ray") +} + +func IsLBProbeRequest(req *http.Request) bool { + return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) +} + +func CreateRequest(stream *h2mux.MuxedStream, originAddr string) (*http.Request, error) { + req, err := http.NewRequest(http.MethodGet, originAddr, h2mux.MuxedStreamReader{MuxedStream: stream}) + if err != nil { + return nil, errors.Wrap(err, "unexpected error from http.NewRequest") + } + err = H2RequestHeadersToH1Request(stream.Headers, req) + if err != nil { + return nil, errors.Wrap(err, "invalid request received") + } + return req, nil +} + +func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { + for _, header := range h2 { + switch header.Name { + case ":method": + h1.Method = header.Value + case ":scheme": + case ":authority": + // Otherwise the host header will be based on the origin URL + h1.Host = header.Value + case ":path": + u, err := url.Parse(header.Value) + if err != nil { + return fmt.Errorf("unparseable path") + } + resolved := h1.URL.ResolveReference(u) + // prevent escaping base URL + if !strings.HasPrefix(resolved.String(), h1.URL.String()) { + return fmt.Errorf("invalid path") + } + h1.URL = resolved + case "content-length": + contentLength, err := strconv.ParseInt(header.Value, 10, 64) + if err != nil { + return fmt.Errorf("unparseable content length") + } + h1.ContentLength = contentLength + default: + h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value) + } + } + return nil +} diff --git a/streamhandler/stream_handler.go b/streamhandler/stream_handler.go new file mode 100644 index 00000000..d350bffb --- /dev/null +++ b/streamhandler/stream_handler.go @@ -0,0 +1,91 @@ +package streamhandler + +import ( + "fmt" + "net/http" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/tunnelhostnamemapper" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/sirupsen/logrus" +) + +// StreamHandler handles new stream opened by the edge. The streams can be used to proxy requests or make RPC. +type StreamHandler struct { + // newConfigChan is a send-only channel to notify Supervisor of a new ClientConfig + newConfigChan chan<- *pogs.ClientConfig + // useConfigResultChan is a receive-only channel for Supervisor to communicate the result of applying a new ClientConfig + useConfigResultChan <-chan *pogs.UseConfigurationResult + // originMapper maps tunnel hostname to origin service + tunnelHostnameMapper *tunnelhostnamemapper.TunnelHostnameMapper + logger *logrus.Entry +} + +// NewStreamHandler creates a new StreamHandler +func NewStreamHandler(newConfigChan chan<- *pogs.ClientConfig, + useConfigResultChan <-chan *pogs.UseConfigurationResult, + logger *logrus.Logger, +) *StreamHandler { + return &StreamHandler{ + newConfigChan: newConfigChan, + useConfigResultChan: useConfigResultChan, + tunnelHostnameMapper: tunnelhostnamemapper.NewTunnelHostnameMapper(), + logger: logger.WithField("subsystem", "streamHandler"), + } +} + +// ServeStream implements MuxedStreamHandler interface +func (s *StreamHandler) ServeStream(stream *h2mux.MuxedStream) error { + if stream.IsRPCStream() { + return fmt.Errorf("serveRPC not implemented") + } + return s.serveRequest(stream) +} + +func (s *StreamHandler) serveRequest(stream *h2mux.MuxedStream) error { + tunnelHostname := stream.TunnelHostname() + if !tunnelHostname.IsSet() { + err := fmt.Errorf("stream doesn't have tunnelHostname") + s.logger.Error(err) + return err + } + + originService, ok := s.tunnelHostnameMapper.Get(tunnelHostname) + if !ok { + err := fmt.Errorf("cannot map tunnel hostname %s to origin", tunnelHostname) + s.logger.Error(err) + return err + } + + req, err := CreateRequest(stream, originService.OriginAddr()) + if err != nil { + return err + } + + logger := s.requestLogger(req, tunnelHostname) + logger.Debugf("Request Headers %+v", req.Header) + + resp, err := originService.Proxy(stream, req) + if err != nil { + logger.WithError(err).Error("Request error") + return err + } + + logger.WithField("status", resp.Status).Debugf("Response Headers %+v", resp.Header) + return nil +} + +func (s *StreamHandler) requestLogger(req *http.Request, tunnelHostname h2mux.TunnelHostname) *logrus.Entry { + cfRay := FindCfRayHeader(req) + lbProbe := IsLBProbeRequest(req) + logger := s.logger.WithField("tunnelHostname", tunnelHostname) + if cfRay != "" { + logger = logger.WithField("CF-RAY", cfRay) + logger.Debugf("%s %s %s", req.Method, req.URL, req.Proto) + } else if lbProbe { + logger.Debugf("Load Balancer health check %s %s %s", req.Method, req.URL, req.Proto) + } else { + logger.Warnf("Requests %v does not have CF-RAY header. Please open a support ticket with Cloudflare.", req) + } + return logger +} diff --git a/tunnelhostnamemapper/tunnelhostnamemapper.go b/tunnelhostnamemapper/tunnelhostnamemapper.go new file mode 100644 index 00000000..bb8f70f1 --- /dev/null +++ b/tunnelhostnamemapper/tunnelhostnamemapper.go @@ -0,0 +1,49 @@ +package tunnelhostnamemapper + +import ( + "sync" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/originservice" +) + +// TunnelHostnameMapper maps TunnelHostname to an OriginService +type TunnelHostnameMapper struct { + sync.RWMutex + tunnelHostnameToOrigin map[h2mux.TunnelHostname]originservice.OriginService +} + +func NewTunnelHostnameMapper() *TunnelHostnameMapper { + return &TunnelHostnameMapper{ + tunnelHostnameToOrigin: make(map[h2mux.TunnelHostname]originservice.OriginService), + } +} + +// Get an OriginService given a TunnelHostname +func (om *TunnelHostnameMapper) Get(key h2mux.TunnelHostname) (originservice.OriginService, bool) { + om.RLock() + defer om.RUnlock() + originService, ok := om.tunnelHostnameToOrigin[key] + return originService, ok +} + +// Add a mapping. If there is already an OriginService with this key, shutdown the old origin service and replace it +// with the new one +func (om *TunnelHostnameMapper) Add(key h2mux.TunnelHostname, os originservice.OriginService) { + om.Lock() + defer om.Unlock() + if oldOS, ok := om.tunnelHostnameToOrigin[key]; ok { + oldOS.Shutdown() + } + om.tunnelHostnameToOrigin[key] = os +} + +// DeleteAll mappings, and shutdown all OriginService +func (om *TunnelHostnameMapper) DeleteAll() { + om.Lock() + defer om.Unlock() + for key, os := range om.tunnelHostnameToOrigin { + os.Shutdown() + delete(om.tunnelHostnameToOrigin, key) + } +} diff --git a/tunnelhostnamemapper/tunnelhostnamemapper_test.go b/tunnelhostnamemapper/tunnelhostnamemapper_test.go new file mode 100644 index 00000000..4c7fd0d8 --- /dev/null +++ b/tunnelhostnamemapper/tunnelhostnamemapper_test.go @@ -0,0 +1,69 @@ +package tunnelhostnamemapper + +import ( + "fmt" + "net/http" + "sync" + "testing" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/originservice" + "github.com/stretchr/testify/assert" +) + +const ( + routines = 1000 +) + +func TestTunnelHostnameMapperConcurrentAccess(t *testing.T) { + thm := NewTunnelHostnameMapper() + + concurrentOps(t, func(i int) { + // om is empty + os, ok := thm.Get(tunnelHostname(i)) + assert.False(t, ok) + assert.Nil(t, os) + }) + + httpOS := originservice.NewHTTPService(http.DefaultTransport, "127.0.0.1:8080", false) + concurrentOps(t, func(i int) { + thm.Add(tunnelHostname(i), httpOS) + }) + + concurrentOps(t, func(i int) { + os, ok := thm.Get(tunnelHostname(i)) + assert.True(t, ok) + assert.Equal(t, httpOS, os) + }) + + secondHTTPOS := originservice.NewHTTPService(http.DefaultTransport, "127.0.0.1:8090", true) + concurrentOps(t, func(i int) { + // Add should httpOS with secondHTTPOS + thm.Add(tunnelHostname(i), secondHTTPOS) + }) + + concurrentOps(t, func(i int) { + os, ok := thm.Get(tunnelHostname(i)) + assert.True(t, ok) + assert.Equal(t, secondHTTPOS, os) + }) + + thm.DeleteAll() + assert.Empty(t, thm.tunnelHostnameToOrigin) +} + +func concurrentOps(t *testing.T, f func(i int)) { + var wg sync.WaitGroup + wg.Add(routines) + for i := 0; i < routines; i++ { + go func(i int) { + f(i) + wg.Done() + }(i) + } + wg.Wait() +} + +func tunnelHostname(i int) h2mux.TunnelHostname { + return h2mux.TunnelHostname(fmt.Sprintf("%d.cftunnel.com", i)) +}