package hello

import (
	"bytes"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"html/template"
	"io/ioutil"
	"net"
	"net/http"
	"os"
	"time"

	"github.com/gorilla/websocket"
	"github.com/rs/zerolog"

	"github.com/cloudflare/cloudflared/tlsconfig"
)

// Routes registered by StartHelloWorldServer, plus the default interval
// between events on SSERoute.
const (
	UptimeRoute    = "/uptime"
	WSRoute        = "/ws"
	SSERoute       = "/sse"
	HealthRoute    = "/_health"
	defaultSSEFreq = time.Second * 10
)

// readHeaderTimeout bounds how long a client may take to send request headers.
// Only the header phase is bounded: ReadTimeout/WriteTimeout would cut off the
// long-lived SSE and WebSocket routes.
const readHeaderTimeout = 10 * time.Second

// templateData is the value rendered into indexTemplate by rootHandler.
type templateData struct {
	ServerName string
	Request    *http.Request
	Body       string
}

// OriginUpTime is the JSON document served on UptimeRoute.
type OriginUpTime struct {
	StartTime time.Time `json:"startTime"`
	UpTime    string    `json:"uptime"`
}

// defaultServerName is used as the server name when os.Hostname fails.
const defaultServerName = "the Cloudflare Tunnel test server"

// indexTemplate is the html/template rendered by rootHandler; it echoes details
// of the incoming request (method, URL, headers, body, ...) back to the client.
const indexTemplate = ` Cloudflare Tunnel Connection

Congrats! You created a tunnel!

Cloudflare Tunnel exposes locally running applications to the internet by running an encrypted, virtual tunnel from your laptop or server to Cloudflare's edge network.

Ready for the next step?

Get started here

Request

Method: {{.Request.Method}}
Protocol: {{.Request.Proto}}
Request URL: {{.Request.URL}}
Transfer encoding: {{.Request.TransferEncoding}}
Host: {{.Request.Host}}
Remote address: {{.Request.RemoteAddr}}
Request URI: {{.Request.RequestURI}}
{{range $key, $value := .Request.Header}}
Header: {{$key}}, Value: {{$value}}
{{end}}
Body: {{.Body}}
`

// StartHelloWorldServer serves the hello-world test routes (UptimeRoute,
// WSRoute, SSERoute, HealthRoute and "/") on listener until shutdownC is
// closed or receives a value, or until serving fails.
//
// Like http.Server.Serve it always returns a non-nil error: http.ErrServerClosed
// after a shutdown triggered through shutdownC, otherwise the error that
// stopped Serve.
func StartHelloWorldServer(log *zerolog.Logger, listener net.Listener, shutdownC <-chan struct{}) error {
	log.Info().Msgf("Starting Hello World server at %s", listener.Addr())
	serverName := defaultServerName
	if hostname, err := os.Hostname(); err == nil {
		serverName = hostname
	}

	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
	}

	muxer := http.NewServeMux()
	muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
	muxer.HandleFunc(WSRoute, websocketHandler(log, upgrader))
	muxer.HandleFunc(SSERoute, sseHandler(log))
	muxer.HandleFunc(HealthRoute, healthHandler())
	muxer.HandleFunc("/", rootHandler(serverName))

	httpServer := &http.Server{
		Addr:              listener.Addr().String(),
		Handler:           muxer,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	// serveDone lets the shutdown watcher exit when Serve returns on its own
	// (e.g. a listener error); otherwise it would block on shutdownC forever.
	serveDone := make(chan struct{})
	defer close(serveDone)
	go func() {
		select {
		case <-shutdownC:
			_ = httpServer.Close()
		case <-serveDone:
		}
	}()

	return httpServer.Serve(listener)
}

