Major refactor of API

fix-voucher-compensation
maru 2024-04-08 20:44:36 -04:00
parent 0edfeab3ca
commit 5778675171
No known key found for this signature in database
GPG Key ID: 37689350E9CD0F0D
6 changed files with 340 additions and 392 deletions

View File

@ -5,9 +5,7 @@ import (
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"encoding/json"
"fmt" "fmt"
"net/http"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
@ -34,19 +32,7 @@ type AccountInfoResponse struct {
} }
// /account/info - get account info // /account/info - get account info
func (s *Server) handleAccountInfo(w http.ResponseWriter, r *http.Request) { func handleAccountInfo(username string, uuid []byte) (AccountInfoResponse, error) {
username, err := getUsernameFromRequest(r)
if err != nil {
httpError(w, r, err.Error(), http.StatusBadRequest)
return
}
uuid, err := getUUIDFromRequest(r) // lazy
if err != nil {
httpError(w, r, err.Error(), http.StatusBadRequest)
return
}
var latestSave time.Time var latestSave time.Time
latestSaveID := -1 latestSaveID := -1
for id := range sessionSlotCount { for id := range sessionSlotCount {
@ -66,142 +52,95 @@ func (s *Server) handleAccountInfo(w http.ResponseWriter, r *http.Request) {
} }
} }
response, err := json.Marshal(AccountInfoResponse{Username: username, LastSessionSlot: latestSaveID}) return AccountInfoResponse{Username: username, LastSessionSlot: latestSaveID}, nil
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
} }
type AccountRegisterRequest GenericAuthRequest type AccountRegisterRequest GenericAuthRequest
// /account/register - register account // /account/register - register account
func (s *Server) handleAccountRegister(w http.ResponseWriter, r *http.Request) { func handleAccountRegister(username, password string) error {
var request AccountRegisterRequest if !isValidUsername(username) {
err := json.NewDecoder(r.Body).Decode(&request) return fmt.Errorf("invalid username")
if err != nil {
httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest)
return
} }
if !isValidUsername(request.Username) { if len(password) < 6 {
httpError(w, r, "invalid username", http.StatusBadRequest) return fmt.Errorf("invalid password")
return
}
if len(request.Password) < 6 {
httpError(w, r, "invalid password", http.StatusBadRequest)
return
} }
uuid := make([]byte, UUIDSize) uuid := make([]byte, UUIDSize)
_, err = rand.Read(uuid) _, err := rand.Read(uuid)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to generate uuid: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to generate uuid: %s", err)
return
} }
salt := make([]byte, ArgonSaltSize) salt := make([]byte, ArgonSaltSize)
_, err = rand.Read(salt) _, err = rand.Read(salt)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to generate salt: %s", err), http.StatusInternalServerError) return fmt.Errorf(fmt.Sprintf("failed to generate salt: %s", err))
return
} }
err = db.AddAccountRecord(uuid, request.Username, argon2.IDKey([]byte(request.Password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize), salt) err = db.AddAccountRecord(uuid, username, argon2.IDKey([]byte(password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize), salt)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to add account record: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to add account record: %s", err)
return
} }
w.WriteHeader(http.StatusOK) return nil
} }
type AccountLoginRequest GenericAuthRequest type AccountLoginRequest GenericAuthRequest
type AccountLoginResponse GenericAuthResponse type AccountLoginResponse GenericAuthResponse
// /account/login - log into account // /account/login - log into account
func (s *Server) handleAccountLogin(w http.ResponseWriter, r *http.Request) { func handleAccountLogin(username, password string) (AccountLoginResponse, error) {
var request AccountLoginRequest if !isValidUsername(username) {
err := json.NewDecoder(r.Body).Decode(&request) return AccountLoginResponse{}, fmt.Errorf("invalid username")
if err != nil {
httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest)
return
} }
if !isValidUsername(request.Username) { if len(password) < 6 {
httpError(w, r, "invalid username", http.StatusBadRequest) return AccountLoginResponse{}, fmt.Errorf("invalid password")
return
} }
if len(request.Password) < 6 { key, salt, err := db.FetchAccountKeySaltFromUsername(username)
httpError(w, r, "invalid password", http.StatusBadRequest)
return
}
key, salt, err := db.FetchAccountKeySaltFromUsername(request.Username)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
httpError(w, r, "account doesn't exist", http.StatusBadRequest) return AccountLoginResponse{}, fmt.Errorf("account doesn't exist")
return
} }
httpError(w, r, err.Error(), http.StatusInternalServerError) return AccountLoginResponse{}, err
return
} }
if !bytes.Equal(key, argon2.IDKey([]byte(request.Password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize)) { if !bytes.Equal(key, argon2.IDKey([]byte(password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize)) {
httpError(w, r, "password doesn't match", http.StatusBadRequest) return AccountLoginResponse{}, fmt.Errorf("password doesn't match")
return
} }
token := make([]byte, 32) token := make([]byte, 32)
_, err = rand.Read(token) _, err = rand.Read(token)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to generate token: %s", err), http.StatusInternalServerError) return AccountLoginResponse{}, fmt.Errorf("failed to generate token: %s", err)
return
} }
err = db.AddAccountSession(request.Username, token) err = db.AddAccountSession(username, token)
if err != nil { if err != nil {
httpError(w, r, "failed to add account session", http.StatusInternalServerError) return AccountLoginResponse{}, fmt.Errorf("failed to add account session")
return
} }
response, err := json.Marshal(AccountLoginResponse{Token: base64.StdEncoding.EncodeToString(token)}) return AccountLoginResponse{Token: base64.StdEncoding.EncodeToString(token)}, nil
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
} }
// /account/logout - log out of account // /account/logout - log out of account
func (s *Server) handleAccountLogout(w http.ResponseWriter, r *http.Request) { func handleAccountLogout(token []byte) error {
token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil {
httpError(w, r, fmt.Sprintf("failed to decode token: %s", err), http.StatusBadRequest)
return
}
if len(token) != 32 { if len(token) != 32 {
httpError(w, r, "invalid token", http.StatusBadRequest) return fmt.Errorf("invalid token")
return
} }
err = db.RemoveSessionFromToken(token) err := db.RemoveSessionFromToken(token)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
httpError(w, r, "token not found", http.StatusBadRequest) return fmt.Errorf("token not found")
return
} }
httpError(w, r, "failed to remove account session", http.StatusInternalServerError) return fmt.Errorf("failed to remove account session")
return
} }
w.WriteHeader(http.StatusOK) return nil
} }

