package main import ( "context" "encoding/json" "fmt" "log" "math/rand" "net/http" "os" "time" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" "github.com/jessevdk/go-flags" "github.com/posener/ctxutil" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/tomwright/queryparam/v4" "github.com/zikaeroh/codies/internal/pkger" "github.com/zikaeroh/codies/internal/protocol" "github.com/zikaeroh/codies/internal/responder" "github.com/zikaeroh/codies/internal/server" "github.com/zikaeroh/codies/internal/version" "github.com/zikaeroh/ctxlog" "go.uber.org/zap" "golang.org/x/sync/errgroup" "nhooyr.io/websocket" ) var args = struct { Addr string `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"` Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"` Prod bool `long:"prod" env:"CODIES_PROD" description:"Enables production mode"` Debug bool `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"` }{ Addr: ":5000", } var wsOpts *websocket.AcceptOptions func main() { if argv := os.Args[1:]; len(argv) > 0 && argv[0] == "version" { fmt.Println(version.Version()) return } rand.Seed(time.Now().Unix()) if _, err := flags.Parse(&args); err != nil { // Default flag parser prints messages, so just exit. os.Exit(1) } if !args.Prod && !args.Debug { log.Fatal("missing required option --prod or --debug") } else if args.Prod && args.Debug { log.Fatal("must specify either --prod or --debug") } ctx := ctxutil.Interrupt() logger := ctxlog.New(args.Debug) defer zap.RedirectStdLog(logger)() ctx = ctxlog.WithLogger(ctx, logger) ctxlog.Info(ctx, "starting", zap.String("version", version.Version())) wsOpts = &websocket.AcceptOptions{ OriginPatterns: args.Origins, CompressionMode: websocket.CompressionContextTakeover, } if args.Debug { ctxlog.Info(ctx, "starting in debug mode, allowing any WebSocket origin host") wsOpts.InsecureSkipVerify = true } else { if !version.VersionSet() { ctxlog.Fatal(ctx, "running production build without version set") } } g, ctx := errgroup.WithContext(ctx) srv := server.NewServer() r := chi.NewMux() r.Use(func(next http.Handler) http.Handler { return promhttp.InstrumentHandlerCounter(metricRequest, next) }) r.Use(middleware.Heartbeat("/ping")) r.Use(middleware.Recoverer) r.NotFound(staticHandler().ServeHTTP) r.Group(func(r chi.Router) { r.Use(middleware.NoCache) r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) { responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()})) }) r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) { rooms, clients := srv.Stats() responder.Respond(w, responder.Body(&protocol.StatsResponse{ Rooms: rooms, Clients: clients, }), responder.Pretty(true), ) }) r.Group(func(r chi.Router) { if !args.Debug { r.Use(checkVersion) } r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) { query := &protocol.ExistsQuery{} if err := queryparam.Parse(r.URL.Query(), query); err != nil { responder.Respond(w, responder.Status(http.StatusBadRequest)) return } room := srv.FindRoomByID(query.RoomID) if room == nil { responder.Respond(w, responder.Status(http.StatusNotFound)) } else { responder.Respond(w, responder.Status(http.StatusOK)) } }) r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() req := &protocol.RoomRequest{} if err := json.NewDecoder(r.Body).Decode(req); err != nil { responder.Respond(w, responder.Status(http.StatusBadRequest)) return } if msg, valid := req.Valid(); !valid { responder.Respond(w, responder.Status(http.StatusBadRequest), responder.Body(&protocol.RoomResponse{ Error: stringPtr(msg), }), ) return } var room *server.Room if req.Create { var err error room, err = srv.CreateRoom(ctx, req.RoomName, req.RoomPass) if err != nil { switch err { case server.ErrRoomExists: responder.Respond(w, responder.Status(http.StatusBadRequest), responder.Body(&protocol.RoomResponse{ Error: stringPtr("Room already exists."), }), ) case server.ErrTooManyRooms: responder.Respond(w, responder.Status(http.StatusServiceUnavailable), responder.Body(&protocol.RoomResponse{ Error: stringPtr("Too many rooms."), }), ) default: responder.Respond(w, responder.Status(http.StatusInternalServerError), responder.Body(&protocol.RoomResponse{ Error: stringPtr("An unknown error occurred."), }), ) } return } } else { room = srv.FindRoom(req.RoomName) if room == nil || room.Password != req.RoomPass { responder.Respond(w, responder.Status(http.StatusNotFound), responder.Body(&protocol.RoomResponse{ Error: stringPtr("Room not found or password does not match."), }), ) return } } responder.Respond(w, responder.Body(&protocol.RoomResponse{ ID: &room.ID, })) }) r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) { query := &protocol.WSQuery{} if err := queryparam.Parse(r.URL.Query(), query); err != nil { responder.Respond(w, responder.Status(http.StatusBadRequest)) return } if _, valid := query.Valid(); !valid { responder.Respond(w, responder.Status(http.StatusBadRequest)) return } room := srv.FindRoomByID(query.RoomID) if room == nil { responder.Respond(w, responder.Status(http.StatusBadRequest)) return } c, err := websocket.Accept(w, r, wsOpts) if err != nil { return } g.Go(func() error { room.HandleConn(ctx, query.Nickname, c) return nil }) }) }) }) g.Go(func() error { return srv.Run(ctx) }) runServer(ctx, g, args.Addr, r) if args.Prod { runServer(ctx, g, ":2112", prometheusHandler()) } exitErr := g.Wait() ctxlog.Fatal(ctx, "exited", zap.Error(exitErr)) } func staticHandler() http.Handler { fs := pkger.Dir("/frontend/build") fsh := http.FileServer(fs) r := chi.NewMux() r.Use(middleware.Compress(5)) r.Handle("/static/*", fsh) r.Handle("/favicon/*", fsh) r.Group(func(r chi.Router) { r.Use(middleware.NoCache) r.Handle("/*", fsh) }) return r } func checkVersion(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { want := version.Version() toCheck := []string{ r.Header.Get("X-CODIES-VERSION"), r.URL.Query().Get("codiesVersion"), } for _, got := range toCheck { if got == want { next.ServeHTTP(w, r) return } } reason := fmt.Sprintf("client version too old, please reload to get %s", want) if r.Header.Get("Upgrade") == "websocket" { c, err := websocket.Accept(w, r, wsOpts) if err != nil { return } c.Close(4418, reason) return } w.WriteHeader(http.StatusTeapot) fmt.Fprint(w, reason) }) } func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) { httpSrv := http.Server{Addr: addr, Handler: handler} g.Go(func() error { <-ctx.Done() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() return httpSrv.Shutdown(ctx) }) g.Go(func() error { return httpSrv.ListenAndServe() }) } func prometheusHandler() http.Handler { mux := http.NewServeMux() mux.Handle("/metrics", promhttp.Handler()) return mux }