Protect against cross-session overwrites
parent
75cf6f3ab1
commit
6acbb6448a
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/pagefaultgames/pokerogue-server/api/account"
|
||||
"github.com/pagefaultgames/pokerogue-server/api/daily"
|
||||
"github.com/pagefaultgames/pokerogue-server/api/savedata"
|
||||
"github.com/pagefaultgames/pokerogue-server/db"
|
||||
"github.com/pagefaultgames/pokerogue-server/defs"
|
||||
)
|
||||
|
||||
|
@ -187,14 +188,64 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
var token []byte
|
||||
token, err = base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/savedata/get":
|
||||
err = db.UpdateActiveSession(uuid, token)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
save, err = savedata.Get(uuid, datatype, slot)
|
||||
case "/savedata/update":
|
||||
var token []byte
|
||||
token, err = base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var active bool
|
||||
active, err = db.IsActiveSession(token)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !active {
|
||||
httpError(w, r, fmt.Errorf("session out of date"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = savedata.Update(uuid, slot, save)
|
||||
case "/savedata/delete":
|
||||
var active bool
|
||||
active, err = db.IsActiveSession(token)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !active {
|
||||
httpError(w, r, fmt.Errorf("session out of date"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = savedata.Delete(uuid, datatype, slot)
|
||||
case "/savedata/clear":
|
||||
var active bool
|
||||
active, err = db.IsActiveSession(token)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if active {
|
||||
s, ok := save.(defs.SessionSaveData)
|
||||
if !ok {
|
||||
httpError(w, r, fmt.Errorf("save data is not type SessionSaveData"), http.StatusBadRequest)
|
||||
|
@ -203,6 +254,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// doesn't return a save, but it works
|
||||
save, err = savedata.Clear(uuid, slot, daily.Seed(), s)
|
||||
} else {
|
||||
var response savedata.ClearResponse
|
||||
response.Error = "session out of date"
|
||||
save = response
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
type ClearResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// /savedata/clear - mark session save data as cleared and delete
|
||||
|
|
|
@ -175,6 +175,25 @@ func FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
|
|||
return key, salt, nil
|
||||
}
|
||||
|
||||
func IsActiveSession(token []byte) (bool, error) {
|
||||
var active int
|
||||
err := handle.QueryRow("SELECT `active` FROM sessions WHERE token = ?", token).Scan(&active)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return active == 1, nil
|
||||
}
|
||||
|
||||
func UpdateActiveSession(uuid []byte, token []byte) error {
|
||||
_, err := handle.Exec("UPDATE sessions SET `active` = CASE WHEN token = ? THEN 1 ELSE 0 END WHERE uuid = ?", token, uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchUUIDFromToken(token []byte) ([]byte, error) {
|
||||
var uuid []byte
|
||||
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ? AND expire > UTC_TIMESTAMP()", token).Scan(&uuid)
|
||||
|
|
Loading…
Reference in New Issue