Initial WIP refactoring logic out of main package
This commit is contained in:
		
							parent
							
								
									29c42adaa1
								
							
						
					
					
						commit
						6ab87d01de
					
				|  | @ -1,13 +1,7 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/tls" |  | ||||||
| 	"encoding/hex" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"math/rand" |  | ||||||
| 	"net" |  | ||||||
| 	"net/http" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/signal" | 	"os/signal" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | @ -18,10 +12,9 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/cloudflare/cloudflare-warp/metrics" | 	"github.com/cloudflare/cloudflare-warp/metrics" | ||||||
| 	"github.com/cloudflare/cloudflare-warp/origin" |  | ||||||
| 	"github.com/cloudflare/cloudflare-warp/tlsconfig" | 	"github.com/cloudflare/cloudflare-warp/tlsconfig" | ||||||
| 	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" |  | ||||||
| 	"github.com/cloudflare/cloudflare-warp/validation" | 	"github.com/cloudflare/cloudflare-warp/validation" | ||||||
|  | 	"github.com/cloudflare/cloudflare-warp/warp" | ||||||
| 
 | 
 | ||||||
| 	"github.com/facebookgo/grace/gracenet" | 	"github.com/facebookgo/grace/gracenet" | ||||||
| 	"github.com/getsentry/raven-go" | 	"github.com/getsentry/raven-go" | ||||||
|  | @ -316,8 +309,8 @@ WARNING: | ||||||
| 
 | 
 | ||||||
