diff --git a/pwd/storage_mock_test.go b/pwd/storage_mock_test.go index 196520f..abbf40e 100644 --- a/pwd/storage_mock_test.go +++ b/pwd/storage_mock_test.go @@ -4,6 +4,7 @@ import "github.com/play-with-docker/play-with-docker/pwd/types" type mockStorage struct { sessionGet func(sessionId string) (*types.Session, error) + sessionGetAll func() (map[string]*types.Session, error) sessionPut func(s *types.Session) error sessionCount func() (int, error) sessionDelete func(sessionId string) error @@ -22,6 +23,14 @@ func (m *mockStorage) SessionGet(sessionId string) (*types.Session, error) { } return nil, nil } + +func (m *mockStorage) SessionGetAll() (map[string]*types.Session, error) { + if m.sessionGetAll != nil { + return m.sessionGetAll() + } + return nil, nil +} + func (m *mockStorage) SessionPut(s *types.Session) error { if m.sessionPut != nil { return m.sessionPut(s) diff --git a/storage/file.go b/storage/file.go index ca75c91..a27e465 100644 --- a/storage/file.go +++ b/storage/file.go @@ -27,6 +27,14 @@ func (store *storage) SessionGet(sessionId string) (*types.Session, error) { return s, nil } + +func (store *storage) SessionGetAll() (map[string]*types.Session, error) { + store.rw.Lock() + defer store.rw.Unlock() + + return store.db, nil +} + func (store *storage) SessionPut(s *types.Session) error { store.rw.Lock() defer store.rw.Unlock() diff --git a/storage/file_test.go b/storage/file_test.go index f3bea7b..71d711f 100644 --- a/storage/file_test.go +++ b/storage/file_test.go @@ -74,6 +74,34 @@ func TestSessionGet(t *testing.T) { assert.Equal(t, expectedSession, loadedSession) } +func TestSessionGetAll(t *testing.T) { + s1 := &types.Session{Id: "session1"} + s2 := &types.Session{Id: "session2"} + sessions := map[string]*types.Session{} + sessions[s1.Id] = s1 + sessions[s2.Id] = s2 + + tmpfile, err := ioutil.TempFile("", "pwd") + if err != nil { + log.Fatal(err) + } + encoder := json.NewEncoder(tmpfile) + err = encoder.Encode(&sessions) + assert.Nil(t, err) + tmpfile.Close() + defer os.Remove(tmpfile.Name()) + + storage, err := NewFileStorage(tmpfile.Name()) + + assert.Nil(t, err) + + loadedSessions, err := storage.SessionGetAll() + assert.Nil(t, err) + + assert.Equal(t, s1, loadedSessions[s1.Id]) + assert.Equal(t, s2, loadedSessions[s2.Id]) +} + func TestInstanceFindByIP(t *testing.T) { tmpfile, err := ioutil.TempFile("", "pwd") if err != nil { diff --git a/storage/storage.go b/storage/storage.go index d9a1538..39f3fc6 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -13,6 +13,7 @@ type StorageApi interface { SessionPut(*types.Session) error SessionCount() (int, error) SessionDelete(string) error + SessionGetAll() (map[string]*types.Session, error) InstanceFindByAlias(sessionPrefix, alias string) (*types.Instance, error) // Should have the session id too, soon