fix errors in /account
parent
df92ff8b6f
commit
9c6374eedf
|
@ -25,8 +25,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func ChangePW(uuid []byte, password string) error {
|
func ChangePW(uuid []byte, password string) error {
|
||||||
if len(password) < 6 {
|
if err := validatePassword(password); err != nil {
|
||||||
return fmt.Errorf("invalid password")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
salt := make([]byte, ArgonSaltSize)
|
salt := make([]byte, ArgonSaltSize)
|
||||||
|
|
|
@ -18,9 +18,12 @@
|
||||||
package account
|
package account
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/pagefaultgames/rogueserver/errors"
|
||||||
"golang.org/x/crypto/argon2"
|
"golang.org/x/crypto/argon2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,3 +55,22 @@ func deriveArgon2IDKey(password, salt []byte) []byte {
|
||||||
|
|
||||||
return argon2.IDKey(password, salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize)
|
return argon2.IDKey(password, salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateUsernamePassword(username string, password string) error {
|
||||||
|
return cmp.Or(validateUsername(username), validatePassword(password))
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateUsername(username string) error {
|
||||||
|
if !isValidUsername(username) {
|
||||||
|
return errors.NewHttpError(http.StatusBadRequest, "invalid username")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePassword(password string) error {
|
||||||
|
if len(password) < 6 {
|
||||||
|
return errors.NewHttpError(http.StatusBadRequest, "invalid password")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pagefaultgames/rogueserver/errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateUsernamePassword(t *testing.T) {
|
||||||
|
t.Run("valid username and password", func(t *testing.T) {
|
||||||
|
err := validateUsernamePassword("validUser", "validPass")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid username", func(t *testing.T) {
|
||||||
|
err := validateUsernamePassword("", "validPass")
|
||||||
|
require.NotNil(t, err)
|
||||||
|
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid password", func(t *testing.T) {
|
||||||
|
err := validateUsernamePassword("validUser", "123")
|
||||||
|
require.NotNil(t, err)
|
||||||
|
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid password"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid username and password", func(t *testing.T) {
|
||||||
|
err := validateUsernamePassword("", "123")
|
||||||
|
require.NotNil(t, err)
|
||||||
|
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUsername(t *testing.T) {
|
||||||
|
t.Run("valid username", func(t *testing.T) {
|
||||||
|
err := validateUsername("validUser")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid username", func(t *testing.T) {
|
||||||
|
err := validateUsername("")
|
||||||
|
require.NotNil(t, err)
|
||||||
|
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePassword(t *testing.T) {
|
||||||
|
t.Run("valid password", func(t *testing.T) {
|
||||||
|
err := validatePassword("validPass")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid password", func(t *testing.T) {
|
||||||
|
err := validatePassword("123")
|
||||||
|
require.NotNil(t, err)
|
||||||
|
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid password"))
|
||||||
|
})
|
||||||
|
}
|
|
@ -23,8 +23,10 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/pagefaultgames/rogueserver/db"
|
"github.com/pagefaultgames/rogueserver/db"
|
||||||
|
"github.com/pagefaultgames/rogueserver/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LoginResponse GenericAuthResponse
|
type LoginResponse GenericAuthResponse
|
||||||
|
@ -33,25 +35,21 @@ type LoginResponse GenericAuthResponse
|
||||||
func Login(username, password string) (LoginResponse, error) {
|
func Login(username, password string) (LoginResponse, error) {
|
||||||
var response LoginResponse
|
var response LoginResponse
|
||||||
|
|
||||||
if !isValidUsername(username) {
|
if err := validateUsernamePassword(username, password); err != nil {
|
||||||
return response, fmt.Errorf("invalid username")
|
return response, err
|
||||||
}
|
|
||||||
|
|
||||||
if len(password) < 6 {
|
|
||||||
return response, fmt.Errorf("invalid password")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
key, salt, err := db.FetchAccountKeySaltFromUsername(username)
|
key, salt, err := db.FetchAccountKeySaltFromUsername(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return response, fmt.Errorf("account doesn't exist")
|
return response, errors.NewHttpError(http.StatusNotFound, "account doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, err
|
return response, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(key, deriveArgon2IDKey([]byte(password), salt)) {
|
if !bytes.Equal(key, deriveArgon2IDKey([]byte(password), salt)) {
|
||||||
return response, fmt.Errorf("password doesn't match")
|
return response, errors.NewHttpError(http.StatusUnauthorized, "password doesn't match")
|
||||||
}
|
}
|
||||||
|
|
||||||
token := make([]byte, TokenSize)
|
token := make([]byte, TokenSize)
|
||||||
|
|
|
@ -19,18 +19,18 @@ package account
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
stderrors "errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/pagefaultgames/rogueserver/db"
|
"github.com/pagefaultgames/rogueserver/db"
|
||||||
|
"github.com/pagefaultgames/rogueserver/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// /account/register - register account
|
// /account/register - register account
|
||||||
func Register(username, password string) error {
|
func Register(username, password string) error {
|
||||||
if !isValidUsername(username) {
|
if err := validateUsernamePassword(username, password); err != nil {
|
||||||
return fmt.Errorf("invalid username")
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
if len(password) < 6 {
|
|
||||||
return fmt.Errorf("invalid password")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uuid := make([]byte, UUIDSize)
|
uuid := make([]byte, UUIDSize)
|
||||||
|
@ -47,6 +47,9 @@ func Register(username, password string) error {
|
||||||
|
|
||||||
err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
|
err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if stderrors.Is(err, db.ErrAccountAlreadyExists) {
|
||||||
|
return errors.NewHttpError(http.StatusConflict, fmt.Sprintf(`username "%s" already taken`, username))
|
||||||
|
}
|
||||||
return fmt.Errorf("failed to add account record: %s", err)
|
return fmt.Errorf("failed to add account record: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,12 +20,15 @@ package api
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
stderrors "errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/pagefaultgames/rogueserver/api/account"
|
"github.com/pagefaultgames/rogueserver/api/account"
|
||||||
"github.com/pagefaultgames/rogueserver/api/daily"
|
"github.com/pagefaultgames/rogueserver/api/daily"
|
||||||
"github.com/pagefaultgames/rogueserver/db"
|
"github.com/pagefaultgames/rogueserver/db"
|
||||||
"log"
|
"github.com/pagefaultgames/rogueserver/errors"
|
||||||
"net/http"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init(mux *http.ServeMux) error {
|
func Init(mux *http.ServeMux) error {
|
||||||
|
@ -69,16 +72,16 @@ func Init(mux *http.ServeMux) error {
|
||||||
|
|
||||||
func tokenFromRequest(r *http.Request) ([]byte, error) {
|
func tokenFromRequest(r *http.Request) ([]byte, error) {
|
||||||
if r.Header.Get("Authorization") == "" {
|
if r.Header.Get("Authorization") == "" {
|
||||||
return nil, fmt.Errorf("missing token")
|
return nil, errors.NewHttpError(http.StatusBadRequest, "missing token")
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
|
token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to decode token: %s", err)
|
return nil, errors.NewHttpError(http.StatusBadRequest, "failed to decode token")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(token) != account.TokenSize {
|
if len(token) != account.TokenSize {
|
||||||
return nil, fmt.Errorf("invalid token length: got %d, expected %d", len(token), account.TokenSize)
|
return nil, errors.NewHttpError(http.StatusBadRequest, "invalid token length")
|
||||||
}
|
}
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
|
@ -97,14 +100,17 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) {
|
||||||
|
|
||||||
uuid, err := db.FetchUUIDFromToken(token)
|
uuid, err := db.FetchUUIDFromToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("failed to validate token: %s", err)
|
if stderrors.Is(err, db.ErrTokenNotFound) {
|
||||||
|
return nil, nil, errors.NewHttpError(http.StatusUnauthorized, "bad token")
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("failed to fetch uuid from db: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return token, uuid, nil
|
return token, uuid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpError(w http.ResponseWriter, r *http.Request, err error, code int) {
|
func httpError(w http.ResponseWriter, r *http.Request, err error, code int) {
|
||||||
log.Printf("%s: %s\n", r.URL.Path, err)
|
log.Printf("%s: %s\n", r.URL.Path, err.Error())
|
||||||
http.Error(w, err.Error(), code)
|
http.Error(w, err.Error(), code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,3 +122,13 @@ func jsonResponse(w http.ResponseWriter, r *http.Request, data any) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func statusCodeFromError(err error) int {
|
||||||
|
var httpErr *errors.HttpError
|
||||||
|
|
||||||
|
if stderrors.As(err, &httpErr) {
|
||||||
|
return httpErr.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
stderrors "errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pagefaultgames/rogueserver/errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStatusCodeFromError(t *testing.T) {
|
||||||
|
t.Run("nil", func(t *testing.T) {
|
||||||
|
code := statusCodeFromError(nil)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, code)
|
||||||
|
})
|
||||||
|
t.Run("http error", func(t *testing.T) {
|
||||||
|
err := errors.NewHttpError(http.StatusTeapot, "teapot")
|
||||||
|
code := statusCodeFromError(err)
|
||||||
|
assert.Equal(t, http.StatusTeapot, code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("standard error", func(t *testing.T) {
|
||||||
|
err := stderrors.New("standard error")
|
||||||
|
code := statusCodeFromError(err)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, code)
|
||||||
|
})
|
||||||
|
}
|
|
@ -43,7 +43,7 @@ import (
|
||||||
func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
|
func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,11 +71,11 @@ func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
err = account.Register(r.Form.Get("username"), r.Form.Get("password"))
|
err = account.Register(r.Form.Get("username"), r.Form.Get("password"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusInternalServerError)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
|
func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -87,7 +87,7 @@ func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
response, err := account.Login(r.Form.Get("username"), r.Form.Get("password"))
|
response, err := account.Login(r.Form.Get("username"), r.Form.Get("password"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusInternalServerError)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,17 +103,17 @@ func handleAccountChangePW(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = account.ChangePW(uuid, r.Form.Get("password"))
|
err = account.ChangePW(uuid, r.Form.Get("password"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusInternalServerError)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
|
func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -129,7 +129,7 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// game
|
// game
|
||||||
|
@ -149,7 +149,7 @@ func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) {
|
||||||
func handleGetSessionData(w http.ResponseWriter, r *http.Request) {
|
func handleGetSessionData(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,7 +195,7 @@ const legacyClientSessionId = "LEGACY_CLIENT"
|
||||||
func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) {
|
func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,7 +244,7 @@ func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) {
|
||||||
func clearSessionData(w http.ResponseWriter, r *http.Request) {
|
func clearSessionData(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,7 +332,7 @@ func clearSessionData(w http.ResponseWriter, r *http.Request) {
|
||||||
func deleteSystemSave(w http.ResponseWriter, r *http.Request) {
|
func deleteSystemSave(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -412,7 +412,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) {
|
||||||
func legacyHandleSaveData(w http.ResponseWriter, r *http.Request) {
|
func legacyHandleSaveData(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -581,7 +581,7 @@ type CombinedSaveData struct {
|
||||||
func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
|
func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -653,7 +653,7 @@ type SystemVerifyRequest struct {
|
||||||
func handleSystemVerify(w http.ResponseWriter, r *http.Request) {
|
func handleSystemVerify(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -705,7 +705,7 @@ func handleSystemVerify(w http.ResponseWriter, r *http.Request) {
|
||||||
func handleGetSystemData(w http.ResponseWriter, r *http.Request) {
|
func handleGetSystemData(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -741,7 +741,7 @@ func handleGetSystemData(w http.ResponseWriter, r *http.Request) {
|
||||||
func legacyHandleNewClear(w http.ResponseWriter, r *http.Request) {
|
func legacyHandleNewClear(w http.ResponseWriter, r *http.Request) {
|
||||||
uuid, err := uuidFromRequest(r)
|
uuid, err := uuidFromRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, r, err, http.StatusBadRequest)
|
httpError(w, r, err, statusCodeFromError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,13 +23,17 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
"github.com/go-sql-driver/mysql"
|
||||||
"github.com/pagefaultgames/rogueserver/defs"
|
"github.com/pagefaultgames/rogueserver/defs"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
|
func AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
|
||||||
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
|
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var mysqlErr *mysql.MySQLError
|
||||||
|
if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 {
|
||||||
|
return ErrAccountAlreadyExists
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -240,6 +244,9 @@ func FetchUUIDFromToken(token []byte) ([]byte, error) {
|
||||||
var uuid []byte
|
var uuid []byte
|
||||||
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid)
|
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrTokenNotFound
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
2
db/db.go
2
db/db.go
|
@ -24,8 +24,6 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var handle *sql.DB
|
var handle *sql.DB
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrAccountAlreadyExists = errors.New("account already exists")
|
||||||
|
ErrTokenNotFound = errors.New("token not found")
|
||||||
|
)
|
|
@ -0,0 +1,14 @@
|
||||||
|
package errors
|
||||||
|
|
||||||
|
type HttpError struct {
|
||||||
|
Code int
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHttpError(code int, message string) *HttpError {
|
||||||
|
return &HttpError{Code: code, Message: message}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h HttpError) Error() string {
|
||||||
|
return h.Message
|
||||||
|
}
|
8
go.mod
8
go.mod
|
@ -6,7 +6,13 @@ require (
|
||||||
github.com/go-sql-driver/mysql v1.7.1
|
github.com/go-sql-driver/mysql v1.7.1
|
||||||
github.com/klauspost/compress v1.17.4
|
github.com/klauspost/compress v1.17.4
|
||||||
github.com/robfig/cron/v3 v3.0.1
|
github.com/robfig/cron/v3 v3.0.1
|
||||||
|
github.com/stretchr/testify v1.9.0
|
||||||
golang.org/x/crypto v0.16.0
|
golang.org/x/crypto v0.16.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require golang.org/x/sys v0.15.0 // indirect
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
golang.org/x/sys v0.15.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
|
|
10
go.sum
10
go.sum
|
@ -1,10 +1,20 @@
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
|
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
|
||||||
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||||
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
|
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
|
||||||
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|
Loading…
Reference in New Issue