TUN-1913: Define OriginService for each type of origin
This commit is contained in:
parent
acd17f6ab6
commit
d32fb8e82c
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/metrics"
|
||||
"github.com/cloudflare/cloudflared/origin"
|
||||
"github.com/cloudflare/cloudflared/signal"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/coreos/go-systemd/daemon"
|
||||
|
@ -444,7 +445,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "cacert",
|
||||
Name: tlsconfig.CaCertFlag,
|
||||
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",
|
||||
EnvVars: []string{"TUNNEL_CACERT"},
|
||||
Hidden: true,
|
||||
|
@ -463,7 +464,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "origin-ca-pool",
|
||||
Name: tlsconfig.OriginCAPoolFlag,
|
||||
Usage: "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.",
|
||||
EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"},
|
||||
Hidden: shouldHide,
|
||||
|
|
|
@ -3,14 +3,12 @@ package tunnel
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -187,7 +185,7 @@ func prepareTunnelConfig(
|
|||
}
|
||||
}
|
||||
|
||||
originCertPool, err := loadCertPool(c, logger)
|
||||
originCertPool, err := tlsconfig.LoadOriginCA(c, logger)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Error loading cert pool")
|
||||
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||
|
@ -236,7 +234,7 @@ func prepareTunnelConfig(
|
|||
return nil, errors.Wrap(err, "unable to connect to the origin")
|
||||
}
|
||||
|
||||
toEdgeTLSConfig, err := createTunnelConfig(c)
|
||||
toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("unable to create TLS config to connect with edge")
|
||||
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
|
||||
|
@ -274,112 +272,6 @@ func prepareTunnelConfig(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) {
|
||||
const originCAPoolFlag = "origin-ca-pool"
|
||||
originCAPoolFilename := c.String(originCAPoolFlag)
|
||||
var originCustomCAPool []byte
|
||||
|
||||
if originCAPoolFilename != "" {
|
||||
var err error
|
||||
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, originCAPoolFlag))
|
||||
}
|
||||
}
|
||||
|
||||
originCertPool, err := loadOriginCertPool(originCustomCAPool)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error loading the certificate pool")
|
||||
}
|
||||
|
||||
// Windows users should be notified that they can use the flag
|
||||
if runtime.GOOS == "windows" && originCAPoolFilename == "" {
|
||||
logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", originCAPoolFlag)
|
||||
}
|
||||
|
||||
return originCertPool, nil
|
||||
}
|
||||
|
||||
func loadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) {
|
||||
// Get the global pool
|
||||
certPool, err := loadGlobalCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Then, add any custom origin CA pool the user may have passed
|
||||
if originCAPoolPEM != nil {
|
||||
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
|
||||
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
|
||||
}
|
||||
}
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
func loadGlobalCertPool() (*x509.CertPool, error) {
|
||||
// First, obtain the system certificate pool
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
if runtime.GOOS != "windows" {
|
||||
logger.WithError(err).Warn("error obtaining the system certificates")
|
||||
}
|
||||
certPool = x509.NewCertPool()
|
||||
}
|
||||
|
||||
// Next, append the Cloudflare CAs into the system pool
|
||||
cfRootCA, err := tlsconfig.GetCloudflareRootCA()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
||||
}
|
||||
for _, cert := range cfRootCA {
|
||||
certPool.AddCert(cert)
|
||||
}
|
||||
|
||||
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
|
||||
}
|
||||
certPool.AddCert(helloCert)
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
func createTunnelConfig(c *cli.Context) (*tls.Config, error) {
|
||||
var rootCAs []string
|
||||
if c.String("cacert") != "" {
|
||||
rootCAs = append(rootCAs, c.String("cacert"))
|
||||
}
|
||||
edgeAddrs := c.StringSlice("edge")
|
||||
|
||||
userConfig := &tlsconfig.TLSParameters{RootCAs: rootCAs}
|
||||
tlsConfig, err := tlsconfig.GetConfig(userConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tlsConfig.RootCAs == nil {
|
||||
rootCAPool := x509.NewCertPool()
|
||||
cfRootCA, err := tlsconfig.GetCloudflareRootCA()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
||||
}
|
||||
for _, cert := range cfRootCA {
|
||||
rootCAPool.AddCert(cert)
|
||||
}
|
||||
tlsConfig.RootCAs = rootCAPool
|
||||
tlsConfig.ServerName = "cftunnel.com"
|
||||
} else if len(edgeAddrs) > 0 {
|
||||
// Set for development environments and for testing specific origintunneld instances
|
||||
tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0])
|
||||
}
|
||||
|
||||
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
|
||||
return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
|
||||
}
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func isRunningFromTerminal() bool {
|
||||
return terminal.IsTerminal(int(os.Stdout.Fd()))
|
||||
}
|
||||
|
|
|
@ -0,0 +1,199 @@
|
|||
// Package client defines and implements interface to proxy to HTTP, websocket and hello world origins
|
||||
package originservice
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/log"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// OriginService is an interface to proxy requests to different type of origins
|
||||
type OriginService interface {
|
||||
Proxy(stream *h2mux.MuxedStream, req *http.Request) (resp *http.Response, err error)
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
// HTTPService talks to origin using HTTP/HTTPS
|
||||
type HTTPService struct {
|
||||
client http.RoundTripper
|
||||
originAddr string
|
||||
chunkedEncoding bool
|
||||
}
|
||||
|
||||
func NewHTTPService(transport http.RoundTripper, originAddr string, chunkedEncoding bool) OriginService {
|
||||
return &HTTPService{
|
||||
client: transport,
|
||||
originAddr: originAddr,
|
||||
chunkedEncoding: chunkedEncoding,
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||
if !hc.chunkedEncoding {
|
||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||
if err == nil {
|
||||
req.ContentLength = int64(cLength)
|
||||
}
|
||||
}
|
||||
|
||||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
resp, err := hc.client.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error proxying request to HTTP origin")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = stream.WriteHeaders(h1ResponseToH2Response(resp))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error writing response header to HTTP origin")
|
||||
}
|
||||
if isEventStream(resp) {
|
||||
writeEventStream(stream, resp.Body)
|
||||
} else {
|
||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||
// compression generates dictionary on first write
|
||||
io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (hc *HTTPService) Shutdown() {}
|
||||
|
||||
// WebsocketService talks to origin using WS/WSS
|
||||
type WebsocketService struct {
|
||||
tlsConfig *tls.Config
|
||||
shutdownC chan struct{}
|
||||
}
|
||||
|
||||
func NewWebSocketService(tlsConfig *tls.Config, url string) (OriginService, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Cannot start Websocket Proxy Server")
|
||||
}
|
||||
shutdownC := make(chan struct{})
|
||||
go func() {
|
||||
websocket.StartProxyServer(log.CreateLogger(), listener, url, shutdownC)
|
||||
}()
|
||||
return &WebsocketService{
|
||||
tlsConfig: tlsConfig,
|
||||
shutdownC: shutdownC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (wsc *WebsocketService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (response *http.Response, err error) {
|
||||
conn, response, err := websocket.ClientConnect(req, wsc.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
err = stream.WriteHeaders(h1ResponseToH2Response(response))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error writing response header to websocket origin")
|
||||
}
|
||||
// Copy to/from stream to the undelying connection. Use the underlying
|
||||
// connection because cloudflared doesn't operate on the message themselves
|
||||
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (wsc *WebsocketService) Shutdown() {
|
||||
close(wsc.shutdownC)
|
||||
}
|
||||
|
||||
// HelloWorldService talks to the hello world example origin
|
||||
type HelloWorldService struct {
|
||||
client http.RoundTripper
|
||||
listener net.Listener
|
||||
shutdownC chan struct{}
|
||||
}
|
||||
|
||||
func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) {
|
||||
listener, err := hello.CreateTLSListener("127.0.0.1:")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Cannot start Hello World Server")
|
||||
}
|
||||
shutdownC := make(chan struct{})
|
||||
go func() {
|
||||
hello.StartHelloWorldServer(log.CreateLogger(), listener, shutdownC)
|
||||
}()
|
||||
return &HelloWorldService{
|
||||
client: transport,
|
||||
listener: listener,
|
||||
shutdownC: shutdownC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
resp, err := hwc.client.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error proxying request to Hello World origin")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = stream.WriteHeaders(h1ResponseToH2Response(resp))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error writing response header to Hello World origin")
|
||||
}
|
||||
|
||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
|
||||
// compression generates dictionary on first write
|
||||
io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (hwc *HelloWorldService) Shutdown() {
|
||||
hwc.listener.Close()
|
||||
}
|
||||
|
||||
func isEventStream(resp *http.Response) bool {
|
||||
// Check if content-type is text/event-stream. We need to check if the header value starts with text/event-stream
|
||||
// because text/event-stream; charset=UTF-8 is also valid
|
||||
// Ref: https://tools.ietf.org/html/rfc7231#section-3.1.1.1
|
||||
for _, contentType := range resp.Header["content-type"] {
|
||||
if strings.HasPrefix(strings.ToLower(contentType), "text/event-stream") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeEventStream(stream *h2mux.MuxedStream, respBody io.ReadCloser) {
|
||||
reader := bufio.NewReader(respBody)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
stream.Write(line)
|
||||
}
|
||||
}
|
||||
|
||||
func h1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
|
||||
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
|
||||
for headerName, headerValues := range h1.Header {
|
||||
for _, headerValue := range headerValues {
|
||||
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package originservice
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsEventStream(t *testing.T) {
|
||||
tests := []struct {
|
||||
resp *http.Response
|
||||
isEventStream bool
|
||||
}{
|
||||
{
|
||||
resp: &http.Response{},
|
||||
isEventStream: false,
|
||||
},
|
||||
{
|
||||
// isEventStream checks all headers
|
||||
resp: &http.Response{
|
||||
Header: http.Header{
|
||||
"accept": []string{"text/html"},
|
||||
"content-type": []string{"text/event-stream"},
|
||||
},
|
||||
},
|
||||
isEventStream: true,
|
||||
},
|
||||
{
|
||||
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
|
||||
resp: &http.Response{
|
||||
Header: http.Header{
|
||||
"content-type": []string{"Text/event-stream;charset=utf-8"},
|
||||
},
|
||||
},
|
||||
isEventStream: true,
|
||||
},
|
||||
{
|
||||
// Content-Type and text/event-stream are case-insensitive. text/event-stream can be followed by OWS parameter
|
||||
resp: &http.Response{
|
||||
Header: http.Header{
|
||||
"content-type": []string{"appication/json", "text/html", "Text/event-stream;charset=utf-8"},
|
||||
},
|
||||
},
|
||||
isEventStream: true,
|
||||
},
|
||||
{
|
||||
// Not an event stream because the content-type value doesn't start with text/event-stream
|
||||
resp: &http.Response{
|
||||
Header: http.Header{
|
||||
"content-type": []string{" text/event-stream"},
|
||||
},
|
||||
},
|
||||
isEventStream: false,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.isEventStream, isEventStream(test.resp), "Header: %v", test.resp.Header)
|
||||
}
|
||||
}
|
|
@ -2,10 +2,22 @@ package tlsconfig
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/getsentry/raven-go"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
OriginCAPoolFlag = "origin-ca-pool"
|
||||
CaCertFlag = "cacert"
|
||||
)
|
||||
|
||||
// CertReloader can load and reload a TLS certificate from a particular filepath.
|
||||
|
@ -51,3 +63,120 @@ func (cr *CertReloader) LoadCert() error {
|
|||
cr.certificate = &cert
|
||||
return nil
|
||||
}
|
||||
|
||||
func LoadOriginCA(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) {
|
||||
var originCustomCAPool []byte
|
||||
|
||||
originCAPoolFilename := c.String(OriginCAPoolFlag)
|
||||
if originCAPoolFilename != "" {
|
||||
var err error
|
||||
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, OriginCAPoolFlag))
|
||||
}
|
||||
}
|
||||
|
||||
originCertPool, err := loadOriginCertPool(originCustomCAPool, logger)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error loading the certificate pool")
|
||||
}
|
||||
|
||||
// Windows users should be notified that they can use the flag
|
||||
if runtime.GOOS == "windows" && originCAPoolFilename == "" {
|
||||
logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", OriginCAPoolFlag)
|
||||
}
|
||||
|
||||
return originCertPool, nil
|
||||
}
|
||||
|
||||
func LoadCustomCertPool(customCertFilename string) (*x509.CertPool, error) {
|
||||
pool := x509.NewCertPool()
|
||||
customCAPoolPEM, err := ioutil.ReadFile(customCertFilename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s", customCertFilename))
|
||||
}
|
||||
if !pool.AppendCertsFromPEM(customCAPoolPEM) {
|
||||
return nil, fmt.Errorf("error appending custom CA to cert pool")
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func CreateTunnelConfig(c *cli.Context) (*tls.Config, error) {
|
||||
var rootCAs []string
|
||||
if c.String(CaCertFlag) != "" {
|
||||
rootCAs = append(rootCAs, c.String(CaCertFlag))
|
||||
}
|
||||
|
||||
userConfig := &TLSParameters{RootCAs: rootCAs}
|
||||
tlsConfig, err := GetConfig(userConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tlsConfig.RootCAs == nil {
|
||||
rootCAPool := x509.NewCertPool()
|
||||
cfRootCA, err := GetCloudflareRootCA()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
||||
}
|
||||
for _, cert := range cfRootCA {
|
||||
rootCAPool.AddCert(cert)
|
||||
}
|
||||
tlsConfig.RootCAs = rootCAPool
|
||||
tlsConfig.ServerName = "cftunnel.com"
|
||||
} else if edgeAddrs := c.StringSlice("edge"); len(edgeAddrs) > 0 {
|
||||
// Set for development environments and for testing specific origintunneld instances
|
||||
tlsConfig.ServerName, _, _ = net.SplitHostPort(edgeAddrs[0])
|
||||
}
|
||||
|
||||
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
|
||||
return nil, fmt.Errorf("either ServerName or InsecureSkipVerify must be specified in the tls.Config")
|
||||
}
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func loadOriginCertPool(originCAPoolPEM []byte, logger *logrus.Logger) (*x509.CertPool, error) {
|
||||
// Get the global pool
|
||||
certPool, err := loadGlobalCertPool(logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Then, add any custom origin CA pool the user may have passed
|
||||
if originCAPoolPEM != nil {
|
||||
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
|
||||
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
|
||||
}
|
||||
}
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
func loadGlobalCertPool(logger *logrus.Logger) (*x509.CertPool, error) {
|
||||
// First, obtain the system certificate pool
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
if runtime.GOOS != "windows" { // See https://github.com/golang/go/issues/16736
|
||||
logger.WithError(err).Warn("error obtaining the system certificates")
|
||||
}
|
||||
certPool = x509.NewCertPool()
|
||||
}
|
||||
|
||||
// Next, append the Cloudflare CAs into the system pool
|
||||
cfRootCA, err := GetCloudflareRootCA()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not append Cloudflare Root CAs to cloudflared certificate pool")
|
||||
}
|
||||
for _, cert := range cfRootCA {
|
||||
certPool.AddCert(cert)
|
||||
}
|
||||
|
||||
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
||||
helloCert, err := GetHelloCertificateX509()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not append Hello server certificate to cloudflared certificate pool")
|
||||
}
|
||||
certPool.AddCert(helloCert)
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
|
|
|
@ -2,11 +2,18 @@ package pogs
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/originservice"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/pkg/errors"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
@ -68,6 +75,9 @@ func NewReverseProxyConfig(
|
|||
|
||||
//go-sumtype:decl OriginConfig
|
||||
type OriginConfig interface {
|
||||
// Service returns a OriginService used to proxy to the origin
|
||||
Service() (originservice.OriginService, error)
|
||||
// go-sumtype requires at least one unexported method, otherwise it will complain that interface is not sealed
|
||||
isOriginConfig()
|
||||
}
|
||||
|
||||
|
@ -86,8 +96,6 @@ type HTTPOriginConfig struct {
|
|||
ChunkedEncoding bool
|
||||
}
|
||||
|
||||
func (_ *HTTPOriginConfig) isOriginConfig() {}
|
||||
|
||||
type OriginAddr interface {
|
||||
Addr() string
|
||||
}
|
||||
|
@ -119,6 +127,39 @@ func (up *UnixPath) Addr() string {
|
|||
return up.Path
|
||||
}
|
||||
|
||||
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
||||
rootCAs, err := tlsconfig.LoadCustomCertPool(hc.OriginCAPool)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dialContext := (&net.Dialer{
|
||||
Timeout: hc.ProxyConnectTimeout,
|
||||
KeepAlive: hc.TCPKeepAlive,
|
||||
DualStack: hc.DialDualStack,
|
||||
}).DialContext
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialContext,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
ServerName: hc.OriginServerName,
|
||||
InsecureSkipVerify: hc.TLSVerify,
|
||||
},
|
||||
TLSHandshakeTimeout: hc.TLSHandshakeTimeout,
|
||||
MaxIdleConns: int(hc.MaxIdleConnections),
|
||||
IdleConnTimeout: hc.IdleConnectionTimeout,
|
||||
ExpectContinueTimeout: hc.ExpectContinueTimeout,
|
||||
}
|
||||
if unixPath, ok := hc.URL.(*UnixPath); ok {
|
||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialContext(ctx, "unix", unixPath.Addr())
|
||||
}
|
||||
}
|
||||
return originservice.NewHTTPService(transport, hc.URL.Addr(), hc.ChunkedEncoding), nil
|
||||
}
|
||||
|
||||
func (_ *HTTPOriginConfig) isOriginConfig() {}
|
||||
|
||||
type WebSocketOriginConfig struct {
|
||||
URL string `capnp:"url"`
|
||||
TLSVerify bool `capnp:"tlsVerify"`
|
||||
|
@ -126,10 +167,48 @@ type WebSocketOriginConfig struct {
|
|||
OriginServerName string
|
||||
}
|
||||
|
||||
func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) {
|
||||
rootCAs, err := tlsconfig.LoadCustomCertPool(wsc.OriginCAPool)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
ServerName: wsc.OriginServerName,
|
||||
InsecureSkipVerify: wsc.TLSVerify,
|
||||
}
|
||||
return originservice.NewWebSocketService(tlsConfig, wsc.URL)
|
||||
}
|
||||
|
||||
func (_ *WebSocketOriginConfig) isOriginConfig() {}
|
||||
|
||||
type HelloWorldOriginConfig struct{}
|
||||
|
||||
func (_ *HelloWorldOriginConfig) Service() (originservice.OriginService, error) {
|
||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Cannot get Hello World server certificate")
|
||||
}
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(helloCert)
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
},
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
return originservice.NewHelloWorldService(transport)
|
||||
}
|
||||
|
||||
func (_ *HelloWorldOriginConfig) isOriginConfig() {}
|
||||
|
||||
/*
|
||||
|
|
|
@ -3,8 +3,9 @@
|
|||
package tunnelrpc
|
||||
|
||||
import (
|
||||
context "golang.org/x/net/context"
|
||||
strconv "strconv"
|
||||
|
||||
context "golang.org/x/net/context"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
text "zombiezen.com/go/capnproto2/encoding/text"
|
||||
schemas "zombiezen.com/go/capnproto2/schemas"
|
||||
|
|
Loading…
Reference in New Issue