TUN-1913: Define OriginService for each type of origin

This commit is contained in:
Chung-Ting Huang 2019-05-28 15:53:35 -05:00
parent acd17f6ab6
commit d32fb8e82c
7 changed files with 476 additions and 115 deletions

View File

@ -21,6 +21,7 @@ import (
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/websocket"
"github.com/coreos/go-systemd/daemon"
@ -444,7 +445,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "cacert",
Name: tlsconfig.CaCertFlag,
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",
EnvVars: []string{"TUNNEL_CACERT"},
Hidden: true,
@ -463,7 +464,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origin-ca-pool",
Name: tlsconfig.OriginCAPoolFlag,
Usage: "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.",
EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"},
Hidden: shouldHide,

View File

@ -3,14 +3,12 @@ package tunnel
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"time"
@ -187,7 +185,7 @@ func prepareTunnelConfig(
}
}
originCertPool, err := loadCertPool(c, logger)
originCertPool, err := tlsconfig.LoadOriginCA(c, logger)
if err != nil {
logger.WithError(err).Error("Error loading cert pool")
return nil, errors.Wrap(err, "Error loading cert pool")
@ -236,7 +234,7 @@ func prepareTunnelConfig(
return nil, errors.Wrap(err, "unable to connect to the origin")
}
toEdgeTLSConfig, err := createTunnelConfig(c)
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c)
if err != nil {
logger.WithError(err).Error("unable to create TLS config to connect with edge")
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
@ -274,112 +272,6 @@ func prepareTunnelConfig(
}, nil
}
func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) {
const originCAPoolFlag = "origin-ca-pool"
originCAPoolFilename := c.String(originCAPoolFlag)
var originCustomCAPool []byte
if originCAPoolFilename != "" {
var err error
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, originCAPoolFlag))
}
}
originCertPool, err := loadOriginCertPool(originCustomCAPool)
if err != nil {
return nil, errors.Wrap(err, "error loading the certificate pool")
}
// Windows users should be notified that they can use the flag
if runtime.GOOS == "windows" && originCAPoolFilename == "" {
logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", originCAPoolFlag)
}
return originCertPool, nil
}
func loadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) {
// Get the global pool
certPool, err := loadGlobalCertPool()
if err != nil {
return nil, err
}
// Then, add any custom origin CA pool the user may have passed
if originCAPoolPEM != nil {
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
}
}
return certPool, nil
}
func loadGlobalCertPool() (*x509.CertPool, error) {
// First, obtain the system certificate pool
certPool, err := x509.SystemCertPool()
if err != nil {
if runtime.GOOS != "windows" {
logger.WithError(err).Warn("error obtaining the system certificates")
}
certPool = x509.NewCertPool()
}
// Next, append the Cloudflare CAs into the system pool
cfRootCA, err := tlsconfig.GetCloudflareRootCA()
if err != nil {
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
}
for _, cert := range cfRootCA {
certPool.AddCert(cert)
}
// Finally, add the Hello certificate into the pool (since it's self-signed)
helloCert, err := tlsconfig.GetHelloCertificateX509()
if err != nil {
return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
}
certPool.AddCert(helloCert)
return certPool, nil
}
func createTunnelConfig(c *cli.Context) (*tls.Config, error) {
var rootCAs []string
if c.String("cacert") != "" {
rootCAs = append(rootCAs, c.String("cacert"))
}
edgeAddrs := c.StringSlice("edge")
userConfig := &tlsconfig.TLSParameters{RootCAs: rootCAs}
tlsConfig, err := tlsconfig.GetConfig(userConfig)
if err != nil {
return nil, err
}
if tlsConfig.RootCAs == nil {
rootCAPool := x509.NewCertPool()
cfRootCA, err := tlsconfig.GetCloudflareRootCA()
if err != nil {
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
}
for _, cert := range cfRootCA {
rootCAPool.AddCert(cert)
}
tlsConfig.RootCAs = rootCAPool
tlsConfig.ServerName = "cftunnel.com"
} else if len(edgeAddrs) > 0 {
// Set for development environments and for testing specific origintunneld instances
tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0])
}
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
}
return tlsConfig, nil
}
func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd()))
}

