Merge branch 'master' of ssh://stash.cfops.it:7999/tun/cloudflared
This commit is contained in:
		
						commit
						25a04e0c69
					
				|  | @ -21,6 +21,7 @@ import ( | |||
| 	"github.com/cloudflare/cloudflared/metrics" | ||||
| 	"github.com/cloudflare/cloudflared/origin" | ||||
| 	"github.com/cloudflare/cloudflared/signal" | ||||
| 	"github.com/cloudflare/cloudflared/tlsconfig" | ||||
| 	"github.com/cloudflare/cloudflared/tunneldns" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| 	"github.com/coreos/go-systemd/daemon" | ||||
|  | @ -444,7 +445,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag { | |||
| 			Hidden:  true, | ||||
| 		}), | ||||
| 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||
| 			Name:    "cacert", | ||||
| 			Name:    tlsconfig.CaCertFlag, | ||||
| 			Usage:   "Certificate Authority authenticating connections with Cloudflare's edge network.", | ||||
| 			EnvVars: []string{"TUNNEL_CACERT"}, | ||||
| 			Hidden:  true, | ||||
|  | @ -463,7 +464,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag { | |||
| 			Hidden:  shouldHide, | ||||
| 		}), | ||||
| 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||
| 			Name:    "origin-ca-pool", | ||||
| 			Name:    tlsconfig.OriginCAPoolFlag, | ||||
| 			Usage:   "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.", | ||||
| 			EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"}, | ||||
| 			Hidden:  shouldHide, | ||||
|  |  | |||
|  | @ -3,14 +3,12 @@ package tunnel | |||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"crypto/x509" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"runtime" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
|  | @ -187,7 +185,7 @@ func prepareTunnelConfig( | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	originCertPool, err := loadCertPool(c, logger) | ||||
| 	originCertPool, err := tlsconfig.LoadOriginCA(c, logger) | ||||
| 	if err != nil { | ||||
| 		logger.WithError(err).Error("Error loading cert pool") | ||||
| 		return nil, errors.Wrap(err, "Error loading cert pool") | ||||
|  | @ -236,7 +234,7 @@ func prepareTunnelConfig( | |||
| 		return nil, errors.Wrap(err, "unable to connect to the origin") | ||||
| 	} | ||||
| 
 | ||||
| 	toEdgeTLSConfig, err := createTunnelConfig(c) | ||||
| 	toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c) | ||||
| 	if err != nil { | ||||
| 		logger.WithError(err).Error("unable to create TLS config to connect with edge") | ||||
| 		return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") | ||||
|  | @ -274,112 +272,6 @@ func prepareTunnelConfig( | |||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) { | ||||
| 	const originCAPoolFlag = "origin-ca-pool" | ||||
| 	originCAPoolFilename := c.String(originCAPoolFlag) | ||||
| 	var originCustomCAPool []byte | ||||
| 
 | ||||
| 	if originCAPoolFilename != "" { | ||||
| 		var err error | ||||
| 		originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, originCAPoolFlag)) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	originCertPool, err := loadOriginCertPool(originCustomCAPool) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "error loading the certificate pool") | ||||
| 	} | ||||
| 
 | ||||
| 	// Windows users should be notified that they can use the flag
 | ||||
| 	if runtime.GOOS == "windows" && originCAPoolFilename == "" { | ||||
| 		logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", originCAPoolFlag) | ||||
| 	} | ||||
| 
 | ||||
| 	return originCertPool, nil | ||||
| } | ||||
| 
 | ||||
| func loadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) { | ||||
| 	// Get the global pool
 | ||||
| 	certPool, err := loadGlobalCertPool() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Then, add any custom origin CA pool the user may have passed
 | ||||
| 	if originCAPoolPEM != nil { | ||||
| 		if !certPool.AppendCertsFromPEM(originCAPoolPEM) { | ||||
| 			logger.Warn("could not append the provided origin CA to the cloudflared certificate pool") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return certPool, nil | ||||
| } | ||||
| 
 | ||||
| func loadGlobalCertPool() (*x509.CertPool, error) { | ||||
| 	// First, obtain the system certificate pool
 | ||||
| 	certPool, err := x509.SystemCertPool() | ||||
| 	if err != nil { | ||||
| 		if runtime.GOOS != "windows" { | ||||
| 			logger.WithError(err).Warn("error obtaining the system certificates") | ||||
| 		} | ||||
| 		certPool = x509.NewCertPool() | ||||
| 	} | ||||
| 
 | ||||
| 	// Next, append the Cloudflare CAs into the system pool
 | ||||
| 	cfRootCA, err := tlsconfig.GetCloudflareRootCA() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool") | ||||
| 	} | ||||
| 	for _, cert := range cfRootCA { | ||||
| 		certPool.AddCert(cert) | ||||
| 	} | ||||
| 
 | ||||
| 	// Finally, add the Hello certificate into the pool (since it's self-signed)
 | ||||
| 	helloCert, err := tlsconfig.GetHelloCertificateX509() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool") | ||||
| 	} | ||||
| 	certPool.AddCert(helloCert) | ||||
| 
 | ||||
| 	return certPool, nil | ||||
| } | ||||
| 
 | ||||
