package hello
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"html/template"
"io/ioutil"
"net"
"net/http"
"os"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/tlsconfig"
)
// URL paths registered on the Hello World server's mux.
const (
	// UptimeRoute is the path served by the uptime handler.
	UptimeRoute = "/uptime"
	// WSRoute is the path served by the websocket echo handler.
	WSRoute = "/ws"
	// SSERoute is the path served by the server-sent events handler.
	SSERoute = "/sse"
	// HealthRoute is the path served by the health handler.
	HealthRoute = "/_health"
)

// defaultSSEFreq is the interval between server-sent events when the request
// does not supply a usable "freq" query parameter.
const defaultSSEFreq = 10 * time.Second
// templateData is the value passed to the index page template when it is
// executed by rootHandler.
type templateData struct {
// ServerName is the server name given to rootHandler.
ServerName string
// Request is the incoming HTTP request being rendered.
Request *http.Request
// Body is the request body read as a string; empty if it could not be read.
Body string
}
// OriginUpTime is the JSON payload written by the uptime handler.
type OriginUpTime struct {
// StartTime is the start time the uptime handler was created with.
StartTime time.Time `json:"startTime"`
// UpTime is the time elapsed since StartTime, formatted with
// time.Duration.String.
UpTime string `json:"uptime"`
}
// defaultServerName is the server name used when os.Hostname returns an error.
const defaultServerName = "the Argo Tunnel test server"
// indexTemplate is the html/template source parsed by rootHandler for the "/"
// route. It is executed with a *templateData and prints the request's method,
// protocol, URL, transfer encoding, host, remote address, request URI, every
// header key/value pair, and the request body.
const indexTemplate = `
Argo Tunnel Connection
Congrats! You created a tunnel!
Argo Tunnel exposes locally running applications to the internet by
running an encrypted, virtual tunnel from your laptop or server to
Cloudflare's edge network.
Ready for the next step?
Get started here
Request
Method: {{.Request.Method}}
Protocol: {{.Request.Proto}}
Request URL: {{.Request.URL}}
Transfer encoding: {{.Request.TransferEncoding}}
Host: {{.Request.Host}}
Remote address: {{.Request.RemoteAddr}}
Request URI: {{.Request.RequestURI}}
{{range $key, $value := .Request.Header}}
Header: {{$key}}, Value: {{$value}}
{{end}}
Body: {{.Body}}
`
// StartHelloWorldServer serves the Hello World test endpoints on listener and
// blocks until the server stops.
//
// It registers the uptime, websocket echo, server-sent events, health and
// index ("/") handlers on a fresh mux. The server name handed to the index
// page is the OS hostname, falling back to defaultServerName when the hostname
// cannot be determined.
//
// A receive on shutdownC (a sent value or a close) closes the server. The
// returned error is whatever http.Server.Serve returns, which is always
// non-nil; after a shutdown it is http.ErrServerClosed.
func StartHelloWorldServer(log *zerolog.Logger, listener net.Listener, shutdownC <-chan struct{}) error {
	log.Info().Msgf("Starting Hello World server at %s", listener.Addr())

	serverName := defaultServerName
	if hostname, err := os.Hostname(); err == nil {
		serverName = hostname
	}

	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
	}

	muxer := http.NewServeMux()
	muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now()))
	muxer.HandleFunc(WSRoute, websocketHandler(log, upgrader))
	muxer.HandleFunc(SSERoute, sseHandler(log))
	muxer.HandleFunc(HealthRoute, healthHandler())
	muxer.HandleFunc("/", rootHandler(serverName))
	httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer}

	// served is closed once Serve has returned, so the watcher goroutine below
	// exits even when the server stops for a reason other than shutdownC
	// (e.g. a listener error). Without it the goroutine would block forever.
	served := make(chan struct{})
	defer close(served)

	go func() {
		select {
		case <-shutdownC:
			// Best effort: Serve's return value already reports the outcome.
			_ = httpServer.Close()
		case <-served:
		}
	}()

	return httpServer.Serve(listener)
}
// CreateTLSListener opens a TCP listener on address that serves TLS with the
// certificate returned by tlsconfig.GetHelloCertificate. It returns an error
// if the certificate cannot be obtained or the listener cannot be created.
func CreateTLSListener(address string) (net.Listener, error) {
	cert, err := tlsconfig.GetHelloCertificate()
	if err != nil {
		return nil, err
	}

	tlsConf := &tls.Config{
		Certificates: []tls.Certificate{cert},
	}

	// If the port in address is empty, a port number is automatically chosen
	return tls.Listen("tcp", address, tlsConf)
}
// uptimeHandler returns a handler that responds with a JSON-encoded
// OriginUpTime holding startTime and the time elapsed since it. If encoding
// fails, the handler replies with 500 Internal Server Error and no body.
func uptimeHandler(startTime time.Time) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		// Note that if autoupdate is enabled, the uptime is reset when a new client
		// release is available
		payload, err := json.Marshal(&OriginUpTime{
			StartTime: startTime,
			UpTime:    time.Since(startTime).String(),
		})
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write(payload)
	}
}
// websocketHandler returns a handler that upgrades the request to a websocket
// connection and echoes every message back to the sender, using the same
// message type, until a read or write fails. Failures are logged at error
// level.
func websocketHandler(log *zerolog.Logger, upgrader websocket.Upgrader) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// r.Host may carry a port while the Origin header does not; strip the
		// port before upgrading so the two can match.
		if hostOnly, _, splitErr := net.SplitHostPort(r.Host); splitErr == nil {
			r.Host = hostOnly
		}

		ws, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Error().Msgf("failed to upgrade to websocket connection, error: %s", err)
			return
		}
		defer ws.Close()

		// Echo loop: the first failed read or write ends the handler.
		for {
			msgType, payload, readErr := ws.ReadMessage()
			if readErr != nil {
				log.Error().Msgf("websocket read message error: %s", readErr)
				return
			}

			writeErr := ws.WriteMessage(msgType, payload)
			if writeErr != nil {
				log.Error().Msgf("websocket write message error: %s", writeErr)
				return
			}
		}
	}
}
// sseHandler returns a handler that streams an incrementing counter, starting
// at 0, to the client as "text/event-stream". One value is written and flushed
// per tick until the request context is done or a write fails.
//
// The tick interval defaults to defaultSSEFreq and can be overridden with the
// "freq" query parameter (a time.ParseDuration string such as "500ms"). Values
// that do not parse, or that are zero or negative, are ignored.
//
// If the ResponseWriter does not implement http.Flusher the handler logs an
// error and replies with 500 Internal Server Error.
func sseHandler(log *zerolog.Logger) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
		flusher, ok := w.(http.Flusher)
		if !ok {
			w.WriteHeader(http.StatusInternalServerError)
			log.Error().Msgf("Can't support SSE. ResponseWriter %T doesn't implement http.Flusher interface", w)
			return
		}

		freq := defaultSSEFreq
		if requestedFreq := r.URL.Query()["freq"]; len(requestedFreq) > 0 {
			// time.NewTicker panics on a non-positive duration, and freq comes
			// straight from the client, so only positive values are accepted.
			if parsedFreq, err := time.ParseDuration(requestedFreq[0]); err == nil && parsedFreq > 0 {
				freq = parsedFreq
			}
		}
		log.Info().Msgf("Server Sent Events every %s", freq)

		ticker := time.NewTicker(freq)
		// Release the ticker when the stream ends instead of leaking one per request.
		defer ticker.Stop()

		counter := 0
		for {
			select {
			case <-r.Context().Done():
				return
			case <-ticker.C:
			}

			if _, err := fmt.Fprintf(w, "%d\n\n", counter); err != nil {
				return
			}
			flusher.Flush()
			counter++
		}
	}
}
// healthHandler returns a handler that answers every request with the body
// "ok" and the default 200 status.
func healthHandler() http.HandlerFunc {
	okBody := []byte("ok")
	return func(w http.ResponseWriter, _ *http.Request) {
		// A failed write means the client is gone; there is nothing to recover.
		_, _ = w.Write(okBody)
	}
}
// rootHandler returns a handler that renders indexTemplate with serverName,
// the incoming request and the request body. The template is parsed once, when
// rootHandler is called, and a parse failure panics (template.Must).
//
// The page is rendered into a buffer first, so a template execution error
// yields a clean 500 response with the error text rather than a partial page.
func rootHandler(serverName string) http.HandlerFunc {
	page := template.Must(template.New("index").Parse(indexTemplate))

	return func(w http.ResponseWriter, r *http.Request) {
		// A body that cannot be read is rendered as an empty string.
		var body string
		if raw, readErr := ioutil.ReadAll(r.Body); readErr == nil {
			body = string(raw)
		}

		data := &templateData{
			ServerName: serverName,
			Request:    r,
			Body:       body,
		}

		var rendered bytes.Buffer
		if err := page.Execute(&rendered, data); err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = fmt.Fprintf(w, "error: %v", err)
			return
		}
		_, _ = rendered.WriteTo(w)
	}
}