Refactor user authentication

This commit is contained in:
Marcos Lilljedahl
2020-10-02 00:22:52 -03:00
parent 2d1515d12f
commit 02804c4b58
3 changed files with 27 additions and 4 deletions

View File

@@ -2,6 +2,7 @@ package pwd
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"math" "math"
@@ -19,6 +20,18 @@ import (
var preparedSessions = map[string]bool{} var preparedSessions = map[string]bool{}
type AccessDeniedError struct {
Err error
}
func (u *AccessDeniedError) Error() string {
return fmt.Sprintf("Acess denied error: %s", u.Err.Error())
}
func (u *AccessDeniedError) Unwrap() error {
return u.Err
}
type sessionBuilderWriter struct { type sessionBuilderWriter struct {
sessionId string sessionId string
event event.EventApi event event.EventApi
@@ -48,8 +61,10 @@ type SessionSetupInstanceConf struct {
func (p *pwd) SessionNew(ctx context.Context, config types.SessionConfig) (*types.Session, error) { func (p *pwd) SessionNew(ctx context.Context, config types.SessionConfig) (*types.Session, error) {
defer observeAction("SessionNew", time.Now()) defer observeAction("SessionNew", time.Now())
if u, err := p.storage.UserGet(config.UserId); err == nil && u.IsBanned { if _, err := p.UserGet(config.UserId); errors.Is(err, userBannedError) {
return nil, fmt.Errorf("User %s is banned\n", config.UserId) return nil, &AccessDeniedError{err}
} else if err != nil {
return nil, err
} }
s := &types.Session{} s := &types.Session{}

View File

@@ -2,6 +2,7 @@ package pwd
import ( import (
"context" "context"
"errors"
"testing" "testing"
"time" "time"
@@ -100,6 +101,7 @@ func TestSessionFailWhenUserIsBanned(t *testing.T) {
s, e := p.SessionNew(context.Background(), sConfig) s, e := p.SessionNew(context.Background(), sConfig)
assert.NotNil(t, e) assert.NotNil(t, e)
assert.Nil(t, s) assert.Nil(t, s)
assert.True(t, errors.Is(e, userBannedError))
assert.Contains(t, e.Error(), "banned") assert.Contains(t, e.Error(), "banned")
_d.AssertExpectations(t) _d.AssertExpectations(t)

View File

@@ -1,10 +1,14 @@
package pwd package pwd
import ( import (
"errors"
"github.com/play-with-docker/play-with-docker/pwd/types" "github.com/play-with-docker/play-with-docker/pwd/types"
"github.com/play-with-docker/play-with-docker/storage" "github.com/play-with-docker/play-with-docker/storage"
) )
var userBannedError = errors.New("User is banned")
func (p *pwd) UserNewLoginRequest(providerName string) (*types.LoginRequest, error) { func (p *pwd) UserNewLoginRequest(providerName string) (*types.LoginRequest, error) {
req := &types.LoginRequest{Id: p.generator.NewId(), Provider: providerName} req := &types.LoginRequest{Id: p.generator.NewId(), Provider: providerName}
if err := p.storage.LoginRequestPut(req); err != nil { if err := p.storage.LoginRequestPut(req); err != nil {
@@ -40,9 +44,11 @@ func (p *pwd) UserLogin(loginRequest *types.LoginRequest, user *types.User) (*ty
return u, nil return u, nil
} }
func (p *pwd) UserGet(id string) (*types.User, error) { func (p *pwd) UserGet(id string) (*types.User, error) {
var user *types.User
if user, err := p.storage.UserGet(id); err != nil { if user, err := p.storage.UserGet(id); err != nil {
return nil, err return nil, err
} else { } else if user.IsBanned {
return user, nil return user, userBannedError
} }
return user, nil
} }