| func createTunnelConfig(c *cli.Context) (*tls.Config, error) { | ||||
| 	var rootCAs []string | ||||
| 	if c.String("cacert") != "" { | ||||
| 		rootCAs = append(rootCAs, c.String("cacert")) | ||||
| 	} | ||||
| 	edgeAddrs := c.StringSlice("edge") | ||||
| 
 | ||||
| 	userConfig := &tlsconfig.TLSParameters{RootCAs: rootCAs} | ||||
| 	tlsConfig, err := tlsconfig.GetConfig(userConfig) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if tlsConfig.RootCAs == nil { | ||||
| 		rootCAPool := x509.NewCertPool() | ||||
| 		cfRootCA, err := tlsconfig.GetCloudflareRootCA() | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool") | ||||
| 		} | ||||
| 		for _, cert := range cfRootCA { | ||||
| 			rootCAPool.AddCert(cert) | ||||
| 		} | ||||
| 		tlsConfig.RootCAs = rootCAPool | ||||
| 		tlsConfig.ServerName = "cftunnel.com" | ||||
| 	} else if len(edgeAddrs) > 0 { | ||||
| 		// Set for development environments and for testing specific origintunneld instances
 | ||||
| 		tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0]) | ||||
| 	} | ||||
| 
 | ||||
| 	if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { | ||||
| 		return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config") | ||||
| 	} | ||||
| 	return tlsConfig, nil | ||||
| } | ||||
| 
 | ||||
| func isRunningFromTerminal() bool { | ||||
| 	return terminal.IsTerminal(int(os.Stdout.Fd())) | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,199 @@ | |||
| // Package client defines and implements interface to proxy to HTTP, websocket and hello world origins
 | ||||
| package originservice | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/h2mux" | ||||
| 	"github.com/cloudflare/cloudflared/hello" | ||||
| 	"github.com/cloudflare/cloudflared/log" | ||||
| 	"github.com/cloudflare/cloudflared/websocket" | ||||
| 
 | ||||
| 	"github.com/pkg/errors" | ||||
| ) | ||||
| 
 | ||||
| // OriginService is an interface to proxy requests to different type of origins
 | ||||
| type OriginService interface { | ||||
| 	Proxy(stream *h2mux.MuxedStream, req *http.Request) (resp *http.Response, err error) | ||||
| 	Shutdown() | ||||
| } | ||||
| 
 | ||||
| // HTTPService talks to origin using HTTP/HTTPS
 | ||||
| type HTTPService struct { | ||||
| 	client          http.RoundTripper | ||||
| 	originAddr      string | ||||
| 	chunkedEncoding bool | ||||
| } | ||||
| 
 | ||||
| func NewHTTPService(transport http.RoundTripper, originAddr string, chunkedEncoding bool) OriginService { | ||||
| 	return &HTTPService{ | ||||
| 		client:          transport, | ||||
| 		originAddr:      originAddr, | ||||
| 		chunkedEncoding: chunkedEncoding, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { | ||||
| 	// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
 | ||||
| 	if !hc.chunkedEncoding { | ||||
| 		req.TransferEncoding = []string{"gzip", "deflate"} | ||||
| 		cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) | ||||
| 		if err == nil { | ||||
| 			req.ContentLength = int64(cLength) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Request origin to keep connection alive to improve performance
 | ||||
| 	req.Header.Set("Connection", "keep-alive") | ||||
| 
 | ||||
| 	resp, err := hc.client.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error proxying request to HTTP origin") | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 
 | ||||
| 	err = stream.WriteHeaders(h1ResponseToH2Response(resp)) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error writing response header to HTTP origin") | ||||
| 	} | ||||
| 	if isEventStream(resp) { | ||||
| 		writeEventStream(stream, resp.Body) | ||||
| 	} else { | ||||
| 		// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
 | ||||
| 		// compression generates dictionary on first write
 | ||||
| 		io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024)) | ||||
| 	} | ||||
| 	return resp, nil | ||||
| } | ||||
| 
 | ||||