View File

@ -0,0 +1,199 @@
// Package client defines and implements interface to proxy to HTTP, websocket and hello world origins
package originservice
import (
"bufio"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/log"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
// OriginService is an interface to proxy requests to different type of origins
type OriginService interface {
Proxy(stream *h2mux.MuxedStream, req *http.Request) (resp *http.Response, err error)
Shutdown()
}
// HTTPService talks to origin using HTTP/HTTPS
type HTTPService struct {
client http.RoundTripper
originAddr string
chunkedEncoding bool
}
func NewHTTPService(transport http.RoundTripper, originAddr string, chunkedEncoding bool) OriginService {
return &HTTPService{
client: transport,
originAddr: originAddr,
chunkedEncoding: chunkedEncoding,
}
}
func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if !hc.chunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
req.ContentLength = int64(cLength)
}
}
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
resp, err := hc.client.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to HTTP origin")
}
defer resp.Body.Close()
err = stream.WriteHeaders(h1ResponseToH2Response(resp))
if err != nil {
return nil, errors.Wrap(err, "Error writing response header to HTTP origin")
}
if isEventStream(resp) {
writeEventStream(stream, resp.Body)
} else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024))
}
return resp, nil
}
func (hc *HTTPService) Shutdown() {}
// WebsocketService talks to origin using WS/WSS
type WebsocketService struct {
tlsConfig *tls.Config
shutdownC chan struct{}
}
func NewWebSocketService(tlsConfig *tls.Config, url string) (OriginService, error) {
listener, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
return nil, errors.Wrap(err, "Cannot start Websocket Proxy Server")
}
shutdownC := make(chan struct{})
go func() {
websocket.StartProxyServer(log.CreateLogger(), listener, url, shutdownC)
}()
return &WebsocketService{
tlsConfig: tlsConfig,
shutdownC: shutdownC,
}, nil
}
func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (response *http.Response, err error) {
conn, response, err := websocket.ClientConnect(req, wsc.tlsConfig)
if err != nil {
return nil, err
}
defer conn.Close()
err = stream.WriteHeaders(h1ResponseToH2Response(response))
if err != nil {
return nil, errors.Wrap(err, "Error writing response header to websocket origin")
}
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), stream)
return response, nil
}
func (wsc *WebsocketService) Shutdown() {
close(wsc.shutdownC)
}
// HelloWorldService talks to the hello world example origin
type HelloWorldService struct {
client http.RoundTripper
listener net.Listener
shutdownC chan struct{}
}
func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) {
listener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
return nil, errors.Wrap(err, "Cannot start Hello World Server")
}
shutdownC := make(chan struct{})
go func() {
hello.StartHelloWorldServer(log.CreateLogger(), listener, shutdownC)
}()
return &HelloWorldService{
client: transport,
listener: listener,
shutdownC: shutdownC,
}, nil
}
func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
resp, err := hwc.client.RoundTrip(req)
if err != nil {
return nil, errors.Wrap(err, "Error proxying request to Hello World origin")
}
defer resp.Body.Close()
err = stream.WriteHeaders(h1ResponseToH2Response(resp))
if err != nil {
return nil, errors.Wrap(err, "Error writing response header to Hello World origin")
}
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write
io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024))
return resp, nil
}
func (hwc *HelloWorldService) Shutdown() {
hwc.listener.Close()
}
func isEventStream(resp *http.Response) bool {
// Check if content-type is text/event-stream. We need to check if the header value starts with text/event-stream
// because text/event-stream; charset=UTF-8 is also valid
// Ref: https://tools.ietf.org/html/rfc7231#section-3.1.1.1
for _, contentType := range resp.Header["content-type"] {
if strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") {
return true
}
}
return false
}
func writeEventStream(stream *h2mux.MuxedStream, respBody io.ReadCloser) {
reader := bufio.NewReader(respBody)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
break
}
stream.Write(line)
}
}
func h1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
for headerName, headerValues := range h1.Header {
for _, headerValue := range headerValues {
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
}
}
return
}

