From 4600b7b024859852f035030237d238a6bd7bbb2d Mon Sep 17 00:00:00 2001 From: maru Date: Thu, 28 Dec 2023 19:53:59 -0500 Subject: [PATCH] More changes --- api/account.go | 126 +++++++++++++++++++++++++++++++++++++++++++++++-- api/generic.go | 6 --- db/account.go | 60 +++++++++++++++++++++-- db/db.go | 1 - go.mod | 4 ++ go.sum | 4 ++ 6 files changed, 186 insertions(+), 15 deletions(-) create mode 100644 go.sum diff --git a/api/account.go b/api/account.go index e8ccba4..c9c18e0 100644 --- a/api/account.go +++ b/api/account.go @@ -1,14 +1,28 @@ package api import ( + "bytes" + "crypto/rand" + "database/sql" "encoding/base64" "encoding/json" "fmt" "net/http" + "regexp" "github.com/Flashfyre/pokerogue-server/db" + "golang.org/x/crypto/argon2" ) +const ( + argonTime = 1 + argonMemory = 256*1024 + argonThreads = 4 + argonKeyLength = 32 +) + +var isValidUsername = regexp.MustCompile(`^\w{6,16}$`).MatchString + // /api/account/info - get account info type AccountInfoResponse struct{ @@ -22,7 +36,7 @@ func HandleAccountInfo(w http.ResponseWriter, r *http.Request) { return } - username, err := db.GetAccountInfoFromToken(token) + username, err := db.GetUsernameFromToken(token) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -40,7 +54,6 @@ func HandleAccountInfo(w http.ResponseWriter, r *http.Request) { // /api/account/register - register account type AccountRegisterRequest GenericAuthRequest -type AccountRegisterResponse GenericAuthResponse func HandleAccountRegister(w http.ResponseWriter, r *http.Request) { var request AccountRegisterRequest @@ -50,7 +63,39 @@ func HandleAccountRegister(w http.ResponseWriter, r *http.Request) { return } - + if isValidUsername(request.Username) { + http.Error(w, "invalid username", http.StatusBadRequest) + return + } + + if len(request.Password) < 6 { + http.Error(w, "invalid password", http.StatusBadRequest) + return + } + + uuid := make([]byte, 16) + + _, err = rand.Read(uuid) + if err != nil { + http.Error(w, fmt.Sprintf("failed to generate uuid: %s", err), http.StatusInternalServerError) + return + } + + salt := make([]byte, 16) + + _, err = rand.Read(salt) + if err != nil { + http.Error(w, fmt.Sprintf("failed to generate salt: %s", err), http.StatusInternalServerError) + return + } + + err = db.AddAccountRecord(uuid, request.Username, argon2.IDKey([]byte(request.Password), salt, argonTime, argonMemory, argonThreads, argonKeyLength), salt) + if err != nil { + http.Error(w, "failed to add account record", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) } // /api/account/login - log into account @@ -59,11 +104,86 @@ type AccountLoginRequest GenericAuthRequest type AccountLoginResponse GenericAuthResponse func HandleAccountLogin(w http.ResponseWriter, r *http.Request) { + var request AccountLoginRequest + err := json.NewDecoder(r.Body).Decode(&request) + if err != nil { + http.Error(w, fmt.Sprintf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + if isValidUsername(request.Username) { + http.Error(w, "invalid username", http.StatusBadRequest) + return + } + + if len(request.Password) < 6 { + http.Error(w, "invalid password", http.StatusBadRequest) + return + } + + key, salt, err := db.GetAccountKeySaltFromUsername(request.Username) + if err != nil { + if err == sql.ErrNoRows { + http.Error(w, "account doesn't exist", http.StatusBadRequest) + return + } + + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if !bytes.Equal(key, argon2.IDKey([]byte(request.Password), salt, argonTime, argonMemory, argonThreads, argonKeyLength)) { + http.Error(w, "password doesn't match", http.StatusBadRequest) + return + } + + token := make([]byte, 16) + + _, err = rand.Read(token) + if err != nil { + http.Error(w, fmt.Sprintf("failed to generate token: %s", err), http.StatusInternalServerError) + return + } + + err = db.AddAccountSession(request.Username, token) + if err != nil { + http.Error(w, "failed to add account session", http.StatusInternalServerError) + return + } + + response, err := json.Marshal(AccountLoginResponse{Token: base64.StdEncoding.EncodeToString(token)}) + if err != nil { + http.Error(w, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) + return + } + + w.Write(response) } // /api/account/logout - log out of account func HandleAccountLogout(w http.ResponseWriter, r *http.Request) { + token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) + if err != nil { + http.Error(w, fmt.Sprintf("failed to decode token: %s", err), http.StatusBadRequest) + return + } + if len(token) != 32 { + http.Error(w, "invalid token", http.StatusBadRequest) + return + } + + err = db.RemoveSessionFromToken(token) + if err != nil { + if err == sql.ErrNoRows { + http.Error(w, "token not found", http.StatusBadRequest) + return + } + + http.Error(w, "failed to remove account session", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) } diff --git a/api/generic.go b/api/generic.go index d09a4b0..bcbbe01 100644 --- a/api/generic.go +++ b/api/generic.go @@ -1,11 +1,5 @@ package api -// error - -type ErrorResponse struct{ - Message string `json:"message"` -} - // auth type GenericAuthRequest struct { diff --git a/db/account.go b/db/account.go index e21098f..3cc8cdc 100644 --- a/db/account.go +++ b/db/account.go @@ -2,19 +2,69 @@ package db import ( "database/sql" - "fmt" ) -func GetAccountInfoFromToken(token []byte) (string, error) { +func AddAccountRecord(uuid []byte, username string, key, salt []byte) error { + _, err := handle.Exec("INSERT INTO accounts (uuid, username, key, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt) + if err != nil { + return err + } + + return nil +} + +func AddAccountSession(username string, token []byte) error { + _, err := handle.Exec("INSERT INTO sessions (token, uuid, expire) SELECT a.uuid, ?, DATE_ADD(UTC_TIMESTAMP(), INTERVAL 1 WEEK) FROM accounts a WHERE a.username = ?", token, username) + if err != nil { + return err + } + + return nil +} + +func GetUsernameFromToken(token []byte) (string, error) { var username string - err := handle.QueryRow("SELECT username FROM accounts WHERE uuid IN (SELECT uuid FROM sessions WHERE token = ? AND expire > UTC_TIMESTAMP())").Scan(&username) + err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON s.uuid = a.uuid WHERE s.token = ? AND s.expire > UTC_TIMESTAMP()").Scan(&username) if err != nil { if err == sql.ErrNoRows { - return "", fmt.Errorf("invalid token") + return "", err } - return "", fmt.Errorf("query failed: %s", err) + return "", err } return username, nil } + +func GetAccountKeySaltFromUsername(username string) ([]byte, []byte, error) { + var key, salt []byte + err := handle.QueryRow("SELECT key, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt) + if err != nil { + return nil, nil, err + } + + return key, salt, nil +} + +func GetUUIDFromToken(token []byte) ([]byte, error) { + var uuid []byte + err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ? AND expire > UTC_TIMESTAMP()", token).Scan(&uuid) + if err != nil { + if err == sql.ErrNoRows { + return nil, err + } + + return nil, err + } + + return uuid, nil +} + +func RemoveSessionFromToken(token []byte) error { + _, err := handle.Exec("DELETE FROM sessions WHERE token = ?", token) + if err != nil { + return err + } + + return nil +} diff --git a/db/db.go b/db/db.go index 4d88574..135f1d7 100644 --- a/db/db.go +++ b/db/db.go @@ -17,4 +17,3 @@ func Init(username, password, protocol, address, database string) error { return nil } - diff --git a/go.mod b/go.mod index d505a52..8aded22 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,7 @@ module github.com/Flashfyre/pokerogue-server go 1.21.4 + +require golang.org/x/crypto v0.16.0 + +require golang.org/x/sys v0.15.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..1bc566e --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=