Initial WIP refactoring logic out of main package
This commit is contained in:
parent
29c42adaa1
commit
6ab87d01de
|
@ -1,13 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -18,10 +12,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflare-warp/metrics"
|
"github.com/cloudflare/cloudflare-warp/metrics"
|
||||||
"github.com/cloudflare/cloudflare-warp/origin"
|
|
||||||
"github.com/cloudflare/cloudflare-warp/tlsconfig"
|
"github.com/cloudflare/cloudflare-warp/tlsconfig"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
|
||||||
"github.com/cloudflare/cloudflare-warp/validation"
|
"github.com/cloudflare/cloudflare-warp/validation"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/warp"
|
||||||
|
|
||||||
"github.com/facebookgo/grace/gracenet"
|
"github.com/facebookgo/grace/gracenet"
|
||||||
"github.com/getsentry/raven-go"
|
"github.com/getsentry/raven-go"
|
||||||
|
@ -316,8 +309,8 @@ WARNING:
|
||||||
|
|
||||||
func startServer(c *cli.Context) {
|
func startServer(c *cli.Context) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errC := make(chan error)
|
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
|
errC := make(chan error)
|
||||||
|
|
||||||
// If the user choose to supply all options through env variables,
|
// If the user choose to supply all options through env variables,
|
||||||
// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
|
// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
|
||||||
|
@ -326,6 +319,7 @@ func startServer(c *cli.Context) {
|
||||||
cli.ShowAppHelp(c)
|
cli.ShowAppHelp(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
|
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log.WithError(err).Fatal("Unknown logging level specified")
|
Log.WithError(err).Fatal("Unknown logging level specified")
|
||||||
|
@ -353,22 +347,16 @@ func startServer(c *cli.Context) {
|
||||||
go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
|
go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
|
||||||
}
|
}
|
||||||
|
|
||||||
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
|
||||||
if err != nil {
|
|
||||||
Log.WithError(err).Fatal("Invalid hostname")
|
|
||||||
|
|
||||||
}
|
|
||||||
clientID := c.String("id")
|
|
||||||
if !c.IsSet("id") {
|
|
||||||
clientID = generateRandomClientID()
|
|
||||||
}
|
|
||||||
|
|
||||||
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
|
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log.WithError(err).Fatal("Tag parse failure")
|
Log.WithError(err).Fatal("Tag parse failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
|
validURL, err := validateUrl(c)
|
||||||
|
if err != nil {
|
||||||
|
Log.WithError(err).Fatal("Error validating url")
|
||||||
|
}
|
||||||
|
|
||||||
if c.IsSet("hello-world") {
|
if c.IsSet("hello-world") {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
listener, err := createListener("127.0.0.1:")
|
listener, err := createListener("127.0.0.1:")
|
||||||
|
@ -381,86 +369,18 @@ func startServer(c *cli.Context) {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
listener.Close()
|
listener.Close()
|
||||||
}()
|
}()
|
||||||
c.Set("url", "https://"+listener.Addr().String())
|
validURL = "https://" + listener.Addr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
url, err := validateUrl(c)
|
Log.Infof("Proxying tunnel requests to %s", validURL)
|
||||||
if err != nil {
|
|
||||||
Log.WithError(err).Fatal("Error validating url")
|
|
||||||
}
|
|
||||||
Log.Infof("Proxying tunnel requests to %s", url)
|
|
||||||
|
|
||||||
// Fail if the user provided an old authentication method
|
// Fail if the user provided an old authentication method
|
||||||
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
|
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
|
||||||
Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
|
Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the user has acquired a certificate using the log in command
|
|
||||||
originCertPath, err := homedir.Expand(c.String("origincert"))
|
|
||||||
if err != nil {
|
|
||||||
Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
|
|
||||||
}
|
|
||||||
ok, err := fileExists(originCertPath)
|
|
||||||
if !ok {
|
|
||||||
Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
|
|
||||||
|
|
||||||
%s
|
|
||||||
|
|
||||||
If the path above is wrong, specify the path with the -origincert option.
|
|
||||||
If you don't have a certificate signed by Cloudflare, run the command:
|
|
||||||
|
|
||||||
%s login
|
|
||||||
`, originCertPath, os.Args[0])
|
|
||||||
}
|
|
||||||
// Easier to send the certificate as []byte via RPC than decoding it at this point
|
|
||||||
originCert, err := ioutil.ReadFile(originCertPath)
|
|
||||||
if err != nil {
|
|
||||||
Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
tunnelMetrics := origin.NewTunnelMetrics()
|
|
||||||
httpTransport := &http.Transport{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: (&net.Dialer{
|
|
||||||
Timeout: c.Duration("proxy-connect-timeout"),
|
|
||||||
KeepAlive: c.Duration("proxy-tcp-keepalive"),
|
|
||||||
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
|
|
||||||
}).DialContext,
|
|
||||||
MaxIdleConns: c.Int("proxy-keepalive-connections"),
|
|
||||||
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
|
|
||||||
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
|
|
||||||
}
|
|
||||||
tunnelConfig := &origin.TunnelConfig{
|
|
||||||
EdgeAddrs: c.StringSlice("edge"),
|
|
||||||
OriginUrl: url,
|
|
||||||
Hostname: hostname,
|
|
||||||
OriginCert: originCert,
|
|
||||||
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
|
|
||||||
ClientTlsConfig: httpTransport.TLSClientConfig,
|
|
||||||
Retries: c.Uint("retries"),
|
|
||||||
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
|
||||||
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
|
||||||
ClientID: clientID,
|
|
||||||
ReportedVersion: Version,
|
|
||||||
LBPool: c.String("lb-pool"),
|
|
||||||
Tags: tags,
|
|
||||||
HAConnections: c.Int("ha-connections"),
|
|
||||||
HTTPTransport: httpTransport,
|
|
||||||
Metrics: tunnelMetrics,
|
|
||||||
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
|
||||||
ProtocolLogger: protoLogger,
|
|
||||||
Logger: Log,
|
|
||||||
IsAutoupdated: c.Bool("is-autoupdated"),
|
|
||||||
}
|
|
||||||
connectedSignal := make(chan struct{})
|
connectedSignal := make(chan struct{})
|
||||||
|
|
||||||
go writePidFile(connectedSignal, c.String("pidfile"))
|
go writePidFile(connectedSignal, c.String("pidfile"))
|
||||||
go func() {
|
|
||||||
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -471,6 +391,44 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
tlsConfig := tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
go func() {
|
||||||
|
errC <- warp.StartServer(warp.ServerConfig{
|
||||||
|
Hostname: c.String("hostname"),
|
||||||
|
ServerURL: validURL,
|
||||||
|
HelloWorld: c.IsSet("hello-world"),
|
||||||
|
Tags: tags,
|
||||||
|
OriginCert: c.String("origincert"),
|
||||||
|
|
||||||
|
ConnectedChan: connectedSignal,
|
||||||
|
ShutdownChan: shutdownC,
|
||||||
|
|
||||||
|
Timeout: c.Duration("proxy-connect-timeout"),
|
||||||
|
KeepAlive: c.Duration("proxy-tcp-keepalive"),
|
||||||
|
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
|
||||||
|
|
||||||
|
MaxIdleConns: c.Int("proxy-keepalive-connections"),
|
||||||
|
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
|
||||||
|
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
|
||||||
|
|
||||||
|
EdgeAddrs: c.StringSlice("edge"),
|
||||||
|
Retries: c.Uint("retries"),
|
||||||
|
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
||||||
|
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
||||||
|
LBPool: c.String("lb-pool"),
|
||||||
|
HAConnections: c.Int("ha-connections"),
|
||||||
|
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
||||||
|
IsAutoupdated: c.Bool("is-autoupdated"),
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
ReportedVersion: Version,
|
||||||
|
ProtoLogger: protoLogger,
|
||||||
|
Logger: Log,
|
||||||
|
})
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
var errCode int
|
var errCode int
|
||||||
err = WaitForSignal(errC, shutdownC)
|
err = WaitForSignal(errC, shutdownC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -504,6 +462,14 @@ func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func login(c *cli.Context) error {
|
||||||
|
err := warp.Login(defaultConfigDir, credentialFile, c.String("url"))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func update(c *cli.Context) error {
|
func update(c *cli.Context) error {
|
||||||
if updateApplied() {
|
if updateApplied() {
|
||||||
os.Exit(64)
|
os.Exit(64)
|
||||||
|
@ -584,13 +550,6 @@ func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, er
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateRandomClientID() string {
|
|
||||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
||||||
id := make([]byte, 32)
|
|
||||||
r.Read(id)
|
|
||||||
return hex.EncodeToString(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func writePidFile(waitForSignal chan struct{}, pidFile string) {
|
func writePidFile(waitForSignal chan struct{}, pidFile string) {
|
||||||
<-waitForSignal
|
<-waitForSignal
|
||||||
daemon.SdNotify(false, "READY=1")
|
daemon.SdNotify(false, "READY=1")
|
||||||
|
|
|
@ -93,11 +93,11 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||||
Version: c.ReportedVersion,
|
Version: c.ReportedVersion,
|
||||||
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
|
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
|
||||||
ExistingTunnelPolicy: policy,
|
ExistingTunnelPolicy: policy,
|
||||||
PoolName: c.LBPool,
|
// PoolName: c.LBPool, // TODO - see issue #2
|
||||||
Tags: c.Tags,
|
Tags: c.Tags,
|
||||||
ConnectionID: connectionID,
|
ConnectionID: connectionID,
|
||||||
OriginLocalIP: OriginLocalIP,
|
OriginLocalIP: OriginLocalIP,
|
||||||
IsAutoupdated: c.IsAutoupdated,
|
// IsAutoupdated: c.IsAutoupdated, // TODO - see issue #2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -90,8 +90,10 @@ func LoadOriginCertsPool() *x509.CertPool {
|
||||||
return certPool
|
return certPool
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config {
|
func CreateTunnelConfig(tlsConfig *tls.Config, addrs []string) *tls.Config {
|
||||||
tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
if tlsConfig == nil {
|
||||||
|
tlsConfig = new(tls.Config)
|
||||||
|
}
|
||||||
if tlsConfig.RootCAs == nil {
|
if tlsConfig.RootCAs == nil {
|
||||||
tlsConfig.RootCAs = GetCloudflareRootCA()
|
tlsConfig.RootCAs = GetCloudflareRootCA()
|
||||||
tlsConfig.ServerName = "cftunnel.com"
|
tlsConfig.ServerName = "cftunnel.com"
|
||||||
|
|
|
@ -47,11 +47,11 @@ type RegistrationOptions struct {
|
||||||
Version string
|
Version string
|
||||||
OS string `capnp:"os"`
|
OS string `capnp:"os"`
|
||||||
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
|
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
|
||||||
PoolName string `capnp:"poolName"`
|
// PoolName string `capnp:"poolName"` // TODO - see issue #2
|
||||||
Tags []Tag
|
Tags []Tag
|
||||||
ConnectionID uint8 `capnp:"connectionId"`
|
ConnectionID uint8 `capnp:"connectionId"`
|
||||||
OriginLocalIP string `capnp:"originLocalIp"`
|
OriginLocalIP string `capnp:"originLocalIp"`
|
||||||
IsAutoupdated bool `capnp:"isAutoupdated"`
|
// IsAutoupdated bool `capnp:"isAutoupdated"` // TODO - see issue #2
|
||||||
}
|
}
|
||||||
|
|
||||||
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package warp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -15,15 +15,18 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
homedir "github.com/mitchellh/go-homedir"
|
homedir "github.com/mitchellh/go-homedir"
|
||||||
cli "gopkg.in/urfave/cli.v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const baseLoginURL = "https://www.cloudflare.com/a/warp"
|
const baseLoginURL = "https://www.cloudflare.com/a/warp"
|
||||||
const baseCertStoreURL = "https://login.cloudflarewarp.com"
|
const baseCertStoreURL = "https://login.cloudflarewarp.com"
|
||||||
const clientTimeout = time.Minute * 20
|
const clientTimeout = time.Minute * 20
|
||||||
|
|
||||||
func login(c *cli.Context) error {
|
// Login obtains credentials from Cloudflare to enable
|
||||||
configPath, err := homedir.Expand(defaultConfigDir)
|
// the creation of tunnels with the Warp service.
|
||||||
|
// baseURL is the base URL from which to login to warp;
|
||||||
|
// leave empty to use default.
|
||||||
|
func Login(configDir, credentialFile, baseURL string) error {
|
||||||
|
configPath, err := homedir.Expand(configDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -38,21 +41,20 @@ func login(c *cli.Context) error {
|
||||||
path := filepath.Join(configPath, credentialFile)
|
path := filepath.Join(configPath, credentialFile)
|
||||||
fileInfo, err := os.Stat(path)
|
fileInfo, err := os.Stat(path)
|
||||||
if err == nil && fileInfo.Size() > 0 {
|
if err == nil && fileInfo.Size() > 0 {
|
||||||
fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.
|
return fmt.Errorf(`You have an existing certificate at %s which login would overwrite.
|
||||||
If this is intentional, please move or delete that file then run this command again.
|
If this is intentional, please move or delete that file then run this command again.
|
||||||
`, path)
|
`, path)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
|
if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// for local debugging
|
// for local debugging
|
||||||
baseURL := baseCertStoreURL
|
if baseURL == "" {
|
||||||
if c.IsSet("url") {
|
baseURL = baseCertStoreURL
|
||||||
baseURL = c.String("url")
|
|
||||||
}
|
}
|
||||||
// Generate a random post URL
|
|
||||||
|
// generate a random post URL
|
||||||
certURL := baseURL + generateRandomPath()
|
certURL := baseURL + generateRandomPath()
|
||||||
loginURL, err := url.Parse(baseLoginURL)
|
loginURL, err := url.Parse(baseLoginURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -67,7 +69,7 @@ If this is intentional, please move or delete that file then run this command ag
|
||||||
|
|
||||||
%s
|
%s
|
||||||
|
|
||||||
Leave cloudflare-warp running to install the certificate automatically.
|
Leave the program running to install the certificate automatically.
|
||||||
`, loginURL.String())
|
`, loginURL.String())
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL:
|
fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL:
|
||||||
|
@ -75,11 +77,10 @@ Leave cloudflare-warp running to install the certificate automatically.
|
||||||
%s
|
%s
|
||||||
|
|
||||||
If the browser failed to open, open it yourself and visit the URL above.
|
If the browser failed to open, open it yourself and visit the URL above.
|
||||||
|
|
||||||
`, loginURL.String())
|
`, loginURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if download(certURL, path) {
|
if ok, err := download(certURL, path); ok && err == nil {
|
||||||
fmt.Fprintf(os.Stderr, `You have successfully logged in.
|
fmt.Fprintf(os.Stderr, `You have successfully logged in.
|
||||||
If you wish to copy your credentials to a server, they have been saved to:
|
If you wish to copy your credentials to a server, they have been saved to:
|
||||||
%s
|
%s
|
||||||
|
@ -126,21 +127,24 @@ func open(url string) error {
|
||||||
return exec.Command(cmd, args...).Start()
|
return exec.Command(cmd, args...).Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
func download(certURL, filePath string) bool {
|
// download downloads a certificate at certURL to filePath.
|
||||||
|
// It returns true if the certificate was successfully
|
||||||
|
// downloaded; false otherwise, with any applicable error.
|
||||||
|
// An error may be returned even if the certificate was
|
||||||
|
// downloaded successfully.
|
||||||
|
func download(certURL, filePath string) (bool, error) {
|
||||||
client := &http.Client{Timeout: clientTimeout}
|
client := &http.Client{Timeout: clientTimeout}
|
||||||
// attempt a (long-running) certificate get
|
// attempt a (long-running) certificate get
|
||||||
for i := 0; i < 20; i++ {
|
for i := 0; i < 20; i++ {
|
||||||
ok, err := tryDownload(client, certURL, filePath)
|
ok, err := tryDownload(client, certURL, filePath)
|
||||||
if ok {
|
if ok {
|
||||||
putSuccess(client, certURL)
|
return true, putSuccess(client, certURL)
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log.WithError(err).Error("Error fetching certificate")
|
return false, fmt.Errorf("fetching certificate: %v", err)
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) {
|
func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) {
|
||||||
|
@ -175,20 +179,19 @@ func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err er
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func putSuccess(client *http.Client, certURL string) {
|
func putSuccess(client *http.Client, certURL string) error {
|
||||||
// indicate success to the relay server
|
// indicate success to the relay server
|
||||||
req, err := http.NewRequest("PUT", certURL+"/ok", nil)
|
req, err := http.NewRequest("PUT", certURL+"/ok", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log.WithError(err).Error("HTTP request error")
|
return fmt.Errorf("HTTP request error: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Log.WithError(err).Error("HTTP error")
|
return fmt.Errorf("HTTP error: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
|
return fmt.Errorf("unexpected HTTP status code %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
|
@ -0,0 +1,186 @@
|
||||||
|
package warp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflare-warp/origin"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/tlsconfig"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/validation"
|
||||||
|
homedir "github.com/mitchellh/go-homedir"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartServer starts a warp proxy server with the given configuration.
|
||||||
|
// It blocks indefinitely.
|
||||||
|
func StartServer(cfg ServerConfig) error {
|
||||||
|
hostname, err := validation.ValidateHostname(cfg.Hostname)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if cfg.ClientID == "" {
|
||||||
|
cfg.ClientID = generateRandomClientID()
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Tags = append(cfg.Tags, tunnelpogs.Tag{Name: "ID", Value: cfg.ClientID})
|
||||||
|
|
||||||
|
cfg.ServerURL, err = validation.ValidateUrl(cfg.ServerURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("validating server URL: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the user has acquired a certificate using the log in command
|
||||||
|
originCertPath, err := homedir.Expand(cfg.OriginCert)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot resolve path %s: %v", cfg.OriginCert, err)
|
||||||
|
}
|
||||||
|
ok, err := fileExists(originCertPath)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf(`Cannot find a valid certificate for your origin at the path:
|
||||||
|
|
||||||
|
%s
|
||||||
|
|
||||||
|
If the path above is wrong, specify the path with the -origincert option.
|
||||||
|
If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
|
|
||||||
|
%s login
|
||||||
|
`, originCertPath, os.Args[0]) // TODO - we need to improve how this is handled
|
||||||
|
}
|
||||||
|
// Easier to send the certificate as []byte via RPC than decoding it at this point
|
||||||
|
originCert, err := ioutil.ReadFile(originCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot read %s to load origin certificate: %v", originCertPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelMetrics := origin.NewTunnelMetrics()
|
||||||
|
|
||||||
|
httpTransport := &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: cfg.Timeout,
|
||||||
|
KeepAlive: cfg.KeepAlive,
|
||||||
|
DualStack: cfg.DualStack,
|
||||||
|
}).DialContext,
|
||||||
|
MaxIdleConns: cfg.MaxIdleConns,
|
||||||
|
IdleConnTimeout: cfg.IdleConnTimeout,
|
||||||
|
TLSHandshakeTimeout: cfg.TLSHandshakeTimeout,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelConfig := &origin.TunnelConfig{
|
||||||
|
EdgeAddrs: cfg.EdgeAddrs,
|
||||||
|
OriginUrl: cfg.ServerURL,
|
||||||
|
Hostname: hostname,
|
||||||
|
OriginCert: originCert,
|
||||||
|
TlsConfig: tlsconfig.CreateTunnelConfig(cfg.TLSConfig, cfg.EdgeAddrs),
|
||||||
|
ClientTlsConfig: httpTransport.TLSClientConfig,
|
||||||
|
Retries: cfg.Retries,
|
||||||
|
HeartbeatInterval: cfg.HeartbeatInterval,
|
||||||
|
MaxHeartbeats: cfg.MaxHeartbeats,
|
||||||
|
ClientID: cfg.ClientID,
|
||||||
|
ReportedVersion: cfg.ReportedVersion,
|
||||||
|
LBPool: cfg.LBPool,
|
||||||
|
Tags: cfg.Tags,
|
||||||
|
HAConnections: cfg.HAConnections,
|
||||||
|
HTTPTransport: httpTransport,
|
||||||
|
Metrics: tunnelMetrics,
|
||||||
|
MetricsUpdateFreq: cfg.MetricsUpdateFreq,
|
||||||
|
ProtocolLogger: cfg.ProtoLogger,
|
||||||
|
Logger: cfg.Logger,
|
||||||
|
IsAutoupdated: cfg.IsAutoupdated,
|
||||||
|
}
|
||||||
|
|
||||||
|
// blocking
|
||||||
|
return origin.StartTunnelDaemon(tunnelConfig, cfg.ShutdownChan, cfg.ConnectedChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileExists(path string) (bool, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
// ignore missing files
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRandomClientID() string {
|
||||||
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
id := make([]byte, 32)
|
||||||
|
r.Read(id)
|
||||||
|
return hex.EncodeToString(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig specifies a warp proxy-server configuration.
|
||||||
|
type ServerConfig struct {
|
||||||
|
// The hostname on a Cloudflare zone with which route
|
||||||
|
// traffic through this tunnel.
|
||||||
|
// Required.
|
||||||
|
Hostname string
|
||||||
|
|
||||||
|
// The URL of the local web server. If empty (if there
|
||||||
|
// is no server), set HelloWorld to true for a demo.
|
||||||
|
// Required.
|
||||||
|
ServerURL string
|
||||||
|
|
||||||
|
// If true, use the established tunnel to expose a
|
||||||
|
// test HTTP server. If false, ServerURL must be set.
|
||||||
|
HelloWorld bool
|
||||||
|
|
||||||
|
// The tunnel ID; leave blank to use a random ID.
|
||||||
|
ClientID string
|
||||||
|
|
||||||
|
// Custom tags to identify this tunnel
|
||||||
|
Tags []tunnelpogs.Tag
|
||||||
|
|
||||||
|
// Specifies the Warp certificate for one of your zones,
|
||||||
|
// authorizing the client to serve as an origin for that zone.
|
||||||
|
// A certificate is required to use Warp. You can obtain a
|
||||||
|
// certificate by using the login command or by visiting
|
||||||
|
// https://www.cloudflare.com/a/warp.
|
||||||
|
OriginCert string
|
||||||
|
|
||||||
|
// The channel to close when the tunnel is connected.
|
||||||
|
ConnectedChan chan struct{}
|
||||||
|
|
||||||
|
// The channel to close when shutting down.
|
||||||
|
ShutdownChan chan struct{}
|
||||||
|
|
||||||
|
Timeout time.Duration // proxy-connect-timeout
|
||||||
|
KeepAlive time.Duration // proxy-tcp-keepalive
|
||||||
|
DualStack bool // proxy-no-happy-eyeballs
|
||||||
|
|
||||||
|
MaxIdleConns int // proxy-keepalive-connections
|
||||||
|
IdleConnTimeout time.Duration // proxy-keepalive-timeout
|
||||||
|
TLSHandshakeTimeout time.Duration // proxy-tls-timeout
|
||||||
|
|
||||||
|
EdgeAddrs []string // edge
|
||||||
|
Retries uint // retries
|
||||||
|
HeartbeatInterval time.Duration // heartbeat-interval
|
||||||
|
MaxHeartbeats uint64 // heartbeat-count
|
||||||
|
LBPool string // lb-pool
|
||||||
|
HAConnections int // ha-connections
|
||||||
|
MetricsUpdateFreq time.Duration // metrics-update-freq
|
||||||
|
IsAutoupdated bool // is-autoupdated
|
||||||
|
|
||||||
|
// The TLS client config used when making the tunnel.
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
|
||||||
|
// The version of the client to report
|
||||||
|
ReportedVersion string
|
||||||
|
|
||||||
|
ProtoLogger *logrus.Logger
|
||||||
|
Logger *logrus.Logger
|
||||||
|
}
|
Loading…
Reference in New Issue