diff --git a/handlers/bootstrap.go b/handlers/bootstrap.go index 9e8e90b..3ce0438 100644 --- a/handlers/bootstrap.go +++ b/handlers/bootstrap.go @@ -32,7 +32,7 @@ func Bootstrap() { s, err := storage.NewFileStorage(config.SessionsFile) if err != nil && !os.IsNotExist(err) { - log.Fatal("Error decoding sessions from disk ", err) + log.Fatal("Error initializing StorageAPI: ", err) } core = pwd.NewPWD(d, t, Broadcast, s) diff --git a/handlers/ws.go b/handlers/ws.go index 37e8b8f..e0fea5d 100644 --- a/handlers/ws.go +++ b/handlers/ws.go @@ -34,8 +34,7 @@ func WS(so socketio.Socket) { so.On("terminal in", func(name, data string) { // User wrote something on the terminal. Need to write it to the instance terminal - instance := core.InstanceGet(session, name) - core.InstanceWriteToTerminal(instance, data) + core.InstanceWriteToTerminal(session.Id, name, data) }) so.On("viewport resize", func(cols, rows uint) { diff --git a/pwd/client.go b/pwd/client.go index 6fc1021..ed3917c 100644 --- a/pwd/client.go +++ b/pwd/client.go @@ -2,6 +2,7 @@ package pwd import ( "log" + "sync/atomic" "time" "github.com/play-with-docker/play-with-docker/pwd/types" @@ -11,6 +12,7 @@ func (p *pwd) ClientNew(id string, session *types.Session) *types.Client { defer observeAction("ClientNew", time.Now()) c := &types.Client{Id: id, Session: session} session.Clients = append(session.Clients, c) + p.clientCount = atomic.AddInt32(&p.clientCount, 1) return c } @@ -29,6 +31,7 @@ func (p *pwd) ClientClose(client *types.Client) { for i, cl := range session.Clients { if cl.Id == client.Id { session.Clients = append(session.Clients[:i], session.Clients[i+1:]...) + p.clientCount = atomic.AddInt32(&p.clientCount, -1) break } } @@ -38,6 +41,10 @@ func (p *pwd) ClientClose(client *types.Client) { p.setGauges() } +func (p *pwd) ClientCount() int { + return int(atomic.LoadInt32(&p.clientCount)) +} + func (p *pwd) notifyClientSmallestViewPort(session *types.Session) { vp := p.SessionGetSmallestViewPort(session) // Resize all terminals in the session diff --git a/pwd/client_test.go b/pwd/client_test.go index 7ab9243..85d0605 100644 --- a/pwd/client_test.go +++ b/pwd/client_test.go @@ -24,6 +24,21 @@ func TestClientNew(t *testing.T) { assert.Equal(t, types.Client{Id: "foobar", Session: session, ViewPort: types.ViewPort{Cols: 0, Rows: 0}}, *client) assert.Contains(t, session.Clients, client) } +func TestClientCount(t *testing.T) { + docker := &mockDocker{} + tasks := &mockTasks{} + broadcast := &mockBroadcast{} + storage := &mockStorage{} + + p := NewPWD(docker, tasks, broadcast, storage) + + session, err := p.SessionNew(time.Hour, "", "", "") + assert.Nil(t, err) + + p.ClientNew("foobar", session) + + assert.Equal(t, 1, p.ClientCount()) +} func TestClientResizeViewPort(t *testing.T) { docker := &mockDocker{} diff --git a/pwd/instance.go b/pwd/instance.go index 4482728..a024c0b 100644 --- a/pwd/instance.go +++ b/pwd/instance.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "path/filepath" "strings" @@ -23,6 +24,8 @@ type sessionWriter struct { broadcast BroadcastApi } +var terms = make(map[string]map[string]net.Conn) + func (s *sessionWriter) Write(p []byte) (n int, err error) { s.broadcast.BroadcastTo(s.sessionId, "terminal out", s.instanceName, string(p)) return len(p), nil @@ -46,6 +49,10 @@ func (p *pwd) InstanceResizeTerminal(instance *types.Instance, rows, cols uint) } func (p *pwd) InstanceAttachTerminal(instance *types.Instance) error { + // already have a connection for this instance + if getInstanceTermConn(instance.SessionId, instance.Name) != nil { + return nil + } conn, err := p.docker.CreateAttachConnection(instance.Name) if err != nil { @@ -54,7 +61,11 @@ func (p *pwd) InstanceAttachTerminal(instance *types.Instance) error { encoder := encoding.Replacement.NewEncoder() sw := &sessionWriter{sessionId: instance.Session.Id, instanceName: instance.Name, broadcast: p.broadcast} - instance.Terminal = conn + if terms[instance.SessionId] == nil { + terms[instance.SessionId] = map[string]net.Conn{instance.Name: conn} + } else { + terms[instance.SessionId][instance.Name] = conn + } io.Copy(encoder.Writer(sw), conn) return nil @@ -155,8 +166,10 @@ func (p *pwd) InstanceFindByAlias(sessionPrefix, alias string) *types.Instance { func (p *pwd) InstanceDelete(session *types.Session, instance *types.Instance) error { defer observeAction("InstanceDelete", time.Now()) - if instance.Terminal != nil { - instance.Terminal.Close() + conn := getInstanceTermConn(session.Id, instance.Name) + if conn != nil { + conn.Close() + delete(terms[instance.SessionId], instance.Name) } err := p.docker.DeleteContainer(instance.Name) if err != nil && !strings.Contains(err.Error(), "No such container") { @@ -167,7 +180,7 @@ func (p *pwd) InstanceDelete(session *types.Session, instance *types.Instance) e p.broadcast.BroadcastTo(session.Id, "delete instance", instance.Name) delete(session.Instances, instance.Name) - if err := p.storage.SessionPut(session); err != nil { + if err := p.storage.InstanceDelete(session.Id, instance.Name); err != nil { return err } @@ -239,6 +252,7 @@ func (p *pwd) InstanceNew(session *types.Session, conf InstanceConfig) (*types.I instance := &types.Instance{} instance.Image = opts.Image instance.IP = ip + instance.SessionId = session.Id instance.Name = containerName instance.Hostname = conf.Hostname instance.Alias = conf.Alias @@ -258,7 +272,7 @@ func (p *pwd) InstanceNew(session *types.Session, conf InstanceConfig) (*types.I go p.InstanceAttachTerminal(instance) - err = p.storage.SessionPut(session) + err = p.storage.InstanceCreate(session.Id, instance) if err != nil { return nil, err } @@ -270,10 +284,11 @@ func (p *pwd) InstanceNew(session *types.Session, conf InstanceConfig) (*types.I return instance, nil } -func (p *pwd) InstanceWriteToTerminal(instance *types.Instance, data string) { +func (p *pwd) InstanceWriteToTerminal(sessionId, instanceName string, data string) { defer observeAction("InstanceWriteToTerminal", time.Now()) - if instance != nil && instance.Terminal != nil && len(data) > 0 { - instance.Terminal.Write([]byte(data)) + conn := getInstanceTermConn(sessionId, instanceName) + if conn != nil && len(data) > 0 { + conn.Write([]byte(data)) } } @@ -292,3 +307,7 @@ func (p *pwd) InstanceExec(instance *types.Instance, cmd []string) (int, error) defer observeAction("InstanceExec", time.Now()) return p.docker.Exec(instance.Name, cmd) } + +func getInstanceTermConn(sessionId, instanceName string) net.Conn { + return terms[sessionId][instanceName] +} diff --git a/pwd/instance_test.go b/pwd/instance_test.go index cb82ac1..2bfb9e5 100644 --- a/pwd/instance_test.go +++ b/pwd/instance_test.go @@ -1,7 +1,9 @@ package pwd import ( + "errors" "fmt" + "net" "sync" "testing" "time" @@ -69,6 +71,7 @@ func TestInstanceNew(t *testing.T) { Alias: "", Image: config.GetDindImageName(), IsDockerHost: true, + SessionId: session.Id, Session: session, } @@ -159,6 +162,7 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) { IP: "10.0.0.1", Alias: "", Image: "redis", + SessionId: session.Id, IsDockerHost: false, Session: session, } @@ -209,6 +213,7 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { Image: "redis", IsDockerHost: false, Session: session, + SessionId: session.Id, } assert.Equal(t, expectedInstance, *instance) @@ -239,3 +244,39 @@ func TestInstanceAllowedImages(t *testing.T) { assert.Equal(t, expectedImages, p.InstanceAllowedImages()) } + +type errConn struct { + *mockConn +} + +func (ec errConn) Read(b []byte) (int, error) { + return 0, errors.New("Error") +} + +func TestTermConnAssignment(t *testing.T) { + dock := &mockDocker{} + tasks := &mockTasks{} + broadcast := &mockBroadcast{} + storage := &mockStorage{} + + dock.createAttachConnection = func(name string) (net.Conn, error) { + // return error connection to unlock the goroutine + return errConn{}, nil + } + + p := NewPWD(dock, tasks, broadcast, storage) + session, _ := p.SessionNew(time.Hour, "", "", "") + mockInstance := &types.Instance{ + Name: fmt.Sprintf("%s_redis-master", session.Id[:8]), + Hostname: "redis-master", + IP: "10.0.0.1", + Alias: "", + SessionId: session.Id, + Image: "redis", + IsDockerHost: false, + Session: session, + } + p.InstanceAttachTerminal(mockInstance) + assert.NotNil(t, getInstanceTermConn(session.Id, mockInstance.Name)) + +} diff --git a/pwd/pwd.go b/pwd/pwd.go index 4c397fb..1d5782f 100644 --- a/pwd/pwd.go +++ b/pwd/pwd.go @@ -43,10 +43,11 @@ func init() { } type pwd struct { - docker docker.DockerApi - tasks SchedulerApi - broadcast BroadcastApi - storage storage.StorageApi + docker docker.DockerApi + tasks SchedulerApi + broadcast BroadcastApi + storage storage.StorageApi + clientCount int32 } type PWDApi interface { @@ -63,17 +64,19 @@ type PWDApi interface { InstanceUploadFromUrl(instance *types.Instance, fileName, dest, url string) error InstanceUploadFromReader(instance *types.Instance, fileName, dest string, reader io.Reader) error InstanceGet(session *types.Session, name string) *types.Instance + // TODO remove this function when we add the session prefix to the PWD url InstanceFindByIP(ip string) *types.Instance InstanceFindByAlias(sessionPrefix, alias string) *types.Instance InstanceFindByIPAndSession(sessionPrefix, ip string) *types.Instance InstanceDelete(session *types.Session, instance *types.Instance) error - InstanceWriteToTerminal(instance *types.Instance, data string) + InstanceWriteToTerminal(sessionId, instanceName string, data string) InstanceAllowedImages() []string InstanceExec(instance *types.Instance, cmd []string) (int, error) ClientNew(id string, session *types.Session) *types.Client ClientResizeViewPort(client *types.Client, cols, rows uint) ClientClose(client *types.Client) + ClientCount() int } func NewPWD(d docker.DockerApi, t SchedulerApi, b BroadcastApi, s storage.StorageApi) *pwd { @@ -85,7 +88,7 @@ func (p *pwd) setGauges() { ses := float64(s) i, _ := p.storage.InstanceCount() ins := float64(i) - c, _ := p.storage.ClientCount() + c := p.ClientCount() cli := float64(c) clientsGauge.Set(cli) diff --git a/pwd/session.go b/pwd/session.go index fedf2f1..9dd26a2 100644 --- a/pwd/session.go +++ b/pwd/session.go @@ -16,6 +16,8 @@ import ( "github.com/twinj/uuid" ) +var preparedSessions = map[string]bool{} + type sessionBuilderWriter struct { sessionId string broadcast BroadcastApi @@ -65,7 +67,7 @@ func (p *pwd) SessionNew(duration time.Duration, stack, stackName, imageName str } log.Printf("Network [%s] created for session [%s]\n", s.Id, s.Id) - if err := p.prepareSession(s); err != nil { + if _, err := p.prepareSession(s); err != nil { log.Println(err) return nil, err } @@ -186,7 +188,7 @@ func (p *pwd) SessionGet(sessionId string) *types.Session { s, _ := p.storage.SessionGet(sessionId) - if err := p.prepareSession(s); err != nil { + if _, err := p.prepareSession(s); err != nil { log.Println(err) return nil } @@ -283,22 +285,27 @@ func (p *pwd) SessionSetup(session *types.Session, conf SessionSetupConf) error return nil } +func isSessionPrepared(sessionId string) bool { + _, ok := preparedSessions[sessionId] + return ok +} + // This function should be called any time a session needs to be prepared: // 1. Like when it is created // 2. When it was loaded from storage -func (p *pwd) prepareSession(session *types.Session) error { +func (p *pwd) prepareSession(session *types.Session) (bool, error) { session.Lock() defer session.Unlock() - if session.IsPrepared() { - return nil + if isSessionPrepared(session.Id) { + return false, nil } p.scheduleSessionClose(session) // Connect PWD daemon to the new network if err := p.connectToNetwork(session); err != nil { - return err + return false, err } // Schedule periodic tasks @@ -309,9 +316,9 @@ func (p *pwd) prepareSession(session *types.Session) error { i.Session = session go p.InstanceAttachTerminal(i) } - session.SetPrepared() + preparedSessions[session.Id] = true - return nil + return true, nil } func (p *pwd) scheduleSessionClose(s *types.Session) { diff --git a/pwd/session_test.go b/pwd/session_test.go index bef5a1d..c31d3cb 100644 --- a/pwd/session_test.go +++ b/pwd/session_test.go @@ -209,10 +209,10 @@ func TestSessionSetup(t *testing.T) { Image: "franela/dind", Hostname: "manager1", IP: "10.0.0.1", + SessionId: s.Id, Alias: "", IsDockerHost: true, Session: s, - Terminal: manager1Received.Terminal, Docker: manager1Received.Docker, }, manager1Received) @@ -225,8 +225,8 @@ func TestSessionSetup(t *testing.T) { IP: "10.0.0.2", Alias: "", IsDockerHost: true, + SessionId: s.Id, Session: s, - Terminal: manager2Received.Terminal, Docker: manager2Received.Docker, }, manager2Received) @@ -238,9 +238,9 @@ func TestSessionSetup(t *testing.T) { Hostname: "manager3", IP: "10.0.0.3", Alias: "", + SessionId: s.Id, IsDockerHost: true, Session: s, - Terminal: manager3Received.Terminal, Docker: manager3Received.Docker, }, manager3Received) @@ -252,9 +252,9 @@ func TestSessionSetup(t *testing.T) { Hostname: "worker1", IP: "10.0.0.4", Alias: "", + SessionId: s.Id, IsDockerHost: true, Session: s, - Terminal: worker1Received.Terminal, Docker: worker1Received.Docker, }, worker1Received) @@ -266,9 +266,9 @@ func TestSessionSetup(t *testing.T) { Hostname: "other", IP: "10.0.0.5", Alias: "", + SessionId: s.Id, IsDockerHost: true, Session: s, - Terminal: otherReceived.Terminal, Docker: otherReceived.Docker, }, otherReceived) @@ -277,3 +277,20 @@ func TestSessionSetup(t *testing.T) { assert.True(t, manager3JoinedHasManager) assert.True(t, worker1JoinedHasWorker) } + +func TestSessionPrepareOnce(t *testing.T) { + dock := &mockDocker{} + tasks := &mockTasks{} + broadcast := &mockBroadcast{} + storage := &mockStorage{} + + p := NewPWD(dock, tasks, broadcast, storage) + session := &types.Session{Id: "1234"} + prepared, err := p.prepareSession(session) + assert.True(t, preparedSessions[session.Id]) + assert.True(t, prepared) + + prepared, err = p.prepareSession(session) + assert.Nil(t, err) + assert.False(t, prepared) +} diff --git a/pwd/storage_mock_test.go b/pwd/storage_mock_test.go index 7074c21..196520f 100644 --- a/pwd/storage_mock_test.go +++ b/pwd/storage_mock_test.go @@ -10,6 +10,8 @@ type mockStorage struct { instanceFindByAlias func(sessionPrefix, alias string) (*types.Instance, error) instanceFindByIP func(ip string) (*types.Instance, error) instanceFindByIPAndSession func(sessionPrefix, ip string) (*types.Instance, error) + instanceCreate func(string, *types.Instance) error + instanceDelete func(sessionId string, instanceName string) error instanceCount func() (int, error) clientCount func() (int, error) } @@ -56,6 +58,18 @@ func (m *mockStorage) InstanceFindByIPAndSession(sessionPrefix, ip string) (*typ } return nil, nil } +func (m *mockStorage) InstanceCreate(sessionId string, instance *types.Instance) error { + if m.instanceCreate != nil { + return m.instanceCreate(sessionId, instance) + } + return nil +} +func (m *mockStorage) InstanceDelete(sessionId, instanceName string) error { + if m.instanceDelete != nil { + return m.instanceDelete(sessionId, instanceName) + } + return nil +} func (m *mockStorage) InstanceCount() (int, error) { if m.instanceCount != nil { return m.instanceCount() diff --git a/pwd/tasks.go b/pwd/tasks.go index adc52ce..6597217 100644 --- a/pwd/tasks.go +++ b/pwd/tasks.go @@ -33,7 +33,7 @@ type scheduler struct { } func (sch *scheduler) Schedule(s *types.Session) { - if s.IsPrepared() { + if isSessionPrepared(s.Id) { return } diff --git a/pwd/types/instance.go b/pwd/types/instance.go index b39f60e..e9f40ed 100644 --- a/pwd/types/instance.go +++ b/pwd/types/instance.go @@ -2,7 +2,6 @@ package types import ( "context" - "net" "sync" "github.com/play-with-docker/play-with-docker/docker" @@ -15,27 +14,28 @@ func (p UInt16Slice) Less(i, j int) bool { return p[i] < p[j] } func (p UInt16Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } type Instance struct { - Image string `json:"image"` - Name string `json:"name"` - Hostname string `json:"hostname"` - IP string `json:"ip"` - IsManager *bool `json:"is_manager"` - Mem string `json:"mem"` - Cpu string `json:"cpu"` - Alias string `json:"alias"` - ServerCert []byte `json:"server_cert"` - ServerKey []byte `json:"server_key"` - CACert []byte `json:"ca_cert"` - Cert []byte `json:"cert"` - Key []byte `json:"key"` - IsDockerHost bool `json:"is_docker_host"` - Docker docker.DockerApi `json:"-"` - Session *Session `json:"-"` - Terminal net.Conn `json:"-"` - ctx context.Context `json:"-"` - tempPorts []uint16 `json:"-"` - Ports UInt16Slice - rw sync.Mutex + Image string `json:"image" bson:"image"` + Name string `json:"name" bson:"name"` + Hostname string `json:"hostname" bson:"hostname"` + IP string `json:"ip" bson:"ip"` + IsManager *bool `json:"is_manager" bson:"is_manager"` + Mem string `json:"mem" bson:"mem"` + Cpu string `json:"cpu" bson:"cpu"` + Alias string `json:"alias" bson:"alias"` + ServerCert []byte `json:"server_cert" bson:"server_cert"` + ServerKey []byte `json:"server_key" bson:"server_key"` + CACert []byte `json:"ca_cert" bson:"ca_cert"` + Cert []byte `json:"cert" bson:"cert"` + Key []byte `json:"key" bson:"key"` + IsDockerHost bool `json:"is_docker_host" bson:"is_docker_host"` + SessionId string `json:"session_id" bson:"session_id"` + SessionPrefix string `json:"session_prefix" bson:"session_prefix"` + Docker docker.DockerApi `json:"-"` + Session *Session `json:"-" bson:"-"` + ctx context.Context `json:"-" bson:"-"` + tempPorts []uint16 `json:"-" bson:"-"` + Ports UInt16Slice + rw sync.Mutex } func (i *Instance) SetUsedPort(port uint16) { diff --git a/pwd/types/session.go b/pwd/types/session.go index 0b4e831..48c9f89 100644 --- a/pwd/types/session.go +++ b/pwd/types/session.go @@ -7,7 +7,7 @@ import ( type Session struct { Id string `json:"id"` - Instances map[string]*Instance `json:"instances"` + Instances map[string]*Instance `json:"instances" bson:"-"` CreatedAt time.Time `json:"created_at"` ExpiresAt time.Time `json:"expires_at"` PwdIpAddress string `json:"pwd_ip_address"` @@ -16,12 +16,11 @@ type Session struct { StackName string `json:"stack_name"` ImageName string `json:"image_name"` Host string `json:"host"` - Clients []*Client `json:"-"` + Clients []*Client `json:"-" bson:"-"` closingTimer *time.Timer `json:"-"` scheduled bool `json:"-"` ticker *time.Ticker `json:"-"` rw sync.Mutex `json:"-"` - prepared bool `json:"-"` } func (s *Session) Lock() { @@ -46,11 +45,3 @@ func (s *Session) SetClosingTimer(t *time.Timer) { func (s *Session) ClosingTimer() *time.Timer { return s.closingTimer } - -func (s *Session) IsPrepared() bool { - return s.prepared -} - -func (s *Session) SetPrepared() { - s.prepared = true -} diff --git a/storage/file.go b/storage/file.go index 0697ec7..ca75c91 100644 --- a/storage/file.go +++ b/storage/file.go @@ -31,6 +31,10 @@ func (store *storage) SessionPut(s *types.Session) error { store.rw.Lock() defer store.rw.Unlock() + // Initialize instances map if nil + if s.Instances == nil { + s.Instances = map[string]*types.Instance{} + } store.db[s.Id] = s return store.save() @@ -85,6 +89,24 @@ func (store *storage) InstanceFindByAlias(sessionPrefix, alias string) (*types.I return nil, fmt.Errorf("%s", notFound) } +func (store *storage) InstanceCreate(sessionId string, instance *types.Instance) error { + store.rw.Lock() + defer store.rw.Unlock() + + s, found := store.db[sessionId] + if !found { + return fmt.Errorf("Session %s", notFound) + } + + s.Instances[instance.Name] = instance + + return store.save() +} + +func (store *storage) InstanceDelete(sessionId, name string) error { + panic("not implemented") +} + func (store *storage) SessionCount() (int, error) { store.rw.Lock() defer store.rw.Unlock() diff --git a/storage/file_test.go b/storage/file_test.go index 993f041..f3bea7b 100644 --- a/storage/file_test.go +++ b/storage/file_test.go @@ -191,6 +191,33 @@ func TestInstanceFindByAlias(t *testing.T) { assert.Nil(t, foundInstance) } +func TestInstanceCreate(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + tmpfile.Close() + os.Remove(tmpfile.Name()) + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + i1 := &types.Instance{Name: "i1", Alias: "foo", IP: "10.0.0.1"} + s1 := &types.Session{Id: "session1"} + err = storage.SessionPut(s1) + assert.Nil(t, err) + err = storage.InstanceCreate(s1.Id, i1) + assert.Nil(t, err) + + loadedSession, err := storage.SessionGet("session1") + assert.Nil(t, err) + + assert.Equal(t, i1, loadedSession.Instances["i1"]) + +} + func TestCounts(t *testing.T) { tmpfile, err := ioutil.TempFile("", "pwd") if err != nil { @@ -226,9 +253,6 @@ func TestCounts(t *testing.T) { assert.Nil(t, err) assert.Equal(t, 2, num) - num, err = storage.ClientCount() - assert.Nil(t, err) - assert.Equal(t, 1, num) } func TestSessionDelete(t *testing.T) { diff --git a/storage/storage.go b/storage/storage.go index 4ae72c6..d9a1538 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -9,16 +9,17 @@ func NotFound(e error) bool { } type StorageApi interface { - SessionGet(sessionId string) (*types.Session, error) + SessionGet(string) (*types.Session, error) SessionPut(*types.Session) error SessionCount() (int, error) - SessionDelete(sessionId string) error + SessionDelete(string) error InstanceFindByAlias(sessionPrefix, alias string) (*types.Instance, error) // Should have the session id too, soon InstanceFindByIP(ip string) (*types.Instance, error) InstanceFindByIPAndSession(sessionPrefix, ip string) (*types.Instance, error) - InstanceCount() (int, error) + InstanceCreate(sessionId string, instance *types.Instance) error + InstanceDelete(sessionId, instanceName string) error - ClientCount() (int, error) + InstanceCount() (int, error) }