| func (hc *HTTPService) Shutdown() {} | ||||
| 
 | ||||
| // WebsocketService talks to origin using WS/WSS
 | ||||
| type WebsocketService struct { | ||||
| 	tlsConfig *tls.Config | ||||
| 	shutdownC chan struct{} | ||||
| } | ||||
| 
 | ||||
| func NewWebSocketService(tlsConfig *tls.Config, url string) (OriginService, error) { | ||||
| 	listener, err := net.Listen("tcp", "127.0.0.1:") | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Cannot start Websocket Proxy Server") | ||||
| 	} | ||||
| 	shutdownC := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		websocket.StartProxyServer(log.CreateLogger(), listener, url, shutdownC) | ||||
| 	}() | ||||
| 	return &WebsocketService{ | ||||
| 		tlsConfig: tlsConfig, | ||||
| 		shutdownC: shutdownC, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (response *http.Response, err error) { | ||||
| 	conn, response, err := websocket.ClientConnect(req, wsc.tlsConfig) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 	err = stream.WriteHeaders(h1ResponseToH2Response(response)) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error writing response header to websocket origin") | ||||
| 	} | ||||
| 	// Copy to/from stream to the undelying connection. Use the underlying
 | ||||
| 	// connection because cloudflared doesn't operate on the message themselves
 | ||||
| 	websocket.Stream(conn.UnderlyingConn(), stream) | ||||
| 	return response, nil | ||||
| } | ||||
| 
 | ||||
| func (wsc *WebsocketService) Shutdown() { | ||||
| 	close(wsc.shutdownC) | ||||
| } | ||||
| 
 | ||||
| // HelloWorldService talks to the hello world example origin
 | ||||
| type HelloWorldService struct { | ||||
| 	client    http.RoundTripper | ||||
| 	listener  net.Listener | ||||
| 	shutdownC chan struct{} | ||||
| } | ||||
| 
 | ||||
| func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) { | ||||
| 	listener, err := hello.CreateTLSListener("127.0.0.1:") | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Cannot start Hello World Server") | ||||
| 	} | ||||
| 	shutdownC := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		hello.StartHelloWorldServer(log.CreateLogger(), listener, shutdownC) | ||||
| 	}() | ||||
| 	return &HelloWorldService{ | ||||
| 		client:    transport, | ||||
| 		listener:  listener, | ||||
| 		shutdownC: shutdownC, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) { | ||||
| 	// Request origin to keep connection alive to improve performance
 | ||||
| 	req.Header.Set("Connection", "keep-alive") | ||||
| 
 | ||||
| 	resp, err := hwc.client.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error proxying request to Hello World origin") | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 
 | ||||
| 	err = stream.WriteHeaders(h1ResponseToH2Response(resp)) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Error writing response header to Hello World origin") | ||||
| 	} | ||||
| 
 | ||||
| 	// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
 | ||||
| 	// compression generates dictionary on first write
 | ||||
| 	io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024)) | ||||
| 
 | ||||
| 	return resp, nil | ||||
| } | ||||
| 
 | ||||
| func (hwc *HelloWorldService) Shutdown() { | ||||
| 	hwc.listener.Close() | ||||
| } | ||||
| 
 | ||||
