Use ServeMux router instead of switch statement

oauth2
maru 2024-04-24 20:05:57 -04:00
parent da572ebdd9
commit fa5dcb0052
No known key found for this signature in database
GPG Key ID: 37689350E9CD0F0D
3 changed files with 302 additions and 311 deletions

View File

@ -10,9 +10,31 @@ import (
"github.com/pagefaultgames/pokerogue-server/db" "github.com/pagefaultgames/pokerogue-server/db"
) )
func Init() { func Init(mux *http.ServeMux) {
scheduleStatRefresh() scheduleStatRefresh()
daily.Init() daily.Init()
// account
mux.HandleFunc("/account/info", handleAccountInfo)
mux.HandleFunc("/account/register", handleAccountRegister)
mux.HandleFunc("/account/login", handleAccountLogin)
mux.HandleFunc("/account/logout", handleAccountLogout)
// game
mux.HandleFunc("/game/playercount", handleGamePlayerCount)
mux.HandleFunc("/game/titlestats", handleGameTitleStats)
mux.HandleFunc("/game/classicsessioncount", handleGameClassicSessionCount)
// savedata
mux.HandleFunc("/savedata/get", handleSaveData)
mux.HandleFunc("/savedata/update", handleSaveData)
mux.HandleFunc("/savedata/delete", handleSaveData)
mux.HandleFunc("/savedata/clear", handleSaveData)
// daily
mux.HandleFunc("/daily/seed", handleDailySeed)
mux.HandleFunc("/daily/rankings", handleDailyRankings)
mux.HandleFunc("/daily/rankingpagecount", handleDailyRankingPageCount)
} }
func getTokenFromRequest(r *http.Request) ([]byte, error) { func getTokenFromRequest(r *http.Request) ([]byte, error) {

View File

@ -7,7 +7,6 @@ import (
"log" "log"
"net/http" "net/http"
"strconv" "strconv"
"sync"
"github.com/pagefaultgames/pokerogue-server/api/account" "github.com/pagefaultgames/pokerogue-server/api/account"
"github.com/pagefaultgames/pokerogue-server/api/daily" "github.com/pagefaultgames/pokerogue-server/api/daily"
@ -16,308 +15,301 @@ import (
"github.com/pagefaultgames/pokerogue-server/defs" "github.com/pagefaultgames/pokerogue-server/defs"
) )
type Server struct {
Debug bool
Exit *sync.RWMutex
}
/* /*
The caller of endpoint handler functions are responsible for extracting the necessary data from the request. The caller of endpoint handler functions are responsible for extracting the necessary data from the request.
Handler functions are responsible for checking the validity of this data and returning a result or error. Handler functions are responsible for checking the validity of this data and returning a result or error.
Handlers should not return serialized JSON, instead return the struct itself. Handlers should not return serialized JSON, instead return the struct itself.
*/ */
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func httpError(w http.ResponseWriter, r *http.Request, err error, code int) {
// kind of misusing the RWMutex but it doesn't matter log.Printf("%s: %s\n", r.URL.Path, err)
s.Exit.RLock() http.Error(w, err.Error(), code)
defer s.Exit.RUnlock() }
if s.Debug { // account
w.Header().Set("Access-Control-Allow-Headers", "*")
w.Header().Set("Access-Control-Allow-Methods", "*")
w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Method == "OPTIONS" { func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) username, err := getUsernameFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
uuid, err := getUUIDFromRequest(r) // lazy
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
response, err := account.Info(username, uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = json.NewEncoder(w).Encode(response)
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
}
func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
return
}
err = account.Register(r.Form.Get("username"), r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
return
}
response, err := account.Login(r.Form.Get("username"), r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = json.NewEncoder(w).Encode(response)
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
}
func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest)
return
}
err = account.Logout(token)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
// game
func handleGamePlayerCount(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(strconv.Itoa(playerCount)))
}
func handleGameTitleStats(w http.ResponseWriter, r *http.Request) {
err := json.NewEncoder(w).Encode(defs.TitleStats{
PlayerCount: playerCount,
BattleCount: battleCount,
})
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
}
func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(strconv.Itoa(classicSessionCount)))
}
func handleSaveData(w http.ResponseWriter, r *http.Request) {
uuid, err := getUUIDFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
datatype := -1
if r.URL.Query().Has("datatype") {
datatype, err = strconv.Atoi(r.URL.Query().Get("datatype"))
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
}
var slot int
if r.URL.Query().Has("slot") {
slot, err = strconv.Atoi(r.URL.Query().Get("slot"))
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
}
var save any
// /savedata/get and /savedata/delete specify datatype, but don't expect data in body
if r.URL.Path != "/savedata/get" && r.URL.Path != "/savedata/delete" {
if datatype == 0 {
var system defs.SystemSaveData
err = json.NewDecoder(r.Body).Decode(&system)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
save = system
// /savedata/clear doesn't specify datatype, it is assumed to be 1 (session)
} else if datatype == 1 || r.URL.Path == "/savedata/clear" {
var session defs.SessionSaveData
err = json.NewDecoder(r.Body).Decode(&session)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
save = session
}
}
var token []byte
token, err = getTokenFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
var active bool
if r.URL.Path == "/savedata/get" {
err = db.UpdateActiveSession(uuid, token)
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}
} else {
active, err = db.IsActiveSession(token)
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
}
// TODO: make this not suck
if !active && r.URL.Path != "/savedata/clear" {
httpError(w, r, fmt.Errorf("session out of date"), http.StatusBadRequest)
return return
} }
} }
switch r.URL.Path { switch r.URL.Path {
// /account case "/savedata/get":
case "/account/info": save, err = savedata.Get(uuid, datatype, slot)
username, err := getUsernameFromRequest(r) case "/savedata/update":
if err != nil { err = savedata.Update(uuid, slot, save)
httpError(w, r, err, http.StatusBadRequest) case "/savedata/delete":
return err = savedata.Delete(uuid, datatype, slot)
} case "/savedata/clear":
if !active {
uuid, err := getUUIDFromRequest(r) // lazy
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
response, err := account.Info(username, uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = json.NewEncoder(w).Encode(response)
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
case "/account/register":
err := r.ParseForm()
if err != nil {
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
return
}
err = account.Register(r.Form.Get("username"), r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
case "/account/login":
err := r.ParseForm()
if err != nil {
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
return
}
response, err := account.Login(r.Form.Get("username"), r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = json.NewEncoder(w).Encode(response)
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
case "/account/logout":
token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest)
return
}
err = account.Logout(token)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
// /game
case "/game/playercount":
w.Write([]byte(strconv.Itoa(playerCount)))
case "/game/titlestats":
err := json.NewEncoder(w).Encode(defs.TitleStats{
PlayerCount: playerCount,
BattleCount: battleCount,
})
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
case "/game/classicsessioncount":
w.Write([]byte(strconv.Itoa(classicSessionCount)))
// /savedata
case "/savedata/get", "/savedata/update", "/savedata/delete", "/savedata/clear":
uuid, err := getUUIDFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
datatype := -1
if r.URL.Query().Has("datatype") {
datatype, err = strconv.Atoi(r.URL.Query().Get("datatype"))
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
}
var slot int
if r.URL.Query().Has("slot") {
slot, err = strconv.Atoi(r.URL.Query().Get("slot"))
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
}
var save any
// /savedata/get and /savedata/delete specify datatype, but don't expect data in body
if r.URL.Path != "/savedata/get" && r.URL.Path != "/savedata/delete" {
if datatype == 0 {
var system defs.SystemSaveData
err = json.NewDecoder(r.Body).Decode(&system)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
save = system
// /savedata/clear doesn't specify datatype, it is assumed to be 1 (session)
} else if datatype == 1 || r.URL.Path == "/savedata/clear" {
var session defs.SessionSaveData
err = json.NewDecoder(r.Body).Decode(&session)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
save = session
}
}
var token []byte
token, err = getTokenFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
var active bool
if r.URL.Path == "/savedata/get" {
err = db.UpdateActiveSession(uuid, token)
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}
} else {
active, err = db.IsActiveSession(token)
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
}
// TODO: make this not suck // TODO: make this not suck
if !active && r.URL.Path != "/savedata/clear" { save = savedata.ClearResponse{Error: "session out of date"}
httpError(w, r, fmt.Errorf("session out of date"), http.StatusBadRequest) break
return
}
} }
switch r.URL.Path { s, ok := save.(defs.SessionSaveData)
case "/savedata/get": if !ok {
save, err = savedata.Get(uuid, datatype, slot) err = fmt.Errorf("save data is not type SessionSaveData")
case "/savedata/update": break
err = savedata.Update(uuid, slot, save)
case "/savedata/delete":
err = savedata.Delete(uuid, datatype, slot)
case "/savedata/clear":
if !active {
// TODO: make this not suck
save = savedata.ClearResponse{Error: "session out of date"}
break
}
s, ok := save.(defs.SessionSaveData)
if !ok {
err = fmt.Errorf("save data is not type SessionSaveData")
break
}
// doesn't return a save, but it works
save, err = savedata.Clear(uuid, slot, daily.Seed(), s)
}
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
} }
if save == nil || r.URL.Path == "/savedata/update" { // doesn't return a save, but it works
w.WriteHeader(http.StatusOK) save, err = savedata.Clear(uuid, slot, daily.Seed(), s)
return }
} if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
err = json.NewEncoder(w).Encode(save)
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
// /daily
case "/daily/seed":
w.Write([]byte(daily.Seed()))
case "/daily/rankings":
var err error
var category int
if r.URL.Query().Has("category") {
category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest)
return
}
}
page := 1
if r.URL.Query().Has("page") {
page, err = strconv.Atoi(r.URL.Query().Get("page"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert page: %s", err), http.StatusBadRequest)
return
}
}
rankings, err := daily.Rankings(category, page)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = json.NewEncoder(w).Encode(rankings)
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
case "/daily/rankingpagecount":
var category int
if r.URL.Query().Has("category") {
var err error
category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest)
return
}
}
count, err := daily.RankingPageCount(category)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
}
w.Write([]byte(strconv.Itoa(count)))
default:
httpError(w, r, fmt.Errorf("unknown endpoint"), http.StatusNotFound)
return return
} }
if save == nil || r.URL.Path == "/savedata/update" {
w.WriteHeader(http.StatusOK)
return
}
err = json.NewEncoder(w).Encode(save)
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
} }
func httpError(w http.ResponseWriter, r *http.Request, err error, code int) { // daily
log.Printf("%s: %s\n", r.URL.Path, err)
http.Error(w, err.Error(), code) func handleDailySeed(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(daily.Seed()))
}
func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
var err error
var category int
if r.URL.Query().Has("category") {
category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest)
return
}
}
page := 1
if r.URL.Query().Has("page") {
page, err = strconv.Atoi(r.URL.Query().Get("page"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert page: %s", err), http.StatusBadRequest)
return
}
}
rankings, err := daily.Rankings(category, page)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = json.NewEncoder(w).Encode(rankings)
if err != nil {
httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
}
func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
var category int
if r.URL.Query().Has("category") {
var err error
category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest)
return
}
}
count, err := daily.RankingPageCount(category)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
}
w.Write([]byte(strconv.Itoa(count)))
} }

