From 4153b708a40bcab48b8eabd3536d48450df5fc77 Mon Sep 17 00:00:00 2001 From: cloudflare-warp-bot Date: Mon, 2 Apr 2018 20:43:11 +0000 Subject: [PATCH] Release Argo Tunnel Client 2018.4.0 --- cmd/cloudflared/main.go | 58 +++++++++++++++++++++++++++++++++++---- origin/supervisor.go | 61 +++++++++++++++++++++++++++++++++++++++-- origin/tunnel.go | 24 +++++++++++++++- tunneldns/tunnel.go | 3 +- 4 files changed, 135 insertions(+), 11 deletions(-) diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index cc57c185..46905e4c 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -43,6 +43,8 @@ const ( quickStartUrl = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/" noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/" licenseUrl = "https://developers.cloudflare.com/argo-tunnel/licence/" + minDNSInitWait = time.Second * 15 + minPingFreq = time.Second * 2 ) var listeners = gracenet.Net{} @@ -288,9 +290,34 @@ func main() { altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ Name: "proxy-dns-upstream", Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.", - Value: cli.NewStringSlice("https://cloudflare-dns.com/.well-known/dns-query"), + Value: cli.NewStringSlice("https://cloudflare-dns.com/dns-query"), EnvVars: []string{"TUNNEL_DNS_UPSTREAM"}, }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "skip-hostname-propagation-check", + Usage: "Flag to instruct cloudflared to skip checking whether DNS record for the hostname has been propagated.", + EnvVars: []string{"TUNNEL_SKIP_HOSTNAME_PROPAGATION_CHECK"}, + }), + altsrc.NewUintFlag(&cli.UintFlag{ + Name: "hostname-propagated-retries", + Usage: "How many pings to test whether send DNS record has been propagated before reregistering tunnel", + Value: 25, + EnvVars: []string{"TUNNEL_HOSTNAME_PROPAGATED_RETRIES"}, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "init-wait-time", + Usage: "Initial waiting time to checking whether DNS record has propagated", + Value: minDNSInitWait, + EnvVars: []string{"TUNNEL_INIT_WAIT_TIME"}, + Hidden: true, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "ping-freq", + Usage: "Ping frequency for checking DNS record has propagated", + Value: minPingFreq, + EnvVars: []string{"TUNNEL_PING_FREQ"}, + Hidden: true, + }), } app.Action = func(c *cli.Context) error { raven.CapturePanic(func() { startServer(c) }, nil) @@ -375,7 +402,7 @@ func main() { &cli.StringSliceFlag{ Name: "upstream", Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.", - Value: cli.NewStringSlice("https://cloudflare-dns.com/.well-known/dns-query"), + Value: cli.NewStringSlice("https://cloudflare-dns.com/dns-query"), EnvVars: []string{"TUNNEL_DNS_UPSTREAM"}, }, }, @@ -460,11 +487,15 @@ func startServer(c *cli.Context) { listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(c.Uint("proxy-dns-port")), c.StringSlice("proxy-dns-upstream")) if err != nil { listener.Stop() - Log.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server") + Log.WithError(err).Fatal("Cannot create the DNS over HTTPS proxy server") } go func() { - listener.Start() - <-shutdownC + err := listener.Start() + if err != nil { + Log.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server") + } else { + <-shutdownC + } listener.Stop() wg.Done() }() @@ -547,6 +578,7 @@ If you don't have a certificate signed by Cloudflare, run the command: ProtocolLogger: protoLogger, Logger: Log, IsAutoupdated: c.Bool("is-autoupdated"), + DNSValidationConfig: getDNSValidationConfig(c), } connectedSignal := make(chan struct{}) @@ -792,3 +824,19 @@ func isAutoupdateEnabled(c *cli.Context) bool { return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 } + +func getDNSValidationConfig(c *cli.Context) *origin.DNSValidationConfig { + dnsValidationConfig := &origin.DNSValidationConfig{ + VerifyDNSPropagated: !c.Bool("skip-hostname-propagation-check"), + DNSPingRetries: c.Uint("hostname-propagated-retries"), + DNSInitWaitTime: c.Duration("init-wait-time"), + PingFreq: c.Duration("ping-freq"), + } + if dnsValidationConfig.DNSInitWaitTime < minDNSInitWait { + dnsValidationConfig.DNSInitWaitTime = minDNSInitWait + } + if dnsValidationConfig.PingFreq < minPingFreq { + dnsValidationConfig.PingFreq = minPingFreq + } + return dnsValidationConfig +} diff --git a/origin/supervisor.go b/origin/supervisor.go index f9e75bd5..5856d427 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -3,9 +3,11 @@ package origin import ( "fmt" "net" + "net/http" "time" "golang.org/x/net/context" + "github.com/pkg/errors" ) const ( @@ -13,6 +15,8 @@ const ( tunnelRetryDuration = time.Second * 10 // SRV record resolution TTL resolveTTL = time.Hour + // Interval between registering new tunnels + registrationInterval = time.Second ) type Supervisor struct { @@ -137,12 +141,16 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct return tunnelError.err case <-connectedSignal: } + if s.config.VerifyDNSPropagated { + err = s.verifyDNSPropagated(ctx) + if err != nil { + return errors.Wrap(err, "Failed to register tunnel") + } + } // At least one successful connection, so start the rest for i := 1; i < s.config.HAConnections; i++ { go s.startTunnel(ctx, i, make(chan struct{})) - // TODO: Add artificial delay between HA connections to make sure all origins - // are registered in LB pool. Temporary fix until we fix LB - time.Sleep(time.Millisecond * 500) + time.Sleep(registrationInterval) } return nil } @@ -175,6 +183,53 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan } } +func (s *Supervisor) verifyDNSPropagated(ctx context.Context) (err error) { + Log.Infof("Waiting for %s DNS record to propagate...", s.config.Hostname) + time.Sleep(s.config.DNSInitWaitTime) + var lastResponseStatus string + tickC := time.Tick(s.config.PingFreq) + req, client, err := s.createPingRequestAndClient() + if err != nil { + return fmt.Errorf("Cannot create GET request to %s", s.config.Hostname) + } + for i := 0; i < int(s.config.DNSPingRetries); i++ { + select { + case <-ctx.Done(): + return fmt.Errorf("Context was canceled") + case <-tickC: + resp, err := client.Do(req) + if err != nil { + continue + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + Log.Infof("Tunnel created and available at %s", s.config.Hostname) + return nil + } + if i == 0 { + Log.Infof("First ping to origin through Argo Tunnel returned %s", resp.Status) + } + lastResponseStatus = resp.Status + } + } + Log.Infof("Last ping to origin through Argo Tunnel returned %s", lastResponseStatus) + return fmt.Errorf("Exceed DNS record validation retry limit") +} + +func (s *Supervisor) createPingRequestAndClient() (*http.Request, *http.Client, error) { + url := fmt.Sprintf("https://%s",s.config.Hostname) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, nil, err + } + req.Header.Add(CloudflaredPingHeader, s.config.ClientID) + transport := s.config.HTTPTransport + if transport == nil { + transport = http.DefaultTransport + } + return req, &http.Client{Transport: transport}, nil +} + // startTunnel starts a new tunnel connection. The resulting error will be sent on // s.tunnelErrors. func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) { diff --git a/origin/tunnel.go b/origin/tunnel.go index e75988b4..b68b5002 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -35,8 +35,16 @@ const ( TagHeaderNamePrefix = "Cf-Warp-Tag-" DuplicateConnectionError = "EDUPCONN" + CloudflaredPingHeader = "Cloudflard-Ping" ) +type DNSValidationConfig struct { + VerifyDNSPropagated bool + DNSPingRetries uint + DNSInitWaitTime time.Duration + PingFreq time.Duration +} + type TunnelConfig struct { EdgeAddrs []string OriginUrl string @@ -58,6 +66,7 @@ type TunnelConfig struct { ProtocolLogger *logrus.Logger Logger *logrus.Logger IsAutoupdated bool + *DNSValidationConfig } type dialError struct { @@ -297,7 +306,6 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi } } - Log.Infof("Registered at %s", registration.Url) return nil } @@ -361,6 +369,7 @@ func FindCfRayHeader(h1 *http.Request) string { return h1.Header.Get("Cf-Ray") } + type TunnelHandler struct { originUrl string muxer *h2mux.Muxer @@ -370,6 +379,7 @@ type TunnelHandler struct { metrics *tunnelMetrics // connectionID is only used by metrics, and prometheus requires labels to be string connectionID string + clientID string } var dialer = net.Dialer{DualStack: true} @@ -387,6 +397,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co tags: config.Tags, metrics: config.Metrics, connectionID: uint8ToString(connectionID), + clientID: config.ClientID, } if h.httpClient == nil { h.httpClient = http.DefaultTransport @@ -442,6 +453,10 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { h.AppendTagHeaders(req) cfRay := FindCfRayHeader(req) h.logRequest(req, cfRay) + if h.isCloudflaredPing(req) { + stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", http.StatusOK)}}) + return nil + } if websocket.IsWebSocketUpgrade(req) { conn, response, err := websocket.ClientConnect(req, h.tlsConfig) if err != nil { @@ -469,6 +484,13 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { return nil } +func (h *TunnelHandler) isCloudflaredPing(h1 *http.Request) bool { + if h1.Header.Get(CloudflaredPingHeader) == h.clientID { + return true + } + return false +} + func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) { Log.WithError(err).Error("HTTP request error") stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) diff --git a/tunneldns/tunnel.go b/tunneldns/tunnel.go index f4b5eed7..145f1285 100644 --- a/tunneldns/tunnel.go +++ b/tunneldns/tunnel.go @@ -1,7 +1,6 @@ package tunneldns import ( - "fmt" "net" "os" "os/signal" @@ -132,7 +131,7 @@ func CreateListener(address string, port uint16, upstreams []string) (*Listener, } // Format an endpoint - endpoint := fmt.Sprintf("dns://%s:%d", address, port) + endpoint := "dns://" + net.JoinHostPort(address, strconv.FormatUint(uint64(port), 10)) // Create the actual middleware server server, err := dnsserver.NewServer(endpoint, []*dnsserver.Config{createConfig(address, port, NewMetricsPlugin(chain))})