| func startServer(c *cli.Context) { | func startServer(c *cli.Context) { | ||||||
| 	var wg sync.WaitGroup | 	var wg sync.WaitGroup | ||||||
| 	errC := make(chan error) |  | ||||||
| 	wg.Add(2) | 	wg.Add(2) | ||||||
|  | 	errC := make(chan error) | ||||||
| 
 | 
 | ||||||
| 	// If the user choose to supply all options through env variables,
 | 	// If the user choose to supply all options through env variables,
 | ||||||
| 	// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
 | 	// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
 | ||||||
|  | @ -326,6 +319,7 @@ func startServer(c *cli.Context) { | ||||||
| 		cli.ShowAppHelp(c) | 		cli.ShowAppHelp(c) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	logLevel, err := logrus.ParseLevel(c.String("loglevel")) | 	logLevel, err := logrus.ParseLevel(c.String("loglevel")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		Log.WithError(err).Fatal("Unknown logging level specified") | 		Log.WithError(err).Fatal("Unknown logging level specified") | ||||||
|  | @ -353,22 +347,16 @@ func startServer(c *cli.Context) { | ||||||
| 		go autoupdate(c.Duration("autoupdate-freq"), shutdownC) | 		go autoupdate(c.Duration("autoupdate-freq"), shutdownC) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	hostname, err := validation.ValidateHostname(c.String("hostname")) |  | ||||||
| 	if err != nil { |  | ||||||
| 		Log.WithError(err).Fatal("Invalid hostname") |  | ||||||
| 
 |  | ||||||
| 	} |  | ||||||
| 	clientID := c.String("id") |  | ||||||
| 	if !c.IsSet("id") { |  | ||||||
| 		clientID = generateRandomClientID() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) | 	tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		Log.WithError(err).Fatal("Tag parse failure") | 		Log.WithError(err).Fatal("Tag parse failure") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) | 	validURL, err := validateUrl(c) | ||||||
|  | 	if err != nil { | ||||||
|  | 		Log.WithError(err).Fatal("Error validating url") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if c.IsSet("hello-world") { | 	if c.IsSet("hello-world") { | ||||||
| 		wg.Add(1) | 		wg.Add(1) | ||||||
| 		listener, err := createListener("127.0.0.1:") | 		listener, err := createListener("127.0.0.1:") | ||||||
|  | @ -381,86 +369,18 @@ func startServer(c *cli.Context) { | ||||||
| 			wg.Done() | 			wg.Done() | ||||||
| 			listener.Close() | 			listener.Close() | ||||||
| 		}() | 		}() | ||||||
| 		c.Set("url", "https://"+listener.Addr().String()) | 		validURL = "https://" + listener.Addr().String() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	url, err := validateUrl(c) | 	Log.Infof("Proxying tunnel requests to %s", validURL) | ||||||
| 	if err != nil { |  | ||||||
| 		Log.WithError(err).Fatal("Error validating url") |  | ||||||
| 	} |  | ||||||
| 	Log.Infof("Proxying tunnel requests to %s", url) |  | ||||||
| 
 | 
 | ||||||
| 	// Fail if the user provided an old authentication method
 | 	// Fail if the user provided an old authentication method
 | ||||||
| 	if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { | 	if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { | ||||||
| 		Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login") | 		Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check that the user has acquired a certificate using the log in command
 |  | ||||||
| 	originCertPath, err := homedir.Expand(c.String("origincert")) |  | ||||||
| 	if err != nil { |  | ||||||
| 		Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) |  | ||||||
| 	} |  | ||||||
| 	ok, err := fileExists(originCertPath) |  | ||||||
| 	if !ok { |  | ||||||
| 		Log.Fatalf(`Cannot find a valid certificate for your origin at the path: |  | ||||||
| 
 |  | ||||||
|     %s |  | ||||||
| 
 |  | ||||||
| If the path above is wrong, specify the path with the -origincert option. |  | ||||||
| If you don't have a certificate signed by Cloudflare, run the command: |  | ||||||
| 
 |  | ||||||
|     %s login |  | ||||||
| `, originCertPath, os.Args[0]) |  | ||||||
| 	} |  | ||||||
| 	// Easier to send the certificate as []byte via RPC than decoding it at this point
 |  | ||||||
| 	originCert, err := ioutil.ReadFile(originCertPath) |  | ||||||
| 	if err != nil { |  | ||||||
| 		Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	tunnelMetrics := origin.NewTunnelMetrics() |  | ||||||
| 	httpTransport := &http.Transport{ |  | ||||||
| 		Proxy: http.ProxyFromEnvironment, |  | ||||||
| 		DialContext: (&net.Dialer{ |  | ||||||
| 			Timeout:   c.Duration("proxy-connect-timeout"), |  | ||||||
| 			KeepAlive: c.Duration("proxy-tcp-keepalive"), |  | ||||||
| 			DualStack: !c.Bool("proxy-no-happy-eyeballs"), |  | ||||||
| 		}).DialContext, |  | ||||||
| 		MaxIdleConns:          c.Int("proxy-keepalive-connections"), |  | ||||||
| 		IdleConnTimeout:       c.Duration("proxy-keepalive-timeout"), |  | ||||||
| 		TLSHandshakeTimeout:   c.Duration("proxy-tls-timeout"), |  | ||||||
| 		ExpectContinueTimeout: 1 * time.Second, |  | ||||||
| 		TLSClientConfig:       &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()}, |  | ||||||
| 	} |  | ||||||
| 	tunnelConfig := &origin.TunnelConfig{ |  | ||||||
| 		EdgeAddrs:         c.StringSlice("edge"), |  | ||||||
| 		OriginUrl:         url, |  | ||||||
| 		Hostname:          hostname, |  | ||||||
| 		OriginCert:        originCert, |  | ||||||
| 		TlsConfig:         tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")), |  | ||||||
| 		ClientTlsConfig:   httpTransport.TLSClientConfig, |  | ||||||
| 		Retries:           c.Uint("retries"), |  | ||||||
| 		HeartbeatInterval: c.Duration("heartbeat-interval"), |  | ||||||
| 		MaxHeartbeats:     c.Uint64("heartbeat-count"), |  | ||||||
| 		ClientID:          clientID, |  | ||||||
| 		ReportedVersion:   Version, |  | ||||||
| 		LBPool:            c.String("lb-pool"), |  | ||||||
| 		Tags:              tags, |  | ||||||
| 		HAConnections:     c.Int("ha-connections"), |  | ||||||
| 		HTTPTransport:     httpTransport, |  | ||||||
| 		Metrics:           tunnelMetrics, |  | ||||||
| 		MetricsUpdateFreq: c.Duration("metrics-update-freq"), |  | ||||||
| 		ProtocolLogger:    protoLogger, |  | ||||||
| 		Logger:            Log, |  | ||||||
| 		IsAutoupdated:     c.Bool("is-autoupdated"), |  | ||||||
| 	} |  | ||||||
| 	connectedSignal := make(chan struct{}) | 	connectedSignal := make(chan struct{}) | ||||||
| 
 |  | ||||||
| 	go writePidFile(connectedSignal, c.String("pidfile")) | 	go writePidFile(connectedSignal, c.String("pidfile")) | ||||||
| 	go func() { |  | ||||||
| 		errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal) |  | ||||||
| 		wg.Done() |  | ||||||
| 	}() |  | ||||||
| 
 | 
 | ||||||
| 	metricsListener, err := listeners.Listen("tcp", c.String("metrics")) | 	metricsListener, err := listeners.Listen("tcp", c.String("metrics")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -471,6 +391,44 @@ If you don't have a certificate signed by Cloudflare, run the command: | ||||||
| 		wg.Done() | 		wg.Done() | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
|  | 	tlsConfig := tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c) | ||||||
|  | 
 | ||||||
|  | 	// Start the server
 | ||||||
|  | 	go func() { | ||||||
|  | 		errC <- warp.StartServer(warp.ServerConfig{ | ||||||
|  | 			Hostname:   c.String("hostname"), | ||||||
|  | 			ServerURL:  validURL, | ||||||
|  | 			HelloWorld: c.IsSet("hello-world"), | ||||||
|  | 			Tags:       tags, | ||||||
|  | 			OriginCert: c.String("origincert"), | ||||||
|  | 
 | ||||||
|  | 			ConnectedChan: connectedSignal, | ||||||
|  | 			ShutdownChan:  shutdownC, | ||||||
|  | 
 | ||||||
|  | 			Timeout:   c.Duration("proxy-connect-timeout"), | ||||||
|  | 			KeepAlive: c.Duration("proxy-tcp-keepalive"), | ||||||
|  | 			DualStack: !c.Bool("proxy-no-happy-eyeballs"), | ||||||
|  | 
 | ||||||
|  | 			MaxIdleConns:        c.Int("proxy-keepalive-connections"), | ||||||
|  | 			IdleConnTimeout:     c.Duration("proxy-keepalive-timeout"), | ||||||
|  | 			TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"), | ||||||
|  | 
 | ||||||
|  | 			EdgeAddrs:         c.StringSlice("edge"), | ||||||
|  | 			Retries:           c.Uint("retries"), | ||||||
|  | 			HeartbeatInterval: c.Duration("heartbeat-interval"), | ||||||
|  | 			MaxHeartbeats:     c.Uint64("heartbeat-count"), | ||||||
|  | 			LBPool:            c.String("lb-pool"), | ||||||
|  | 			HAConnections:     c.Int("ha-connections"), | ||||||
|  | 			MetricsUpdateFreq: c.Duration("metrics-update-freq"), | ||||||
|  | 			IsAutoupdated:     c.Bool("is-autoupdated"), | ||||||
|  | 			TLSConfig:         tlsConfig, | ||||||
|  | 			ReportedVersion:   Version, | ||||||
|  | 			ProtoLogger:       protoLogger, | ||||||
|  | 			Logger:            Log, | ||||||
|  | 		}) | ||||||
|  | 		wg.Done() | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
| 	var errCode int | 	var errCode int | ||||||
| 	err = WaitForSignal(errC, shutdownC) | 	err = WaitForSignal(errC, shutdownC) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -504,6 +462,14 @@ func WaitForSignal(errC chan error, shutdownC chan struct{}) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func login(c *cli.Context) error { | ||||||
|  | 	err := warp.Login(defaultConfigDir, credentialFile, c.String("url")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Println(err) | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func update(c *cli.Context) error { | func update(c *cli.Context) error { | ||||||
| 	if updateApplied() { | 	if updateApplied() { | ||||||
| 		os.Exit(64) | 		os.Exit(64) | ||||||
|  | @ -584,13 +550,6 @@ func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, er | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func generateRandomClientID() string { |  | ||||||
| 	r := rand.New(rand.NewSource(time.Now().UnixNano())) |  | ||||||
| 	id := make([]byte, 32) |  | ||||||
| 	r.Read(id) |  | ||||||
| 	return hex.EncodeToString(id) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func writePidFile(waitForSignal chan struct{}, pidFile string) { | func writePidFile(waitForSignal chan struct{}, pidFile string) { | ||||||
| 	<-waitForSignal | 	<-waitForSignal | ||||||
| 	daemon.SdNotify(false, "READY=1") | 	daemon.SdNotify(false, "READY=1") | ||||||
|  | @ -619,10 +578,10 @@ func validateUrl(c *cli.Context) (string, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error { | func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error { | ||||||
| 	fileMode := os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_TRUNC | 	fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC | ||||||
| 	// do not truncate log file if the client has been autoupdated
 | 	// do not truncate log file if the client has been autoupdated
 | ||||||
| 	if c.Bool("is-autoupdated") { | 	if c.Bool("is-autoupdated") { | ||||||
| 		fileMode = os.O_WRONLY|os.O_APPEND|os.O_CREATE | 		fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE | ||||||
| 	} | 	} | ||||||
| 	f, err := os.OpenFile(c.String("logfile"), fileMode, 0664) | 	f, err := os.OpenFile(c.String("logfile"), fileMode, 0664) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -93,11 +93,11 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str | ||||||
| 		Version:              c.ReportedVersion, | 		Version:              c.ReportedVersion, | ||||||
| 		OS:                   fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), | 		OS:                   fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), | ||||||
| 		ExistingTunnelPolicy: policy, | 		ExistingTunnelPolicy: policy, | ||||||
| 		PoolName:             c.LBPool, | 		// PoolName:             c.LBPool,        // TODO - see issue #2
 | ||||||
| 		Tags:          c.Tags, | 		Tags:          c.Tags, | ||||||
| 		ConnectionID:  connectionID, | 		ConnectionID:  connectionID, | ||||||
| 		OriginLocalIP: OriginLocalIP, | 		OriginLocalIP: OriginLocalIP, | ||||||
| 		IsAutoupdated:        c.IsAutoupdated, | 		// IsAutoupdated:        c.IsAutoupdated, // TODO - see issue #2
 | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -90,8 +90,10 @@ func LoadOriginCertsPool() *x509.CertPool { | ||||||
| 	return certPool | 	return certPool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config { | func CreateTunnelConfig(tlsConfig *tls.Config, addrs []string) *tls.Config { | ||||||
| 	tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c) | 	if tlsConfig == nil { | ||||||
|  | 		tlsConfig = new(tls.Config) | ||||||
|  | 	} | ||||||
| 	if tlsConfig.RootCAs == nil { | 	if tlsConfig.RootCAs == nil { | ||||||
| 		tlsConfig.RootCAs = GetCloudflareRootCA() | 		tlsConfig.RootCAs = GetCloudflareRootCA() | ||||||
| 		tlsConfig.ServerName = "cftunnel.com" | 		tlsConfig.ServerName = "cftunnel.com" | ||||||
|  |  | ||||||
|  | @ -47,11 +47,11 @@ type RegistrationOptions struct { | ||||||
| 	Version              string | 	Version              string | ||||||
| 	OS                   string `capnp:"os"` | 	OS                   string `capnp:"os"` | ||||||
| 	ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy | 	ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy | ||||||
| 	PoolName             string `capnp:"poolName"` | 	// PoolName             string `capnp:"poolName"`       // TODO - see issue #2
 | ||||||
| 	Tags          []Tag | 	Tags          []Tag | ||||||
| 	ConnectionID  uint8  `capnp:"connectionId"` | 	ConnectionID  uint8  `capnp:"connectionId"` | ||||||
| 	OriginLocalIP string `capnp:"originLocalIp"` | 	OriginLocalIP string `capnp:"originLocalIp"` | ||||||
| 	IsAutoupdated        bool   `capnp:"isAutoupdated"` | 	// IsAutoupdated        bool   `capnp:"isAutoupdated"`  // TODO - see issue #2
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error { | func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error { | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| package main | package warp | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
|  | @ -15,15 +15,18 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	homedir "github.com/mitchellh/go-homedir" | 	homedir "github.com/mitchellh/go-homedir" | ||||||
| 	cli "gopkg.in/urfave/cli.v2" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const baseLoginURL = "https://www.cloudflare.com/a/warp" | const baseLoginURL = "https://www.cloudflare.com/a/warp" | ||||||
| const baseCertStoreURL = "https://login.cloudflarewarp.com" | const baseCertStoreURL = "https://login.cloudflarewarp.com" | ||||||
| const clientTimeout = time.Minute * 20 | const clientTimeout = time.Minute * 20 | ||||||
| 
 | 
 | ||||||
| func login(c *cli.Context) error { | // Login obtains credentials from Cloudflare to enable
 | ||||||
| 	configPath, err := homedir.Expand(defaultConfigDir) | // the creation of tunnels with the Warp service.
 | ||||||
|  | // baseURL is the base URL from which to login to warp;
 | ||||||
|  | // leave empty to use default.
 | ||||||
|  | func Login(configDir, credentialFile, baseURL string) error { | ||||||
|  | 	configPath, err := homedir.Expand(configDir) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | @ -38,21 +41,20 @@ func login(c *cli.Context) error { | ||||||
| 	path := filepath.Join(configPath, credentialFile) | 	path := filepath.Join(configPath, credentialFile) | ||||||
| 	fileInfo, err := os.Stat(path) | 	fileInfo, err := os.Stat(path) | ||||||
| 	if err == nil && fileInfo.Size() > 0 { | 	if err == nil && fileInfo.Size() > 0 { | ||||||
| 		fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite. | 		return fmt.Errorf(`You have an existing certificate at %s which login would overwrite. | ||||||
| If this is intentional, please move or delete that file then run this command again. | If this is intentional, please move or delete that file then run this command again. | ||||||
| `, path) | `, path) | ||||||
| 		return nil |  | ||||||
| 	} | 	} | ||||||
| 	if err != nil && err.(*os.PathError).Err != syscall.ENOENT { | 	if err != nil && err.(*os.PathError).Err != syscall.ENOENT { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// for local debugging
 | 	// for local debugging
 | ||||||
| 	baseURL := baseCertStoreURL | 	if baseURL == "" { | ||||||
| 	if c.IsSet("url") { | 		baseURL = baseCertStoreURL | ||||||
| 		baseURL = c.String("url") |  | ||||||
| 	} | 	} | ||||||
| 	// Generate a random post URL
 | 
 | ||||||
|  | 	// generate a random post URL
 | ||||||
| 	certURL := baseURL + generateRandomPath() | 	certURL := baseURL + generateRandomPath() | ||||||
| 	loginURL, err := url.Parse(baseLoginURL) | 	loginURL, err := url.Parse(baseLoginURL) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -67,7 +69,7 @@ If this is intentional, please move or delete that file then run this command ag | ||||||
| 
 | 
 | ||||||
| %s | %s | ||||||
| 
 | 
 | ||||||
| Leave cloudflare-warp running to install the certificate automatically. | Leave the program running to install the certificate automatically. | ||||||
| `, loginURL.String()) | `, loginURL.String()) | ||||||
| 	} else { | 	} else { | ||||||
| 		fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL: | 		fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL: | ||||||
|  | @ -75,11 +77,10 @@ Leave cloudflare-warp running to install the certificate automatically. | ||||||
| %s | %s | ||||||
| 
 | 
 | ||||||
| If the browser failed to open, open it yourself and visit the URL above. | If the browser failed to open, open it yourself and visit the URL above. | ||||||
| 
 |  | ||||||
| `, loginURL.String()) | `, loginURL.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if download(certURL, path) { | 	if ok, err := download(certURL, path); ok && err == nil { | ||||||
| 		fmt.Fprintf(os.Stderr, `You have successfully logged in. | 		fmt.Fprintf(os.Stderr, `You have successfully logged in. | ||||||
| If you wish to copy your credentials to a server, they have been saved to: | If you wish to copy your credentials to a server, they have been saved to: | ||||||
| %s | %s | ||||||
|  | @ -126,21 +127,24 @@ func open(url string) error { | ||||||
| 	return exec.Command(cmd, args...).Start() | 	return exec.Command(cmd, args...).Start() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func download(certURL, filePath string) bool { | // download downloads a certificate at certURL to filePath.
 | ||||||
|  | // It returns true if the certificate was successfully
 | ||||||
|  | // downloaded; false otherwise, with any applicable error.
 | ||||||
|  | // An error may be returned even if the certificate was
 | ||||||
|  | // downloaded successfully.
 | ||||||
|  | func download(certURL, filePath string) (bool, error) { | ||||||
| 	client := &http.Client{Timeout: clientTimeout} | 	client := &http.Client{Timeout: clientTimeout} | ||||||
| 	// attempt a (long-running) certificate get
 | 	// attempt a (long-running) certificate get
 | ||||||
| 	for i := 0; i < 20; i++ { | 	for i := 0; i < 20; i++ { | ||||||
| 		ok, err := tryDownload(client, certURL, filePath) | 		ok, err := tryDownload(client, certURL, filePath) | ||||||
| 		if ok { | 		if ok { | ||||||
| 			putSuccess(client, certURL) | 			return true, putSuccess(client, certURL) | ||||||
| 			return true |  | ||||||
| 		} | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			Log.WithError(err).Error("Error fetching certificate") | 			return false, fmt.Errorf("fetching certificate: %v", err) | ||||||
| 			return false |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return false | 	return false, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) { | func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) { | ||||||
|  | @ -175,20 +179,19 @@ func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err er | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func putSuccess(client *http.Client, certURL string) { | func putSuccess(client *http.Client, certURL string) error { | ||||||
| 	// indicate success to the relay server
 | 	// indicate success to the relay server
 | ||||||
| 	req, err := http.NewRequest("PUT", certURL+"/ok", nil) | 	req, err := http.NewRequest("PUT", certURL+"/ok", nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		Log.WithError(err).Error("HTTP request error") | 		return fmt.Errorf("HTTP request error: %v", err) | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 	resp, err := client.Do(req) | 	resp, err := client.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		Log.WithError(err).Error("HTTP error") | 		return fmt.Errorf("HTTP error: %v", err) | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
| 	resp.Body.Close() | 	resp.Body.Close() | ||||||
| 	if resp.StatusCode != 200 { | 	if resp.StatusCode != 200 { | ||||||
| 		Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode) | 		return fmt.Errorf("unexpected HTTP status code %d", resp.StatusCode) | ||||||
| 	} | 	} | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
|  | @ -0,0 +1,186 @@ | ||||||
|  | package warp | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"encoding/hex" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"os" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/cloudflare/cloudflare-warp/origin" | ||||||
|  | 	"github.com/cloudflare/cloudflare-warp/tlsconfig" | ||||||
|  | 	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" | ||||||
|  | 	"github.com/cloudflare/cloudflare-warp/validation" | ||||||
|  | 	homedir "github.com/mitchellh/go-homedir" | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // StartServer starts a warp proxy server with the given configuration.
 | ||||||
|  | // It blocks indefinitely.
 | ||||||
|  | func StartServer(cfg ServerConfig) error { | ||||||
|  | 	hostname, err := validation.ValidateHostname(cfg.Hostname) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if cfg.ClientID == "" { | ||||||
|  | 		cfg.ClientID = generateRandomClientID() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	cfg.Tags = append(cfg.Tags, tunnelpogs.Tag{Name: "ID", Value: cfg.ClientID}) | ||||||
|  | 
 | ||||||
|  | 	cfg.ServerURL, err = validation.ValidateUrl(cfg.ServerURL) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("validating server URL: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Check that the user has acquired a certificate using the log in command
 | ||||||
|  | 	originCertPath, err := homedir.Expand(cfg.OriginCert) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("cannot resolve path %s: %v", cfg.OriginCert, err) | ||||||
|  | 	} | ||||||
|  | 	ok, err := fileExists(originCertPath) | ||||||
|  | 	if !ok { | ||||||
|  | 		return fmt.Errorf(`Cannot find a valid certificate for your origin at the path: | ||||||
|  | 
 | ||||||
|  |     %s | ||||||
|  | 
 | ||||||
|  | If the path above is wrong, specify the path with the -origincert option. | ||||||
|  | If you don't have a certificate signed by Cloudflare, run the command: | ||||||
|  | 
 | ||||||
|  |     %s login | ||||||
|  | `, originCertPath, os.Args[0]) // TODO - we need to improve how this is handled
 | ||||||
|  | 	} | ||||||
|  | 	// Easier to send the certificate as []byte via RPC than decoding it at this point
 | ||||||
|  | 	originCert, err := ioutil.ReadFile(originCertPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("cannot read %s to load origin certificate: %v", originCertPath, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tunnelMetrics := origin.NewTunnelMetrics() | ||||||
|  | 
 | ||||||
|  | 	httpTransport := &http.Transport{ | ||||||
|  | 		Proxy: http.ProxyFromEnvironment, | ||||||
|  | 		DialContext: (&net.Dialer{ | ||||||
|  | 			Timeout:   cfg.Timeout, | ||||||
|  | 			KeepAlive: cfg.KeepAlive, | ||||||
|  | 			DualStack: cfg.DualStack, | ||||||
|  | 		}).DialContext, | ||||||
|  | 		MaxIdleConns:          cfg.MaxIdleConns, | ||||||
|  | 		IdleConnTimeout:       cfg.IdleConnTimeout, | ||||||
|  | 		TLSHandshakeTimeout:   cfg.TLSHandshakeTimeout, | ||||||
|  | 		ExpectContinueTimeout: 1 * time.Second, | ||||||
|  | 		TLSClientConfig:       &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	tunnelConfig := &origin.TunnelConfig{ | ||||||
|  | 		EdgeAddrs:         cfg.EdgeAddrs, | ||||||
|  | 		OriginUrl:         cfg.ServerURL, | ||||||
|  | 		Hostname:          hostname, | ||||||
|  | 		OriginCert:        originCert, | ||||||
|  | 		TlsConfig:         tlsconfig.CreateTunnelConfig(cfg.TLSConfig, cfg.EdgeAddrs), | ||||||
|  | 		ClientTlsConfig:   httpTransport.TLSClientConfig, | ||||||
|  | 		Retries:           cfg.Retries, | ||||||
|  | 		HeartbeatInterval: cfg.HeartbeatInterval, | ||||||
|  | 		MaxHeartbeats:     cfg.MaxHeartbeats, | ||||||
|  | 		ClientID:          cfg.ClientID, | ||||||
|  | 		ReportedVersion:   cfg.ReportedVersion, | ||||||
|  | 		LBPool:            cfg.LBPool, | ||||||
|  | 		Tags:              cfg.Tags, | ||||||
|  | 		HAConnections:     cfg.HAConnections, | ||||||
|  | 		HTTPTransport:     httpTransport, | ||||||
|  | 		Metrics:           tunnelMetrics, | ||||||
|  | 		MetricsUpdateFreq: cfg.MetricsUpdateFreq, | ||||||
|  | 		ProtocolLogger:    cfg.ProtoLogger, | ||||||
|  | 		Logger:            cfg.Logger, | ||||||
|  | 		IsAutoupdated:     cfg.IsAutoupdated, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// blocking
 | ||||||
|  | 	return origin.StartTunnelDaemon(tunnelConfig, cfg.ShutdownChan, cfg.ConnectedChan) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func fileExists(path string) (bool, error) { | ||||||
|  | 	f, err := os.Open(path) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if os.IsNotExist(err) { | ||||||
|  | 			// ignore missing files
 | ||||||
|  | 			return false, nil | ||||||
|  | 		} | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	f.Close() | ||||||
|  | 	return true, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func generateRandomClientID() string { | ||||||
|  | 	r := rand.New(rand.NewSource(time.Now().UnixNano())) | ||||||
|  | 	id := make([]byte, 32) | ||||||
|  | 	r.Read(id) | ||||||
|  | 	return hex.EncodeToString(id) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ServerConfig specifies a warp proxy-server configuration.
 | ||||||
|  | type ServerConfig struct { | ||||||
|  | 	// The hostname on a Cloudflare zone with which route
 | ||||||
|  | 	// traffic through this tunnel.
 | ||||||
|  | 	// Required.
 | ||||||
|  | 	Hostname string | ||||||
|  | 
 | ||||||
|  | 	// The URL of the local web server. If empty (if there
 | ||||||
|  | 	// is no server), set HelloWorld to true for a demo.
 | ||||||
|  | 	// Required.
 | ||||||
|  | 	ServerURL string | ||||||
|  | 
 | ||||||
|  | 	// If true, use the established tunnel to expose a
 | ||||||
|  | 	// test HTTP server. If false, ServerURL must be set.
 | ||||||
|  | 	HelloWorld bool | ||||||
|  | 
 | ||||||
|  | 	// The tunnel ID; leave blank to use a random ID.
 | ||||||
|  | 	ClientID string | ||||||
|  | 
 | ||||||
|  | 	// Custom tags to identify this tunnel
 | ||||||
|  | 	Tags []tunnelpogs.Tag | ||||||
|  | 
 | ||||||
|  | 	// Specifies the Warp certificate for one of your zones,
 | ||||||
|  | 	// authorizing the client to serve as an origin for that zone.
 | ||||||
|  | 	// A certificate is required to use Warp. You can obtain a
 | ||||||
|  | 	// certificate by using the login command or by visiting
 | ||||||
|  | 	// https://www.cloudflare.com/a/warp.
 | ||||||
|  | 	OriginCert string | ||||||
|  | 
 | ||||||
|  | 	// The channel to close when the tunnel is connected.
 | ||||||
|  | 	ConnectedChan chan struct{} | ||||||
|  | 
 | ||||||
|  | 	// The channel to close when shutting down.
 | ||||||
|  | 	ShutdownChan chan struct{} | ||||||
|  | 
 | ||||||
|  | 	Timeout   time.Duration // proxy-connect-timeout
 | ||||||
|  | 	KeepAlive time.Duration // proxy-tcp-keepalive
 | ||||||
|  | 	DualStack bool          // proxy-no-happy-eyeballs
 | ||||||
|  | 
 | ||||||
|  | 	MaxIdleConns        int           // proxy-keepalive-connections
 | ||||||
|  | 	IdleConnTimeout     time.Duration // proxy-keepalive-timeout
 | ||||||
|  | 	TLSHandshakeTimeout time.Duration // proxy-tls-timeout
 | ||||||
|  | 
 | ||||||
|  | 	EdgeAddrs         []string      // edge
 | ||||||
|  | 	Retries           uint          // retries
 | ||||||
|  | 	HeartbeatInterval time.Duration // heartbeat-interval
 | ||||||
|  | 	MaxHeartbeats     uint64        // heartbeat-count
 | ||||||
|  | 	LBPool            string        // lb-pool
 | ||||||
|  | 	HAConnections     int           // ha-connections
 | ||||||
|  | 	MetricsUpdateFreq time.Duration // metrics-update-freq
 | ||||||
|  | 	IsAutoupdated     bool          // is-autoupdated
 | ||||||
|  | 
 | ||||||
|  | 	// The TLS client config used when making the tunnel.
 | ||||||
|  | 	TLSConfig *tls.Config | ||||||
|  | 
 | ||||||
|  | 	// The version of the client to report
 | ||||||
|  | 	ReportedVersion string | ||||||
|  | 
 | ||||||
|  | 	ProtoLogger *logrus.Logger | ||||||
|  | 	Logger      *logrus.Logger | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue