From 02804c4b58c656a831e927a586e3e09eb76217e9 Mon Sep 17 00:00:00 2001 From: Marcos Lilljedahl Date: Fri, 2 Oct 2020 00:22:52 -0300 Subject: [PATCH] Refactor user authentication --- pwd/session.go | 19 +++++++++++++++++-- pwd/session_test.go | 2 ++ pwd/user.go | 10 ++++++++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/pwd/session.go b/pwd/session.go index 192e653..2083add 100644 --- a/pwd/session.go +++ b/pwd/session.go @@ -2,6 +2,7 @@ package pwd import ( "context" + "errors" "fmt" "log" "math" @@ -19,6 +20,18 @@ import ( var preparedSessions = map[string]bool{} +type AccessDeniedError struct { + Err error +} + +func (u *AccessDeniedError) Error() string { + return fmt.Sprintf("Acess denied error: %s", u.Err.Error()) +} + +func (u *AccessDeniedError) Unwrap() error { + return u.Err +} + type sessionBuilderWriter struct { sessionId string event event.EventApi @@ -48,8 +61,10 @@ type SessionSetupInstanceConf struct { func (p *pwd) SessionNew(ctx context.Context, config types.SessionConfig) (*types.Session, error) { defer observeAction("SessionNew", time.Now()) - if u, err := p.storage.UserGet(config.UserId); err == nil && u.IsBanned { - return nil, fmt.Errorf("User %s is banned\n", config.UserId) + if _, err := p.UserGet(config.UserId); errors.Is(err, userBannedError) { + return nil, &AccessDeniedError{err} + } else if err != nil { + return nil, err } s := &types.Session{} diff --git a/pwd/session_test.go b/pwd/session_test.go index 9563af7..31f6837 100644 --- a/pwd/session_test.go +++ b/pwd/session_test.go @@ -2,6 +2,7 @@ package pwd import ( "context" + "errors" "testing" "time" @@ -100,6 +101,7 @@ func TestSessionFailWhenUserIsBanned(t *testing.T) { s, e := p.SessionNew(context.Background(), sConfig) assert.NotNil(t, e) assert.Nil(t, s) + assert.True(t, errors.Is(e, userBannedError)) assert.Contains(t, e.Error(), "banned") _d.AssertExpectations(t) diff --git a/pwd/user.go b/pwd/user.go index 508e20f..3cb330d 100644 --- a/pwd/user.go +++ b/pwd/user.go @@ -1,10 +1,14 @@ package pwd import ( + "errors" + "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/play-with-docker/play-with-docker/storage" ) +var userBannedError = errors.New("User is banned") + func (p *pwd) UserNewLoginRequest(providerName string) (*types.LoginRequest, error) { req := &types.LoginRequest{Id: p.generator.NewId(), Provider: providerName} if err := p.storage.LoginRequestPut(req); err != nil { @@ -40,9 +44,11 @@ func (p *pwd) UserLogin(loginRequest *types.LoginRequest, user *types.User) (*ty return u, nil } func (p *pwd) UserGet(id string) (*types.User, error) { + var user *types.User if user, err := p.storage.UserGet(id); err != nil { return nil, err - } else { - return user, nil + } else if user.IsBanned { + return user, userBannedError } + return user, nil }