View File

@ -0,0 +1,60 @@
package originservice
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsEventStream(t *testing.T) {
tests := []struct {
resp *http.Response
isEventStream bool
}{
{
resp: &http.Response{},
isEventStream: false,
},
{
// isEventStream checks all headers
resp: &http.Response{
Header: http.Header{
"accept": []string{"text/html"},
"content-type": []string{"text/event-stream"},
},
},
isEventStream: true,
},
{
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
resp: &http.Response{
Header: http.Header{
"content-type": []string{"Text/event-stream;charset=utf-8"},
},
},
isEventStream: true,
},
{
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
resp: &http.Response{
Header: http.Header{
"content-type": []string{"appication/json", "text/html", "Text/event-stream;charset=utf-8"},
},
},
isEventStream: true,
},
{
// Not an event stream because the content-type value doesn't start with text/event-stream
resp: &http.Response{
Header: http.Header{
"content-type": []string{" text/event-stream"},
},
},
isEventStream: false,
},
}
for _, test := range tests {
assert.Equal(t, test.isEventStream, isEventStream(test.resp), "Header: %v", test.resp.Header)
}
}

View File

@ -2,10 +2,22 @@ package tlsconfig
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"runtime"
"sync"
"github.com/getsentry/raven-go"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"gopkg.in/urfave/cli.v2"
)
const (
OriginCAPoolFlag = "origin-ca-pool"
CaCertFlag = "cacert"
)
// CertReloader can load and reload a TLS certificate from a particular filepath.
@ -51,3 +63,120 @@ func (cr *CertReloader) LoadCert() error {
cr.certificate = &cert
return nil
}
func LoadOriginCA(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) {
var originCustomCAPool []byte
originCAPoolFilename := c.String(OriginCAPoolFlag)
if originCAPoolFilename != "" {
var err error
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag))
}
}
originCertPool, err := loadOriginCertPool(originCustomCAPool, logger)
if err != nil {
return nil, errors.Wrap(err, "error loading the certificate pool")
}
// Windows users should be notified that they can use the flag
if runtime.GOOS == "windows" && originCAPoolFilename == "" {
logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", OriginCAPoolFlag)
}
return originCertPool, nil
}
func LoadCustomCertPool(customCertFilename string) (*x509.CertPool, error) {
pool := x509.NewCertPool()
customCAPoolPEM, err := ioutil.ReadFile(customCertFilename)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", customCertFilename))
}
if !pool.AppendCertsFromPEM(customCAPoolPEM) {
return nil, fmt.Errorf("error appending custom CA to cert pool")
}
return pool, nil
}
func CreateTunnelConfig(c *cli.Context) (*tls.Config, error) {
var rootCAs []string
if c.String(CaCertFlag) != "" {
rootCAs = append(rootCAs, c.String(CaCertFlag))
}
userConfig := &TLSParameters{RootCAs: rootCAs}
tlsConfig, err := GetConfig(userConfig)
if err != nil {
return nil, err
}
if tlsConfig.RootCAs == nil {
rootCAPool := x509.NewCertPool()
cfRootCA, err := GetCloudflareRootCA()
if err != nil {
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
}
for _, cert := range cfRootCA {
rootCAPool.AddCert(cert)
}
tlsConfig.RootCAs = rootCAPool
tlsConfig.ServerName = "cftunnel.com"
} else if edgeAddrs := c.StringSlice("edge"); len(edgeAddrs) > 0 {
// Set for development environments and for testing specific origintunneld instances
tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0])
}
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
}
return tlsConfig, nil
}
func loadOriginCertPool(originCAPoolPEM []byte, logger *logrus.Logger) (*x509.CertPool, error) {
// Get the global pool
certPool, err := loadGlobalCertPool(logger)
if err != nil {
return nil, err
}
// Then, add any custom origin CA pool the user may have passed
if originCAPoolPEM != nil {
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
}
}
return certPool, nil
}
func loadGlobalCertPool(logger *logrus.Logger) (*x509.CertPool, error) {
// First, obtain the system certificate pool
certPool, err := x509.SystemCertPool()
if err != nil {
if runtime.GOOS != "windows" { // See https://github.com/golang/go/issues/16736
logger.WithError(err).Warn("error obtaining the system certificates")
}
certPool = x509.NewCertPool()
}
// Next, append the Cloudflare CAs into the system pool
cfRootCA, err := GetCloudflareRootCA()
if err != nil {
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
}
for _, cert := range cfRootCA {
certPool.AddCert(cert)
}
// Finally, add the Hello certificate into the pool (since it's self-signed)
helloCert, err := GetHelloCertificateX509()
if err != nil {
return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
}
certPool.AddCert(helloCert)
return certPool, nil
}