| func isEventStream(resp *http.Response) bool { | ||||
| 	// Check if content-type is text/event-stream. We need to check if the header value starts with text/event-stream
 | ||||
| 	// because text/event-stream; charset=UTF-8 is also valid
 | ||||
| 	// Ref: https://tools.ietf.org/html/rfc7231#section-3.1.1.1
 | ||||
| 	for _, contentType := range resp.Header["content-type"] { | ||||
| 		if strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func writeEventStream(stream *h2mux.MuxedStream, respBody io.ReadCloser) { | ||||
| 	reader := bufio.NewReader(respBody) | ||||
| 	for { | ||||
| 		line, err := reader.ReadBytes('\n') | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		stream.Write(line) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func h1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) { | ||||
| 	h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}} | ||||
| 	for headerName, headerValues := range h1.Header { | ||||
| 		for _, headerValue := range headerValues { | ||||
| 			h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue}) | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | @ -0,0 +1,60 @@ | |||
| package originservice | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| func TestIsEventStream(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		resp          *http.Response | ||||
| 		isEventStream bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			resp:          &http.Response{}, | ||||
| 			isEventStream: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			// isEventStream checks all headers
 | ||||
| 			resp: &http.Response{ | ||||
| 				Header: http.Header{ | ||||
| 					"accept":       []string{"text/html"}, | ||||
| 					"content-type": []string{"text/event-stream"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			isEventStream: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
 | ||||
| 			resp: &http.Response{ | ||||
| 				Header: http.Header{ | ||||
| 					"content-type": []string{"Text/event-stream;charset=utf-8"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			isEventStream: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
 | ||||
| 			resp: &http.Response{ | ||||
| 				Header: http.Header{ | ||||
| 					"content-type": []string{"appication/json", "text/html", "Text/event-stream;charset=utf-8"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			isEventStream: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			// Not an event stream because the content-type value doesn't start with text/event-stream
 | ||||
| 			resp: &http.Response{ | ||||
| 				Header: http.Header{ | ||||
| 					"content-type": []string{" text/event-stream"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			isEventStream: false, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, test := range tests { | ||||
| 		assert.Equal(t, test.isEventStream, isEventStream(test.resp), "Header: %v", test.resp.Header) | ||||
| 	} | ||||
| } | ||||
|  | @ -2,10 +2,22 @@ package tlsconfig | |||
| 
 | ||||
| import ( | ||||
| 	"crypto/tls" | ||||
| 	"crypto/x509" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net" | ||||
| 	"runtime" | ||||
| 	"sync" | ||||
| 
 | ||||
| 	"github.com/getsentry/raven-go" | ||||
| 	"github.com/pkg/errors" | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"gopkg.in/urfave/cli.v2" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	OriginCAPoolFlag = "origin-ca-pool" | ||||
| 	CaCertFlag       = "cacert" | ||||
| ) | ||||
| 
 | ||||
| // CertReloader can load and reload a TLS certificate from a particular filepath.
 | ||||
|  | @ -51,3 +63,120 @@ func (cr *CertReloader) LoadCert() error { | |||
| 	cr.certificate = &cert | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func LoadOriginCA(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) { | ||||
| 	var originCustomCAPool []byte | ||||
| 
 | ||||
| 	originCAPoolFilename := c.String(OriginCAPoolFlag) | ||||
| 	if originCAPoolFilename != "" { | ||||
| 		var err error | ||||
| 		originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename) | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag)) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	originCertPool, err := loadOriginCertPool(originCustomCAPool, logger) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "error loading the certificate pool") | ||||
| 	} | ||||
| 
 | ||||
| 	// Windows users should be notified that they can use the flag
 | ||||
| 	if runtime.GOOS == "windows" && originCAPoolFilename == "" { | ||||
| 		logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", OriginCAPoolFlag) | ||||
| 	} | ||||
| 
 | ||||
| 	return originCertPool, nil | ||||
| } | ||||
| 
 | ||||
| func LoadCustomCertPool(customCertFilename string) (*x509.CertPool, error) { | ||||
| 	pool := x509.NewCertPool() | ||||
| 	customCAPoolPEM, err := ioutil.ReadFile(customCertFilename) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", customCertFilename)) | ||||
| 	} | ||||
| 	if !pool.AppendCertsFromPEM(customCAPoolPEM) { | ||||
| 		return nil, fmt.Errorf("error appending custom CA to cert pool") | ||||
| 	} | ||||
| 	return pool, nil | ||||
| } | ||||
| 
 | ||||
| func CreateTunnelConfig(c *cli.Context) (*tls.Config, error) { | ||||
| 	var rootCAs []string | ||||
| 	if c.String(CaCertFlag) != "" { | ||||
| 		rootCAs = append(rootCAs, c.String(CaCertFlag)) | ||||
| 	} | ||||
| 
 | ||||
| 	userConfig := &TLSParameters{RootCAs: rootCAs} | ||||
| 	tlsConfig, err := GetConfig(userConfig) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if tlsConfig.RootCAs == nil { | ||||
| 		rootCAPool := x509.NewCertPool() | ||||
| 		cfRootCA, err := GetCloudflareRootCA() | ||||
| 		if err != nil { | ||||
| 			return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool") | ||||
| 		} | ||||
| 		for _, cert := range cfRootCA { | ||||
| 			rootCAPool.AddCert(cert) | ||||
| 		} | ||||
| 		tlsConfig.RootCAs = rootCAPool | ||||
| 		tlsConfig.ServerName = "cftunnel.com" | ||||
| 	} else if edgeAddrs := c.StringSlice("edge"); len(edgeAddrs) > 0 { | ||||
| 		// Set for development environments and for testing specific origintunneld instances
 | ||||
| 		tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0]) | ||||
| 	} | ||||
| 
 | ||||
| 	if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { | ||||
| 		return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config") | ||||
| 	} | ||||
| 	return tlsConfig, nil | ||||
| } | ||||
| 
 | ||||
| func loadOriginCertPool(originCAPoolPEM []byte, logger *logrus.Logger) (*x509.CertPool, error) { | ||||
| 	// Get the global pool
 | ||||
| 	certPool, err := loadGlobalCertPool(logger) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Then, add any custom origin CA pool the user may have passed
 | ||||
| 	if originCAPoolPEM != nil { | ||||
| 		if !certPool.AppendCertsFromPEM(originCAPoolPEM) { | ||||
| 			logger.Warn("could not append the provided origin CA to the cloudflared certificate pool") | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return certPool, nil | ||||
| } | ||||
| 
 | ||||
| func loadGlobalCertPool(logger *logrus.Logger) (*x509.CertPool, error) { | ||||
| 	// First, obtain the system certificate pool
 | ||||
| 	certPool, err := x509.SystemCertPool() | ||||
| 	if err != nil { | ||||
| 		if runtime.GOOS != "windows" { // See https://github.com/golang/go/issues/16736
 | ||||
| 			logger.WithError(err).Warn("error obtaining the system certificates") | ||||
| 		} | ||||
| 		certPool = x509.NewCertPool() | ||||
| 	} | ||||
| 
 | ||||
| 	// Next, append the Cloudflare CAs into the system pool
 | ||||
| 	cfRootCA, err := GetCloudflareRootCA() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool") | ||||
| 	} | ||||
| 	for _, cert := range cfRootCA { | ||||
| 		certPool.AddCert(cert) | ||||
| 	} | ||||
| 
 | ||||
| 	// Finally, add the Hello certificate into the pool (since it's self-signed)
 | ||||
| 	helloCert, err := GetHelloCertificateX509() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool") | ||||
| 	} | ||||
| 	certPool.AddCert(helloCert) | ||||
| 
 | ||||
| 	return certPool, nil | ||||
| } | ||||
|  |  | |||
|  | @ -2,11 +2,18 @@ package pogs | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"crypto/x509" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/cloudflare/cloudflared/originservice" | ||||
| 	"github.com/cloudflare/cloudflared/tlsconfig" | ||||
| 	"github.com/cloudflare/cloudflared/tunnelrpc" | ||||
| 	"github.com/pkg/errors" | ||||
| 	capnp "zombiezen.com/go/capnproto2" | ||||
| 	"zombiezen.com/go/capnproto2/pogs" | ||||
| 	"zombiezen.com/go/capnproto2/rpc" | ||||
|  | @ -68,6 +75,9 @@ func NewReverseProxyConfig( | |||
| 
 | ||||
| //go-sumtype:decl OriginConfig
 | ||||
| type OriginConfig interface { | ||||
| 	// Service returns a OriginService used to proxy to the origin
 | ||||
| 	Service() (originservice.OriginService, error) | ||||
| 	// go-sumtype requires at least one unexported method, otherwise it will complain that interface is not sealed
 | ||||
| 	isOriginConfig() | ||||
| } | ||||
| 
 | ||||
|  | @ -86,8 +96,6 @@ type HTTPOriginConfig struct { | |||
| 	ChunkedEncoding       bool | ||||
| } | ||||
| 
 | ||||
| func (_ *HTTPOriginConfig) isOriginConfig() {} | ||||
| 
 | ||||
| type OriginAddr interface { | ||||
| 	Addr() string | ||||
| } | ||||
|  | @ -119,6 +127,39 @@ func (up *UnixPath) Addr() string { | |||
| 	return up.Path | ||||
| } | ||||
| 
 | ||||
| func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) { | ||||
| 	rootCAs, err := tlsconfig.LoadCustomCertPool(hc.OriginCAPool) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	dialContext := (&net.Dialer{ | ||||
| 		Timeout:   hc.ProxyConnectTimeout, | ||||
| 		KeepAlive: hc.TCPKeepAlive, | ||||
| 		DualStack: hc.DialDualStack, | ||||
| 	}).DialContext | ||||
| 	transport := &http.Transport{ | ||||
| 		Proxy:       http.ProxyFromEnvironment, | ||||
| 		DialContext: dialContext, | ||||
| 		TLSClientConfig: &tls.Config{ | ||||
| 			RootCAs:            rootCAs, | ||||
| 			ServerName:         hc.OriginServerName, | ||||
| 			InsecureSkipVerify: hc.TLSVerify, | ||||
| 		}, | ||||
| 		TLSHandshakeTimeout:   hc.TLSHandshakeTimeout, | ||||
| 		MaxIdleConns:          int(hc.MaxIdleConnections), | ||||
| 		IdleConnTimeout:       hc.IdleConnectionTimeout, | ||||
| 		ExpectContinueTimeout: hc.ExpectContinueTimeout, | ||||
| 	} | ||||
| 	if unixPath, ok := hc.URL.(*UnixPath); ok { | ||||
| 		transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { | ||||
| 			return dialContext(ctx, "unix", unixPath.Addr()) | ||||
| 		} | ||||
| 	} | ||||
| 	return originservice.NewHTTPService(transport, hc.URL.Addr(), hc.ChunkedEncoding), nil | ||||
| } | ||||
| 
 | ||||
| func (_ *HTTPOriginConfig) isOriginConfig() {} | ||||
| 
 | ||||
| type WebSocketOriginConfig struct { | ||||
| 	URL              string `capnp:"url"` | ||||
| 	TLSVerify        bool   `capnp:"tlsVerify"` | ||||
|  | @ -126,10 +167,48 @@ type WebSocketOriginConfig struct { | |||
| 	OriginServerName string | ||||
| } | ||||
| 
 | ||||
| func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) { | ||||
| 	rootCAs, err := tlsconfig.LoadCustomCertPool(wsc.OriginCAPool) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	tlsConfig := &tls.Config{ | ||||
| 		RootCAs:            rootCAs, | ||||
| 		ServerName:         wsc.OriginServerName, | ||||
| 		InsecureSkipVerify: wsc.TLSVerify, | ||||
| 	} | ||||
| 	return originservice.NewWebSocketService(tlsConfig, wsc.URL) | ||||
| } | ||||
| 
 | ||||
| func (_ *WebSocketOriginConfig) isOriginConfig() {} | ||||
| 
 | ||||
| type HelloWorldOriginConfig struct{} | ||||
| 
 | ||||
| func (_ *HelloWorldOriginConfig) Service() (originservice.OriginService, error) { | ||||
| 	helloCert, err := tlsconfig.GetHelloCertificateX509() | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrap(err, "Cannot get Hello World server certificate") | ||||
| 	} | ||||
| 	rootCAs := x509.NewCertPool() | ||||
| 	rootCAs.AddCert(helloCert) | ||||
| 	transport := &http.Transport{ | ||||
| 		Proxy: http.ProxyFromEnvironment, | ||||
| 		DialContext: (&net.Dialer{ | ||||
| 			Timeout:   30 * time.Second, | ||||
| 			KeepAlive: 30 * time.Second, | ||||
| 			DualStack: true, | ||||
| 		}).DialContext, | ||||
| 		TLSClientConfig: &tls.Config{ | ||||
| 			RootCAs: rootCAs, | ||||
| 		}, | ||||
| 		MaxIdleConns:          100, | ||||
| 		IdleConnTimeout:       90 * time.Second, | ||||
| 		TLSHandshakeTimeout:   10 * time.Second, | ||||
| 		ExpectContinueTimeout: 1 * time.Second, | ||||
| 	} | ||||
| 	return originservice.NewHelloWorldService(transport) | ||||
| } | ||||
| 
 | ||||
| func (_ *HelloWorldOriginConfig) isOriginConfig() {} | ||||
| 
 | ||||
| /* | ||||
|  |  | |||
|  | @ -3,8 +3,9 @@ | |||
| package tunnelrpc | ||||
| 
 | ||||
| import ( | ||||
| 	context "golang.org/x/net/context" | ||||
| 	strconv "strconv" | ||||
| 
 | ||||
| 	context "golang.org/x/net/context" | ||||
| 	capnp "zombiezen.com/go/capnproto2" | ||||
| 	text "zombiezen.com/go/capnproto2/encoding/text" | ||||
| 	schemas "zombiezen.com/go/capnproto2/schemas" | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue