fix errors in /account

pull/14/head
Krystian Chmura 2024-05-15 20:57:58 +02:00
parent df92ff8b6f
commit 9c6374eedf
No known key found for this signature in database
14 changed files with 215 additions and 44 deletions

View File

@ -25,8 +25,8 @@ import (
)
func ChangePW(uuid []byte, password string) error {
if len(password) < 6 {
return fmt.Errorf("invalid password")
if err := validatePassword(password); err != nil {
return err
}
salt := make([]byte, ArgonSaltSize)

View File

@ -18,9 +18,12 @@
package account
import (
"cmp"
"net/http"
"regexp"
"runtime"
"github.com/pagefaultgames/rogueserver/errors"
"golang.org/x/crypto/argon2"
)
@ -52,3 +55,22 @@ func deriveArgon2IDKey(password, salt []byte) []byte {
return argon2.IDKey(password, salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeySize)
}
func validateUsernamePassword(username string, password string) error {
return cmp.Or(validateUsername(username), validatePassword(password))
}
func validateUsername(username string) error {
if !isValidUsername(username) {
return errors.NewHttpError(http.StatusBadRequest, "invalid username")
}
return nil
}
func validatePassword(password string) error {
if len(password) < 6 {
return errors.NewHttpError(http.StatusBadRequest, "invalid password")
}
return nil
}

View File

@ -0,0 +1,61 @@
package account
import (
"net/http"
"testing"
"github.com/pagefaultgames/rogueserver/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateUsernamePassword(t *testing.T) {
t.Run("valid username and password", func(t *testing.T) {
err := validateUsernamePassword("validUser", "validPass")
assert.NoError(t, err)
})
t.Run("invalid username", func(t *testing.T) {
err := validateUsernamePassword("", "validPass")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
})
t.Run("invalid password", func(t *testing.T) {
err := validateUsernamePassword("validUser", "123")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid password"))
})
t.Run("invalid username and password", func(t *testing.T) {
err := validateUsernamePassword("", "123")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
})
}
func TestValidateUsername(t *testing.T) {
t.Run("valid username", func(t *testing.T) {
err := validateUsername("validUser")
assert.NoError(t, err)
})
t.Run("invalid username", func(t *testing.T) {
err := validateUsername("")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid username"))
})
}
func TestValidatePassword(t *testing.T) {
t.Run("valid password", func(t *testing.T) {
err := validatePassword("validPass")
assert.NoError(t, err)
})
t.Run("invalid password", func(t *testing.T) {
err := validatePassword("123")
require.NotNil(t, err)
assert.Equal(t, err, errors.NewHttpError(http.StatusBadRequest, "invalid password"))
})
}

View File

@ -23,8 +23,10 @@ import (
"database/sql"
"encoding/base64"
"fmt"
"net/http"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/errors"
)
type LoginResponse GenericAuthResponse
@ -33,25 +35,21 @@ type LoginResponse GenericAuthResponse
func Login(username, password string) (LoginResponse, error) {
var response LoginResponse
if !isValidUsername(username) {
return response, fmt.Errorf("invalid username")
}
if len(password) < 6 {
return response, fmt.Errorf("invalid password")
if err := validateUsernamePassword(username, password); err != nil {
return response, err
}
key, salt, err := db.FetchAccountKeySaltFromUsername(username)
if err != nil {
if err == sql.ErrNoRows {
return response, fmt.Errorf("account doesn't exist")
return response, errors.NewHttpError(http.StatusNotFound, "account doesn't exist")
}
return response, err
}
if !bytes.Equal(key, deriveArgon2IDKey([]byte(password), salt)) {
return response, fmt.Errorf("password doesn't match")
return response, errors.NewHttpError(http.StatusUnauthorized, "password doesn't match")
}
token := make([]byte, TokenSize)

View File

@ -19,18 +19,18 @@ package account
import (
"crypto/rand"
stderrors "errors"
"fmt"
"net/http"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/errors"
)
// /account/register - register account
func Register(username, password string) error {
if !isValidUsername(username) {
return fmt.Errorf("invalid username")
}
if len(password) < 6 {
return fmt.Errorf("invalid password")
if err := validateUsernamePassword(username, password); err != nil {
return err
}
uuid := make([]byte, UUIDSize)
@ -47,6 +47,9 @@ func Register(username, password string) error {
err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
if err != nil {
if stderrors.Is(err, db.ErrAccountAlreadyExists) {
return errors.NewHttpError(http.StatusConflict, fmt.Sprintf(`username "%s" already taken`, username))
}
return fmt.Errorf("failed to add account record: %s", err)
}

View File

@ -20,12 +20,15 @@ package api
import (
"encoding/base64"
"encoding/json"
stderrors "errors"
"fmt"
"log"
"net/http"
"github.com/pagefaultgames/rogueserver/api/account"
"github.com/pagefaultgames/rogueserver/api/daily"
"github.com/pagefaultgames/rogueserver/db"
"log"
"net/http"
"github.com/pagefaultgames/rogueserver/errors"
)
func Init(mux *http.ServeMux) error {
@ -69,16 +72,16 @@ func Init(mux *http.ServeMux) error {
func tokenFromRequest(r *http.Request) ([]byte, error) {
if r.Header.Get("Authorization") == "" {
return nil, fmt.Errorf("missing token")
return nil, errors.NewHttpError(http.StatusBadRequest, "missing token")
}
token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil {
return nil, fmt.Errorf("failed to decode token: %s", err)
return nil, errors.NewHttpError(http.StatusBadRequest, "failed to decode token")
}
if len(token) != account.TokenSize {
return nil, fmt.Errorf("invalid token length: got %d, expected %d", len(token), account.TokenSize)
return nil, errors.NewHttpError(http.StatusBadRequest, "invalid token length")
}
return token, nil
@ -97,14 +100,17 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) {
uuid, err := db.FetchUUIDFromToken(token)
if err != nil {
return nil, nil, fmt.Errorf("failed to validate token: %s", err)
if stderrors.Is(err, db.ErrTokenNotFound) {
return nil, nil, errors.NewHttpError(http.StatusUnauthorized, "bad token")
}
return nil, nil, fmt.Errorf("failed to fetch uuid from db: %w", err)
}
return token, uuid, nil
}
func httpError(w http.ResponseWriter, r *http.Request, err error, code int) {
log.Printf("%s: %s\n", r.URL.Path, err)
log.Printf("%s: %s\n", r.URL.Path, err.Error())
http.Error(w, err.Error(), code)
}
@ -116,3 +122,13 @@ func jsonResponse(w http.ResponseWriter, r *http.Request, data any) {
return
}
}
func statusCodeFromError(err error) int {
var httpErr *errors.HttpError
if stderrors.As(err, &httpErr) {
return httpErr.Code
}
return http.StatusInternalServerError
}

28
api/common_test.go Normal file
View File

@ -0,0 +1,28 @@
package api
import (
stderrors "errors"
"net/http"
"testing"
"github.com/pagefaultgames/rogueserver/errors"
"github.com/stretchr/testify/assert"
)
func TestStatusCodeFromError(t *testing.T) {
t.Run("nil", func(t *testing.T) {
code := statusCodeFromError(nil)
assert.Equal(t, http.StatusInternalServerError, code)
})
t.Run("http error", func(t *testing.T) {
err := errors.NewHttpError(http.StatusTeapot, "teapot")
code := statusCodeFromError(err)
assert.Equal(t, http.StatusTeapot, code)
})
t.Run("standard error", func(t *testing.T) {
err := stderrors.New("standard error")
code := statusCodeFromError(err)
assert.Equal(t, http.StatusInternalServerError, code)
})
}

View File

@ -43,7 +43,7 @@ import (
func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -71,11 +71,11 @@ func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
err = account.Register(r.Form.Get("username"), r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
httpError(w, r, err, statusCodeFromError(err))
return
}
w.WriteHeader(http.StatusOK)
w.WriteHeader(http.StatusCreated)
}
func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
@ -87,7 +87,7 @@ func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
response, err := account.Login(r.Form.Get("username"), r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -103,17 +103,17 @@ func handleAccountChangePW(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
err = account.ChangePW(uuid, r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
httpError(w, r, err, statusCodeFromError(err))
return
}
w.WriteHeader(http.StatusOK)
w.WriteHeader(http.StatusNoContent)
}
func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
@ -129,7 +129,7 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
return
}
w.WriteHeader(http.StatusOK)
w.WriteHeader(http.StatusNoContent)
}
// game
@ -149,7 +149,7 @@ func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) {
func handleGetSessionData(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -195,7 +195,7 @@ const legacyClientSessionId = "LEGACY_CLIENT"
func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -244,7 +244,7 @@ func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) {
func clearSessionData(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -332,7 +332,7 @@ func clearSessionData(w http.ResponseWriter, r *http.Request) {
func deleteSystemSave(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -412,7 +412,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) {
func legacyHandleSaveData(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -581,7 +581,7 @@ type CombinedSaveData struct {
func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -653,7 +653,7 @@ type SystemVerifyRequest struct {
func handleSystemVerify(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -705,7 +705,7 @@ func handleSystemVerify(w http.ResponseWriter, r *http.Request) {
func handleGetSystemData(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}
@ -741,7 +741,7 @@ func handleGetSystemData(w http.ResponseWriter, r *http.Request) {
func legacyHandleNewClear(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusBadRequest)
httpError(w, r, err, statusCodeFromError(err))
return
}

View File

@ -23,13 +23,17 @@ import (
"fmt"
"slices"
_ "github.com/go-sql-driver/mysql"
"github.com/go-sql-driver/mysql"
"github.com/pagefaultgames/rogueserver/defs"
)
func AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
if err != nil {
var mysqlErr *mysql.MySQLError
if errors.As(err, &mysqlErr) && mysqlErr.Number == 1062 {
return ErrAccountAlreadyExists
}
return err
}
@ -240,6 +244,9 @@ func FetchUUIDFromToken(token []byte) ([]byte, error) {
var uuid []byte
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrTokenNotFound
}
return nil, err
}

View File

@ -24,8 +24,6 @@ import (
"log"
"os"
"time"
_ "github.com/go-sql-driver/mysql"
)
var handle *sql.DB

8
db/errors.go Normal file
View File

@ -0,0 +1,8 @@
package db
import "errors"
var (
ErrAccountAlreadyExists = errors.New("account already exists")
ErrTokenNotFound = errors.New("token not found")
)

14
errors/errors.go Normal file
View File

@ -0,0 +1,14 @@
package errors
type HttpError struct {
Code int
Message string
}
func NewHttpError(code int, message string) *HttpError {
return &HttpError{Code: code, Message: message}
}
func (h HttpError) Error() string {
return h.Message
}

8
go.mod
View File

@ -6,7 +6,13 @@ require (
github.com/go-sql-driver/mysql v1.7.1
github.com/klauspost/compress v1.17.4
github.com/robfig/cron/v3 v3.0.1
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.16.0
)
require golang.org/x/sys v0.15.0 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.15.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum
View File

@ -1,10 +1,20 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=