Merge branch 'master' of ssh://stash.cfops.it:7999/tun/cloudflared
This commit is contained in:
commit
25a04e0c69
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/metrics"
|
"github.com/cloudflare/cloudflared/metrics"
|
||||||
"github.com/cloudflare/cloudflared/origin"
|
"github.com/cloudflare/cloudflared/origin"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
"github.com/cloudflare/cloudflared/tunneldns"
|
"github.com/cloudflare/cloudflared/tunneldns"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
"github.com/coreos/go-systemd/daemon"
|
"github.com/coreos/go-systemd/daemon"
|
||||||
|
@ -444,7 +445,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "cacert",
|
Name: tlsconfig.CaCertFlag,
|
||||||
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",
|
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",
|
||||||
EnvVars: []string{"TUNNEL_CACERT"},
|
EnvVars: []string{"TUNNEL_CACERT"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
|
@ -463,7 +464,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "origin-ca-pool",
|
Name: tlsconfig.OriginCAPoolFlag,
|
||||||
Usage: "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.",
|
Usage: "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.",
|
||||||
EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"},
|
EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"},
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
|
|
|
@ -3,14 +3,12 @@ package tunnel
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -187,7 +185,7 @@ func prepareTunnelConfig(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
originCertPool, err := loadCertPool(c, logger)
|
originCertPool, err := tlsconfig.LoadOriginCA(c, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("Error loading cert pool")
|
logger.WithError(err).Error("Error loading cert pool")
|
||||||
return nil, errors.Wrap(err, "Error loading cert pool")
|
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||||
|
@ -236,7 +234,7 @@ func prepareTunnelConfig(
|
||||||
return nil, errors.Wrap(err, "unable to connect to the origin")
|
return nil, errors.Wrap(err, "unable to connect to the origin")
|
||||||
}
|
}
|
||||||
|
|
||||||
toEdgeTLSConfig, err := createTunnelConfig(c)
|
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Error("unable to create TLS config to connect with edge")
|
logger.WithError(err).Error("unable to create TLS config to connect with edge")
|
||||||
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
||||||
|
@ -274,112 +272,6 @@ func prepareTunnelConfig(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) {
|
|
||||||
const originCAPoolFlag = "origin-ca-pool"
|
|
||||||
originCAPoolFilename := c.String(originCAPoolFlag)
|
|
||||||
var originCustomCAPool []byte
|
|
||||||
|
|
||||||
if originCAPoolFilename != "" {
|
|
||||||
var err error
|
|
||||||
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, originCAPoolFlag))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
originCertPool, err := loadOriginCertPool(originCustomCAPool)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "error loading the certificate pool")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Windows users should be notified that they can use the flag
|
|
||||||
if runtime.GOOS == "windows" && originCAPoolFilename == "" {
|
|
||||||
logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", originCAPoolFlag)
|
|
||||||
}
|
|
||||||
|
|
||||||
return originCertPool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) {
|
|
||||||
// Get the global pool
|
|
||||||
certPool, err := loadGlobalCertPool()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then, add any custom origin CA pool the user may have passed
|
|
||||||
if originCAPoolPEM != nil {
|
|
||||||
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
|
|
||||||
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return certPool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadGlobalCertPool() (*x509.CertPool, error) {
|
|
||||||
// First, obtain the system certificate pool
|
|
||||||
certPool, err := x509.SystemCertPool()
|
|
||||||
if err != nil {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
logger.WithError(err).Warn("error obtaining the system certificates")
|
|
||||||
}
|
|
||||||
certPool = x509.NewCertPool()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, append the Cloudflare CAs into the system pool
|
|
||||||
cfRootCA, err := tlsconfig.GetCloudflareRootCA()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
|
||||||
}
|
|
||||||
for _, cert := range cfRootCA {
|
|
||||||
certPool.AddCert(cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
|
||||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
|
|
||||||
}
|
|
||||||
certPool.AddCert(helloCert)
|
|
||||||
|
|
||||||
return certPool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func createTunnelConfig(c *cli.Context) (*tls.Config, error) {
|
|
||||||
var rootCAs []string
|
|
||||||
if c.String("cacert") != "" {
|
|
||||||
rootCAs = append(rootCAs, c.String("cacert"))
|
|
||||||
}
|
|
||||||
edgeAddrs := c.StringSlice("edge")
|
|
||||||
|
|
||||||
userConfig := &tlsconfig.TLSParameters{RootCAs: rootCAs}
|
|
||||||
tlsConfig, err := tlsconfig.GetConfig(userConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if tlsConfig.RootCAs == nil {
|
|
||||||
rootCAPool := x509.NewCertPool()
|
|
||||||
cfRootCA, err := tlsconfig.GetCloudflareRootCA()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
|
||||||
}
|
|
||||||
for _, cert := range cfRootCA {
|
|
||||||
rootCAPool.AddCert(cert)
|
|
||||||
}
|
|
||||||
tlsConfig.RootCAs = rootCAPool
|
|
||||||
tlsConfig.ServerName = "cftunnel.com"
|
|
||||||
} else if len(edgeAddrs) > 0 {
|
|
||||||
// Set for development environments and for testing specific origintunneld instances
|
|
||||||
tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
|
|
||||||
return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
|
|
||||||
}
|
|
||||||
return tlsConfig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRunningFromTerminal() bool {
|
func isRunningFromTerminal() bool {
|
||||||
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,199 @@
|
||||||
|
// Package client defines and implements interface to proxy to HTTP, websocket and hello world origins
|
||||||
|
package originservice
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
|
"github.com/cloudflare/cloudflared/log"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OriginService is an interface to proxy requests to different type of origins
|
||||||
|
type OriginService interface {
|
||||||
|
Proxy(stream *h2mux.MuxedStream, req *http.Request) (resp *http.Response, err error)
|
||||||
|
Shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPService talks to origin using HTTP/HTTPS
|
||||||
|
type HTTPService struct {
|
||||||
|
client http.RoundTripper
|
||||||
|
originAddr string
|
||||||
|
chunkedEncoding bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHTTPService(transport http.RoundTripper, originAddr string, chunkedEncoding bool) OriginService {
|
||||||
|
return &HTTPService{
|
||||||
|
client: transport,
|
||||||
|
originAddr: originAddr,
|
||||||
|
chunkedEncoding: chunkedEncoding,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||||
|
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||||
|
if !hc.chunkedEncoding {
|
||||||
|
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||||
|
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||||
|
if err == nil {
|
||||||
|
req.ContentLength = int64(cLength)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request origin to keep connection alive to improve performance
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
resp, err := hc.client.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error proxying request to HTTP origin")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
err = stream.WriteHeaders(h1ResponseToH2Response(resp))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error writing response header to HTTP origin")
|
||||||
|
}
|
||||||
|
if isEventStream(resp) {
|
||||||
|
writeEventStream(stream, resp.Body)
|
||||||
|
} else {
|
||||||
|
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||||
|
// compression generates dictionary on first write
|
||||||
|
io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024))
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hc *HTTPService) Shutdown() {}
|
||||||
|
|
||||||
|
// WebsocketService talks to origin using WS/WSS
|
||||||
|
type WebsocketService struct {
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
shutdownC chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWebSocketService(tlsConfig *tls.Config, url string) (OriginService, error) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Cannot start Websocket Proxy Server")
|
||||||
|
}
|
||||||
|
shutdownC := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
websocket.StartProxyServer(log.CreateLogger(), listener, url, shutdownC)
|
||||||
|
}()
|
||||||
|
return &WebsocketService{
|
||||||
|
tlsConfig: tlsConfig,
|
||||||
|
shutdownC: shutdownC,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (response *http.Response, err error) {
|
||||||
|
conn, response, err := websocket.ClientConnect(req, wsc.tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
err = stream.WriteHeaders(h1ResponseToH2Response(response))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error writing response header to websocket origin")
|
||||||
|
}
|
||||||
|
// Copy to/from stream to the undelying connection. Use the underlying
|
||||||
|
// connection because cloudflared doesn't operate on the message themselves
|
||||||
|
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wsc *WebsocketService) Shutdown() {
|
||||||
|
close(wsc.shutdownC)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HelloWorldService talks to the hello world example origin
|
||||||
|
type HelloWorldService struct {
|
||||||
|
client http.RoundTripper
|
||||||
|
listener net.Listener
|
||||||
|
shutdownC chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) {
|
||||||
|
listener, err := hello.CreateTLSListener("127.0.0.1:")
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Cannot start Hello World Server")
|
||||||
|
}
|
||||||
|
shutdownC := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
hello.StartHelloWorldServer(log.CreateLogger(), listener, shutdownC)
|
||||||
|
}()
|
||||||
|
return &HelloWorldService{
|
||||||
|
client: transport,
|
||||||
|
listener: listener,
|
||||||
|
shutdownC: shutdownC,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||||
|
// Request origin to keep connection alive to improve performance
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
|
||||||
|
resp, err := hwc.client.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error proxying request to Hello World origin")
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
err = stream.WriteHeaders(h1ResponseToH2Response(resp))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error writing response header to Hello World origin")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||||
|
// compression generates dictionary on first write
|
||||||
|
io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024))
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hwc *HelloWorldService) Shutdown() {
|
||||||
|
hwc.listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isEventStream(resp *http.Response) bool {
|
||||||
|
// Check if content-type is text/event-stream. We need to check if the header value starts with text/event-stream
|
||||||
|
// because text/event-stream; charset=UTF-8 is also valid
|
||||||
|
// Ref: https://tools.ietf.org/html/rfc7231#section-3.1.1.1
|
||||||
|
for _, contentType := range resp.Header["content-type"] {
|
||||||
|
if strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeEventStream(stream *h2mux.MuxedStream, respBody io.ReadCloser) {
|
||||||
|
reader := bufio.NewReader(respBody)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
stream.Write(line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func h1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
|
||||||
|
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
|
||||||
|
for headerName, headerValues := range h1.Header {
|
||||||
|
for _, headerValue := range headerValues {
|
||||||
|
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
package originservice
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsEventStream(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
resp *http.Response
|
||||||
|
isEventStream bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
resp: &http.Response{},
|
||||||
|
isEventStream: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// isEventStream checks all headers
|
||||||
|
resp: &http.Response{
|
||||||
|
Header: http.Header{
|
||||||
|
"accept": []string{"text/html"},
|
||||||
|
"content-type": []string{"text/event-stream"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
isEventStream: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
|
||||||
|
resp: &http.Response{
|
||||||
|
Header: http.Header{
|
||||||
|
"content-type": []string{"Text/event-stream;charset=utf-8"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
isEventStream: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
|
||||||
|
resp: &http.Response{
|
||||||
|
Header: http.Header{
|
||||||
|
"content-type": []string{"appication/json", "text/html", "Text/event-stream;charset=utf-8"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
isEventStream: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Not an event stream because the content-type value doesn't start with text/event-stream
|
||||||
|
resp: &http.Response{
|
||||||
|
Header: http.Header{
|
||||||
|
"content-type": []string{" text/event-stream"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
isEventStream: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
assert.Equal(t, test.isEventStream, isEventStream(test.resp), "Header: %v", test.resp.Header)
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,10 +2,22 @@ package tlsconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/getsentry/raven-go"
|
"github.com/getsentry/raven-go"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"gopkg.in/urfave/cli.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OriginCAPoolFlag = "origin-ca-pool"
|
||||||
|
CaCertFlag = "cacert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CertReloader can load and reload a TLS certificate from a particular filepath.
|
// CertReloader can load and reload a TLS certificate from a particular filepath.
|
||||||
|
@ -51,3 +63,120 @@ func (cr *CertReloader) LoadCert() error {
|
||||||
cr.certificate = &cert
|
cr.certificate = &cert
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LoadOriginCA(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) {
|
||||||
|
var originCustomCAPool []byte
|
||||||
|
|
||||||
|
originCAPoolFilename := c.String(OriginCAPoolFlag)
|
||||||
|
if originCAPoolFilename != "" {
|
||||||
|
var err error
|
||||||
|
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
originCertPool, err := loadOriginCertPool(originCustomCAPool, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error loading the certificate pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Windows users should be notified that they can use the flag
|
||||||
|
if runtime.GOOS == "windows" && originCAPoolFilename == "" {
|
||||||
|
logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", OriginCAPoolFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
return originCertPool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadCustomCertPool(customCertFilename string) (*x509.CertPool, error) {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
customCAPoolPEM, err := ioutil.ReadFile(customCertFilename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", customCertFilename))
|
||||||
|
}
|
||||||
|
if !pool.AppendCertsFromPEM(customCAPoolPEM) {
|
||||||
|
return nil, fmt.Errorf("error appending custom CA to cert pool")
|
||||||
|
}
|
||||||
|
return pool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateTunnelConfig(c *cli.Context) (*tls.Config, error) {
|
||||||
|
var rootCAs []string
|
||||||
|
if c.String(CaCertFlag) != "" {
|
||||||
|
rootCAs = append(rootCAs, c.String(CaCertFlag))
|
||||||
|
}
|
||||||
|
|
||||||
|
userConfig := &TLSParameters{RootCAs: rootCAs}
|
||||||
|
tlsConfig, err := GetConfig(userConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig.RootCAs == nil {
|
||||||
|
rootCAPool := x509.NewCertPool()
|
||||||
|
cfRootCA, err := GetCloudflareRootCA()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
||||||
|
}
|
||||||
|
for _, cert := range cfRootCA {
|
||||||
|
rootCAPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = rootCAPool
|
||||||
|
tlsConfig.ServerName = "cftunnel.com"
|
||||||
|
} else if edgeAddrs := c.StringSlice("edge"); len(edgeAddrs) > 0 {
|
||||||
|
// Set for development environments and for testing specific origintunneld instances
|
||||||
|
tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
|
||||||
|
return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
|
||||||
|
}
|
||||||
|
return tlsConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadOriginCertPool(originCAPoolPEM []byte, logger *logrus.Logger) (*x509.CertPool, error) {
|
||||||
|
// Get the global pool
|
||||||
|
certPool, err := loadGlobalCertPool(logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then, add any custom origin CA pool the user may have passed
|
||||||
|
if originCAPoolPEM != nil {
|
||||||
|
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
|
||||||
|
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return certPool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadGlobalCertPool(logger *logrus.Logger) (*x509.CertPool, error) {
|
||||||
|
// First, obtain the system certificate pool
|
||||||
|
certPool, err := x509.SystemCertPool()
|
||||||
|
if err != nil {
|
||||||
|
if runtime.GOOS != "windows" { // See https://github.com/golang/go/issues/16736
|
||||||
|
logger.WithError(err).Warn("error obtaining the system certificates")
|
||||||
|
}
|
||||||
|
certPool = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, append the Cloudflare CAs into the system pool
|
||||||
|
cfRootCA, err := GetCloudflareRootCA()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
||||||
|
}
|
||||||
|
for _, cert := range cfRootCA {
|
||||||
|
certPool.AddCert(cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
||||||
|
helloCert, err := GetHelloCertificateX509()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
|
||||||
|
}
|
||||||
|
certPool.AddCert(helloCert)
|
||||||
|
|
||||||
|
return certPool, nil
|
||||||
|
}
|
||||||
|
|
|
@ -2,11 +2,18 @@ package pogs
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/originservice"
|
||||||
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
|
"github.com/pkg/errors"
|
||||||
capnp "zombiezen.com/go/capnproto2"
|
capnp "zombiezen.com/go/capnproto2"
|
||||||
"zombiezen.com/go/capnproto2/pogs"
|
"zombiezen.com/go/capnproto2/pogs"
|
||||||
"zombiezen.com/go/capnproto2/rpc"
|
"zombiezen.com/go/capnproto2/rpc"
|
||||||
|
@ -68,6 +75,9 @@ func NewReverseProxyConfig(
|
||||||
|
|
||||||
//go-sumtype:decl OriginConfig
|
//go-sumtype:decl OriginConfig
|
||||||
type OriginConfig interface {
|
type OriginConfig interface {
|
||||||
|
// Service returns a OriginService used to proxy to the origin
|
||||||
|
Service() (originservice.OriginService, error)
|
||||||
|
// go-sumtype requires at least one unexported method, otherwise it will complain that interface is not sealed
|
||||||
isOriginConfig()
|
isOriginConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,8 +96,6 @@ type HTTPOriginConfig struct {
|
||||||
ChunkedEncoding bool
|
ChunkedEncoding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (_ *HTTPOriginConfig) isOriginConfig() {}
|
|
||||||
|
|
||||||
type OriginAddr interface {
|
type OriginAddr interface {
|
||||||
Addr() string
|
Addr() string
|
||||||
}
|
}
|
||||||
|
@ -119,6 +127,39 @@ func (up *UnixPath) Addr() string {
|
||||||
return up.Path
|
return up.Path
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
||||||
|
rootCAs, err := tlsconfig.LoadCustomCertPool(hc.OriginCAPool)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dialContext := (&net.Dialer{
|
||||||
|
Timeout: hc.ProxyConnectTimeout,
|
||||||
|
KeepAlive: hc.TCPKeepAlive,
|
||||||
|
DualStack: hc.DialDualStack,
|
||||||
|
}).DialContext
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: dialContext,
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
ServerName: hc.OriginServerName,
|
||||||
|
InsecureSkipVerify: hc.TLSVerify,
|
||||||
|
},
|
||||||
|
TLSHandshakeTimeout: hc.TLSHandshakeTimeout,
|
||||||
|
MaxIdleConns: int(hc.MaxIdleConnections),
|
||||||
|
IdleConnTimeout: hc.IdleConnectionTimeout,
|
||||||
|
ExpectContinueTimeout: hc.ExpectContinueTimeout,
|
||||||
|
}
|
||||||
|
if unixPath, ok := hc.URL.(*UnixPath); ok {
|
||||||
|
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
return dialContext(ctx, "unix", unixPath.Addr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return originservice.NewHTTPService(transport, hc.URL.Addr(), hc.ChunkedEncoding), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_ *HTTPOriginConfig) isOriginConfig() {}
|
||||||
|
|
||||||
type WebSocketOriginConfig struct {
|
type WebSocketOriginConfig struct {
|
||||||
URL string `capnp:"url"`
|
URL string `capnp:"url"`
|
||||||
TLSVerify bool `capnp:"tlsVerify"`
|
TLSVerify bool `capnp:"tlsVerify"`
|
||||||
|
@ -126,10 +167,48 @@ type WebSocketOriginConfig struct {
|
||||||
OriginServerName string
|
OriginServerName string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) {
|
||||||
|
rootCAs, err := tlsconfig.LoadCustomCertPool(wsc.OriginCAPool)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
ServerName: wsc.OriginServerName,
|
||||||
|
InsecureSkipVerify: wsc.TLSVerify,
|
||||||
|
}
|
||||||
|
return originservice.NewWebSocketService(tlsConfig, wsc.URL)
|
||||||
|
}
|
||||||
|
|
||||||
func (_ *WebSocketOriginConfig) isOriginConfig() {}
|
func (_ *WebSocketOriginConfig) isOriginConfig() {}
|
||||||
|
|
||||||
type HelloWorldOriginConfig struct{}
|
type HelloWorldOriginConfig struct{}
|
||||||
|
|
||||||
|
func (_ *HelloWorldOriginConfig) Service() (originservice.OriginService, error) {
|
||||||
|
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Cannot get Hello World server certificate")
|
||||||
|
}
|
||||||
|
rootCAs := x509.NewCertPool()
|
||||||
|
rootCAs.AddCert(helloCert)
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
DualStack: true,
|
||||||
|
}).DialContext,
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: rootCAs,
|
||||||
|
},
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
return originservice.NewHelloWorldService(transport)
|
||||||
|
}
|
||||||
|
|
||||||
func (_ *HelloWorldOriginConfig) isOriginConfig() {}
|
func (_ *HelloWorldOriginConfig) isOriginConfig() {}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -3,8 +3,9 @@
|
||||||
package tunnelrpc
|
package tunnelrpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
context "golang.org/x/net/context"
|
|
||||||
strconv "strconv"
|
strconv "strconv"
|
||||||
|
|
||||||
|
context "golang.org/x/net/context"
|
||||||
capnp "zombiezen.com/go/capnproto2"
|
capnp "zombiezen.com/go/capnproto2"
|
||||||
text "zombiezen.com/go/capnproto2/encoding/text"
|
text "zombiezen.com/go/capnproto2/encoding/text"
|
||||||
schemas "zombiezen.com/go/capnproto2/schemas"
|
schemas "zombiezen.com/go/capnproto2/schemas"
|
||||||
|
|
Loading…
Reference in New Issue