codies/main.go

287 lines
6.6 KiB
Go
Raw Normal View History

2020-05-23 21:51:11 +00:00
package main
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"reflect"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/gofrs/uuid"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
// args holds the process configuration. main fills it in by passing its
// address to flags.Parse; each field can be set with the command-line flag
// named in its `long` tag or the environment variable named in its `env` tag.
var args = struct {
	// Addr is the address the HTTP server listens at.
	Addr string `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`

	// Origins lists additional origin patterns accepted for WebSocket
	// connections; the environment form is comma-delimited.
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`

	// Prod selects production mode.
	Prod bool `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`

	// Debug selects debug mode.
	Debug bool `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{
	// Default listen address used when neither flag nor environment sets one.
	Addr: ":5000",
}
// wsOpts holds the options handed to websocket.Accept whenever a request is
// upgraded to a WebSocket connection. It is nil until main assigns it after
// parsing the command-line options.
var wsOpts *websocket.AcceptOptions
2020-05-23 21:51:11 +00:00
func main() {
2020-05-23 22:14:58 +00:00
rand.Seed(time.Now().Unix())
log.SetFlags(log.LstdFlags | log.Lshortfile)
2020-05-23 21:51:11 +00:00
if _, err := flags.Parse(&args); err != nil {
// Default flag parser prints messages, so just exit.
os.Exit(1)
}
if !args.Prod && !args.Debug {
log.Fatal("missing required option --prod or --debug")
} else if args.Prod && args.Debug {
log.Fatal("must specify either --prod or --debug")
}
2020-05-23 22:14:58 +00:00
log.Printf("starting codies server, version %s", version.Version())
wsOpts = &websocket.AcceptOptions{
2020-05-23 21:51:11 +00:00
OriginPatterns: args.Origins,
}
if args.Debug {
log.Println("starting in debug mode, allowing any WebSocket origin host")
wsOpts.OriginPatterns = []string{"*"}
} else {
if !version.VersionSet() {
log.Fatal("running production build without version set")
}
2020-05-23 21:51:11 +00:00
}
g, ctx := errgroup.WithContext(ctxutil.Interrupt())
srv := server.NewServer()
r := chi.NewMux()
r.Use(middleware.Heartbeat("/ping"))
r.Use(middleware.Recoverer)
r.NotFound(staticRouter().ServeHTTP)
2020-05-23 21:51:11 +00:00
r.Group(func(r chi.Router) {
r.Use(middleware.NoCache)
r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
})
r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
rooms, clients := srv.Stats()
2020-05-23 21:51:11 +00:00
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
_ = enc.Encode(&protocol.StatsResponse{
Rooms: rooms,
Clients: clients,
})
2020-05-23 21:51:11 +00:00
})
r.Group(func(r chi.Router) {
if !args.Debug {
r.Use(checkVersion)
2020-05-23 21:51:11 +00:00
}
r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
query := &protocol.ExistsQuery{}
if err := queryparam.Parse(r.URL.Query(), query); err != nil {
httpErr(w, http.StatusBadRequest)
return
2020-05-23 21:51:11 +00:00
}
room := srv.FindRoomByID(query.RoomID)
if room == nil {
2020-05-23 21:51:11 +00:00
w.WriteHeader(http.StatusNotFound)
} else {
w.WriteHeader(http.StatusOK)
}
_, _ = w.Write([]byte("."))
})
2020-05-23 21:51:11 +00:00
r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
2020-05-23 21:51:11 +00:00
req := &protocol.RoomRequest{}
if err := json.NewDecoder(r.Body).Decode(req); err != nil {
httpErr(w, http.StatusBadRequest)
return
}
2020-05-23 21:51:11 +00:00
if !req.Valid() {
httpErr(w, http.StatusBadRequest)
return
}
2020-05-23 21:51:11 +00:00
resp := &protocol.RoomResponse{}
w.Header().Add("Content-Type", "application/json")
if req.Create {
room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
if err != nil {
switch err {
case server.ErrRoomExists:
resp.Error = stringPtr("Room already exists.")
w.WriteHeader(http.StatusBadRequest)
case server.ErrTooManyRooms:
resp.Error = stringPtr("Too many rooms.")
w.WriteHeader(http.StatusServiceUnavailable)
default:
resp.Error = stringPtr("An unknown error occurred.")
w.WriteHeader(http.StatusInternalServerError)
}
} else {
resp.ID = &room.ID
w.WriteHeader(http.StatusOK)
}
} else {
room := srv.FindRoom(req.RoomName)
if room == nil || room.Password != req.RoomPass {
resp.Error = stringPtr("Room not found or password does not match.")
w.WriteHeader(http.StatusNotFound)
} else {
resp.ID = &room.ID
w.WriteHeader(http.StatusOK)
}
}
2020-05-23 21:51:11 +00:00
_ = json.NewEncoder(w).Encode(resp)
2020-05-23 21:51:11 +00:00
})
r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
query := &protocol.WSQuery{}
if err := queryparam.Parse(r.URL.Query(), query); err != nil {
httpErr(w, http.StatusBadRequest)
return
}
2020-05-23 21:51:11 +00:00
if !query.Valid() {
httpErr(w, http.StatusBadRequest)
return
}
room := srv.FindRoomByID(query.RoomID)
if room == nil {
httpErr(w, http.StatusNotFound)
return
}
c, err := websocket.Accept(w, r, wsOpts)
if err != nil {
log.Println(err)
return
}
g.Go(func() error {
room.HandleConn(query.PlayerID, query.Nickname, c)
return nil
})
2020-05-23 21:51:11 +00:00
})
})
})
g.Go(func() error {
return srv.Run(ctx)
})
httpSrv := http.Server{Addr: args.Addr, Handler: r}
g.Go(func() error {
<-ctx.Done()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return httpSrv.Shutdown(ctx)
})
g.Go(func() error {
return httpSrv.ListenAndServe()
})
log.Fatal(g.Wait())
}
// staticRouter builds the handler that serves the frontend from
// ./frontend/build. All responses pass through the compression middleware
// (level 5). Files under /static/ and /favicon/ are served as-is, while every
// other path is additionally wrapped in the NoCache middleware.
func staticRouter() http.Handler {
	files := http.FileServer(http.Dir("./frontend/build"))

	mux := chi.NewMux()
	mux.Use(middleware.Compress(5))

	mux.Handle("/static/*", files)
	mux.Handle("/favicon/*", files)

	// Everything else goes through NoCache before reaching the file server.
	mux.Group(func(uncached chi.Router) {
		uncached.Use(middleware.NoCache)
		uncached.Handle("/*", files)
	})

	return mux
}
// checkVersion returns middleware that only lets a request through to next
// when the client reports the same version as version.Version().
//
// The client version is read from the X-CODIES-VERSION header or the
// codiesVersion query parameter; a match in either is sufficient. On a
// mismatch, a WebSocket upgrade request is accepted and immediately closed
// with status code 4418 and an explanatory reason, and any other request
// receives a 418 response carrying the same text.
func checkVersion(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		want := version.Version()

		if r.Header.Get("X-CODIES-VERSION") == want || r.URL.Query().Get("codiesVersion") == want {
			next.ServeHTTP(w, r)
			return
		}

		reason := fmt.Sprintf("client version too old, please reload to get %s", want)

		// The "websocket" upgrade token is case-insensitive (RFC 6455
		// section 4.2.1), so compare with EqualFold; otherwise a client
		// sending e.g. "WebSocket" would get the plain HTTP reply instead of
		// a close frame.
		if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
			c, err := websocket.Accept(w, r, wsOpts)
			if err != nil {
				log.Println(err)
				return
			}

			// Best effort: the connection is being discarded either way.
			_ = c.Close(4418, reason)
			return
		}

		w.WriteHeader(http.StatusTeapot)
		fmt.Fprint(w, reason)
	})
}
2020-05-23 21:51:11 +00:00
// httpErr replies to the request with the given status code, using the
// standard status text for that code as the response body.
func httpErr(w http.ResponseWriter, code int) {
	text := http.StatusText(code)
	http.Error(w, text, code)
}
// stringPtr returns a pointer to a fresh copy of s, for filling optional
// string fields that are represented as *string.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
func init() {
queryparam.DefaultParser.ValueParsers[reflect.TypeOf(uuid.UUID{})] = func(value string, _ string) (reflect.Value, error) {
id, err := uuid.FromString(value)
return reflect.ValueOf(id), err
}
}