View File

@ -2,11 +2,18 @@ package pogs
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"net/url"
"time"
"github.com/cloudflare/cloudflared/originservice"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/pkg/errors"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/rpc"
@ -68,6 +75,9 @@ func NewReverseProxyConfig(
//go-sumtype:decl OriginConfig
type OriginConfig interface {
// Service returns a OriginService used to proxy to the origin
Service() (originservice.OriginService, error)
// go-sumtype requires at least one unexported method, otherwise it will complain that interface is not sealed
isOriginConfig()
}
@ -86,8 +96,6 @@ type HTTPOriginConfig struct {
ChunkedEncoding bool
}
func (_ *HTTPOriginConfig) isOriginConfig() {}
type OriginAddr interface {
Addr() string
}
@ -119,6 +127,39 @@ func (up *UnixPath) Addr() string {
return up.Path
}
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
rootCAs, err := tlsconfig.LoadCustomCertPool(hc.OriginCAPool)
if err != nil {
return nil, err
}
dialContext := (&net.Dialer{
Timeout: hc.ProxyConnectTimeout,
KeepAlive: hc.TCPKeepAlive,
DualStack: hc.DialDualStack,
}).DialContext
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContext,
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,
ServerName: hc.OriginServerName,
InsecureSkipVerify: hc.TLSVerify,
},
TLSHandshakeTimeout: hc.TLSHandshakeTimeout,
MaxIdleConns: int(hc.MaxIdleConnections),
IdleConnTimeout: hc.IdleConnectionTimeout,
ExpectContinueTimeout: hc.ExpectContinueTimeout,
}
if unixPath, ok := hc.URL.(*UnixPath); ok {
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialContext(ctx, "unix", unixPath.Addr())
}
}
return originservice.NewHTTPService(transport, hc.URL.Addr(), hc.ChunkedEncoding), nil
}
func (_ *HTTPOriginConfig) isOriginConfig() {}
type WebSocketOriginConfig struct {
URL string `capnp:"url"`
TLSVerify bool `capnp:"tlsVerify"`
@ -126,10 +167,48 @@ type WebSocketOriginConfig struct {
OriginServerName string
}
func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) {
rootCAs, err := tlsconfig.LoadCustomCertPool(wsc.OriginCAPool)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
RootCAs: rootCAs,
ServerName: wsc.OriginServerName,
InsecureSkipVerify: wsc.TLSVerify,
}
return originservice.NewWebSocketService(tlsConfig, wsc.URL)
}
func (_ *WebSocketOriginConfig) isOriginConfig() {}
type HelloWorldOriginConfig struct{}
func (_ *HelloWorldOriginConfig) Service() (originservice.OriginService, error) {
helloCert, err := tlsconfig.GetHelloCertificateX509()
if err != nil {
return nil, errors.Wrap(err, "Cannot get Hello World server certificate")
}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(helloCert)
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
TLSClientConfig: &tls.Config{
RootCAs: rootCAs,
},
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
return originservice.NewHelloWorldService(transport)
}
func (_ *HelloWorldOriginConfig) isOriginConfig() {}
/*

View File

@ -3,8 +3,9 @@
package tunnelrpc
import (
context "golang.org/x/net/context"
strconv "strconv"
context "golang.org/x/net/context"
capnp "zombiezen.com/go/capnproto2"
text "zombiezen.com/go/capnproto2/encoding/text"
schemas "zombiezen.com/go/capnproto2/schemas"