package origin import ( "crypto/tls" "fmt" "io" "net" "net/http" "net/url" "runtime" "strings" "time" "golang.org/x/net/context" "github.com/cloudflare/cloudflare-warp/h2mux" "github.com/cloudflare/cloudflare-warp/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" "github.com/cloudflare/cloudflare-warp/validation" log "github.com/Sirupsen/logrus" raven "github.com/getsentry/raven-go" "github.com/pkg/errors" rpc "zombiezen.com/go/capnproto2/rpc" ) const ( dialTimeout = 15 * time.Second TagHeaderNamePrefix = "Cf-Warp-Tag-" ) type TunnelConfig struct { EdgeAddr string OriginUrl string Hostname string OriginCert []byte TlsConfig *tls.Config Retries uint HeartbeatInterval time.Duration MaxHeartbeats uint64 ClientID string ReportedVersion string LBPool string Tags []tunnelpogs.Tag ConnectedSignal h2mux.Signal } type dialError struct { cause error } func (e dialError) Error() string { return e.cause.Error() } type printableRegisterTunnelError struct { cause error permanent bool } func (e printableRegisterTunnelError) Error() string { return e.cause.Error() } func (c *TunnelConfig) RegistrationOptions() *tunnelpogs.RegistrationOptions { policy := tunnelrpc.ExistingTunnelPolicy_disconnect if c.LBPool != "" { policy = tunnelrpc.ExistingTunnelPolicy_balance } return &tunnelpogs.RegistrationOptions{ ClientID: c.ClientID, Version: c.ReportedVersion, OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), ExistingTunnelPolicy: policy, PoolID: c.LBPool, Tags: c.Tags, } } func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}) error { ctx, cancel := context.WithCancel(context.Background()) go func() { <-shutdownC cancel() }() backoff := BackoffHandler{MaxRetries: config.Retries} for { err, recoverable := ServeTunnel(ctx, config, &backoff) if recoverable { if duration, ok := backoff.GetBackoffDuration(ctx); ok { log.Infof("Retrying in %s seconds", duration) backoff.Backoff(ctx) continue } } return err } } func ServeTunnel( ctx context.Context, config *TunnelConfig, backoff *BackoffHandler, ) (err error, recoverable bool) { // Returns error from parsing the origin URL or handshake errors handler, err := NewTunnelHandler(ctx, config) if err != nil { errLog := log.WithError(err) switch err.(type) { case dialError: errLog.Error("Unable to dial edge") case h2mux.MuxerHandshakeError: errLog.Error("Handshake failed with edge server") default: errLog.Error("Tunnel creation failure") return err, false } return err, true } serveCtx, serveCancel := context.WithCancel(ctx) registerErrC := make(chan error, 1) go func() { err := RegisterTunnel(serveCtx, handler.muxer, config) if err == nil { config.ConnectedSignal.Signal() backoff.SetGracePeriod() } else { serveCancel() } registerErrC <- err }() go func() { <-serveCtx.Done() handler.muxer.Shutdown() }() err = handler.muxer.Serve() serveCancel() registerErr := <-registerErrC if err != nil { log.WithError(err).Error("Tunnel error") return err, true } if registerErr != nil { // Don't retry on errors like entitlement failure or version too old if e, ok := registerErr.(printableRegisterTunnelError); ok { log.Error(e) if e.permanent { return nil, false } return e.cause, true } // Only log errors to Sentry that may have been caused by the client side, to reduce dupes raven.CaptureError(registerErr, nil) log.Error("Cannot register") return err, true } return nil, false } func IsRPCStreamResponse(headers []h2mux.Header) bool { if len(headers) != 1 { return false } if headers[0].Name != ":status" || headers[0].Value != "200" { return false } return true } func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig) error { logger := log.WithField("subsystem", "rpc") logger.Debug("initiating RPC stream") stream, err := muxer.OpenStream([]h2mux.Header{ {Name: ":method", Value: "RPC"}, {Name: ":scheme", Value: "capnp"}, {Name: ":path", Value: "*"}, }, nil) if err != nil { // RPC stream open error raven.CaptureError(err, nil) return err } if !IsRPCStreamResponse(stream.Headers) { // stream response error raven.CaptureError(err, nil) return err } conn := rpc.NewConn( tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), ) defer conn.Close() ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)} // Request server info without blocking tunnel registration; must use capnp library directly. tsClient := tunnelrpc.TunnelServer{Client: ts.Client} serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { return nil }) registration, err := ts.RegisterTunnel( ctx, config.OriginCert, config.Hostname, config.RegistrationOptions(), ) LogServerInfo(logger, serverInfoPromise.Result()) if err != nil { // RegisterTunnel RPC failure return err } for _, logLine := range registration.LogLines { logger.Info(logLine) } if registration.Err != "" { return printableRegisterTunnelError{ cause: fmt.Errorf("Server error: %s", registration.Err), permanent: registration.PermanentFailure, } } log.Infof("Registered at %s", registration.Url) for _, logLine := range registration.LogLines { log.Infof(logLine) } return nil } func LogServerInfo(logger *log.Entry, promise tunnelrpc.ServerInfo_Promise) { serverInfoMessage, err := promise.Struct() if err != nil { logger.WithError(err).Warn("Failed to retrieve server information") return } serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage) if err != nil { logger.WithError(err).Warn("Failed to retrieve server information") return } log.Infof("Connected to %s", serverInfo.LocationName) } type TunnelHandler struct { originUrl string muxer *h2mux.Muxer httpClient *http.Client tags []tunnelpogs.Tag } var dialer = net.Dialer{DualStack: true} func NewTunnelHandler(ctx context.Context, config *TunnelConfig) (*TunnelHandler, error) { url, err := validation.ValidateUrl(config.OriginUrl) if err != nil { return nil, fmt.Errorf("Unable to parse origin url %#v", url) } h := &TunnelHandler{ originUrl: url, httpClient: &http.Client{Timeout: time.Minute}, tags: config.Tags, } // Inherit from parent context so we can cancel (Ctrl-C) while dialing dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout) // TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one) plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", config.EdgeAddr) dialCancel() if err != nil { return nil, dialError{cause: err} } edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig) edgeConn.SetDeadline(time.Now().Add(dialTimeout)) err = edgeConn.Handshake() if err != nil { return nil, dialError{cause: err} } // clear the deadline on the conn; h2mux has its own timeouts edgeConn.SetDeadline(time.Time{}) // Establish a muxed connection with the edge // Client mux handshake with agent server h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ Timeout: 5 * time.Second, Handler: h, IsClient: true, HeartbeatInterval: config.HeartbeatInterval, MaxHeartbeats: config.MaxHeartbeats, }) if err != nil { return h, errors.New("TLS handshake error") } return h, err } func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { for _, header := range h2 { switch header.Name { case ":method": h1.Method = header.Value case ":scheme": case ":authority": // Otherwise the host header will be based on the origin URL h1.Host = header.Value case ":path": u, err := url.Parse(header.Value) if err != nil { return fmt.Errorf("unparseable path") } resolved := h1.URL.ResolveReference(u) // prevent escaping base URL if !strings.HasPrefix(resolved.String(), h1.URL.String()) { return fmt.Errorf("invalid path") } h1.URL = resolved default: h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value) } } return nil } func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) { h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}} for headerName, headerValues := range h1.Header { for _, headerValue := range headerValues { h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue}) } } return } func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { for _, tag := range h.tags { r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) } } func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) if err != nil { log.WithError(err).Panic("Unexpected error from http.NewRequest") } err = H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { log.WithError(err).Error("invalid request received") } h.AppendTagHeaders(req) response, err := h.httpClient.Do(req) if err != nil { log.WithError(err).Error("HTTP request error") stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) stream.Write([]byte("502 Bad Gateway")) } else { defer response.Body.Close() stream.WriteHeaders(H1ResponseToH2Response(response)) io.Copy(stream, response.Body) } return nil }