View File

@ -5,15 +5,13 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net/http"
"os" "os"
"strconv"
"time" "time"
"github.com/Flashfyre/pokerogue-server/db" "github.com/Flashfyre/pokerogue-server/db"
"github.com/Flashfyre/pokerogue-server/defs"
"github.com/go-co-op/gocron" "github.com/go-co-op/gocron"
) )
@ -81,79 +79,27 @@ func deriveDailyRunSeed(seedTime time.Time) []byte {
return hashedSeed[:] return hashedSeed[:]
} }
// /daily/seed - fetch daily run seed
func (s *Server) handleSeed(w http.ResponseWriter) {
w.Write([]byte(dailyRunSeed))
}
// /daily/rankings - fetch daily rankings // /daily/rankings - fetch daily rankings
func (s *Server) handleRankings(w http.ResponseWriter, r *http.Request) { func handleRankings(uuid []byte, category, page int) ([]defs.DailyRanking, error) {
uuid, err := getUUIDFromRequest(r) err := db.UpdateAccountLastActivity(uuid)
if err != nil {
httpError(w, r, err.Error(), http.StatusBadRequest)
return
}
err = db.UpdateAccountLastActivity(uuid)
if err != nil { if err != nil {
log.Print("failed to update account last activity") log.Print("failed to update account last activity")
} }
var category int
if r.URL.Query().Has("category") {
category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil {
httpError(w, r, fmt.Sprintf("failed to convert category: %s", err), http.StatusBadRequest)
return
}
}
page := 1
if r.URL.Query().Has("page") {
page, err = strconv.Atoi(r.URL.Query().Get("page"))
if err != nil {
httpError(w, r, fmt.Sprintf("failed to convert page: %s", err), http.StatusBadRequest)
return
}
}
rankings, err := db.FetchRankings(category, page) rankings, err := db.FetchRankings(category, page)
if err != nil { if err != nil {
log.Print("failed to retrieve rankings") log.Print("failed to retrieve rankings")
} }
response, err := json.Marshal(rankings) return rankings, nil
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
} }
// /daily/rankingpagecount - fetch daily ranking page count // /daily/rankingpagecount - fetch daily ranking page count
func (s *Server) handleRankingPageCount(w http.ResponseWriter, r *http.Request) { func handleRankingPageCount(category int) (int, error) {
var err error
var category int
if r.URL.Query().Has("category") {
category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil {
httpError(w, r, fmt.Sprintf("failed to convert category: %s", err), http.StatusBadRequest)
return
}
}
pageCount, err := db.FetchRankingPageCount(category) pageCount, err := db.FetchRankingPageCount(category)
if err != nil { if err != nil {
log.Print("failed to retrieve ranking page count") log.Print("failed to retrieve ranking page count")
} }
response, err := json.Marshal(pageCount) return pageCount, nil
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
} }

View File

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
) )
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) { func httpError(w http.ResponseWriter, r *http.Request, err error, code int) {
log.Printf("%s: %s\n", r.URL.Path, error) log.Printf("%s: %s\n", r.URL.Path, err)
http.Error(w, error, code) http.Error(w, err.Error(), code)
} }

View File

@ -1,14 +1,10 @@
package api package api
import ( import (
"encoding/json"
"fmt"
"log" "log"
"net/http"
"time" "time"
"github.com/Flashfyre/pokerogue-server/db" "github.com/Flashfyre/pokerogue-server/db"
"github.com/Flashfyre/pokerogue-server/defs"
"github.com/go-co-op/gocron" "github.com/go-co-op/gocron"
) )
@ -30,49 +26,14 @@ func updateStats() {
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} }
battleCount, err = db.FetchBattleCount() battleCount, err = db.FetchBattleCount()
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} }
classicSessionCount, err = db.FetchClassicSessionCount() classicSessionCount, err = db.FetchClassicSessionCount()
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} }
} }
// /game/playercount - get player count
func (s *Server) handlePlayerCountGet(w http.ResponseWriter, r *http.Request) {
response, err := json.Marshal(playerCount)
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
}
// /game/titlestats - get title stats
func (s *Server) handleTitleStatsGet(w http.ResponseWriter, r *http.Request) {
titleStats := &defs.TitleStats{
PlayerCount: playerCount,
BattleCount: battleCount,
}
response, err := json.Marshal(titleStats)
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
}
// /game/classicsessioncount - get classic session count
func (s *Server) handleClassicSessionCountGet(w http.ResponseWriter, r *http.Request) {
response, err := json.Marshal(classicSessionCount)
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
}

View File

@ -1,14 +1,26 @@
package api package api
import ( import (
"encoding/base64"
"encoding/gob" "encoding/gob"
"encoding/json"
"fmt"
"net/http" "net/http"
"strconv"
"github.com/Flashfyre/pokerogue-server/defs"
) )
type Server struct { type Server struct {
Debug bool Debug bool
} }
/*
The caller of endpoint handler functions are responsible for extracting the necessary data from the request.
Handler functions are responsible for checking the validity of this data and returning a result or error.
Handlers should not return serialized JSON, instead return the struct itself.
*/
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
gob.Register([]interface{}{}) gob.Register([]interface{}{})
gob.Register(map[string]interface{}{}) gob.Register(map[string]interface{}{})
@ -25,37 +37,238 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
switch r.URL.Path { switch r.URL.Path {
// /account
case "/account/info": case "/account/info":
s.handleAccountInfo(w, r) username, err := getUsernameFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
uuid, err := getUUIDFromRequest(r) // lazy
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
info, err := handleAccountInfo(username, uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
response, err := json.Marshal(info)
if err != nil {
httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
case "/account/register": case "/account/register":
s.handleAccountRegister(w, r) var request AccountRegisterRequest
err := json.NewDecoder(r.Body).Decode(&request)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
err = handleAccountRegister(request.Username, request.Password)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
case "/account/login": case "/account/login":
s.handleAccountLogin(w, r) var request AccountLoginRequest
err := json.NewDecoder(r.Body).Decode(&request)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
token, err := handleAccountLogin(request.Username, request.Password)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
response, err := json.Marshal(token)
if err != nil {
httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
case "/account/logout": case "/account/logout":
s.handleAccountLogout(w, r) token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest)
return
}
err = handleAccountLogout(token)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
// /game
case "/game/playercount": case "/game/playercount":
s.handlePlayerCountGet(w, r) w.Write([]byte(strconv.Itoa(playerCount)))
case "/game/titlestats": case "/game/titlestats":
s.handleTitleStatsGet(w, r) response, err := json.Marshal(&defs.TitleStats{
PlayerCount: playerCount,
BattleCount: battleCount,
})
if err != nil {
httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
case "/game/classicsessioncount": case "/game/classicsessioncount":
s.handleClassicSessionCountGet(w, r) w.Write([]byte(strconv.Itoa(classicSessionCount)))
case "/savedata/get": // /savedata
s.handleSavedataGet(w, r) case "/savedata/get", "/savedata/update", "/savedata/delete", "/savedata/clear":
case "/savedata/update": uuid, err := getUUIDFromRequest(r)
s.handleSavedataUpdate(w, r) if err != nil {
case "/savedata/delete": httpError(w, r, err, http.StatusBadRequest)
s.handleSavedataDelete(w, r) return
case "/savedata/clear": }
s.handleSavedataClear(w, r)
datatype := -1
if r.URL.Query().Has("datatype") {
datatype, err = strconv.Atoi(r.URL.Query().Get("datatype"))
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
}
var slot int
if r.URL.Query().Has("slot") {
slot, err = strconv.Atoi(r.URL.Query().Get("slot"))
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
}
var save any
// /savedata/delete specifies datatype, but doesn't expect data in body
if r.URL.Path != "/savedata/delete" {
if datatype == 0 {
var system defs.SystemSaveData
err = json.NewDecoder(r.Body).Decode(&system)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
save = system
// /savedata/clear doesn't specify datatype, it is assumed to be 1 (session)
} else if datatype == 1 || r.URL.Path == "/savedata/clear" {
var session defs.SessionSaveData
err = json.NewDecoder(r.Body).Decode(&session)
if err != nil {
httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
save = session
}
}
switch r.URL.Path {
case "/savedata/get":
save, err = handleSavedataGet(uuid, datatype, slot)
case "/savedata/update":
err = handleSavedataUpdate(uuid, slot, save)
case "/savedata/delete":
err = handleSavedataDelete(uuid, datatype, slot)
case "/savedata/clear":
// doesn't return a save, but it works
save, err = handleSavedataClear(uuid, slot, save.(defs.SessionSaveData))
}
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
if save == nil {
w.WriteHeader(http.StatusOK)
}
response, err := json.Marshal(save)
if err != nil {
httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
// /daily
case "/daily/seed": case "/daily/seed":
s.handleSeed(w) w.Write([]byte(dailyRunSeed))
case "/daily/rankings": case "/daily/rankings":
s.handleRankings(w, r) uuid, err := getUUIDFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
return
}
var category int
if r.URL.Query().Has("category") {
category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest)
return
}
}
page := 1
if r.URL.Query().Has("page") {
page, err = strconv.Atoi(r.URL.Query().Get("page"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert page: %s", err), http.StatusBadRequest)
return
}
}
rankings, err := handleRankings(uuid, category, page)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
response, err := json.Marshal(rankings)
if err != nil {
httpError(w, r, fmt.Errorf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
w.Write(response)
case "/daily/rankingpagecount": case "/daily/rankingpagecount":
s.handleRankingPageCount(w, r) var category int
if r.URL.Query().Has("category") {
var err error
category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest)
return
}
}
count, err := handleRankingPageCount(category)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
}
w.Write([]byte(strconv.Itoa(count)))
} }
} }

View File

