Use ServeMux router instead of switch statement

oauth2
maru 2024-04-24 20:05:57 -04:00
parent da572ebdd9
commit fa5dcb0052
No known key found for this signature in database
GPG Key ID: 37689350E9CD0F0D
3 changed files with 302 additions and 311 deletions

View File

@ -10,9 +10,31 @@ import (
"github.com/pagefaultgames/pokerogue-server/db" "github.com/pagefaultgames/pokerogue-server/db"
) )
func Init() { func Init(mux *http.ServeMux) {
scheduleStatRefresh() scheduleStatRefresh()
daily.Init() daily.Init()
// account
mux.HandleFunc("/account/info", handleAccountInfo)
mux.HandleFunc("/account/register", handleAccountRegister)
mux.HandleFunc("/account/login", handleAccountLogin)
mux.HandleFunc("/account/logout", handleAccountLogout)
// game
mux.HandleFunc("/game/playercount", handleGamePlayerCount)
mux.HandleFunc("/game/titlestats", handleGameTitleStats)
mux.HandleFunc("/game/classicsessioncount", handleGameClassicSessionCount)
// savedata
mux.HandleFunc("/savedata/get", handleSaveData)
mux.HandleFunc("/savedata/update", handleSaveData)
mux.HandleFunc("/savedata/delete", handleSaveData)
mux.HandleFunc("/savedata/clear", handleSaveData)
// daily
mux.HandleFunc("/daily/seed", handleDailySeed)
mux.HandleFunc("/daily/rankings", handleDailyRankings)
mux.HandleFunc("/daily/rankingpagecount", handleDailyRankingPageCount)
} }
func getTokenFromRequest(r *http.Request) ([]byte, error) { func getTokenFromRequest(r *http.Request) ([]byte, error) {

View File

@ -7,7 +7,6 @@ import (
"log" "log"
"net/http" "net/http"
"strconv" "strconv"
"sync"
"github.com/pagefaultgames/pokerogue-server/api/account" "github.com/pagefaultgames/pokerogue-server/api/account"
"github.com/pagefaultgames/pokerogue-server/api/daily" "github.com/pagefaultgames/pokerogue-server/api/daily"
@ -16,36 +15,20 @@ import (
"github.com/pagefaultgames/pokerogue-server/defs" "github.com/pagefaultgames/pokerogue-server/defs"
) )
type Server struct {
Debug bool
Exit *sync.RWMutex
}
/* /*
The caller of endpoint handler functions are responsible for extracting the necessary data from the request. The caller of endpoint handler functions are responsible for extracting the necessary data from the request.
Handler functions are responsible for checking the validity of this data and returning a result or error. Handler functions are responsible for checking the validity of this data and returning a result or error.
Handlers should not return serialized JSON, instead return the struct itself. Handlers should not return serialized JSON, instead return the struct itself.
*/ */
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func httpError(w http.ResponseWriter, r *http.Request, err error, code int) {
// kind of misusing the RWMutex but it doesn't matter log.Printf("%s: %s\n", r.URL.Path, err)
s.Exit.RLock() http.Error(w, err.Error(), code)
defer s.Exit.RUnlock() }
if s.Debug { // account
w.Header().Set("Access-Control-Allow-Headers", "*")
w.Header().Set("Access-Control-Allow-Methods", "*")
w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Method == "OPTIONS" { func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
return
}
}
switch r.URL.Path {
// /account
case "/account/info":
username, err := getUsernameFromRequest(r) username, err := getUsernameFromRequest(r)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusBadRequest) httpError(w, r, err, http.StatusBadRequest)
@ -71,7 +54,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
case "/account/register": }
func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
@ -85,7 +70,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
case "/account/login": }
func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
@ -105,7 +92,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
case "/account/logout": }
func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest) httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest)
@ -119,11 +108,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}
// /game // game
case "/game/playercount":
func handleGamePlayerCount(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(strconv.Itoa(playerCount))) w.Write([]byte(strconv.Itoa(playerCount)))
case "/game/titlestats": }
func handleGameTitleStats(w http.ResponseWriter, r *http.Request) {
err := json.NewEncoder(w).Encode(defs.TitleStats{ err := json.NewEncoder(w).Encode(defs.TitleStats{
PlayerCount: playerCount, PlayerCount: playerCount,
BattleCount: battleCount, BattleCount: battleCount,
@ -134,11 +127,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
case "/game/classicsessioncount": }
w.Write([]byte(strconv.Itoa(classicSessionCount)))
// /savedata func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) {
case "/savedata/get", "/savedata/update", "/savedata/delete", "/savedata/clear": w.Write([]byte(strconv.Itoa(classicSessionCount)))
}
func handleSaveData(w http.ResponseWriter, r *http.Request) {
uuid, err := getUUIDFromRequest(r) uuid, err := getUUIDFromRequest(r)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusBadRequest) httpError(w, r, err, http.StatusBadRequest)
@ -256,11 +251,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
}
// /daily // daily
case "/daily/seed":
func handleDailySeed(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(daily.Seed())) w.Write([]byte(daily.Seed()))
case "/daily/rankings": }
func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
var err error var err error
var category int var category int
@ -294,7 +293,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
case "/daily/rankingpagecount": }
func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
var category int var category int
if r.URL.Query().Has("category") { if r.URL.Query().Has("category") {
var err error var err error
@ -311,13 +312,4 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
w.Write([]byte(strconv.Itoa(count))) w.Write([]byte(strconv.Itoa(count)))
default:
httpError(w, r, fmt.Errorf("unknown endpoint"), http.StatusNotFound)
return
}
}
func httpError(w http.ResponseWriter, r *http.Request, err error, code int) {
log.Printf("%s: %s\n", r.URL.Path, err)
http.Error(w, err.Error(), code)
} }

View File

@ -7,9 +7,6 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"os/signal"
"sync"
"syscall"
"github.com/pagefaultgames/pokerogue-server/api" "github.com/pagefaultgames/pokerogue-server/api"
"github.com/pagefaultgames/pokerogue-server/db" "github.com/pagefaultgames/pokerogue-server/db"
@ -17,8 +14,6 @@ import (
func main() { func main() {
// flag stuff // flag stuff
debug := flag.Bool("debug", false, "debug mode")
proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)") proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)")
addr := flag.String("addr", "0.0.0.0", "network address for api to listen on") addr := flag.String("addr", "0.0.0.0", "network address for api to listen on")
@ -46,15 +41,13 @@ func main() {
log.Fatalf("failed to create net listener: %s", err) log.Fatalf("failed to create net listener: %s", err)
} }
// create exit handler mux := http.NewServeMux()
var exit sync.RWMutex
createExitHandler(&exit)
// init api // init api
api.Init() api.Init(mux)
// start web server // start web server
err = http.Serve(listener, &api.Server{Debug: *debug, Exit: &exit}) err = http.Serve(listener, mux)
if err != nil { if err != nil {
log.Fatalf("failed to create http server or server errored: %s", err) log.Fatalf("failed to create http server or server errored: %s", err)
} }
@ -76,19 +69,3 @@ func createListener(proto, addr string) (net.Listener, error) {
return listener, nil return listener, nil
} }
func createExitHandler(mtx *sync.RWMutex) {
s := make(chan os.Signal, 1)
signal.Notify(s, syscall.SIGINT, syscall.SIGTERM)
go func() {
// wait for exit signal of some kind
<-s
// block new requests and wait for existing ones to finish
mtx.Lock()
// bail
os.Exit(0)
}()
}