diff --git a/config/config.go b/config/config.go index 32892ca..56bbb33 100644 --- a/config/config.go +++ b/config/config.go @@ -38,6 +38,7 @@ var SecureCookie *securecookie.SecureCookie var GithubClientID, GithubClientSecret string var FacebookClientID, FacebookClientSecret string +var DockerClientID, DockerClientSecret string type stringslice []string @@ -78,6 +79,9 @@ func ParseFlags() { flag.StringVar(&FacebookClientID, "oauth-facebook-client-id", "", "Facebook OAuth Client ID") flag.StringVar(&FacebookClientSecret, "oauth-facebook-client-secret", "", "Facebook OAuth Client Secret") + flag.StringVar(&DockerClientID, "oauth-docker-client-id", "", "Docker OAuth Client ID") + flag.StringVar(&DockerClientSecret, "oauth-docker-client-secret", "", "Docker OAuth Client Secret") + flag.Parse() SecureCookie = securecookie.New([]byte(CookieHashKey), []byte(CookieBlockKey)) @@ -107,6 +111,20 @@ func registerOAuthProviders() { Providers["facebook"] = conf } + if DockerClientID != "" && DockerClientSecret != "" { + oauth2.RegisterBrokenAuthHeaderProvider(".id.docker.com") + conf := &oauth2.Config{ + ClientID: DockerClientID, + ClientSecret: DockerClientSecret, + Scopes: []string{"openid"}, + Endpoint: oauth2.Endpoint{ + AuthURL: "https://id.docker.com/id/oauth/authorize/", + TokenURL: "https://id.docker.com/id/oauth/token", + }, + } + + Providers["docker"] = conf + } } func GetDindImageName() string { diff --git a/handlers/bootstrap.go b/handlers/bootstrap.go index a61227c..9fbd42f 100644 --- a/handlers/bootstrap.go +++ b/handlers/bootstrap.go @@ -79,6 +79,7 @@ func Register(extend HandlerExtender) { http.ServeFile(rw, r, "./www/landing.html") }).Methods("GET") + r.HandleFunc("/users/me", LoggedInUser).Methods("GET") r.HandleFunc("/oauth/providers", ListProviders).Methods("GET") r.HandleFunc("/oauth/providers/{provider}/login", Login).Methods("GET") r.HandleFunc("/oauth/providers/{provider}/callback", LoginCallback).Methods("GET") diff --git a/handlers/cookie_id.go b/handlers/cookie_id.go index c19b25a..e18023f 100644 --- a/handlers/cookie_id.go +++ b/handlers/cookie_id.go @@ -15,10 +15,11 @@ type CookieID struct { func (c *CookieID) SetCookie(rw http.ResponseWriter) error { if encoded, err := config.SecureCookie.Encode("id", c); err == nil { cookie := &http.Cookie{ - Name: "id", - Value: encoded, - Path: "/", - Secure: config.UseLetsEncrypt, + Name: "id", + Value: encoded, + Path: "/", + Secure: config.UseLetsEncrypt, + HttpOnly: true, } http.SetCookie(rw, cookie) } else { diff --git a/handlers/login.go b/handlers/login.go index 070b484..637adc3 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -14,8 +14,25 @@ import ( fb "github.com/huandu/facebook" "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/pwd/types" + "github.com/twinj/uuid" ) +func LoggedInUser(rw http.ResponseWriter, req *http.Request) { + cookie, err := ReadCookie(req) + if err != nil { + log.Println("Cannot read cookie") + rw.WriteHeader(http.StatusUnauthorized) + return + } + user, err := core.UserGet(cookie.Id) + if err != nil { + log.Printf("Couldn't get user with id %s. Got: %v\n", cookie.Id, err) + rw.WriteHeader(http.StatusUnauthorized) + return + } + json.NewEncoder(rw).Encode(user) +} + func ListProviders(rw http.ResponseWriter, req *http.Request) { providers := []string{} for name, _ := range config.Providers { @@ -51,7 +68,7 @@ func Login(rw http.ResponseWriter, req *http.Request) { host = req.Host } provider.RedirectURL = fmt.Sprintf("%s://%s/oauth/providers/%s/callback", scheme, host, providerName) - url := provider.AuthCodeURL(loginRequest.Id) + url := provider.AuthCodeURL(loginRequest.Id, oauth2.SetAuthURLParam("nonce", uuid.NewV4().String())) http.Redirect(rw, req, url, http.StatusFound) } @@ -125,6 +142,28 @@ func LoginCallback(rw http.ResponseWriter, req *http.Request) { user.Name = res.Get("name").(string) user.Avatar = res.Get("picture.data.url").(string) user.Email = res.Get("email").(string) + } else if providerName == "docker" { + ts := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: tok.AccessToken}, + ) + tc := oauth2.NewClient(ctx, ts) + resp, err := tc.Get("https://id.docker.com/api/id/v1/openid/userinfo") + if err != nil { + log.Printf("Could not get user from docker. Got: %v\n", err) + rw.WriteHeader(http.StatusInternalServerError) + return + } + + userInfo := map[string]string{} + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + log.Printf("Could not decode user info. Got: %v\n", err) + rw.WriteHeader(http.StatusInternalServerError) + return + } + + user.ProviderUserId = userInfo["sub"] + user.Name = userInfo["preferred_username"] + user.Email = userInfo["email"] } user, err = core.UserLogin(loginRequest, user) @@ -142,5 +181,5 @@ func LoginCallback(rw http.ResponseWriter, req *http.Request) { return } - http.Redirect(rw, req, "/", http.StatusFound) + fmt.Fprintf(rw, `
`) } diff --git a/pwd/mock.go b/pwd/mock.go index c417b4c..459d997 100644 --- a/pwd/mock.go +++ b/pwd/mock.go @@ -119,3 +119,7 @@ func (m *Mock) UserLogin(loginRequest *types.LoginRequest, user *types.User) (*t args := m.Called(loginRequest, user) return args.Get(0).(*types.User), args.Error(1) } +func (m *Mock) UserGet(id string) (*types.User, error) { + args := m.Called(id) + return args.Get(0).(*types.User), args.Error(1) +} diff --git a/pwd/pwd.go b/pwd/pwd.go index 5aaf2e5..630c64f 100644 --- a/pwd/pwd.go +++ b/pwd/pwd.go @@ -97,6 +97,7 @@ type PWDApi interface { UserNewLoginRequest(providerName string) (*types.LoginRequest, error) UserGetLoginRequest(id string) (*types.LoginRequest, error) UserLogin(loginRequest *types.LoginRequest, user *types.User) (*types.User, error) + UserGet(id string) (*types.User, error) } func NewPWD(f docker.FactoryApi, e event.EventApi, s storage.StorageApi, sp provisioner.SessionProvisionerApi, ipf provisioner.InstanceProvisionerFactoryApi) *pwd { diff --git a/pwd/user.go b/pwd/user.go index 985e1e8..0d6bc3f 100644 --- a/pwd/user.go +++ b/pwd/user.go @@ -41,3 +41,10 @@ func (p *pwd) UserLogin(loginRequest *types.LoginRequest, user *types.User) (*ty return user, nil } +func (p *pwd) UserGet(id string) (*types.User, error) { + if user, err := p.storage.UserGet(id); err != nil { + return nil, err + } else { + return user, nil + } +} diff --git a/storage/file.go b/storage/file.go index 72d651c..6baaade 100644 --- a/storage/file.go +++ b/storage/file.go @@ -353,6 +353,16 @@ func (store *storage) UserPut(user *types.User) error { return nil } +func (store *storage) UserGet(id string) (*types.User, error) { + store.rw.Lock() + defer store.rw.Unlock() + + if user, found := store.db.Users[id]; !found { + return nil, NotFoundError + } else { + return user, nil + } +} func (store *storage) load() error { file, err := os.Open(store.path) diff --git a/storage/mock.go b/storage/mock.go index 3d18278..1a2c062 100644 --- a/storage/mock.go +++ b/storage/mock.go @@ -102,3 +102,7 @@ func (m *Mock) UserPut(user *types.User) error { args := m.Called(user) return args.Error(0) } +func (m *Mock) UserGet(id string) (*types.User, error) { + args := m.Called(id) + return args.Get(0).(*types.User), args.Error(1) +} diff --git a/storage/storage.go b/storage/storage.go index d199ef6..62f649a 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -41,4 +41,5 @@ type StorageApi interface { UserFindByProvider(providerName, providerUserId string) (*types.User, error) UserPut(user *types.User) error + UserGet(id string) (*types.User, error) } diff --git a/www/landing.html b/www/landing.html index 116f9fa..b900f4e 100644 --- a/www/landing.html +++ b/www/landing.html @@ -41,7 +41,7 @@ Login