Refactor user authentication
This commit is contained in:
@@ -2,6 +2,7 @@ package pwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
@@ -19,6 +20,18 @@ import (
|
||||
|
||||
var preparedSessions = map[string]bool{}
|
||||
|
||||
type AccessDeniedError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (u *AccessDeniedError) Error() string {
|
||||
return fmt.Sprintf("Acess denied error: %s", u.Err.Error())
|
||||
}
|
||||
|
||||
func (u *AccessDeniedError) Unwrap() error {
|
||||
return u.Err
|
||||
}
|
||||
|
||||
type sessionBuilderWriter struct {
|
||||
sessionId string
|
||||
event event.EventApi
|
||||
@@ -48,8 +61,10 @@ type SessionSetupInstanceConf struct {
|
||||
func (p *pwd) SessionNew(ctx context.Context, config types.SessionConfig) (*types.Session, error) {
|
||||
defer observeAction("SessionNew", time.Now())
|
||||
|
||||
if u, err := p.storage.UserGet(config.UserId); err == nil && u.IsBanned {
|
||||
return nil, fmt.Errorf("User %s is banned\n", config.UserId)
|
||||
if _, err := p.UserGet(config.UserId); errors.Is(err, userBannedError) {
|
||||
return nil, &AccessDeniedError{err}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &types.Session{}
|
||||
|
||||
@@ -2,6 +2,7 @@ package pwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -100,6 +101,7 @@ func TestSessionFailWhenUserIsBanned(t *testing.T) {
|
||||
s, e := p.SessionNew(context.Background(), sConfig)
|
||||
assert.NotNil(t, e)
|
||||
assert.Nil(t, s)
|
||||
assert.True(t, errors.Is(e, userBannedError))
|
||||
assert.Contains(t, e.Error(), "banned")
|
||||
|
||||
_d.AssertExpectations(t)
|
||||
|
||||
10
pwd/user.go
10
pwd/user.go
@@ -1,10 +1,14 @@
|
||||
package pwd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/play-with-docker/play-with-docker/pwd/types"
|
||||
"github.com/play-with-docker/play-with-docker/storage"
|
||||
)
|
||||
|
||||
var userBannedError = errors.New("User is banned")
|
||||
|
||||
func (p *pwd) UserNewLoginRequest(providerName string) (*types.LoginRequest, error) {
|
||||
req := &types.LoginRequest{Id: p.generator.NewId(), Provider: providerName}
|
||||
if err := p.storage.LoginRequestPut(req); err != nil {
|
||||
@@ -40,9 +44,11 @@ func (p *pwd) UserLogin(loginRequest *types.LoginRequest, user *types.User) (*ty
|
||||
return u, nil
|
||||
}
|
||||
func (p *pwd) UserGet(id string) (*types.User, error) {
|
||||
var user *types.User
|
||||
if user, err := p.storage.UserGet(id); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return user, nil
|
||||
} else if user.IsBanned {
|
||||
return user, userBannedError
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user