View File

@ -7,9 +7,6 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"os/signal"
"sync"
"syscall"
"github.com/pagefaultgames/pokerogue-server/api" "github.com/pagefaultgames/pokerogue-server/api"
"github.com/pagefaultgames/pokerogue-server/db" "github.com/pagefaultgames/pokerogue-server/db"
@ -17,8 +14,6 @@ import (
func main() { func main() {
// flag stuff // flag stuff
debug := flag.Bool("debug", false, "debug mode")
proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)") proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)")
addr := flag.String("addr", "0.0.0.0", "network address for api to listen on") addr := flag.String("addr", "0.0.0.0", "network address for api to listen on")
@ -46,15 +41,13 @@ func main() {
log.Fatalf("failed to create net listener: %s", err) log.Fatalf("failed to create net listener: %s", err)
} }
// create exit handler mux := http.NewServeMux()
var exit sync.RWMutex
createExitHandler(&exit)
// init api // init api
api.Init() api.Init(mux)
// start web server // start web server
err = http.Serve(listener, &api.Server{Debug: *debug, Exit: &exit}) err = http.Serve(listener, mux)
if err != nil { if err != nil {
log.Fatalf("failed to create http server or server errored: %s", err) log.Fatalf("failed to create http server or server errored: %s", err)
} }
@ -75,20 +68,4 @@ func createListener(proto, addr string) (net.Listener, error) {
} }
return listener, nil return listener, nil
} }
func createExitHandler(mtx *sync.RWMutex) {
s := make(chan os.Signal, 1)
signal.Notify(s, syscall.SIGINT, syscall.SIGTERM)
go func() {
// wait for exit signal of some kind
<-s
// block new requests and wait for existing ones to finish
mtx.Lock()
// bail
os.Exit(0)
}()
}