From 954c52471b4285ffa4d88abbfca1a8eeb20849de Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Fri, 1 Sep 2017 20:12:19 -0300 Subject: [PATCH] Refactor storage to support shallow types. Add Client to storage. Fix client resizing issues. --- api.go | 2 +- handlers/get_session.go | 21 +- handlers/new_instance.go | 10 +- handlers/session_setup.go | 13 +- handlers/ws.go | 9 +- provisioner/dind.go | 19 +- provisioner/windows.go | 25 +- pwd/client.go | 50 +-- pwd/client_test.go | 17 +- pwd/instance.go | 33 +- pwd/instance_test.go | 23 +- pwd/mock.go | 11 +- pwd/pwd.go | 17 +- pwd/session.go | 47 +-- pwd/session_test.go | 12 +- pwd/types/client.go | 10 +- pwd/types/instance.go | 7 +- pwd/types/session.go | 23 +- scheduler/scheduler.go | 9 +- scheduler/scheduler_test.go | 30 +- storage/file.go | 367 +++++++++++++------- storage/file_test.go | 663 ++++++++++++++++++++++++------------ storage/mock.go | 95 +++--- storage/storage.go | 39 ++- www/assets/app.js | 16 +- 25 files changed, 1005 insertions(+), 563 deletions(-) diff --git a/api.go b/api.go index 4de298e..be9eaa9 100644 --- a/api.go +++ b/api.go @@ -22,7 +22,7 @@ func main() { s := initStorage() f := initFactory(s) - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(f, s), provisioner.NewDinD(f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(f, s), provisioner.NewDinD(f, s)) sp := provisioner.NewOverlaySessionProvisioner(f) core := pwd.NewPWD(f, e, s, sp, ipf) diff --git a/handlers/get_session.go b/handlers/get_session.go index 55aba26..a473f45 100644 --- a/handlers/get_session.go +++ b/handlers/get_session.go @@ -2,21 +2,38 @@ package handlers import ( "encoding/json" + "log" "net/http" "github.com/gorilla/mux" + "github.com/play-with-docker/play-with-docker/pwd/types" ) +type SessionInfo struct { + *types.Session + Instances map[string]*types.Instance `json:"instances"` +} + func GetSession(rw http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) sessionId := vars["sessionId"] session := core.SessionGet(sessionId) - if session == nil { rw.WriteHeader(http.StatusNotFound) return } - json.NewEncoder(rw).Encode(session) + instances, err := core.InstanceFindBySession(session) + if err != nil { + log.Println(err) + rw.WriteHeader(http.StatusInternalServerError) + return + } + is := map[string]*types.Instance{} + for _, i := range instances { + is[i.Name] = i + } + + json.NewEncoder(rw).Encode(SessionInfo{session, is}) } diff --git a/handlers/new_instance.go b/handlers/new_instance.go index 6aeda76..545710f 100644 --- a/handlers/new_instance.go +++ b/handlers/new_instance.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/play-with-docker/play-with-docker/pwd" "github.com/play-with-docker/play-with-docker/pwd/types" ) @@ -19,13 +20,12 @@ func NewInstance(rw http.ResponseWriter, req *http.Request) { s := core.SessionGet(sessionId) - if len(s.Instances) >= 5 { - rw.WriteHeader(http.StatusConflict) - return - } - i, err := core.InstanceNew(s, body) if err != nil { + if pwd.SessionComplete(err) { + rw.WriteHeader(http.StatusConflict) + return + } log.Println(err) rw.WriteHeader(http.StatusInternalServerError) return diff --git a/handlers/session_setup.go b/handlers/session_setup.go index fa96c88..b0cee0b 100644 --- a/handlers/session_setup.go +++ b/handlers/session_setup.go @@ -19,16 +19,15 @@ func SessionSetup(rw http.ResponseWriter, req *http.Request) { s := core.SessionGet(sessionId) - if len(s.Instances) > 0 { - log.Println("Cannot setup a session that contains instances") - rw.WriteHeader(http.StatusConflict) - rw.Write([]byte("Cannot setup a session that contains instances")) - return - } - s.Host = req.Host err := core.SessionSetup(s, body) if err != nil { + if pwd.SessionNotEmpty(err) { + log.Println("Cannot setup a session that contains instances") + rw.WriteHeader(http.StatusConflict) + rw.Write([]byte("Cannot setup a session that contains instances")) + return + } log.Println(err) rw.WriteHeader(http.StatusInternalServerError) return diff --git a/handlers/ws.go b/handlers/ws.go index 6dadf6c..e91b110 100644 --- a/handlers/ws.go +++ b/handlers/ws.go @@ -32,8 +32,13 @@ func WS(so socketio.Socket) { so.Join(session.Id) + instances, err := core.InstanceFindBySession(session) + if err != nil { + log.Printf("Couldn't find instances for session with id [%s]. Got: %v\n", sessionId, err) + return + } var rw sync.Mutex - trackedTerminals := make(map[string]net.Conn, len(session.Instances)) + trackedTerminals := make(map[string]net.Conn, len(instances)) attachTerminalToSocket := func(instance *types.Instance, ws socketio.Socket) { rw.Lock() @@ -73,7 +78,7 @@ func WS(so socketio.Socket) { }(instance.Name, conn, ws) } // since this is a new connection, get all terminals of the session and attach - for _, instance := range session.Instances { + for _, instance := range instances { attachTerminalToSocket(instance, so) } diff --git a/provisioner/dind.go b/provisioner/dind.go index 25d1bac..86001f2 100644 --- a/provisioner/dind.go +++ b/provisioner/dind.go @@ -14,20 +14,22 @@ import ( "github.com/play-with-docker/play-with-docker/docker" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/play-with-docker/play-with-docker/router" + "github.com/play-with-docker/play-with-docker/storage" ) type DinD struct { factory docker.FactoryApi + storage storage.StorageApi } -func NewDinD(f docker.FactoryApi) *DinD { - return &DinD{factory: f} +func NewDinD(f docker.FactoryApi, s storage.StorageApi) *DinD { + return &DinD{factory: f, storage: s} } -func checkHostnameExists(session *types.Session, hostname string) bool { - containerName := fmt.Sprintf("%s_%s", session.Id[:8], hostname) +func checkHostnameExists(sessionId, hostname string, instances []*types.Instance) bool { + containerName := fmt.Sprintf("%s_%s", sessionId[:8], hostname) exists := false - for _, instance := range session.Instances { + for _, instance := range instances { if instance.Name == containerName { exists = true break @@ -42,10 +44,14 @@ func (d *DinD) InstanceNew(session *types.Session, conf types.InstanceConfig) (* } log.Printf("NewInstance - using image: [%s]\n", conf.ImageName) if conf.Hostname == "" { + instances, err := d.storage.InstanceFindBySessionId(session.Id) + if err != nil { + return nil, err + } var nodeName string for i := 1; ; i++ { nodeName = fmt.Sprintf("node%d", i) - exists := checkHostnameExists(session, nodeName) + exists := checkHostnameExists(session.Id, nodeName, instances) if !exists { break } @@ -92,7 +98,6 @@ func (d *DinD) InstanceNew(session *types.Session, conf types.InstanceConfig) (* instance.ServerCert = conf.ServerCert instance.ServerKey = conf.ServerKey instance.CACert = conf.CACert - instance.Session = session instance.ProxyHost = router.EncodeHost(session.Id, instance.RoutableIP, router.HostOpts{}) instance.SessionHost = session.Host diff --git a/provisioner/windows.go b/provisioner/windows.go index f1daa66..74ddb21 100644 --- a/provisioner/windows.go +++ b/provisioner/windows.go @@ -58,10 +58,14 @@ func (d *windows) InstanceNew(session *types.Session, conf types.InstanceConfig) } if conf.Hostname == "" { + instances, err := d.storage.InstanceFindBySessionId(session.Id) + if err != nil { + return nil, err + } var nodeName string for i := 1; ; i++ { nodeName = fmt.Sprintf("node%d", i) - exists := checkHostnameExists(session, nodeName) + exists := checkHostnameExists(session.Id, nodeName, instances) if !exists { break } @@ -87,11 +91,11 @@ func (d *windows) InstanceNew(session *types.Session, conf types.InstanceConfig) dockerClient, err := d.factory.GetForSession(session.Id) if err != nil { - d.releaseInstance(session.Id, winfo.id) + d.releaseInstance(winfo.id) return nil, err } if err = dockerClient.CreateContainer(opts); err != nil { - d.releaseInstance(session.Id, winfo.id) + d.releaseInstance(winfo.id) return nil, err } @@ -108,7 +112,6 @@ func (d *windows) InstanceNew(session *types.Session, conf types.InstanceConfig) instance.ServerCert = conf.ServerCert instance.ServerKey = conf.ServerKey instance.CACert = conf.CACert - instance.Session = session instance.ProxyHost = router.EncodeHost(session.Id, instance.RoutableIP, router.HostOpts{}) instance.SessionHost = session.Host @@ -167,11 +170,11 @@ func (d *windows) InstanceDelete(session *types.Session, instance *types.Instanc return err } - return d.releaseInstance(session.Id, instance.WindowsId) + return d.releaseInstance(instance.WindowsId) } -func (d *windows) releaseInstance(sessionId, instanceId string) error { - return d.storage.InstanceDeleteWindows(sessionId, instanceId) +func (d *windows) releaseInstance(instanceId string) error { + return d.storage.WindowsInstanceDelete(instanceId) } func (d *windows) InstanceResizeTerminal(instance *types.Instance, rows, cols uint) error { @@ -222,10 +225,10 @@ func (d *windows) getWindowsInstanceInfo(sessionId string) (*instanceInfo, error } } - assignedInstances, err := d.storage.InstanceGetAllWindows() + assignedInstances, err := d.storage.WindowsInstanceGetAll() assignedInstancesIds := []string{} for _, ai := range assignedInstances { - assignedInstancesIds = append(assignedInstancesIds, ai.ID) + assignedInstancesIds = append(assignedInstancesIds, ai.Id) } if err != nil { @@ -243,7 +246,7 @@ func (d *windows) getWindowsInstanceInfo(sessionId string) (*instanceInfo, error }) if err != nil { // TODO retry x times and free the instance that was picked? - d.releaseInstance(sessionId, avInstanceId) + d.releaseInstance(avInstanceId) return nil, err } @@ -273,7 +276,7 @@ func (d *windows) pickFreeInstance(sessionId string, availInstances, assignedIns } if !found { - err := d.storage.InstanceCreateWindows(&types.WindowsInstance{SessionId: sessionId, ID: av}) + err := d.storage.WindowsInstancePut(&types.WindowsInstance{SessionId: sessionId, Id: av}) if err != nil { // TODO either storage error or instance is already assigned (race condition) } diff --git a/pwd/client.go b/pwd/client.go index 01b3516..d01ad94 100644 --- a/pwd/client.go +++ b/pwd/client.go @@ -2,7 +2,6 @@ package pwd import ( "log" - "sync/atomic" "time" "github.com/play-with-docker/play-with-docker/event" @@ -11,9 +10,10 @@ import ( func (p *pwd) ClientNew(id string, session *types.Session) *types.Client { defer observeAction("ClientNew", time.Now()) - c := &types.Client{Id: id, Session: session} - session.Clients = append(session.Clients, c) - p.clientCount = atomic.AddInt32(&p.clientCount, 1) + c := &types.Client{Id: id, SessionId: session.Id} + if err := p.storage.ClientPut(c); err != nil { + log.Println("Error saving client", err) + } return c } @@ -22,38 +22,46 @@ func (p *pwd) ClientResizeViewPort(c *types.Client, cols, rows uint) { c.ViewPort.Rows = rows c.ViewPort.Cols = cols - p.notifyClientSmallestViewPort(c.Session) + if err := p.storage.ClientPut(c); err != nil { + log.Println("Error saving client", err) + return + } + p.notifyClientSmallestViewPort(c.SessionId) } func (p *pwd) ClientClose(client *types.Client) { defer observeAction("ClientClose", time.Now()) // Client has disconnected. Remove from session and recheck terminal sizes. - session := client.Session - for i, cl := range session.Clients { - if cl.Id == client.Id { - session.Clients = append(session.Clients[:i], session.Clients[i+1:]...) - p.clientCount = atomic.AddInt32(&p.clientCount, -1) - break - } + if err := p.storage.ClientDelete(client.Id); err != nil { + log.Println("Error deleting client", err) + return } - if len(session.Clients) > 0 { - p.notifyClientSmallestViewPort(session) - } - p.setGauges() + p.notifyClientSmallestViewPort(client.SessionId) } func (p *pwd) ClientCount() int { - return int(atomic.LoadInt32(&p.clientCount)) + count, err := p.storage.ClientCount() + if err != nil { + log.Println("Error counting clients", err) + return 0 + } + return count } -func (p *pwd) notifyClientSmallestViewPort(session *types.Session) { - vp := p.SessionGetSmallestViewPort(session) +func (p *pwd) notifyClientSmallestViewPort(sessionId string) { + instances, err := p.storage.InstanceFindBySessionId(sessionId) + if err != nil { + log.Printf("Error finding instances for session [%s]. Got: %v\n", sessionId, err) + return + } + + vp := p.SessionGetSmallestViewPort(sessionId) // Resize all terminals in the session - p.event.Emit(event.INSTANCE_VIEWPORT_RESIZE, session.Id, vp.Cols, vp.Rows) - for _, instance := range session.Instances { + for _, instance := range instances { err := p.InstanceResizeTerminal(instance, vp.Rows, vp.Cols) if err != nil { log.Println("Error resizing terminal", err) } } + p.event.Emit(event.INSTANCE_VIEWPORT_RESIZE, sessionId, vp.Cols, vp.Rows) } diff --git a/pwd/client_test.go b/pwd/client_test.go index ed69b31..27d0a79 100644 --- a/pwd/client_test.go +++ b/pwd/client_test.go @@ -22,7 +22,7 @@ func TestClientNew(t *testing.T) { _d := &docker.Mock{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _g.On("NewId").Return("aaaabbbbcccc") @@ -33,6 +33,8 @@ func TestClientNew(t *testing.T) { _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) _s.On("InstanceCount").Return(0, nil) + _s.On("ClientCount").Return(1, nil) + _s.On("ClientPut", mock.AnythingOfType("*types.Client")).Return(nil) var nilArgs []interface{} _e.M.On("Emit", event.SESSION_NEW, "aaaabbbbcccc", nilArgs).Return() @@ -45,8 +47,7 @@ func TestClientNew(t *testing.T) { client := p.ClientNew("foobar", session) - assert.Equal(t, types.Client{Id: "foobar", Session: session, ViewPort: types.ViewPort{Cols: 0, Rows: 0}}, *client) - assert.Contains(t, session.Clients, client) + assert.Equal(t, types.Client{Id: "foobar", SessionId: session.Id, ViewPort: types.ViewPort{Cols: 0, Rows: 0}}, *client) _d.AssertExpectations(t) _f.AssertExpectations(t) @@ -61,7 +62,7 @@ func TestClientCount(t *testing.T) { _g := &mockGenerator{} _d := &docker.Mock{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _g.On("NewId").Return("aaaabbbbcccc") @@ -70,6 +71,8 @@ func TestClientCount(t *testing.T) { _d.On("GetDaemonHost").Return("localhost") _d.On("ConnectNetwork", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) + _s.On("ClientPut", mock.AnythingOfType("*types.Client")).Return(nil) + _s.On("ClientCount").Return(1, nil) _s.On("SessionCount").Return(1, nil) _s.On("InstanceCount").Return(-1, nil) var nilArgs []interface{} @@ -98,7 +101,7 @@ func TestClientResizeViewPort(t *testing.T) { _g := &mockGenerator{} _d := &docker.Mock{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _g.On("NewId").Return("aaaabbbbcccc") @@ -109,6 +112,9 @@ func TestClientResizeViewPort(t *testing.T) { _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) _s.On("InstanceCount").Return(0, nil) + _s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil) + _s.On("ClientPut", mock.AnythingOfType("*types.Client")).Return(nil) + _s.On("ClientCount").Return(1, nil) var nilArgs []interface{} _e.M.On("Emit", event.SESSION_NEW, "aaaabbbbcccc", nilArgs).Return() @@ -119,6 +125,7 @@ func TestClientResizeViewPort(t *testing.T) { session, err := p.SessionNew(time.Hour, "", "", "") assert.Nil(t, err) client := p.ClientNew("foobar", session) + _s.On("ClientFindBySessionId", "aaaabbbbcccc").Return([]*types.Client{client}, nil) p.ClientResizeViewPort(client, 80, 24) diff --git a/pwd/instance.go b/pwd/instance.go index 2b05dde..ed405f3 100644 --- a/pwd/instance.go +++ b/pwd/instance.go @@ -51,7 +51,7 @@ func (p *pwd) InstanceUploadFromReader(instance *types.Instance, fileName, dest func (p *pwd) InstanceGet(session *types.Session, name string) *types.Instance { defer observeAction("InstanceGet", time.Now()) - instance, err := p.storage.InstanceGet(session.Id, name) + instance, err := p.storage.InstanceGet(name) if err != nil { log.Println(err) return nil @@ -59,14 +59,14 @@ func (p *pwd) InstanceGet(session *types.Session, name string) *types.Instance { return instance } -func (p *pwd) InstanceFindByIP(sessionId, ip string) *types.Instance { - defer observeAction("InstanceFindByIP", time.Now()) - i, err := p.storage.InstanceFindByIP(sessionId, ip) +func (p *pwd) InstanceFindBySession(session *types.Session) ([]*types.Instance, error) { + defer observeAction("InstanceFindBySession", time.Now()) + instances, err := p.storage.InstanceFindBySessionId(session.Id) if err != nil { - return nil + log.Println(err) + return nil, err } - - return i + return instances, nil } func (p *pwd) InstanceDelete(session *types.Session, instance *types.Instance) error { @@ -83,7 +83,7 @@ func (p *pwd) InstanceDelete(session *types.Session, instance *types.Instance) e return err } - if err := p.storage.InstanceDelete(session.Id, instance.Name); err != nil { + if err := p.storage.InstanceDelete(instance.Name); err != nil { return err } @@ -99,6 +99,16 @@ func (p *pwd) InstanceNew(session *types.Session, conf types.InstanceConfig) (*t session.Lock() defer session.Unlock() + instances, err := p.storage.InstanceFindBySessionId(session.Id) + if err != nil { + log.Println(err) + return nil, err + } + + if len(instances) >= 5 { + return nil, sessionComplete + } + prov, err := p.getProvisioner(conf.Type) if err != nil { return nil, err @@ -109,12 +119,7 @@ func (p *pwd) InstanceNew(session *types.Session, conf types.InstanceConfig) (*t return nil, err } - if session.Instances == nil { - session.Instances = make(map[string]*types.Instance) - } - session.Instances[instance.Name] = instance - - err = p.storage.InstanceCreate(session.Id, instance) + err = p.storage.InstancePut(instance) if err != nil { return nil, err } diff --git a/pwd/instance_test.go b/pwd/instance_test.go index ce6b45b..1f0aeb1 100644 --- a/pwd/instance_test.go +++ b/pwd/instance_test.go @@ -23,7 +23,7 @@ func TestInstanceResizeTerminal(t *testing.T) { _s := &storage.Mock{} _g := &mockGenerator{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _d.On("ContainerResize", "foobar", uint(24), uint(80)).Return(nil) @@ -47,7 +47,7 @@ func TestInstanceNew(t *testing.T) { _s := &storage.Mock{} _g := &mockGenerator{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _g.On("NewId").Return("aaaabbbbcccc") @@ -57,7 +57,9 @@ func TestInstanceNew(t *testing.T) { _d.On("ConnectNetwork", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) + _s.On("ClientCount").Return(0, nil) _s.On("InstanceCount").Return(0, nil) + _s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil) var nilArgs []interface{} _e.M.On("Emit", event.SESSION_NEW, "aaaabbbbcccc", nilArgs).Return() @@ -75,7 +77,6 @@ func TestInstanceNew(t *testing.T) { RoutableIP: "10.0.0.1", Image: config.GetDindImageName(), SessionId: session.Id, - Session: session, SessionHost: session.Host, ProxyHost: router.EncodeHost(session.Id, "10.0.0.1", router.HostOpts{}), } @@ -94,7 +95,7 @@ func TestInstanceNew(t *testing.T) { } _d.On("CreateContainer", expectedContainerOpts).Return(nil) _d.On("GetContainerIPs", expectedInstance.Name).Return(map[string]string{session.Id: "10.0.0.1"}, nil) - _s.On("InstanceCreate", "aaaabbbbcccc", mock.AnythingOfType("*types.Instance")).Return(nil) + _s.On("InstancePut", mock.AnythingOfType("*types.Instance")).Return(nil) _e.M.On("Emit", event.INSTANCE_NEW, "aaaabbbbcccc", []interface{}{"aaaabbbb_node1", "10.0.0.1", "node1", "ip10-0-0-1-aaaabbbbcccc"}).Return() instance, err := p.InstanceNew(session, types.InstanceConfig{Host: "something.play-with-docker.com"}) @@ -115,7 +116,7 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) { _s := &storage.Mock{} _g := &mockGenerator{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _g.On("NewId").Return("aaaabbbbcccc") @@ -125,7 +126,9 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) { _d.On("ConnectNetwork", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) + _s.On("ClientCount").Return(0, nil) _s.On("InstanceCount").Return(0, nil) + _s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil) var nilArgs []interface{} _e.M.On("Emit", event.SESSION_NEW, "aaaabbbbcccc", nilArgs).Return() @@ -144,7 +147,6 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) { RoutableIP: "10.0.0.1", Image: "redis", SessionId: session.Id, - Session: session, SessionHost: session.Host, ProxyHost: router.EncodeHost(session.Id, "10.0.0.1", router.HostOpts{}), } @@ -162,7 +164,7 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) { } _d.On("CreateContainer", expectedContainerOpts).Return(nil) _d.On("GetContainerIPs", expectedInstance.Name).Return(map[string]string{session.Id: "10.0.0.1"}, nil) - _s.On("InstanceCreate", "aaaabbbbcccc", mock.AnythingOfType("*types.Instance")).Return(nil) + _s.On("InstancePut", mock.AnythingOfType("*types.Instance")).Return(nil) _e.M.On("Emit", event.INSTANCE_NEW, "aaaabbbbcccc", []interface{}{"aaaabbbb_node1", "10.0.0.1", "node1", "ip10-0-0-1-aaaabbbbcccc"}).Return() instance, err := p.InstanceNew(session, types.InstanceConfig{ImageName: "redis"}) @@ -184,7 +186,7 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { _g := &mockGenerator{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _g.On("NewId").Return("aaaabbbbcccc") @@ -194,7 +196,9 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { _d.On("ConnectNetwork", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) + _s.On("ClientCount").Return(0, nil) _s.On("InstanceCount").Return(0, nil) + _s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil) var nilArgs []interface{} _e.M.On("Emit", event.SESSION_NEW, "aaaabbbbcccc", nilArgs).Return() @@ -211,7 +215,6 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { IP: "10.0.0.1", RoutableIP: "10.0.0.1", Image: "redis", - Session: session, SessionHost: session.Host, SessionId: session.Id, ProxyHost: router.EncodeHost(session.Id, "10.0.0.1", router.HostOpts{}), @@ -231,7 +234,7 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { _d.On("CreateContainer", expectedContainerOpts).Return(nil) _d.On("GetContainerIPs", expectedInstance.Name).Return(map[string]string{session.Id: "10.0.0.1"}, nil) - _s.On("InstanceCreate", "aaaabbbbcccc", mock.AnythingOfType("*types.Instance")).Return(nil) + _s.On("InstancePut", mock.AnythingOfType("*types.Instance")).Return(nil) _e.M.On("Emit", event.INSTANCE_NEW, "aaaabbbbcccc", []interface{}{"aaaabbbb_redis-master", "10.0.0.1", "redis-master", "ip10-0-0-1-aaaabbbbcccc"}).Return() instance, err := p.InstanceNew(session, types.InstanceConfig{ImageName: "redis", Hostname: "redis-master"}) diff --git a/pwd/mock.go b/pwd/mock.go index a5cd34f..cb46a06 100644 --- a/pwd/mock.go +++ b/pwd/mock.go @@ -23,8 +23,8 @@ func (m *Mock) SessionClose(session *types.Session) error { return args.Error(0) } -func (m *Mock) SessionGetSmallestViewPort(session *types.Session) types.ViewPort { - args := m.Called(session) +func (m *Mock) SessionGetSmallestViewPort(sessionId string) types.ViewPort { + args := m.Called(sessionId) return args.Get(0).(types.ViewPort) } @@ -72,10 +72,9 @@ func (m *Mock) InstanceGet(session *types.Session, name string) *types.Instance args := m.Called(session, name) return args.Get(0).(*types.Instance) } - -func (m *Mock) InstanceFindByIP(session, ip string) *types.Instance { - args := m.Called(session, ip) - return args.Get(0).(*types.Instance) +func (m *Mock) InstanceFindBySession(session *types.Session) ([]*types.Instance, error) { + args := m.Called(session) + return args.Get(0).([]*types.Instance), args.Error(1) } func (m *Mock) InstanceDelete(session *types.Session, instance *types.Instance) error { diff --git a/pwd/pwd.go b/pwd/pwd.go index 814495e..c0f196a 100644 --- a/pwd/pwd.go +++ b/pwd/pwd.go @@ -1,6 +1,7 @@ package pwd import ( + "errors" "io" "net" "time" @@ -79,10 +80,22 @@ func (m *mockGenerator) NewId() string { return args.String(0) } +var sessionComplete = errors.New("Session is complete") + +func SessionComplete(e error) bool { + return e == sessionComplete +} + +var sessionNotEmpty = errors.New("Session is not empty") + +func SessionNotEmpty(e error) bool { + return e == sessionNotEmpty +} + type PWDApi interface { SessionNew(duration time.Duration, stack string, stackName, imageName string) (*types.Session, error) SessionClose(session *types.Session) error - SessionGetSmallestViewPort(session *types.Session) types.ViewPort + SessionGetSmallestViewPort(sessionId string) types.ViewPort SessionDeployStack(session *types.Session) error SessionGet(id string) *types.Session SessionSetup(session *types.Session, conf SessionSetupConf) error @@ -93,7 +106,7 @@ type PWDApi interface { InstanceUploadFromUrl(instance *types.Instance, fileName, dest, url string) error InstanceUploadFromReader(instance *types.Instance, fileName, dest string, reader io.Reader) error InstanceGet(session *types.Session, name string) *types.Instance - InstanceFindByIP(sessionId, ip string) *types.Instance + InstanceFindBySession(session *types.Session) ([]*types.Instance, error) InstanceDelete(session *types.Session, instance *types.Instance) error InstanceExec(instance *types.Instance, cmd []string) (int, error) diff --git a/pwd/session.go b/pwd/session.go index f5a8ba7..daa5990 100644 --- a/pwd/session.go +++ b/pwd/session.go @@ -15,7 +15,6 @@ import ( "github.com/play-with-docker/play-with-docker/docker" "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" - "github.com/play-with-docker/play-with-docker/storage" ) var preparedSessions = map[string]bool{} @@ -46,7 +45,6 @@ func (p *pwd) SessionNew(duration time.Duration, stack, stackName, imageName str s := &types.Session{} s.Id = p.generator.NewId() - s.Instances = map[string]*types.Instance{} s.CreatedAt = time.Now() s.ExpiresAt = s.CreatedAt.Add(duration) s.Ready = true @@ -81,21 +79,14 @@ func (p *pwd) SessionNew(duration time.Duration, stack, stackName, imageName str func (p *pwd) SessionClose(s *types.Session) error { defer observeAction("SessionClose", time.Now()) - updatedSession, err := p.storage.SessionGet(s.Id) - if err != nil { - if storage.NotFound(err) { - log.Printf("Session with id [%s] was not found in storage.\n", s.Id) - return err - } else { - log.Printf("Couldn't close session. Got: %s\n", err) - return err - } - } - s = updatedSession - log.Printf("Starting clean up of session [%s]\n", s.Id) g, _ := errgroup.WithContext(context.Background()) - for _, i := range s.Instances { + instances, err := p.storage.InstanceFindBySessionId(s.Id) + if err != nil { + log.Printf("Could not find instances in session %s. Got %v\n", s.Id, err) + return err + } + for _, i := range instances { i := i g.Go(func() error { return p.InstanceDelete(s, i) @@ -124,13 +115,22 @@ func (p *pwd) SessionClose(s *types.Session) error { } -func (p *pwd) SessionGetSmallestViewPort(s *types.Session) types.ViewPort { +func (p *pwd) SessionGetSmallestViewPort(sessionId string) types.ViewPort { defer observeAction("SessionGetSmallestViewPort", time.Now()) - minRows := s.Clients[0].ViewPort.Rows - minCols := s.Clients[0].ViewPort.Cols + clients, err := p.storage.ClientFindBySessionId(sessionId) + if err != nil { + log.Printf("Error finding clients for session [%s]. Got: %v\n", sessionId, err) + return types.ViewPort{Rows: 24, Cols: 80} + } + if len(clients) == 0 { + log.Printf("Session [%s] doesn't have clients. Returning default viewport\n", sessionId) + return types.ViewPort{Rows: 24, Cols: 80} + } + minRows := clients[0].ViewPort.Rows + minCols := clients[0].ViewPort.Cols - for _, c := range s.Clients { + for _, c := range clients { minRows = uint(math.Min(float64(minRows), float64(c.ViewPort.Rows))) minCols = uint(math.Min(float64(minCols), float64(c.ViewPort.Cols))) } @@ -206,6 +206,15 @@ func (p *pwd) SessionSetup(session *types.Session, conf SessionSetupConf) error var tokens *docker.SwarmTokens = nil var firstSwarmManager *types.Instance = nil + instances, err := p.storage.InstanceFindBySessionId(session.Id) + if err != nil { + log.Println(err) + return err + } + if len(instances) > 0 { + return sessionNotEmpty + } + // first look for a swarm manager and create it for _, conf := range conf.Instances { if conf.IsSwarmManager { diff --git a/pwd/session_test.go b/pwd/session_test.go index 1facdb9..7aa3529 100644 --- a/pwd/session_test.go +++ b/pwd/session_test.go @@ -9,6 +9,7 @@ import ( "github.com/play-with-docker/play-with-docker/docker" "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/provisioner" + "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/play-with-docker/play-with-docker/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -23,7 +24,7 @@ func TestSessionNew(t *testing.T) { _g := &mockGenerator{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _g.On("NewId").Return("aaaabbbbcccc") @@ -34,6 +35,7 @@ func TestSessionNew(t *testing.T) { _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) _s.On("SessionCount").Return(1, nil) _s.On("InstanceCount").Return(0, nil) + _s.On("ClientCount").Return(0, nil) var nilArgs []interface{} _e.M.On("Emit", event.SESSION_NEW, "aaaabbbbcccc", nilArgs).Return() @@ -77,7 +79,7 @@ func TestSessionSetup(t *testing.T) { _s := &storage.Mock{} _g := &mockGenerator{} _e := &event.Mock{} - ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f)) + ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_f, _s)) sp := provisioner.NewOverlaySessionProvisioner(_f) _g.On("NewId").Return("aaaabbbbcccc") @@ -86,9 +88,11 @@ func TestSessionSetup(t *testing.T) { _d.On("GetDaemonHost").Return("localhost") _d.On("ConnectNetwork", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil) _s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil) - _s.On("InstanceCreate", "aaaabbbbcccc", mock.AnythingOfType("*types.Instance")).Return(nil) + _s.On("InstancePut", mock.AnythingOfType("*types.Instance")).Return(nil) _s.On("SessionCount").Return(1, nil) + _s.On("ClientCount").Return(1, nil) _s.On("InstanceCount").Return(0, nil) + _s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil) _d.On("CreateContainer", docker.CreateContainerOpts{Image: "franela/dind", SessionId: "aaaabbbbcccc", PwdIpAddress: "10.0.0.1", ContainerName: "aaaabbbb_manager1", Hostname: "manager1", Privileged: true, HostFQDN: "localhost", Networks: []string{"aaaabbbbcccc"}}).Return(nil) _d.On("GetContainerIPs", "aaaabbbb_manager1").Return(map[string]string{"aaaabbbbcccc": "10.0.0.2"}, nil) @@ -153,8 +157,6 @@ func TestSessionSetup(t *testing.T) { }) assert.Nil(t, err) - assert.Equal(t, 5, len(s.Instances)) - _d.AssertExpectations(t) _f.AssertExpectations(t) _s.AssertExpectations(t) diff --git a/pwd/types/client.go b/pwd/types/client.go index 2e31e5e..e6a798c 100644 --- a/pwd/types/client.go +++ b/pwd/types/client.go @@ -1,12 +1,12 @@ package types type Client struct { - Id string - ViewPort ViewPort - Session *Session + Id string `json:"id" bson:"id"` + SessionId string `json:"session_id"` + ViewPort ViewPort `json:"viewport"` } type ViewPort struct { - Rows uint - Cols uint + Rows uint `json:"rows"` + Cols uint `json:"cols"` } diff --git a/pwd/types/instance.go b/pwd/types/instance.go index 18c5e7b..160d324 100644 --- a/pwd/types/instance.go +++ b/pwd/types/instance.go @@ -3,8 +3,8 @@ package types import "context" type Instance struct { - Image string `json:"image" bson:"image"` Name string `json:"name" bson:"name"` + Image string `json:"image" bson:"image"` Hostname string `json:"hostname" bson:"hostname"` IP string `json:"ip" bson:"ip"` RoutableIP string `json:"routable_ip" bson:"routable_id"` @@ -17,13 +17,12 @@ type Instance struct { ProxyHost string `json:"proxy_host" bson:"proxy_host"` SessionHost string `json:"session_host" bson:"session_host"` Type string `json:"type" bson:"type"` - Session *Session `json:"-" bson:"-"` - ctx context.Context `json:"-" bson:"-"` WindowsId string `json:"-" bson:"windows_id"` + ctx context.Context `json:"-" bson:"-"` } type WindowsInstance struct { - ID string `bson:"id"` + Id string `bson:"id"` SessionId string `bson:"session_id"` } diff --git a/pwd/types/session.go b/pwd/types/session.go index 944f799..1347751 100644 --- a/pwd/types/session.go +++ b/pwd/types/session.go @@ -6,19 +6,16 @@ import ( ) type Session struct { - Id string `json:"id"` - Instances map[string]*Instance `json:"instances" bson:"-"` - CreatedAt time.Time `json:"created_at"` - ExpiresAt time.Time `json:"expires_at"` - PwdIpAddress string `json:"pwd_ip_address"` - Ready bool `json:"ready"` - Stack string `json:"stack"` - StackName string `json:"stack_name"` - ImageName string `json:"image_name"` - Host string `json:"host"` - Clients []*Client `json:"-" bson:"-"` - WindowsAssigned []*WindowsInstance `json:"-" bson:"-"` - rw sync.Mutex `json:"-"` + Id string `json:"id" bson:"id"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + PwdIpAddress string `json:"pwd_ip_address"` + Ready bool `json:"ready"` + Stack string `json:"stack"` + StackName string `json:"stack_name"` + ImageName string `json:"image_name"` + Host string `json:"host"` + rw sync.Mutex `json:"-"` } func (s *Session) Lock() { diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index 47ad57b..5a7d58b 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -161,9 +161,14 @@ func (s *scheduler) processSession(ctx context.Context, session *types.Session) return } + instances, err := s.storage.InstanceFindBySessionId(updatedSession.Id) + if err != nil { + log.Printf("Couldn't find instances for session [%s]. Got: %v\n", updatedSession.Id, err) + return + } wg := sync.WaitGroup{} - wg.Add(len(updatedSession.Instances)) - for _, ins := range updatedSession.Instances { + wg.Add(len(instances)) + for _, ins := range instances { go func(ins *types.Instance) { s.processInstance(ctx, ins) wg.Done() diff --git a/scheduler/scheduler_test.go b/scheduler/scheduler_test.go index df14f7f..152aa76 100644 --- a/scheduler/scheduler_test.go +++ b/scheduler/scheduler_test.go @@ -45,16 +45,19 @@ func TestNew(t *testing.T) { s := &types.Session{ Id: "aaabbbccc", ExpiresAt: time.Now().Add(time.Hour), - Instances: map[string]*types.Instance{ - "node1": &types.Instance{ - Name: "node1", - IP: "10.0.0.1", - }, - }, + } + + i := &types.Instance{ + SessionId: s.Id, + Name: "node1", + IP: "10.0.0.1", } err := store.SessionPut(s) assert.Nil(t, err) + err = store.InstancePut(i) + assert.Nil(t, err) + sch, err := NewScheduler(store, event.NewLocalBroker(), &pwd.Mock{}) assert.Nil(t, err) assert.Len(t, sch.scheduledSessions, 1) @@ -99,16 +102,19 @@ func TestStart(t *testing.T) { s := &types.Session{ Id: "aaabbbccc", ExpiresAt: time.Now().Add(time.Hour), - Instances: map[string]*types.Instance{ - "node1": &types.Instance{ - Name: "node1", - IP: "10.0.0.1", - }, - }, + } + + i := &types.Instance{ + SessionId: s.Id, + Name: "node1", + IP: "10.0.0.1", } err := store.SessionPut(s) assert.Nil(t, err) + err = store.InstancePut(i) + assert.Nil(t, err) + sch, err := NewScheduler(store, event.NewLocalBroker(), &pwd.Mock{}) assert.Nil(t, err) diff --git a/storage/file.go b/storage/file.go index c1c75b3..68f174c 100644 --- a/storage/file.go +++ b/storage/file.go @@ -2,9 +2,7 @@ package storage import ( "encoding/json" - "fmt" "os" - "strings" "sync" "github.com/play-with-docker/play-with-docker/pwd/types" @@ -13,146 +11,76 @@ import ( type storage struct { rw sync.Mutex path string - db map[string]*types.Session + db *DB } -func (store *storage) SessionGet(sessionId string) (*types.Session, error) { +type DB struct { + Sessions map[string]*types.Session `json:"sessions"` + Instances map[string]*types.Instance `json:"instances"` + Clients map[string]*types.Client `json:"clients"` + WindowsInstances map[string]*types.WindowsInstance `json:"windows_instances"` + + WindowsInstancesBySessionId map[string][]string `json:"windows_instances_by_session_id"` + InstancesBySessionId map[string][]string `json:"instances_by_session_id"` + ClientsBySessionId map[string][]string `json:"clients_by_session_id"` +} + +func (store *storage) SessionGet(id string) (*types.Session, error) { store.rw.Lock() defer store.rw.Unlock() - s, found := store.db[sessionId] + s, found := store.db.Sessions[id] if !found { - return nil, fmt.Errorf("%s", notFound) + return nil, notFound } return s, nil } -func (store *storage) SessionGetAll() (map[string]*types.Session, error) { +func (store *storage) SessionGetAll() ([]*types.Session, error) { store.rw.Lock() defer store.rw.Unlock() - return store.db, nil + sessions := make([]*types.Session, len(store.db.Sessions)) + i := 0 + for _, s := range store.db.Sessions { + sessions[i] = s + i++ + } + + return sessions, nil } -func (store *storage) SessionPut(s *types.Session) error { +func (store *storage) SessionPut(session *types.Session) error { store.rw.Lock() defer store.rw.Unlock() - // Initialize instances map if nil - if s.Instances == nil { - s.Instances = map[string]*types.Instance{} - } - store.db[s.Id] = s + store.db.Sessions[session.Id] = session return store.save() } -func (store *storage) InstanceGetAllWindows() ([]*types.WindowsInstance, error) { +func (store *storage) SessionDelete(id string) error { store.rw.Lock() defer store.rw.Unlock() - instances := []*types.WindowsInstance{} - - for _, s := range store.db { - instances = append(instances, s.WindowsAssigned...) - } - - return instances, nil -} - -func (store *storage) InstanceGet(sessionId, name string) (*types.Instance, error) { - store.rw.Lock() - defer store.rw.Unlock() - - s := store.db[sessionId] - if s == nil { - return nil, fmt.Errorf("%s", notFound) - } - i := s.Instances[name] - if i == nil { - return nil, fmt.Errorf("%s", notFound) - } - return i, nil -} - -func (store *storage) InstanceFindByIP(sessionId, ip string) (*types.Instance, error) { - store.rw.Lock() - defer store.rw.Unlock() - - for id, s := range store.db { - if strings.HasPrefix(id, sessionId[:8]) { - for _, i := range s.Instances { - if i.IP == ip { - return i, nil - } - } - } - } - - return nil, fmt.Errorf("%s", notFound) -} - -func (store *storage) InstanceCreate(sessionId string, instance *types.Instance) error { - store.rw.Lock() - defer store.rw.Unlock() - - s, found := store.db[sessionId] + _, found := store.db.Sessions[id] if !found { - return fmt.Errorf("Session %s", notFound) - } - - s.Instances[instance.Name] = instance - - return store.save() -} - -func (store *storage) InstanceCreateWindows(instance *types.WindowsInstance) error { - store.rw.Lock() - defer store.rw.Unlock() - - s, found := store.db[instance.SessionId] - if !found { - return fmt.Errorf("Session %s", notFound) - } - - s.WindowsAssigned = append(s.WindowsAssigned, instance) - - return store.save() -} - -func (store *storage) InstanceDelete(sessionId, name string) error { - store.rw.Lock() - defer store.rw.Unlock() - - s, found := store.db[sessionId] - if !found { - return fmt.Errorf("Session %s", notFound) - } - - if _, found := s.Instances[name]; !found { return nil } - delete(s.Instances, name) - - return store.save() -} - -func (store *storage) InstanceDeleteWindows(sessionId, id string) error { - store.rw.Lock() - defer store.rw.Unlock() - - s, found := store.db[sessionId] - if !found { - return fmt.Errorf("Session %s", notFound) + for _, i := range store.db.WindowsInstancesBySessionId[id] { + delete(store.db.WindowsInstances, i) } - - for i, winst := range s.WindowsAssigned { - if winst.ID == id { - s.WindowsAssigned = append(s.WindowsAssigned[:i], s.WindowsAssigned[i+1:]...) - } - + store.db.WindowsInstancesBySessionId[id] = []string{} + for _, i := range store.db.InstancesBySessionId[id] { + delete(store.db.Instances, i) } + store.db.InstancesBySessionId[id] = []string{} + for _, i := range store.db.ClientsBySessionId[id] { + delete(store.db.Clients, i) + } + store.db.ClientsBySessionId[id] = []string{} + delete(store.db.Sessions, id) return store.save() } @@ -161,30 +89,217 @@ func (store *storage) SessionCount() (int, error) { store.rw.Lock() defer store.rw.Unlock() - return len(store.db), nil + return len(store.db.Sessions), nil +} + +func (store *storage) InstanceGet(name string) (*types.Instance, error) { + store.rw.Lock() + defer store.rw.Unlock() + + i := store.db.Instances[name] + if i == nil { + return nil, notFound + } + return i, nil +} + +func (store *storage) InstancePut(instance *types.Instance) error { + store.rw.Lock() + defer store.rw.Unlock() + + _, found := store.db.Sessions[string(instance.SessionId)] + if !found { + return notFound + } + + store.db.Instances[instance.Name] = instance + found = false + for _, i := range store.db.InstancesBySessionId[string(instance.SessionId)] { + if i == instance.Name { + found = true + break + } + } + if !found { + store.db.InstancesBySessionId[string(instance.SessionId)] = append(store.db.InstancesBySessionId[string(instance.SessionId)], instance.Name) + } + + return store.save() +} + +func (store *storage) InstanceDelete(name string) error { + store.rw.Lock() + defer store.rw.Unlock() + + instance, found := store.db.Instances[name] + if !found { + return nil + } + + instances := store.db.InstancesBySessionId[string(instance.SessionId)] + for n, i := range instances { + if i == name { + instances = append(instances[:n], instances[n+1:]...) + break + } + } + store.db.InstancesBySessionId[string(instance.SessionId)] = instances + delete(store.db.Instances, name) + + return store.save() } func (store *storage) InstanceCount() (int, error) { store.rw.Lock() defer store.rw.Unlock() - var ins int - - for _, s := range store.db { - ins += len(s.Instances) - } - - return ins, nil + return len(store.db.Instances), nil } -func (store *storage) SessionDelete(sessionId string) error { +func (store *storage) InstanceFindBySessionId(sessionId string) ([]*types.Instance, error) { store.rw.Lock() defer store.rw.Unlock() - delete(store.db, sessionId) + instanceIds := store.db.InstancesBySessionId[sessionId] + instances := make([]*types.Instance, len(instanceIds)) + for i, id := range instanceIds { + instances[i] = store.db.Instances[id] + } + + return instances, nil +} + +func (store *storage) WindowsInstanceGetAll() ([]*types.WindowsInstance, error) { + store.rw.Lock() + defer store.rw.Unlock() + + instances := []*types.WindowsInstance{} + + for _, s := range store.db.WindowsInstances { + instances = append(instances, s) + } + + return instances, nil +} + +func (store *storage) WindowsInstancePut(instance *types.WindowsInstance) error { + store.rw.Lock() + defer store.rw.Unlock() + + _, found := store.db.Sessions[string(instance.SessionId)] + if !found { + return notFound + } + store.db.WindowsInstances[instance.Id] = instance + found = false + for _, i := range store.db.WindowsInstancesBySessionId[string(instance.SessionId)] { + if i == instance.Id { + found = true + break + } + } + if !found { + store.db.WindowsInstancesBySessionId[string(instance.SessionId)] = append(store.db.WindowsInstancesBySessionId[string(instance.SessionId)], instance.Id) + } + return store.save() } +func (store *storage) WindowsInstanceDelete(id string) error { + store.rw.Lock() + defer store.rw.Unlock() + + instance, found := store.db.WindowsInstances[id] + if !found { + return nil + } + + instances := store.db.WindowsInstancesBySessionId[string(instance.SessionId)] + for n, i := range instances { + if i == id { + instances = append(instances[:n], instances[n+1:]...) + break + } + } + store.db.WindowsInstancesBySessionId[string(instance.SessionId)] = instances + delete(store.db.WindowsInstances, id) + + return store.save() +} + +func (store *storage) ClientGet(id string) (*types.Client, error) { + store.rw.Lock() + defer store.rw.Unlock() + + i := store.db.Clients[id] + if i == nil { + return nil, notFound + } + return i, nil +} +func (store *storage) ClientPut(client *types.Client) error { + store.rw.Lock() + defer store.rw.Unlock() + + _, found := store.db.Sessions[string(client.SessionId)] + if !found { + return notFound + } + + store.db.Clients[client.Id] = client + found = false + for _, i := range store.db.ClientsBySessionId[string(client.SessionId)] { + if i == client.Id { + found = true + break + } + } + if !found { + store.db.ClientsBySessionId[string(client.SessionId)] = append(store.db.ClientsBySessionId[string(client.SessionId)], client.Id) + } + + return store.save() +} +func (store *storage) ClientDelete(id string) error { + store.rw.Lock() + defer store.rw.Unlock() + + client, found := store.db.Clients[id] + if !found { + return nil + } + + clients := store.db.ClientsBySessionId[string(client.SessionId)] + for n, i := range clients { + if i == client.Id { + clients = append(clients[:n], clients[n+1:]...) + break + } + } + store.db.ClientsBySessionId[string(client.SessionId)] = clients + delete(store.db.Clients, id) + + return store.save() +} +func (store *storage) ClientCount() (int, error) { + store.rw.Lock() + defer store.rw.Unlock() + + return len(store.db.Clients), nil +} +func (store *storage) ClientFindBySessionId(sessionId string) ([]*types.Client, error) { + store.rw.Lock() + defer store.rw.Unlock() + + clientIds := store.db.ClientsBySessionId[sessionId] + clients := make([]*types.Client, len(clientIds)) + for i, id := range clientIds { + clients[i] = store.db.Clients[id] + } + + return clients, nil +} + func (store *storage) load() error { file, err := os.Open(store.path) @@ -196,7 +311,15 @@ func (store *storage) load() error { return err } } else { - store.db = map[string]*types.Session{} + store.db = &DB{ + Sessions: map[string]*types.Session{}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{}, + } } file.Close() diff --git a/storage/file_test.go b/storage/file_test.go index aeda53d..edb6b99 100644 --- a/storage/file_test.go +++ b/storage/file_test.go @@ -29,9 +29,16 @@ func TestSessionPut(t *testing.T) { assert.Nil(t, err) - var loadedSessions map[string]*types.Session - expectedSessions := map[string]*types.Session{} - expectedSessions[s.Id] = s + expectedDB := &DB{ + Sessions: map[string]*types.Session{s.Id: s}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{}, + } + var loadedDB *DB file, err := os.Open(tmpfile.Name()) @@ -39,24 +46,31 @@ func TestSessionPut(t *testing.T) { defer file.Close() decoder := json.NewDecoder(file) - err = decoder.Decode(&loadedSessions) + err = decoder.Decode(&loadedDB) assert.Nil(t, err) - assert.EqualValues(t, expectedSessions, loadedSessions) + assert.EqualValues(t, expectedDB, loadedDB) } func TestSessionGet(t *testing.T) { - expectedSession := &types.Session{Id: "session1"} - sessions := map[string]*types.Session{} - sessions[expectedSession.Id] = expectedSession + expectedSession := &types.Session{Id: "aaabbbccc"} + expectedDB := &DB{ + Sessions: map[string]*types.Session{expectedSession.Id: expectedSession}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{}, + } tmpfile, err := ioutil.TempFile("", "pwd") if err != nil { log.Fatal(err) } encoder := json.NewEncoder(tmpfile) - err = encoder.Encode(&sessions) + err = encoder.Encode(&expectedDB) assert.Nil(t, err) tmpfile.Close() defer os.Remove(tmpfile.Name()) @@ -65,28 +79,34 @@ func TestSessionGet(t *testing.T) { assert.Nil(t, err) - _, err = storage.SessionGet("bad id") + _, err = storage.SessionGet("foobar") assert.True(t, NotFound(err)) - loadedSession, err := storage.SessionGet("session1") + loadedSession, err := storage.SessionGet("aaabbbccc") assert.Nil(t, err) assert.Equal(t, expectedSession, loadedSession) } func TestSessionGetAll(t *testing.T) { - s1 := &types.Session{Id: "session1"} - s2 := &types.Session{Id: "session2"} - sessions := map[string]*types.Session{} - sessions[s1.Id] = s1 - sessions[s2.Id] = s2 + s1 := &types.Session{Id: "aaabbbccc"} + s2 := &types.Session{Id: "dddeeefff"} + expectedDB := &DB{ + Sessions: map[string]*types.Session{s1.Id: s1, s2.Id: s2}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{}, + } tmpfile, err := ioutil.TempFile("", "pwd") if err != nil { log.Fatal(err) } encoder := json.NewEncoder(tmpfile) - err = encoder.Encode(&sessions) + err = encoder.Encode(&expectedDB) assert.Nil(t, err) tmpfile.Close() defer os.Remove(tmpfile.Name()) @@ -95,216 +115,11 @@ func TestSessionGetAll(t *testing.T) { assert.Nil(t, err) - loadedSessions, err := storage.SessionGetAll() + sessions, err := storage.SessionGetAll() assert.Nil(t, err) - assert.Equal(t, s1, loadedSessions[s1.Id]) - assert.Equal(t, s2, loadedSessions[s2.Id]) -} - -func TestInstanceFindByIP(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "pwd") - if err != nil { - log.Fatal(err) - } - tmpfile.Close() - os.Remove(tmpfile.Name()) - defer os.Remove(tmpfile.Name()) - - storage, err := NewFileStorage(tmpfile.Name()) - - assert.Nil(t, err) - - i1 := &types.Instance{Name: "i1", IP: "10.0.0.1"} - i2 := &types.Instance{Name: "i2", IP: "10.1.0.1"} - s1 := &types.Session{Id: "session1", Instances: map[string]*types.Instance{"i1": i1}} - s2 := &types.Session{Id: "session2", Instances: map[string]*types.Instance{"i2": i2}} - err = storage.SessionPut(s1) - assert.Nil(t, err) - err = storage.SessionPut(s2) - assert.Nil(t, err) - - foundInstance, err := storage.InstanceFindByIP("session1", "10.0.0.1") - assert.Nil(t, err) - assert.Equal(t, i1, foundInstance) - - foundInstance, err = storage.InstanceFindByIP("session2", "10.1.0.1") - assert.Nil(t, err) - assert.Equal(t, i2, foundInstance) - - foundInstance, err = storage.InstanceFindByIP("session3", "10.1.0.1") - assert.True(t, NotFound(err)) - assert.Nil(t, foundInstance) - - foundInstance, err = storage.InstanceFindByIP("session1", "10.1.0.1") - assert.True(t, NotFound(err)) - assert.Nil(t, foundInstance) - - foundInstance, err = storage.InstanceFindByIP("session1", "192.168.0.1") - assert.True(t, NotFound(err)) - assert.Nil(t, foundInstance) -} - -func TestInstanceGet(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "pwd") - if err != nil { - log.Fatal(err) - } - tmpfile.Close() - os.Remove(tmpfile.Name()) - defer os.Remove(tmpfile.Name()) - - storage, err := NewFileStorage(tmpfile.Name()) - - assert.Nil(t, err) - - i1 := &types.Instance{Name: "i1", IP: "10.0.0.1"} - s1 := &types.Session{Id: "session1", Instances: map[string]*types.Instance{"i1": i1}} - err = storage.SessionPut(s1) - assert.Nil(t, err) - - foundInstance, err := storage.InstanceGet("session1", "i1") - assert.Nil(t, err) - assert.Equal(t, i1, foundInstance) -} - -func TestInstanceGetAllWindows(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "pwd") - if err != nil { - log.Fatal(err) - } - tmpfile.Close() - os.Remove(tmpfile.Name()) - defer os.Remove(tmpfile.Name()) - - storage, err := NewFileStorage(tmpfile.Name()) - - assert.Nil(t, err) - w1 := []*types.WindowsInstance{{ID: "one"}, {ID: "two"}} - w2 := []*types.WindowsInstance{{ID: "three"}, {ID: "four"}} - s1 := &types.Session{Id: "session1", WindowsAssigned: w1} - s2 := &types.Session{Id: "session2", WindowsAssigned: w2} - err = storage.SessionPut(s1) - err = storage.SessionPut(s2) - assert.Nil(t, err) - - allw, err := storage.InstanceGetAllWindows() - assert.Nil(t, err) - assert.Equal(t, allw, append(w1, w2...)) -} - -func TestInstanceCreate(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "pwd") - if err != nil { - log.Fatal(err) - } - tmpfile.Close() - os.Remove(tmpfile.Name()) - defer os.Remove(tmpfile.Name()) - - storage, err := NewFileStorage(tmpfile.Name()) - - assert.Nil(t, err) - - i1 := &types.Instance{Name: "i1", IP: "10.0.0.1"} - s1 := &types.Session{Id: "session1"} - err = storage.SessionPut(s1) - assert.Nil(t, err) - err = storage.InstanceCreate(s1.Id, i1) - assert.Nil(t, err) - - loadedSession, err := storage.SessionGet("session1") - assert.Nil(t, err) - - assert.Equal(t, i1, loadedSession.Instances["i1"]) - -} - -func TestInstanceCreateWindows(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "pwd") - if err != nil { - log.Fatal(err) - } - tmpfile.Close() - os.Remove(tmpfile.Name()) - defer os.Remove(tmpfile.Name()) - - storage, err := NewFileStorage(tmpfile.Name()) - - assert.Nil(t, err) - - s1 := &types.Session{Id: "session1"} - i1 := &types.WindowsInstance{SessionId: s1.Id, ID: "some id"} - err = storage.SessionPut(s1) - assert.Nil(t, err) - err = storage.InstanceCreateWindows(i1) - assert.Nil(t, err) - - loadedSession, err := storage.SessionGet("session1") - assert.Nil(t, err) - - assert.Equal(t, i1, loadedSession.WindowsAssigned[0]) -} - -func TestInstanceDeleteWindows(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "pwd") - if err != nil { - log.Fatal(err) - } - tmpfile.Close() - os.Remove(tmpfile.Name()) - defer os.Remove(tmpfile.Name()) - - storage, err := NewFileStorage(tmpfile.Name()) - - assert.Nil(t, err) - - s1 := &types.Session{Id: "session1", WindowsAssigned: []*types.WindowsInstance{{ID: "one"}}} - err = storage.SessionPut(s1) - assert.Nil(t, err) - - err = storage.InstanceDeleteWindows(s1.Id, "one") - assert.Nil(t, err) - - found, err := storage.SessionGet(s1.Id) - assert.Equal(t, 0, len(found.WindowsAssigned)) -} - -func TestCounts(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "pwd") - if err != nil { - log.Fatal(err) - } - tmpfile.Close() - os.Remove(tmpfile.Name()) - defer os.Remove(tmpfile.Name()) - - storage, err := NewFileStorage(tmpfile.Name()) - - assert.Nil(t, err) - - c1 := &types.Client{} - i1 := &types.Instance{Name: "i1", IP: "10.0.0.1"} - i2 := &types.Instance{Name: "i2", IP: "10.1.0.1"} - s1 := &types.Session{Id: "session1", Instances: map[string]*types.Instance{"i1": i1}} - s2 := &types.Session{Id: "session2", Instances: map[string]*types.Instance{"i2": i2}} - s3 := &types.Session{Id: "session3", Clients: []*types.Client{c1}} - - err = storage.SessionPut(s1) - assert.Nil(t, err) - err = storage.SessionPut(s2) - assert.Nil(t, err) - err = storage.SessionPut(s3) - assert.Nil(t, err) - - num, err := storage.SessionCount() - assert.Nil(t, err) - assert.Equal(t, 3, num) - - num, err = storage.InstanceCount() - assert.Nil(t, err) - assert.Equal(t, 2, num) - + assert.Subset(t, sessions, []*types.Session{s1, s2}) + assert.Len(t, sessions, 2) } func TestSessionDelete(t *testing.T) { @@ -335,3 +150,401 @@ func TestSessionDelete(t *testing.T) { assert.True(t, NotFound(err)) assert.Nil(t, found) } + +func TestInstanceGet(t *testing.T) { + expectedInstance := &types.Instance{SessionId: "aaabbbccc", Name: "i1", IP: "10.0.0.1"} + expectedDB := &DB{ + Sessions: map[string]*types.Session{}, + Instances: map[string]*types.Instance{expectedInstance.Name: expectedInstance}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{expectedInstance.SessionId: []string{expectedInstance.Name}}, + ClientsBySessionId: map[string][]string{}, + } + + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + encoder := json.NewEncoder(tmpfile) + err = encoder.Encode(&expectedDB) + assert.Nil(t, err) + tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + foundInstance, err := storage.InstanceGet("i1") + assert.Nil(t, err) + assert.Equal(t, expectedInstance, foundInstance) +} + +func TestInstancePut(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + tmpfile.Close() + os.Remove(tmpfile.Name()) + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + s := &types.Session{Id: "aaabbbccc"} + i := &types.Instance{Name: "i1", IP: "10.0.0.1", SessionId: s.Id} + + err = storage.SessionPut(s) + assert.Nil(t, err) + + err = storage.InstancePut(i) + assert.Nil(t, err) + + expectedDB := &DB{ + Sessions: map[string]*types.Session{s.Id: s}, + Instances: map[string]*types.Instance{i.Name: i}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{i.SessionId: []string{i.Name}}, + ClientsBySessionId: map[string][]string{}, + } + var loadedDB *DB + + file, err := os.Open(tmpfile.Name()) + + assert.Nil(t, err) + defer file.Close() + + decoder := json.NewDecoder(file) + err = decoder.Decode(&loadedDB) + + assert.Nil(t, err) + + assert.EqualValues(t, expectedDB, loadedDB) +} + +func TestInstanceDelete(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + tmpfile.Close() + os.Remove(tmpfile.Name()) + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + s := &types.Session{Id: "session1"} + err = storage.SessionPut(s) + assert.Nil(t, err) + + i := &types.Instance{Name: "i1", IP: "10.0.0.1", SessionId: s.Id} + err = storage.InstancePut(i) + assert.Nil(t, err) + + found, err := storage.InstanceGet(i.Name) + assert.Nil(t, err) + assert.Equal(t, i, found) + + err = storage.InstanceDelete(i.Name) + assert.Nil(t, err) + + found, err = storage.InstanceGet(i.Name) + assert.True(t, NotFound(err)) + assert.Nil(t, found) +} + +func TestInstanceFindBySessionId(t *testing.T) { + i1 := &types.Instance{SessionId: "aaabbbccc", Name: "c1"} + i2 := &types.Instance{SessionId: "aaabbbccc", Name: "c2"} + expectedDB := &DB{ + Sessions: map[string]*types.Session{}, + Instances: map[string]*types.Instance{i1.Name: i1, i2.Name: i2}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{i1.SessionId: []string{i1.Name, i2.Name}}, + ClientsBySessionId: map[string][]string{}, + } + + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + encoder := json.NewEncoder(tmpfile) + err = encoder.Encode(&expectedDB) + assert.Nil(t, err) + tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + instances, err := storage.InstanceFindBySessionId("aaabbbccc") + assert.Nil(t, err) + assert.Subset(t, instances, []*types.Instance{i1, i2}) + assert.Len(t, instances, 2) +} + +func TestWindowsInstanceGetAll(t *testing.T) { + i1 := &types.WindowsInstance{SessionId: "aaabbbccc", Id: "i1"} + i2 := &types.WindowsInstance{SessionId: "aaabbbccc", Id: "i2"} + expectedDB := &DB{ + Sessions: map[string]*types.Session{}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{i1.Id: i1, i2.Id: i2}, + WindowsInstancesBySessionId: map[string][]string{i1.SessionId: []string{i1.Id, i2.Id}}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{}, + } + + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + encoder := json.NewEncoder(tmpfile) + err = encoder.Encode(&expectedDB) + assert.Nil(t, err) + tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + instances, err := storage.WindowsInstanceGetAll() + assert.Nil(t, err) + assert.Subset(t, instances, []*types.WindowsInstance{i1, i2}) + assert.Len(t, instances, 2) +} + +func TestWindowsInstancePut(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + tmpfile.Close() + os.Remove(tmpfile.Name()) + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + s := &types.Session{Id: "aaabbbccc"} + i := &types.WindowsInstance{Id: "i1", SessionId: s.Id} + + err = storage.SessionPut(s) + assert.Nil(t, err) + + err = storage.WindowsInstancePut(i) + assert.Nil(t, err) + + expectedDB := &DB{ + Sessions: map[string]*types.Session{s.Id: s}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{}, + WindowsInstances: map[string]*types.WindowsInstance{i.Id: i}, + WindowsInstancesBySessionId: map[string][]string{i.SessionId: []string{i.Id}}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{}, + } + var loadedDB *DB + + file, err := os.Open(tmpfile.Name()) + + assert.Nil(t, err) + defer file.Close() + + decoder := json.NewDecoder(file) + err = decoder.Decode(&loadedDB) + + assert.Nil(t, err) + + assert.EqualValues(t, expectedDB, loadedDB) +} + +func TestWindowsInstanceDelete(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + tmpfile.Close() + os.Remove(tmpfile.Name()) + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + s := &types.Session{Id: "session1"} + err = storage.SessionPut(s) + assert.Nil(t, err) + + i := &types.WindowsInstance{Id: "i1", SessionId: s.Id} + err = storage.WindowsInstancePut(i) + assert.Nil(t, err) + + found, err := storage.WindowsInstanceGetAll() + assert.Nil(t, err) + assert.Equal(t, []*types.WindowsInstance{i}, found) + + err = storage.WindowsInstanceDelete(i.Id) + assert.Nil(t, err) + + found, err = storage.WindowsInstanceGetAll() + assert.Nil(t, err) + assert.Empty(t, found) +} + +func TestClientGet(t *testing.T) { + c := &types.Client{SessionId: "aaabbbccc", Id: "c1"} + expectedDB := &DB{ + Sessions: map[string]*types.Session{}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{c.Id: c}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{c.SessionId: []string{c.Id}}, + } + + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + encoder := json.NewEncoder(tmpfile) + err = encoder.Encode(&expectedDB) + assert.Nil(t, err) + tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + found, err := storage.ClientGet("c1") + assert.Nil(t, err) + assert.Equal(t, c, found) +} + +func TestClientPut(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + tmpfile.Close() + os.Remove(tmpfile.Name()) + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + s := &types.Session{Id: "aaabbbccc"} + c := &types.Client{Id: "c1", SessionId: s.Id} + + err = storage.SessionPut(s) + assert.Nil(t, err) + + err = storage.ClientPut(c) + assert.Nil(t, err) + + expectedDB := &DB{ + Sessions: map[string]*types.Session{s.Id: s}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{c.Id: c}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{c.SessionId: []string{c.Id}}, + } + var loadedDB *DB + + file, err := os.Open(tmpfile.Name()) + + assert.Nil(t, err) + defer file.Close() + + decoder := json.NewDecoder(file) + err = decoder.Decode(&loadedDB) + + assert.Nil(t, err) + + assert.EqualValues(t, expectedDB, loadedDB) +} + +func TestClientDelete(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + tmpfile.Close() + os.Remove(tmpfile.Name()) + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + s := &types.Session{Id: "session1"} + err = storage.SessionPut(s) + assert.Nil(t, err) + + c := &types.Client{Id: "c1", SessionId: s.Id} + err = storage.ClientPut(c) + assert.Nil(t, err) + + found, err := storage.ClientGet(c.Id) + assert.Nil(t, err) + assert.Equal(t, c, found) + + err = storage.ClientDelete(c.Id) + assert.Nil(t, err) + + found, err = storage.ClientGet(c.Id) + assert.True(t, NotFound(err)) + assert.Nil(t, found) +} + +func TestClientFindBySessionId(t *testing.T) { + c1 := &types.Client{SessionId: "aaabbbccc", Id: "c1"} + c2 := &types.Client{SessionId: "aaabbbccc", Id: "c2"} + expectedDB := &DB{ + Sessions: map[string]*types.Session{}, + Instances: map[string]*types.Instance{}, + Clients: map[string]*types.Client{c1.Id: c1, c2.Id: c2}, + WindowsInstances: map[string]*types.WindowsInstance{}, + WindowsInstancesBySessionId: map[string][]string{}, + InstancesBySessionId: map[string][]string{}, + ClientsBySessionId: map[string][]string{c1.SessionId: []string{c1.Id, c2.Id}}, + } + + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + encoder := json.NewEncoder(tmpfile) + err = encoder.Encode(&expectedDB) + assert.Nil(t, err) + tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + clients, err := storage.ClientFindBySessionId("aaabbbccc") + assert.Nil(t, err) + assert.Subset(t, clients, []*types.Client{c1, c2}) + assert.Len(t, clients, 2) +} diff --git a/storage/mock.go b/storage/mock.go index f8ef5d8..490ec05 100644 --- a/storage/mock.go +++ b/storage/mock.go @@ -9,67 +9,76 @@ type Mock struct { mock.Mock } -func (m *Mock) SessionGet(sessionId string) (*types.Session, error) { - args := m.Called(sessionId) +func (m *Mock) SessionGet(id string) (*types.Session, error) { + args := m.Called(id) return args.Get(0).(*types.Session), args.Error(1) } - +func (m *Mock) SessionGetAll() ([]*types.Session, error) { + args := m.Called() + return args.Get(0).([]*types.Session), args.Error(1) +} func (m *Mock) SessionPut(session *types.Session) error { args := m.Called(session) return args.Error(0) } - +func (m *Mock) SessionDelete(id string) error { + args := m.Called(id) + return args.Error(0) +} func (m *Mock) SessionCount() (int, error) { args := m.Called() return args.Int(0), args.Error(1) } - -func (m *Mock) SessionDelete(sessionId string) error { - args := m.Called(sessionId) - return args.Error(0) -} - -func (m *Mock) SessionGetAll() (map[string]*types.Session, error) { - args := m.Called() - return args.Get(0).(map[string]*types.Session), args.Error(1) -} - -func (m *Mock) InstanceGet(sessionId, name string) (*types.Instance, error) { - args := m.Called(sessionId, name) +func (m *Mock) InstanceGet(name string) (*types.Instance, error) { + args := m.Called(name) return args.Get(0).(*types.Instance), args.Error(1) } - -func (m *Mock) InstanceGetAllWindows() ([]*types.WindowsInstance, error) { - args := m.Called() - return args.Get(0).([]*types.WindowsInstance), args.Error(1) -} - -func (m *Mock) InstanceFindByIP(sessionId, ip string) (*types.Instance, error) { - args := m.Called(sessionId, ip) - return args.Get(0).(*types.Instance), args.Error(1) -} - -func (m *Mock) InstanceCreate(sessionId string, instance *types.Instance) error { - args := m.Called(sessionId, instance) - return args.Error(0) -} - -func (m *Mock) InstanceCreateWindows(instance *types.WindowsInstance) error { +func (m *Mock) InstancePut(instance *types.Instance) error { args := m.Called(instance) return args.Error(0) } - -func (m *Mock) InstanceDelete(sessionId, instanceName string) error { - args := m.Called(sessionId, instanceName) +func (m *Mock) InstanceDelete(name string) error { + args := m.Called(name) return args.Error(0) } - -func (m *Mock) InstanceDeleteWindows(sessionId, instanceId string) error { - args := m.Called(sessionId, instanceId) - return args.Error(0) -} - func (m *Mock) InstanceCount() (int, error) { args := m.Called() return args.Int(0), args.Error(1) } +func (m *Mock) InstanceFindBySessionId(sessionId string) ([]*types.Instance, error) { + args := m.Called(sessionId) + return args.Get(0).([]*types.Instance), args.Error(1) +} + +func (m *Mock) WindowsInstanceGetAll() ([]*types.WindowsInstance, error) { + args := m.Called() + return args.Get(0).([]*types.WindowsInstance), args.Error(1) +} +func (m *Mock) WindowsInstancePut(instance *types.WindowsInstance) error { + args := m.Called(instance) + return args.Error(0) +} +func (m *Mock) WindowsInstanceDelete(id string) error { + args := m.Called(id) + return args.Error(0) +} +func (m *Mock) ClientGet(id string) (*types.Client, error) { + args := m.Called(id) + return args.Get(0).(*types.Client), args.Error(1) +} +func (m *Mock) ClientPut(client *types.Client) error { + args := m.Called(client) + return args.Error(0) +} +func (m *Mock) ClientDelete(id string) error { + args := m.Called(id) + return args.Error(0) +} +func (m *Mock) ClientCount() (int, error) { + args := m.Called() + return args.Int(0), args.Error(1) +} +func (m *Mock) ClientFindBySessionId(sessionId string) ([]*types.Client, error) { + args := m.Called(sessionId) + return args.Get(0).([]*types.Client), args.Error(1) +} diff --git a/storage/storage.go b/storage/storage.go index 90d90e7..2dacbd6 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -1,34 +1,37 @@ package storage import ( - "fmt" + "errors" "github.com/play-with-docker/play-with-docker/pwd/types" ) -const notFound = "NotFound" +var notFound = errors.New("NotFound") func NotFound(e error) bool { - return e.Error() == notFound -} - -func NewNotFoundError() error { - return fmt.Errorf("%s", notFound) + return e == notFound } type StorageApi interface { - SessionGet(string) (*types.Session, error) - SessionPut(*types.Session) error + SessionGet(id string) (*types.Session, error) + SessionGetAll() ([]*types.Session, error) + SessionPut(session *types.Session) error + SessionDelete(id string) error SessionCount() (int, error) - SessionDelete(string) error - SessionGetAll() (map[string]*types.Session, error) - InstanceGet(sessionId, name string) (*types.Instance, error) - InstanceFindByIP(session, ip string) (*types.Instance, error) - InstanceCreate(sessionId string, instance *types.Instance) error - InstanceDelete(sessionId, instanceName string) error - InstanceDeleteWindows(sessionId, instanceId string) error + InstanceGet(name string) (*types.Instance, error) + InstancePut(instance *types.Instance) error + InstanceDelete(name string) error InstanceCount() (int, error) - InstanceGetAllWindows() ([]*types.WindowsInstance, error) - InstanceCreateWindows(*types.WindowsInstance) error + InstanceFindBySessionId(sessionId string) ([]*types.Instance, error) + + WindowsInstanceGetAll() ([]*types.WindowsInstance, error) + WindowsInstancePut(instance *types.WindowsInstance) error + WindowsInstanceDelete(id string) error + + ClientGet(id string) (*types.Client, error) + ClientPut(client *types.Client) error + ClientDelete(id string) error + ClientCount() (int, error) + ClientFindBySessionId(sessionId string) ([]*types.Client, error) } diff --git a/www/assets/app.js b/www/assets/app.js index 5537042..553c7b2 100644 --- a/www/assets/app.js +++ b/www/assets/app.js @@ -59,9 +59,16 @@ var selectedKeyboardShortcuts = KeyboardShortcutService.getCurrentShortcuts(); + $scope.resizeHandler = null; + angular.element($window).bind('resize', function() { if ($scope.selectedInstance) { - $scope.resize($scope.selectedInstance.term.proposeGeometry()); + if (!$scope.resizeHandler) { + $scope.resizeHandler = setTimeout(function() { + $scope.resizeHandler = null + $scope.resize($scope.selectedInstance.term.proposeGeometry()); + }, 1000); + } } }); @@ -263,7 +270,12 @@ // If instance is passed in URL, select it let inst = $scope.idx[$location.hash()]; - if (inst) $scope.showInstance(inst); + if (inst) { + $scope.showInstance(inst); + } else if($scope.instances.length > 0) { + // if no instance has been passed, select the first. + $scope.showInstance($scope.instances[0]); + } }, function(response) { if (response.status == 404) { document.write('session not found');