diff --git a/pwd/client_test.go b/pwd/client_test.go index af30718..accc758 100644 --- a/pwd/client_test.go +++ b/pwd/client_test.go @@ -33,6 +33,7 @@ func TestClientNew(t *testing.T) { _d.On("DaemonHost").Return("localhost") _d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) + _s.On("UserGet", mock.Anything).Return(&types.User{}, nil) _s.On("SessionCount").Return(1, nil) _s.On("InstanceCount").Return(0, nil) _s.On("ClientCount").Return(1, nil) @@ -76,6 +77,7 @@ func TestClientCount(t *testing.T) { _d.On("DaemonHost").Return("localhost") _d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) + _s.On("UserGet", mock.Anything).Return(&types.User{}, nil) _s.On("ClientPut", mock.AnythingOfType("*types.Client")).Return(nil) _s.On("ClientCount").Return(1, nil) _s.On("SessionCount").Return(1, nil) @@ -118,6 +120,7 @@ func TestClientResizeViewPort(t *testing.T) { _d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) + _s.On("UserGet", mock.Anything).Return(&types.User{}, nil) _s.On("InstanceCount").Return(0, nil) _s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil) _s.On("ClientPut", mock.AnythingOfType("*types.Client")).Return(nil) diff --git a/pwd/instance_test.go b/pwd/instance_test.go index 8f40250..129bc8f 100644 --- a/pwd/instance_test.go +++ b/pwd/instance_test.go @@ -61,6 +61,7 @@ func TestInstanceNew(t *testing.T) { _d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) + _s.On("UserGet", mock.Anything).Return(&types.User{}, nil) _s.On("ClientCount").Return(0, nil) _s.On("InstanceCount").Return(0, nil) _s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil) @@ -134,6 +135,7 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) { _d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) + _s.On("UserGet", mock.Anything).Return(&types.User{}, nil) _s.On("ClientCount").Return(0, nil) _s.On("InstanceCount").Return(0, nil) _s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil) @@ -204,6 +206,7 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { _d.On("DaemonHost").Return("localhost") _d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) + _s.On("UserGet", mock.Anything).Return(&types.User{}, nil) _s.On("SessionCount").Return(1, nil) _s.On("ClientCount").Return(0, nil) _s.On("InstanceCount").Return(0, nil) diff --git a/pwd/session.go b/pwd/session.go index d0a1fc7..e3c1a0f 100644 --- a/pwd/session.go +++ b/pwd/session.go @@ -47,7 +47,7 @@ type SessionSetupInstanceConf struct { func (p *pwd) SessionNew(ctx context.Context, config types.SessionConfig) (*types.Session, error) { defer observeAction("SessionNew", time.Now()) - if u, _ := p.storage.UserGet(config.UserId); u.IsBanned { + if u, err := p.storage.UserGet(config.UserId); err == nil && u.IsBanned { return nil, fmt.Errorf("User %s is banned\n", config.UserId) } diff --git a/pwd/session_test.go b/pwd/session_test.go index 6accba5..9563af7 100644 --- a/pwd/session_test.go +++ b/pwd/session_test.go @@ -35,6 +35,7 @@ func TestSessionNew(t *testing.T) { _d.On("DaemonHost").Return("localhost") _d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) + _s.On("UserGet", mock.Anything).Return(&types.User{}, nil) _s.On("SessionCount").Return(1, nil) _s.On("InstanceCount").Return(0, nil) _s.On("ClientCount").Return(0, nil) @@ -89,19 +90,7 @@ func TestSessionFailWhenUserIsBanned(t *testing.T) { ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_g, _f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) - _g.On("NewId").Return("aaaabbbbcccc") - _f.On("GetForSession", mock.AnythingOfType("*types.Session")).Return(_d, nil) - _d.On("NetworkCreate", "aaaabbbbcccc", dtypes.NetworkCreate{Attachable: true, Driver: "overlay"}).Return(nil) - _d.On("DaemonHost").Return("localhost") - _d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) - _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("UserGet", mock.Anything).Return(&types.User{IsBanned: true}, nil) - _s.On("SessionCount").Return(1, nil) - _s.On("InstanceCount").Return(0, nil) - _s.On("ClientCount").Return(0, nil) - - var nilArgs []interface{} - _e.M.On("Emit", event.SESSION_NEW, "aaaabbbbcccc", nilArgs).Return() p := NewPWD(_f, _e, _s, sp, ipf) p.generator = _g @@ -112,6 +101,12 @@ func TestSessionFailWhenUserIsBanned(t *testing.T) { assert.NotNil(t, e) assert.Nil(t, s) assert.Contains(t, e.Error(), "banned") + + _d.AssertExpectations(t) + _f.AssertExpectations(t) + _s.AssertExpectations(t) + _g.AssertExpectations(t) + _e.M.AssertExpectations(t) } /*