// CreateTLSListener returns a TCP listener on address that terminates TLS
// using the certificate from tlsconfig.GetHelloCertificate.
func CreateTLSListener(address string) (net.Listener, error) {
	certificate, err := tlsconfig.GetHelloCertificate()
	if err != nil {
		return nil, err
	}
	// If the port in address is empty, a port number is automatically chosen
	return tls.Listen(
		"tcp",
		address,
		&tls.Config{Certificates: []tls.Certificate{certificate}})
}

// uptimeHandler reports startTime and the time elapsed since it as an
// OriginUpTime JSON document. It responds 500 with an empty body if encoding
// fails.
func uptimeHandler(startTime time.Time) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Note that if autoupdate is enabled, the uptime is reset when a new client
		// release is available
		resp := &OriginUpTime{StartTime: startTime, UpTime: time.Since(startTime).String()}
		respJSON, err := json.Marshal(resp)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write(respJSON)
	}
}

// websocketHandler upgrades the request to a WebSocket and echoes every
// message back with the same message type until a read or write fails.
func websocketHandler(log *zerolog.Logger, upgrader websocket.Upgrader) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// This addresses the issue of r.Host includes port but origin header doesn't
		host, _, err := net.SplitHostPort(r.Host)
		if err == nil {
			r.Host = host
		}

		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Err(err).Msg("failed to upgrade to websocket connection")
			return
		}
		defer conn.Close()
		for {
			mt, message, err := conn.ReadMessage()
			if err != nil {
				log.Err(err).Msg("websocket read message error")
				break
			}

			if err := conn.WriteMessage(mt, message); err != nil {
				log.Err(err).Msg("websocket write message error")
				break
			}
		}
	}
}

// sseHandler streams an incrementing counter, one event per tick, until the
// client goes away or a write fails. The tick interval defaults to
// defaultSSEFreq and can be overridden with a positive Go duration in the
// "freq" query parameter (e.g. ?freq=500ms).
func sseHandler(log *zerolog.Logger) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
		flusher, ok := w.(http.Flusher)
		if !ok {
			w.WriteHeader(http.StatusInternalServerError)
			log.Error().Msgf("Can't support SSE. ResponseWriter %T doesn't implement http.Flusher interface", w)
			return
		}

		freq := defaultSSEFreq
		if requestedFreq := r.URL.Query()["freq"]; len(requestedFreq) > 0 {
			// time.NewTicker panics on a non-positive interval, so zero and
			// negative durations are ignored just like unparsable ones.
			if parsedFreq, err := time.ParseDuration(requestedFreq[0]); err == nil && parsedFreq > 0 {
				freq = parsedFreq
			}
		}
		log.Info().Msgf("Server Sent Events every %s", freq)

		ticker := time.NewTicker(freq)
		// Release the ticker when the client disconnects or a write fails.
		defer ticker.Stop()
		counter := 0
		for {
			select {
			case <-r.Context().Done():
				return
			case <-ticker.C:
			}
			if _, err := fmt.Fprintf(w, "%d\n\n", counter); err != nil {
				return
			}
			flusher.Flush()
			counter++
		}
	}
}

// healthHandler always responds with the body "ok".
func healthHandler() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write([]byte("ok"))
	}
}

// rootHandler renders indexTemplate with serverName and details of the
// incoming request, including its body. The template is parsed once, when the
// handler is built; template.Must panics if indexTemplate is invalid.
func rootHandler(serverName string) http.HandlerFunc {
	responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
	return func(w http.ResponseWriter, r *http.Request) {
		// A body that cannot be read is rendered as empty rather than failing
		// the request.
		var body string
		if rawBody, err := ioutil.ReadAll(r.Body); err == nil {
			body = string(rawBody)
		}

		// Render into a buffer first so a template error can still be reported
		// with a 500 status before anything is written to the client.
		var buffer bytes.Buffer
		err := responseTemplate.Execute(&buffer, &templateData{
			ServerName: serverName,
			Request:    r,
			Body:       body,
		})
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = fmt.Fprintf(w, "error: %v", err)
			return
		}
		_, _ = buffer.WriteTo(w)
	}
}