@ -4,10 +4,8 @@ import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net/http"
"os" "os"
"strconv" "strconv"
@ -19,228 +17,146 @@ import (
const sessionSlotCount = 3 const sessionSlotCount = 3
// /savedata/get - get save data // /savedata/get - get save data
func (s *Server) handleSavedataGet(w http.ResponseWriter, r *http.Request) { func handleSavedataGet(uuid []byte, datatype, slot int) (any, error) {
uuid, err := getUUIDFromRequest(r) switch datatype {
if err != nil { case 0: // System
httpError(w, r, err.Error(), http.StatusBadRequest)
return
}
switch r.URL.Query().Get("datatype") {
case "0": // System
system, err := readSystemSaveData(uuid) system, err := readSystemSaveData(uuid)
if err != nil { if err != nil {
httpError(w, r, err.Error(), http.StatusInternalServerError) return nil, err
return
} }
saveJson, err := json.Marshal(system) return system, nil
case 1: // Session
if slot < 0 || slot >= sessionSlotCount {
return nil, fmt.Errorf("slot id %d out of range", slot)
}
session, err := readSessionSaveData(uuid, slot)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal save to json: %s", err), http.StatusInternalServerError) return nil, err
return
} }
w.Write(saveJson) return session, nil
case "1": // Session
slotID, err := strconv.Atoi(r.URL.Query().Get("slot"))
if err != nil {
httpError(w, r, fmt.Sprintf("failed to convert slot id: %s", err), http.StatusBadRequest)
return
}
if slotID < 0 || slotID >= sessionSlotCount {
httpError(w, r, fmt.Sprintf("slot id %d out of range", slotID), http.StatusBadRequest)
return
}
session, err := readSessionSaveData(uuid, slotID)
if err != nil {
httpError(w, r, err.Error(), http.StatusInternalServerError)
return
}
saveJson, err := json.Marshal(session)
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal save to json: %s", err), http.StatusInternalServerError)
return
}
w.Write(saveJson)
default: default:
httpError(w, r, "invalid data type", http.StatusBadRequest) return nil, fmt.Errorf("invalid data type")
return
} }
} }
// /savedata/update - update save data // /savedata/update - update save data
func (s *Server) handleSavedataUpdate(w http.ResponseWriter, r *http.Request) { func handleSavedataUpdate(uuid []byte, slot int, save any) error {
uuid, err := getUUIDFromRequest(r) err := db.UpdateAccountLastActivity(uuid)
if err != nil {
httpError(w, r, err.Error(), http.StatusBadRequest)
return
}
err = db.UpdateAccountLastActivity(uuid)
if err != nil { if err != nil {
log.Print("failed to update account last activity") log.Print("failed to update account last activity")
} }
hexUUID := hex.EncodeToString(uuid) hexUUID := hex.EncodeToString(uuid)
switch r.URL.Query().Get("datatype") { switch save := save.(type) {
case "0": // System case defs.SystemSaveData: // System
var system defs.SystemSaveData if save.TrainerID == 0 && save.SecretID == 0 {
err = json.NewDecoder(r.Body).Decode(&system) return fmt.Errorf("invalid system data")
if err != nil {
httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest)
return
} }
if system.TrainerID == 0 && system.SecretID == 0 { err = db.UpdateAccountStats(uuid, save.GameStats)
httpError(w, r, "invalid system data", http.StatusInternalServerError)
return
}
err = db.UpdateAccountStats(uuid, system.GameStats)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to update account stats: %s", err), http.StatusBadRequest) return fmt.Errorf("failed to update account stats: %s", err)
return
} }
var gobBuffer bytes.Buffer var gobBuffer bytes.Buffer
err = gob.NewEncoder(&gobBuffer).Encode(system) err = gob.NewEncoder(&gobBuffer).Encode(save)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to serialize save: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to serialize save: %s", err)
return
} }
zstdWriter, err := zstd.NewWriter(nil) zstdWriter, err := zstd.NewWriter(nil)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to create zstd writer, %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to create zstd writer, %s", err)
return
} }
compressed := zstdWriter.EncodeAll(gobBuffer.Bytes(), nil) compressed := zstdWriter.EncodeAll(gobBuffer.Bytes(), nil)
err = os.MkdirAll("userdata/"+hexUUID, 0755) err = os.MkdirAll("userdata/"+hexUUID, 0755)
if err != nil && !os.IsExist(err) { if err != nil && !os.IsExist(err) {
httpError(w, r, fmt.Sprintf("failed to create userdata folder: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to create userdata folder: %s", err)
return
} }
err = os.WriteFile("userdata/"+hexUUID+"/system.pzs", compressed, 0644) err = os.WriteFile("userdata/"+hexUUID+"/system.pzs", compressed, 0644)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to write save file: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to write save file: %s", err)
return
} }
case "1": // Session case defs.SessionSaveData: // Session
slotID, err := strconv.Atoi(r.URL.Query().Get("slot")) if slot < 0 || slot >= sessionSlotCount {
if err != nil { return fmt.Errorf("slot id %d out of range", slot)
httpError(w, r, fmt.Sprintf("failed to convert slot id: %s", err), http.StatusBadRequest)
return
}
if slotID < 0 || slotID >= sessionSlotCount {
httpError(w, r, fmt.Sprintf("slot id %d out of range", slotID), http.StatusBadRequest)
return
} }
fileName := "session" fileName := "session"
if slotID != 0 { if slot != 0 {
fileName += strconv.Itoa(slotID) fileName += strconv.Itoa(slot)
}
var session defs.SessionSaveData
err = json.NewDecoder(r.Body).Decode(&session)
if err != nil {
httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest)
return
} }
var gobBuffer bytes.Buffer var gobBuffer bytes.Buffer
err = gob.NewEncoder(&gobBuffer).Encode(session) err = gob.NewEncoder(&gobBuffer).Encode(save)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to serialize save: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to serialize save: %s", err)
return
} }
zstdWriter, err := zstd.NewWriter(nil) zstdWriter, err := zstd.NewWriter(nil)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to create zstd writer, %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to create zstd writer, %s", err)
return
} }
compressed := zstdWriter.EncodeAll(gobBuffer.Bytes(), nil) compressed := zstdWriter.EncodeAll(gobBuffer.Bytes(), nil)
err = os.MkdirAll("userdata/"+hexUUID, 0755) err = os.MkdirAll("userdata/"+hexUUID, 0755)
if err != nil && !os.IsExist(err) { if err != nil && !os.IsExist(err) {
httpError(w, r, fmt.Sprintf("failed to create userdata folder: %s", err), http.StatusInternalServerError) return fmt.Errorf(fmt.Sprintf("failed to create userdata folder: %s", err))
return
} }
err = os.WriteFile(fmt.Sprintf("userdata/%s/%s.pzs", hexUUID, fileName), compressed, 0644) err = os.WriteFile(fmt.Sprintf("userdata/%s/%s.pzs", hexUUID, fileName), compressed, 0644)
if err != nil { if err != nil {
httpError(w, r, fmt.Sprintf("failed to write save file: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to write save file: %s", err)
return
} }
default: default:
httpError(w, r, "invalid data type", http.StatusBadRequest) return fmt.Errorf("invalid data type")
return
} }
w.WriteHeader(http.StatusOK) return nil
} }
// /savedata/delete - delete save data // /savedata/delete - delete save data
func (s *Server) handleSavedataDelete(w http.ResponseWriter, r *http.Request) { func handleSavedataDelete(uuid []byte, datatype, slot int) error {
uuid, err := getUUIDFromRequest(r) err := db.UpdateAccountLastActivity(uuid)
if err != nil {
httpError(w, r, err.Error(), http.StatusBadRequest)
return
}
err = db.UpdateAccountLastActivity(uuid)
if err != nil { if err != nil {
log.Print("failed to update account last activity") log.Print("failed to update account last activity")
} }
hexUUID := hex.EncodeToString(uuid) hexUUID := hex.EncodeToString(uuid)
switch r.URL.Query().Get("datatype") { switch datatype {
case "0": // System case 0: // System
err := os.Remove("userdata/" + hexUUID + "/system.pzs") err := os.Remove("userdata/" + hexUUID + "/system.pzs")
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
httpError(w, r, fmt.Sprintf("failed to delete save file: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to delete save file: %s", err)
return
} }
case "1": // Session case 1: // Session
slotID, err := strconv.Atoi(r.URL.Query().Get("slot")) if slot < 0 || slot >= sessionSlotCount {
if err != nil { return fmt.Errorf("slot id %d out of range", slot)
httpError(w, r, fmt.Sprintf("failed to convert slot id: %s", err), http.StatusBadRequest)
return
}
if slotID < 0 || slotID >= sessionSlotCount {
httpError(w, r, fmt.Sprintf("slot id %d out of range", slotID), http.StatusBadRequest)
return
} }
fileName := "session" fileName := "session"
if slotID != 0 { if slot != 0 {
fileName += strconv.Itoa(slotID) fileName += strconv.Itoa(slot)
} }
err = os.Remove(fmt.Sprintf("userdata/%s/%s.pzs", hexUUID, fileName)) err = os.Remove(fmt.Sprintf("userdata/%s/%s.pzs", hexUUID, fileName))
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
httpError(w, r, fmt.Sprintf("failed to delete save file: %s", err), http.StatusInternalServerError) return fmt.Errorf("failed to delete save file: %s", err)
return
} }
default: default:
httpError(w, r, "invalid data type", http.StatusBadRequest) return fmt.Errorf("invalid data type")
return
} }
w.WriteHeader(http.StatusOK) return nil
} }
type SavedataClearResponse struct { type SavedataClearResponse struct {
@ -248,73 +164,46 @@ type SavedataClearResponse struct {
} }
// /savedata/clear - mark session save data as cleared and delete // /savedata/clear - mark session save data as cleared and delete
func (s *Server) handleSavedataClear(w http.ResponseWriter, r *http.Request) { func handleSavedataClear(uuid []byte, slot int, save defs.SessionSaveData) (SavedataClearResponse, error) {
uuid, err := getUUIDFromRequest(r) err := db.UpdateAccountLastActivity(uuid)
if err != nil {
httpError(w, r, err.Error(), http.StatusBadRequest)
return
}
err = db.UpdateAccountLastActivity(uuid)
if err != nil { if err != nil {
log.Print("failed to update account last activity") log.Print("failed to update account last activity")
} }
slotID, err := strconv.Atoi(r.URL.Query().Get("slot")) if slot < 0 || slot >= sessionSlotCount {
if err != nil { return SavedataClearResponse{}, fmt.Errorf("slot id %d out of range", slot)
httpError(w, r, fmt.Sprintf("failed to convert slot id: %s", err), http.StatusBadRequest)
return
} }
if slotID < 0 || slotID >= sessionSlotCount { sessionCompleted := validateSessionCompleted(save)
httpError(w, r, fmt.Sprintf("slot id %d out of range", slotID), http.StatusBadRequest)
return
}
var session defs.SessionSaveData
err = json.NewDecoder(r.Body).Decode(&session)
if err != nil {
httpError(w, r, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest)
return
}
sessionCompleted := validateSessionCompleted(session)
newCompletion := false newCompletion := false
if session.GameMode == 3 && session.Seed == dailyRunSeed { if save.GameMode == 3 && save.Seed == dailyRunSeed {
waveCompleted := session.WaveIndex waveCompleted := save.WaveIndex
if !sessionCompleted { if !sessionCompleted {
waveCompleted-- waveCompleted--
} }
err = db.AddOrUpdateAccountDailyRun(uuid, session.Score, waveCompleted) err = db.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted)
if err != nil { if err != nil {
log.Printf("failed to add or update daily run record: %s", err) log.Printf("failed to add or update daily run record: %s", err)
} }
} }
if sessionCompleted { if sessionCompleted {
newCompletion, err = db.TryAddSeedCompletion(uuid, session.Seed, int(session.GameMode)) newCompletion, err = db.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode))
if err != nil { if err != nil {
log.Printf("failed to mark seed as completed: %s", err) log.Printf("failed to mark seed as completed: %s", err)
} }
} }
response, err := json.Marshal(SavedataClearResponse{Success: newCompletion})
if err != nil {
httpError(w, r, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)
return
}
fileName := "session" fileName := "session"
if slotID != 0 { if slot != 0 {
fileName += strconv.Itoa(slotID) fileName += strconv.Itoa(slot)
} }
err = os.Remove(fmt.Sprintf("userdata/%s/%s.pzs", hex.EncodeToString(uuid), fileName)) err = os.Remove(fmt.Sprintf("userdata/%s/%s.pzs", hex.EncodeToString(uuid), fileName))
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
httpError(w, r, fmt.Sprintf("failed to delete save file: %s", err), http.StatusInternalServerError) return SavedataClearResponse{}, fmt.Errorf("failed to delete save file: %s", err)
return
} }
w.Write(response) return SavedataClearResponse{Success: newCompletion}, nil
} }