From cd6d172cfadb9fe82b563ea66c779c3a4e57d4fc Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Mon, 19 Jun 2017 11:03:31 -0300 Subject: [PATCH 1/6] WIP --- api.go | 32 +---- {handlers => router/dns}/dns.go | 2 +- router/router.go | 117 ++++++++++++++++++ router/router_test.go | 49 ++++++++ handlers/reverseproxy.go => router/tcp/tcp.go | 4 +- handlers/tlsproxy.go => router/tls/tls.go | 2 +- 6 files changed, 173 insertions(+), 33 deletions(-) rename {handlers => router/dns}/dns.go (99%) create mode 100644 router/router.go create mode 100644 router/router_test.go rename handlers/reverseproxy.go => router/tcp/tcp.go (97%) rename handlers/tlsproxy.go => router/tls/tls.go (99%) diff --git a/api.go b/api.go index 313e2bf..eff0941 100644 --- a/api.go +++ b/api.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "log" "net/http" "os" @@ -9,7 +8,6 @@ import ( gh "github.com/gorilla/handlers" "github.com/gorilla/mux" - "github.com/miekg/dns" "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/handlers" "github.com/play-with-docker/play-with-docker/templates" @@ -18,29 +16,11 @@ import ( ) func main() { - config.ParseFlags() handlers.Bootstrap() bypassCaptcha := len(os.Getenv("GOOGLE_RECAPTCHA_DISABLED")) > 0 - // Start the DNS server - dns.HandleFunc(".", handlers.DnsRequest) - udpDnsServer := &dns.Server{Addr: ":53", Net: "udp"} - go func() { - err := udpDnsServer.ListenAndServe() - if err != nil { - log.Fatal(err) - } - }() - tcpDnsServer := &dns.Server{Addr: ":53", Net: "tcp"} - go func() { - err := tcpDnsServer.ListenAndServe() - if err != nil { - log.Fatal(err) - } - }() - server := handlers.Broadcast.GetHandler() r := mux.NewRouter() @@ -55,10 +35,6 @@ func main() { corsHandler := gh.CORS(gh.AllowCredentials(), gh.AllowedHeaders([]string{"x-requested-with", "content-type"}), gh.AllowedOrigins([]string{"*"})) // Specific routes - r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}-{port:%s}.{tld:.*}", config.PWDHostnameRegex, config.PortRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}.{tld:.*}", config.PWDHostnameRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}-{port:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex, config.PortRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex)).Handler(tcpHandler) r.HandleFunc("/ping", handlers.Ping).Methods("GET") corsRouter.HandleFunc("/instances/images", handlers.GetInstanceImages).Methods("GET") corsRouter.HandleFunc("/sessions/{sessionId}", handlers.GetSession).Methods("GET") @@ -106,13 +82,11 @@ func main() { ReadHeaderTimeout: 5 * time.Second, } - go func() { - log.Println("Listening on port " + config.PortNumber) - log.Fatal(httpServer.ListenAndServe()) - }() - go handlers.ListenSSHProxy("0.0.0.0:1022") // Now listen for TLS connections that need to be proxied handlers.StartTLSProxy(config.SSLPortNumber) + + log.Println("Listening on port " + config.PortNumber) + log.Fatal(httpServer.ListenAndServe()) } diff --git a/handlers/dns.go b/router/dns/dns.go similarity index 99% rename from handlers/dns.go rename to router/dns/dns.go index e199296..2160e92 100644 --- a/handlers/dns.go +++ b/router/dns/dns.go @@ -1,4 +1,4 @@ -package handlers +package dns import ( "fmt" diff --git a/router/router.go b/router/router.go new file mode 100644 index 0000000..0bb0d8c --- /dev/null +++ b/router/router.go @@ -0,0 +1,117 @@ +package router + +import ( + "fmt" + "io" + "log" + "net" + + vhost "github.com/inconshreveable/go-vhost" +) + +type Director func(host string) (*net.TCPAddr, error) + +type proxyRouter struct { + director Director +} + +func (r *proxyRouter) Listen(laddr string) { + l, err := net.Listen("tcp", laddr) + defer l.Close() + if err != nil { + log.Fatal(err) + } + for { + conn, err := l.Accept() + if err != nil { + log.Println(err) + continue + } + go r.handleConnection(conn) + } +} + +func (r *proxyRouter) handleConnection(c net.Conn) { + defer c.Close() + // first try tls + vhostConn, err := vhost.TLS(c) + if err != nil { + log.Printf("Incoming TLS connection produced an error. Error: %s", err) + return + } + defer vhostConn.Close() + + host := vhostConn.ClientHelloMsg.ServerName + c.LocalAddr() + dstHost, err := r.director(fmt.Sprintf("%s:%d", host, 12)) + if err != nil { + log.Printf("Error directing request: %v\n", err) + return + } + + d, err := net.Dial("tcp", dstHost.String()) + if err != nil { + log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) + return + } + + errc := make(chan error, 2) + cp := func(dst io.Writer, src io.Reader) { + _, err := io.Copy(dst, src) + errc <- err + } + go cp(d, vhostConn) + go cp(vhostConn, d) + <-errc + /* + req, err := http.ReadRequest(bufio.NewReader(c)) + if err != nil { + log.Println(err) + return + } + + log.Println(req.Header) + */ +} + +func NewRouter(director Director) *proxyRouter { + return &proxyRouter{director: director} +} + +/* + // Start the DNS server + dns.HandleFunc(".", routerDns.DnsRequest) + udpDnsServer := &dns.Server{Addr: ":53", Net: "udp"} + go func() { + err := udpDnsServer.ListenAndServe() + if err != nil { + log.Fatal(err) + } + }() + tcpDnsServer := &dns.Server{Addr: ":53", Net: "tcp"} + go func() { + err := tcpDnsServer.ListenAndServe() + if err != nil { + log.Fatal(err) + } + }() + r := mux.NewRouter() + tcpHandler := handlers.NewTCPProxy() + r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}-{port:%s}.{tld:.*}", config.PWDHostnameRegex, config.PortRegex)).Handler(tcpHandler) + r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}.{tld:.*}", config.PWDHostnameRegex)).Handler(tcpHandler) + r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}-{port:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex, config.PortRegex)).Handler(tcpHandler) + r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex)).Handler(tcpHandler) + r.HandleFunc("/ping", handlers.Ping).Methods("GET") + n := negroni.Classic() + n.UseHandler(r) + + httpServer := http.Server{ + Addr: "0.0.0.0:" + config.PortNumber, + Handler: n, + IdleTimeout: 30 * time.Second, + ReadHeaderTimeout: 5 * time.Second, + } + // Now listen for TLS connections that need to be proxied + tls.StartTLSProxy(config.SSLPortNumber) + http.ListenAndServe() +*/ diff --git a/router/router_test.go b/router/router_test.go new file mode 100644 index 0000000..192b524 --- /dev/null +++ b/router/router_test.go @@ -0,0 +1,49 @@ +package router + +import ( + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProxy_TLS(t *testing.T) { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: tr} + + const msg = "It works!" + + var receivedHost string + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }) + go r.Listen(":8080") + + req, err := http.NewRequest("GET", "https://localhost:8080", nil) + assert.Nil(t, err) + + resp, err := client.Do(req) + assert.Nil(t, err) + + body, err := ioutil.ReadAll(resp.Body) + assert.Nil(t, err) + assert.Equal(t, msg, string(body)) + assert.Equal(t, "localhost:8080", receivedHost) +} diff --git a/handlers/reverseproxy.go b/router/tcp/tcp.go similarity index 97% rename from handlers/reverseproxy.go rename to router/tcp/tcp.go index ed5a72e..80a1cca 100644 --- a/handlers/reverseproxy.go +++ b/router/tcp/tcp.go @@ -1,4 +1,4 @@ -package handlers +package tcp import ( "crypto/tls" @@ -9,8 +9,8 @@ import ( "net/http" "strings" + "github.com/franela/pwd.old/config" "github.com/gorilla/mux" - "github.com/play-with-docker/play-with-docker/config" ) func getTargetInfo(vars map[string]string, req *http.Request) (string, string) { diff --git a/handlers/tlsproxy.go b/router/tls/tls.go similarity index 99% rename from handlers/tlsproxy.go rename to router/tls/tls.go index 4a1ed03..a3c2ada 100644 --- a/handlers/tlsproxy.go +++ b/router/tls/tls.go @@ -1,4 +1,4 @@ -package handlers +package tls import ( "fmt" From ffaad303d6c4424a36faa6ea8971fb4a8ace9e12 Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Wed, 21 Jun 2017 16:52:04 -0300 Subject: [PATCH 2/6] WIP --- router/router.go | 113 +++++++++++++++++++---------- router/router_test.go | 162 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 236 insertions(+), 39 deletions(-) diff --git a/router/router.go b/router/router.go index 0bb0d8c..b62895a 100644 --- a/router/router.go +++ b/router/router.go @@ -1,10 +1,12 @@ package router import ( - "fmt" + "bufio" "io" "log" "net" + "net/http" + "sync" vhost "github.com/inconshreveable/go-vhost" ) @@ -12,66 +14,105 @@ import ( type Director func(host string) (*net.TCPAddr, error) type proxyRouter struct { + sync.Mutex + director Director + listener net.Listener + closed bool } func (r *proxyRouter) Listen(laddr string) { l, err := net.Listen("tcp", laddr) - defer l.Close() if err != nil { log.Fatal(err) } - for { - conn, err := l.Accept() - if err != nil { - log.Println(err) - continue + r.listener = l + go func() { + for !r.closed { + conn, err := r.listener.Accept() + if err != nil { + continue + } + go r.handleConnection(conn) } - go r.handleConnection(conn) + }() +} + +func (r *proxyRouter) Close() { + r.Lock() + defer r.Unlock() + + if r.listener != nil { + r.listener.Close() } + r.closed = true +} + +func (r *proxyRouter) ListenAddress() string { + if r.listener != nil { + return r.listener.Addr().String() + } + return "" } func (r *proxyRouter) handleConnection(c net.Conn) { defer c.Close() // first try tls vhostConn, err := vhost.TLS(c) - if err != nil { - log.Printf("Incoming TLS connection produced an error. Error: %s", err) - return - } - defer vhostConn.Close() - host := vhostConn.ClientHelloMsg.ServerName - c.LocalAddr() - dstHost, err := r.director(fmt.Sprintf("%s:%d", host, 12)) - if err != nil { - log.Printf("Error directing request: %v\n", err) - return - } + if err == nil { + // It is a TLS connection + defer vhostConn.Close() + host := vhostConn.ClientHelloMsg.ServerName + dstHost, err := r.director(host) + if err != nil { + log.Printf("Error directing request: %v\n", err) + return + } + d, err := net.Dial("tcp", dstHost.String()) + if err != nil { + log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) + return + } - d, err := net.Dial("tcp", dstHost.String()) - if err != nil { - log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) - return - } + proxy(vhostConn, d) + } else { + // it is not TLS + // treat it as an http connection + req, err := http.ReadRequest(bufio.NewReader(vhostConn)) + if err != nil { + // It is not http neither. So just close the connection. + return + } + dstHost, err := r.director(req.Host) + if err != nil { + log.Printf("Error directing request: %v\n", err) + return + } + d, err := net.Dial("tcp", dstHost.String()) + if err != nil { + log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) + return + } + err = req.Write(d) + if err != nil { + log.Printf("Error requesting backend %s: %v\n", dstHost.String(), err) + return + } + proxy(c, d) + } +} + +func proxy(src, dst net.Conn) { errc := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { _, err := io.Copy(dst, src) errc <- err } - go cp(d, vhostConn) - go cp(vhostConn, d) + go cp(src, dst) + go cp(dst, src) <-errc - /* - req, err := http.ReadRequest(bufio.NewReader(c)) - if err != nil { - log.Println(err) - return - } - - log.Println(req.Header) - */ } func NewRouter(director Director) *proxyRouter { diff --git a/router/router_test.go b/router/router_test.go index 192b524..fcdba89 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -4,15 +4,24 @@ import ( "crypto/tls" "fmt" "io/ioutil" + "log" "net" "net/http" "net/http/httptest" "net/url" + "strings" + "sync" "testing" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) +func getRouterUrl(scheme string, r *proxyRouter) string { + chunks := strings.Split(r.ListenAddress(), ":") + return fmt.Sprintf("%s://localhost:%s", scheme, chunks[len(chunks)-1]) +} + func TestProxy_TLS(t *testing.T) { tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, @@ -34,9 +43,10 @@ func TestProxy_TLS(t *testing.T) { a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil }) - go r.Listen(":8080") + r.Listen(":0") + defer r.Close() - req, err := http.NewRequest("GET", "https://localhost:8080", nil) + req, err := http.NewRequest("GET", getRouterUrl("https", r), nil) assert.Nil(t, err) resp, err := client.Do(req) @@ -45,5 +55,151 @@ func TestProxy_TLS(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) assert.Nil(t, err) assert.Equal(t, msg, string(body)) - assert.Equal(t, "localhost:8080", receivedHost) + assert.Equal(t, "localhost", receivedHost) +} + +func TestProxy_Http(t *testing.T) { + const msg = "It works!" + + var receivedHost string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }) + r.Listen(":0") + defer r.Close() + + req, err := http.NewRequest("GET", getRouterUrl("http", r), nil) + assert.Nil(t, err) + + resp, err := http.DefaultClient.Do(req) + assert.Nil(t, err) + + body, err := ioutil.ReadAll(resp.Body) + assert.Nil(t, err) + assert.Equal(t, msg, string(body)) + + u, _ := url.Parse(getRouterUrl("http", r)) + assert.Equal(t, u.Host, receivedHost) +} + +func TestProxy_WS(t *testing.T) { + const msg = "It works!" + + var serverReceivedMessage string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + serverReceivedMessage = string(message) + err = c.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + return + } + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }) + r.Listen(":0") + defer r.Close() + + c, _, err := websocket.DefaultDialer.Dial(getRouterUrl("ws", r), nil) + assert.Nil(t, err) + defer c.Close() + + var clientReceivedMessage []byte + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + _, clientReceivedMessage, _ = c.ReadMessage() + wg.Done() + }() + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, msg, string(clientReceivedMessage)) + assert.Equal(t, msg, serverReceivedMessage) +} + +func TestProxy_WSS(t *testing.T) { + const msg = "It works!" + + var serverReceivedMessage string + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + serverReceivedMessage = string(message) + err = c.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + return + } + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }) + r.Listen(":0") + defer r.Close() + + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + c, _, err := d.Dial(getRouterUrl("wss", r), nil) + assert.Nil(t, err) + defer c.Close() + + var clientReceivedMessage []byte + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + _, clientReceivedMessage, _ = c.ReadMessage() + wg.Done() + }() + + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, msg, string(clientReceivedMessage)) + assert.Equal(t, msg, serverReceivedMessage) } From f7253a79cd0a47179b27bca06784940714332e03 Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Thu, 22 Jun 2017 17:34:24 -0300 Subject: [PATCH 3/6] WIP --- api.go | 9 - router/dns/dns.go | 111 ------------- router/router.go | 377 ++++++++++++++++++++++++++++++++++++------ router/router_test.go | 336 ++++++++++++++++++++++++++++++++++++- router/tcp/tcp.go | 143 ---------------- router/tls/tls.go | 95 ----------- 6 files changed, 649 insertions(+), 422 deletions(-) delete mode 100644 router/dns/dns.go delete mode 100644 router/tcp/tcp.go delete mode 100644 router/tls/tls.go diff --git a/api.go b/api.go index eff0941..fdcb8bc 100644 --- a/api.go +++ b/api.go @@ -26,12 +26,6 @@ func main() { r := mux.NewRouter() corsRouter := mux.NewRouter() - // Reverse proxy (needs to be the first route, to make sure it is the first thing we check) - //proxyHandler := handlers.NewMultipleHostReverseProxy() - //websocketProxyHandler := handlers.NewMultipleHostWebsocketReverseProxy() - - tcpHandler := handlers.NewTCPProxy() - corsHandler := gh.CORS(gh.AllowCredentials(), gh.AllowedHeaders([]string{"x-requested-with", "content-type"}), gh.AllowedOrigins([]string{"*"})) // Specific routes @@ -84,9 +78,6 @@ func main() { go handlers.ListenSSHProxy("0.0.0.0:1022") - // Now listen for TLS connections that need to be proxied - handlers.StartTLSProxy(config.SSLPortNumber) - log.Println("Listening on port " + config.PortNumber) log.Fatal(httpServer.ListenAndServe()) } diff --git a/router/dns/dns.go b/router/dns/dns.go deleted file mode 100644 index 2160e92..0000000 --- a/router/dns/dns.go +++ /dev/null @@ -1,111 +0,0 @@ -package dns - -import ( - "fmt" - "log" - "net" - "strings" - - "github.com/miekg/dns" - "github.com/play-with-docker/play-with-docker/config" -) - -func DnsRequest(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) > 0 && config.NameFilter.MatchString(r.Question[0].Name) { - // this is something we know about and we should try to handle - question := r.Question[0].Name - - match := config.NameFilter.FindStringSubmatch(question) - - ip := strings.Replace(match[1], "-", ".", -1) - - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ip)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } else if len(r.Question) > 0 && config.AliasFilter.MatchString(r.Question[0].Name) { - // this is something we know about and we should try to handle - question := r.Question[0].Name - - match := config.AliasFilter.FindStringSubmatch(question) - - i := core.InstanceFindByAlias(match[2], match[1]) - - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, i.IP)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } else { - if len(r.Question) > 0 { - question := r.Question[0].Name - - if question == "localhost." { - log.Printf("Not a PWD host. Asked for [localhost.] returning automatically [127.0.0.1]\n") - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A 127.0.0.1", question)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } - - log.Printf("Not a PWD host. Looking up [%s]\n", question) - ips, err := net.LookupIP(question) - if err != nil { - // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured - w.Close() - // dns.HandleFailed(w, r) - return - } - log.Printf("Not a PWD host. Looking up [%s] got [%s]\n", question, ips) - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - for _, ip := range ips { - ipv4 := ip.To4() - if ipv4 == nil { - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN AAAA %s", question, ip.String())) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - } else { - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ipv4.String())) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - } - } - w.WriteMsg(m) - return - - } else { - log.Printf("Not a PWD host. Got DNS without any question\n") - // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured - w.Close() - // dns.HandleFailed(w, r) - return - } - } -} diff --git a/router/router.go b/router/router.go index b62895a..b46a986 100644 --- a/router/router.go +++ b/router/router.go @@ -2,13 +2,19 @@ package router import ( "bufio" + "fmt" "io" + "io/ioutil" "log" "net" "net/http" + "strings" "sync" + "golang.org/x/crypto/ssh" + vhost "github.com/inconshreveable/go-vhost" + "github.com/miekg/dns" ) type Director func(host string) (*net.TCPAddr, error) @@ -16,45 +22,279 @@ type Director func(host string) (*net.TCPAddr, error) type proxyRouter struct { sync.Mutex - director Director - listener net.Listener - closed bool + keyPath string + director Director + closed bool + httpListener net.Listener + udpDnsServer *dns.Server + tcpDnsServer *dns.Server + sshListener net.Listener + sshConfig *ssh.ServerConfig } -func (r *proxyRouter) Listen(laddr string) { - l, err := net.Listen("tcp", laddr) +func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { + l, err := net.Listen("tcp", httpAddr) if err != nil { log.Fatal(err) } - r.listener = l + r.httpListener = l go func() { for !r.closed { - conn, err := r.listener.Accept() + conn, err := r.httpListener.Accept() if err != nil { continue } go r.handleConnection(conn) } }() + + dnsMux := dns.NewServeMux() + dnsMux.HandleFunc(".", r.dnsRequest) + r.udpDnsServer = &dns.Server{Addr: dnsAddr, Net: "udp", Handler: dnsMux} + r.tcpDnsServer = &dns.Server{Addr: dnsAddr, Net: "tcp", Handler: dnsMux} + + wg := sync.WaitGroup{} + wg.Add(2) + + r.udpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + r.tcpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + go r.udpDnsServer.ListenAndServe() + go r.tcpDnsServer.ListenAndServe() + wg.Wait() + + lssh, err := net.Listen("tcp", sshAddr) + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + r.sshListener = lssh + go func() { + for { + nConn, err := lssh.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + go r.sshHandle(nConn) + } + }() +} + +func (r *proxyRouter) sshHandle(nConn net.Conn) { + sshCon, chans, reqs, err := ssh.NewServerConn(nConn, r.sshConfig) + if err != nil { + nConn.Close() + return + } + + dstHost, err := r.director(sshCon.User()) + if err != nil { + nConn.Close() + return + } + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + newChannel := <-chans + if newChannel == nil { + sshCon.Close() + return + } + + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + return + } + + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatalf("Could not accept channel: %v", err) + } + + stderr := channel.Stderr() + + fmt.Fprintf(stderr, "Connecting to %s\r\n", dstHost.String()) + + clientConfig := &ssh.ClientConfig{ + User: "root", + Auth: []ssh.AuthMethod{ + ssh.Password("root"), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + + client, err := ssh.Dial("tcp", dstHost.String(), clientConfig) + if err != nil { + fmt.Fprintf(stderr, "Connect failed: %v\r\n", err) + channel.Close() + return + } + + go func() { + for newChannel = range chans { + if newChannel == nil { + return + } + + channel2, reqs2, err := client.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()) + if err != nil { + x, ok := err.(*ssh.OpenChannelError) + if ok { + newChannel.Reject(x.Reason, x.Message) + } else { + newChannel.Reject(ssh.Prohibited, "remote server denied channel request") + } + continue + } + + channel, reqs, err := newChannel.Accept() + if err != nil { + channel2.Close() + continue + } + go proxySsh(reqs, reqs2, channel, channel2) + } + }() + + // Forward the session channel + channel2, reqs2, err := client.OpenChannel("session", []byte{}) + if err != nil { + fmt.Fprintf(stderr, "Remote session setup failed: %v\r\n", err) + channel.Close() + return + } + + maskedReqs := make(chan *ssh.Request, 1) + go func() { + for req := range requests { + if req.Type == "auth-agent-req@openssh.com" { + continue + } + maskedReqs <- req + } + }() + proxySsh(maskedReqs, reqs2, channel, channel2) +} + +func (r *proxyRouter) dnsRequest(w dns.ResponseWriter, req *dns.Msg) { + if len(req.Question) > 0 { + question := req.Question[0].Name + + if question == "localhost." { + log.Printf("Asked for [localhost.] returning automatically [127.0.0.1]\n") + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A 127.0.0.1", question)) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + w.WriteMsg(m) + return + } + + dstHost, err := r.director(strings.TrimSuffix(question, ".")) + if err != nil { + // Director couldn't resolve it, try to lookup in the system's DNS + ips, err := net.LookupIP(question) + if err != nil { + // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured + w.Close() + // dns.HandleFailed(w, r) + return + } + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + for _, ip := range ips { + ipv4 := ip.To4() + if ipv4 == nil { + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN AAAA %s", question, ip.String())) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + } else { + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ipv4.String())) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + } + } + w.WriteMsg(m) + return + } + + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, dstHost.IP)) + if err != nil { + log.Println(err) + w.Close() + // dns.HandleFailed(w, r) + return + } + m.Answer = append(m.Answer, a) + w.WriteMsg(m) + return + } } func (r *proxyRouter) Close() { r.Lock() defer r.Unlock() - if r.listener != nil { - r.listener.Close() + if r.httpListener != nil { + r.httpListener.Close() + } + if r.udpDnsServer != nil { + r.udpDnsServer.Shutdown() } r.closed = true } -func (r *proxyRouter) ListenAddress() string { - if r.listener != nil { - return r.listener.Addr().String() +func (r *proxyRouter) ListenHttpAddress() string { + if r.httpListener != nil { + return r.httpListener.Addr().String() } return "" } +func (r *proxyRouter) ListenDnsUdpAddress() string { + if r.udpDnsServer != nil && r.udpDnsServer.PacketConn != nil { + return r.udpDnsServer.PacketConn.LocalAddr().String() + } + + return "" +} +func (r *proxyRouter) ListenDnsTcpAddress() string { + if r.tcpDnsServer != nil && r.tcpDnsServer.Listener != nil { + return r.tcpDnsServer.Listener.Addr().String() + } + + return "" +} + +func (r *proxyRouter) ListenSshAddress() string { + if r.sshListener != nil { + return r.sshListener.Addr().String() + } + + return "" +} + func (r *proxyRouter) handleConnection(c net.Conn) { defer c.Close() // first try tls @@ -75,7 +315,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) { return } - proxy(vhostConn, d) + proxyConn(vhostConn, d) } else { // it is not TLS // treat it as an http connection @@ -100,11 +340,59 @@ func (r *proxyRouter) handleConnection(c net.Conn) { log.Printf("Error requesting backend %s: %v\n", dstHost.String(), err) return } - proxy(c, d) + proxyConn(c, d) } } -func proxy(src, dst net.Conn) { +func proxySsh(reqs1, reqs2 <-chan *ssh.Request, channel1, channel2 ssh.Channel) { + var closer sync.Once + closeFunc := func() { + channel1.Close() + channel2.Close() + } + + defer closer.Do(closeFunc) + + closerChan := make(chan bool, 1) + + go func() { + io.Copy(channel1, channel2) + closerChan <- true + }() + + go func() { + io.Copy(channel2, channel1) + closerChan <- true + }() + + for { + select { + case req := <-reqs1: + if req == nil { + return + } + b, err := channel2.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + return + } + req.Reply(b, nil) + + case req := <-reqs2: + if req == nil { + return + } + b, err := channel1.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + return + } + req.Reply(b, nil) + case <-closerChan: + return + } + } +} + +func proxyConn(src, dst net.Conn) { errc := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { _, err := io.Copy(dst, src) @@ -115,44 +403,23 @@ func proxy(src, dst net.Conn) { <-errc } -func NewRouter(director Director) *proxyRouter { - return &proxyRouter{director: director} -} - -/* - // Start the DNS server - dns.HandleFunc(".", routerDns.DnsRequest) - udpDnsServer := &dns.Server{Addr: ":53", Net: "udp"} - go func() { - err := udpDnsServer.ListenAndServe() - if err != nil { - log.Fatal(err) - } - }() - tcpDnsServer := &dns.Server{Addr: ":53", Net: "tcp"} - go func() { - err := tcpDnsServer.ListenAndServe() - if err != nil { - log.Fatal(err) - } - }() - r := mux.NewRouter() - tcpHandler := handlers.NewTCPProxy() - r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}-{port:%s}.{tld:.*}", config.PWDHostnameRegex, config.PortRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}.{tld:.*}", config.PWDHostnameRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}-{port:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex, config.PortRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex)).Handler(tcpHandler) - r.HandleFunc("/ping", handlers.Ping).Methods("GET") - n := negroni.Classic() - n.UseHandler(r) - - httpServer := http.Server{ - Addr: "0.0.0.0:" + config.PortNumber, - Handler: n, - IdleTimeout: 30 * time.Second, - ReadHeaderTimeout: 5 * time.Second, +func NewRouter(director Director, keyPath string) *proxyRouter { + var sshConfig = &ssh.ServerConfig{ + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + return nil, nil + }, } - // Now listen for TLS connections that need to be proxied - tls.StartTLSProxy(config.SSLPortNumber) - http.ListenAndServe() -*/ + privateBytes, err := ioutil.ReadFile(keyPath) + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + + sshConfig.AddHostKey(private) + + return &proxyRouter{director: director, sshConfig: sshConfig} +} diff --git a/router/router_test.go b/router/router_test.go index fcdba89..12f63d0 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -1,7 +1,12 @@ package router import ( + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "encoding/asn1" + "encoding/pem" "fmt" "io/ioutil" "log" @@ -9,20 +14,207 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "strings" "sync" "testing" + "golang.org/x/crypto/ssh" + "github.com/gorilla/websocket" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) +func testSshClient(user string, r *proxyRouter) error { + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return err + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return err + } + config := &ssh.ClientConfig{ + User: user, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + chunks := strings.Split(r.ListenSshAddress(), ":") + l := fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + client, err := ssh.Dial("tcp", l, config) + if err != nil { + return err + } + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + + return nil +} + +func testSshServer(f func(user, pass, ctype string)) (string, error) { + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return "", err + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return "", err + } + + var receivedUser string + var receivedPass string + var receivedChannelType string + + config := &ssh.ServerConfig{ + PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + receivedUser = conn.User() + receivedPass = string(password) + + return nil, nil + }, + } + config.AddHostKey(signer) + listener, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return "", err + } + + go func() { + defer listener.Close() + nConn, err := listener.Accept() + if err != nil { + log.Println(err) + return + } + conn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Println(err) + return + } + go ssh.DiscardRequests(reqs) + defer nConn.Close() + defer conn.Close() + + ch := <-chans + + receivedChannelType = ch.ChannelType() + + f(receivedUser, receivedPass, receivedChannelType) + }() + + return listener.Addr().String(), nil +} + +func generateKeys() (string, string, string, error) { + dir, err := ioutil.TempDir("", "pwd") + if err != nil { + return "", "", "", err + } + + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return "", "", "", err + } + + privateFile, err := os.OpenFile(fmt.Sprintf("%s/id_rsa", dir), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return "", "", "", err + } + defer privateFile.Close() + + var privateKey = &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + err = pem.Encode(privateFile, privateKey) + if err != nil { + return "", "", "", err + } + + publicFile, err := os.OpenFile(fmt.Sprintf("%s/id_rsa.pub", dir), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return "", "", "", err + } + defer publicFile.Close() + + asn1Bytes, err := asn1.Marshal(key.PublicKey) + if err != nil { + return "", "", "", err + } + + var publicKey = &pem.Block{ + Type: "PUBLIC KEY", + Bytes: asn1Bytes, + } + + err = pem.Encode(publicFile, publicKey) + if err != nil { + return "", "", "", err + } + return dir, privateFile.Name(), publicFile.Name(), nil +} + func getRouterUrl(scheme string, r *proxyRouter) string { - chunks := strings.Split(r.ListenAddress(), ":") + chunks := strings.Split(r.ListenHttpAddress(), ":") return fmt.Sprintf("%s://localhost:%s", scheme, chunks[len(chunks)-1]) } +func routerLookup(protocol, domain string, r *proxyRouter) ([]string, error) { + c := dns.Client{Net: protocol} + m := dns.Msg{} + + m.SetQuestion(fmt.Sprintf("%s.", domain), dns.TypeA) + var l string + if protocol == "udp" { + chunks := strings.Split(r.ListenDnsUdpAddress(), ":") + l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + } else if protocol == "tcp" { + chunks := strings.Split(r.ListenDnsTcpAddress(), ":") + l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + } + res, _, err := c.Exchange(&m, l) + + if err != nil { + return nil, err + } + + if len(res.Answer) == 0 { + return nil, fmt.Errorf("Didn't receive any answer") + } + addrs := []string{} + for _, a := range res.Answer { + if b, ok := a.(*dns.A); ok { + addrs = append(addrs, b.A.String()) + } else if b, ok := a.(*dns.AAAA); ok { + addrs = append(addrs, b.AAAA.String()) + } + } + + return addrs, nil +} + func TestProxy_TLS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } @@ -42,8 +234,8 @@ func TestProxy_TLS(t *testing.T) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil - }) - r.Listen(":0") + }, private) + r.Listen(":0", ":0", ":0") defer r.Close() req, err := http.NewRequest("GET", getRouterUrl("https", r), nil) @@ -59,6 +251,9 @@ func TestProxy_TLS(t *testing.T) { } func TestProxy_Http(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + const msg = "It works!" var receivedHost string @@ -73,8 +268,8 @@ func TestProxy_Http(t *testing.T) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil - }) - r.Listen(":0") + }, private) + r.Listen(":0", ":0", ":0") defer r.Close() req, err := http.NewRequest("GET", getRouterUrl("http", r), nil) @@ -92,6 +287,9 @@ func TestProxy_Http(t *testing.T) { } func TestProxy_WS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + const msg = "It works!" var serverReceivedMessage string @@ -122,8 +320,8 @@ func TestProxy_WS(t *testing.T) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil - }) - r.Listen(":0") + }, private) + r.Listen(":0", ":0", ":0") defer r.Close() c, _, err := websocket.DefaultDialer.Dial(getRouterUrl("ws", r), nil) @@ -147,6 +345,9 @@ func TestProxy_WS(t *testing.T) { } func TestProxy_WSS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + const msg = "It works!" var serverReceivedMessage string @@ -177,8 +378,8 @@ func TestProxy_WSS(t *testing.T) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil - }) - r.Listen(":0") + }, private) + r.Listen(":0", ":0", ":0") defer r.Close() d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} @@ -203,3 +404,120 @@ func TestProxy_WSS(t *testing.T) { assert.Equal(t, msg, string(clientReceivedMessage)) assert.Equal(t, msg, serverReceivedMessage) } + +func TestProxy_DNS_UDP(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedHost string + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10_0_0_1.foo.bar" { + a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r) + assert.Nil(t, err) + assert.Equal(t, "10_0_0_1.foo.bar", receivedHost) + assert.Equal(t, []string{"10.0.0.1"}, ips) + + ips, err = routerLookup("udp", "www.google.com", r) + assert.Nil(t, err) + assert.Equal(t, "www.google.com", receivedHost) + + expectedIps, err := net.LookupHost("www.google.com") + assert.Nil(t, err) + assert.Equal(t, expectedIps, ips) + + ips, err = routerLookup("udp", "localhost", r) + assert.Nil(t, err) + assert.NotEqual(t, "localhost", receivedHost) + assert.Equal(t, []string{"127.0.0.1"}, ips) +} + +func TestProxy_DNS_TCP(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedHost string + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10_0_0_1.foo.bar" { + a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + ips, err := routerLookup("tcp", "10_0_0_1.foo.bar", r) + assert.Nil(t, err) + assert.Equal(t, "10_0_0_1.foo.bar", receivedHost) + assert.Equal(t, []string{"10.0.0.1"}, ips) + + ips, err = routerLookup("tcp", "www.google.com", r) + assert.Nil(t, err) + assert.Equal(t, "www.google.com", receivedHost) + + expectedIps, err := net.LookupHost("www.google.com") + assert.Nil(t, err) + assert.Equal(t, expectedIps, ips) + + ips, err = routerLookup("tcp", "localhost", r) + assert.Nil(t, err) + assert.NotEqual(t, "localhost", receivedHost) + assert.Equal(t, []string{"127.0.0.1"}, ips) +} + +func TestProxy_SSH(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedUser string + var receivedPass string + var receivedChannelType string + var receivedHost string + + wg := sync.WaitGroup{} + wg.Add(1) + laddr, err := testSshServer(func(user, pass, ctype string) { + receivedUser = user + receivedPass = pass + receivedChannelType = ctype + wg.Done() + }) + assert.Nil(t, err) + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10-0-0-1-aaaabbbb" { + chunks := strings.Split(laddr, ":") + a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])) + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + err = testSshClient("10-0-0-1-aaaabbbb", r) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, "root", receivedUser) + assert.Equal(t, "root", receivedPass) + assert.Equal(t, "session", receivedChannelType) + assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost) +} diff --git a/router/tcp/tcp.go b/router/tcp/tcp.go deleted file mode 100644 index 80a1cca..0000000 --- a/router/tcp/tcp.go +++ /dev/null @@ -1,143 +0,0 @@ -package tcp - -import ( - "crypto/tls" - "fmt" - "io" - "log" - "net" - "net/http" - "strings" - - "github.com/franela/pwd.old/config" - "github.com/gorilla/mux" -) - -func getTargetInfo(vars map[string]string, req *http.Request) (string, string) { - node := vars["node"] - port := vars["port"] - alias := vars["alias"] - sessionPrefix := vars["session"] - hostPort := strings.Split(req.Host, ":") - - // give priority to the URL host port - if len(hostPort) > 1 && hostPort[1] != config.PortNumber { - port = hostPort[1] - } else if port == "" { - port = "80" - } - - if alias != "" { - instance := core.InstanceFindByAlias(sessionPrefix, alias) - if instance != nil { - node = instance.IP - return node, port - } - } - - // Node is actually an ip, need to convert underscores by dots. - ip := strings.Replace(node, "-", ".", -1) - - if net.ParseIP(ip) == nil { - // Not a valid IP, so treat this is a hostname. - } else { - node = ip - } - - return node, port - -} - -type tcpProxy struct { - Director func(*http.Request) - ErrorLog *log.Logger - Dial func(network, addr string) (net.Conn, error) -} - -func (p *tcpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - logFunc := log.Printf - if p.ErrorLog != nil { - logFunc = p.ErrorLog.Printf - } - - vars := mux.Vars(r) - instanceIP := vars["node"] - - if i := core.InstanceFindByIP(strings.Replace(instanceIP, "-", ".", -1)); i == nil { - w.WriteHeader(http.StatusServiceUnavailable) - return - } - - outreq := new(http.Request) - // shallow copying - *outreq = *r - p.Director(outreq) - host := outreq.URL.Host - - dial := p.Dial - if dial == nil { - dial = net.Dial - } - - if outreq.URL.Scheme == "wss" || outreq.URL.Scheme == "https" { - var tlsConfig *tls.Config - tlsConfig = &tls.Config{InsecureSkipVerify: true} - dial = func(network, address string) (net.Conn, error) { - return tls.Dial("tcp", host, tlsConfig) - } - } - - d, err := dial("tcp", host) - if err != nil { - http.Error(w, "Error forwarding request.", 500) - logFunc("Error dialing websocket backend %s: %v", outreq.URL, err) - return - } - // All request generated by the http package implement this interface. - hj, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "Not a hijacker?", 500) - return - } - // Hijack() tells the http package not to do anything else with the connection. - // After, it bcomes this functions job to manage it. `nc` is of type *net.Conn. - nc, _, err := hj.Hijack() - if err != nil { - logFunc("Hijack error: %v", err) - return - } - defer nc.Close() // must close the underlying net connection after hijacking - defer d.Close() - - // write the modified incoming request to the dialed connection - err = outreq.Write(d) - if err != nil { - logFunc("Error copying request to target: %v", err) - return - } - errc := make(chan error, 2) - cp := func(dst io.Writer, src io.Reader) { - _, err := io.Copy(dst, src) - errc <- err - } - go cp(d, nc) - go cp(nc, d) - <-errc -} -func NewTCPProxy() http.Handler { - director := func(req *http.Request) { - v := mux.Vars(req) - - node, port := getTargetInfo(v, req) - - if port == "443" { - if strings.Contains(req.URL.Scheme, "http") { - req.URL.Scheme = "https" - } else { - req.URL.Scheme = "wss" - } - } - req.URL.Host = fmt.Sprintf("%s:%s", node, port) - } - return &tcpProxy{Director: director} -} diff --git a/router/tls/tls.go b/router/tls/tls.go deleted file mode 100644 index a3c2ada..0000000 --- a/router/tls/tls.go +++ /dev/null @@ -1,95 +0,0 @@ -package tls - -import ( - "fmt" - "io" - "log" - "net" - "strings" - - vhost "github.com/inconshreveable/go-vhost" - "github.com/play-with-docker/play-with-docker/config" -) - -func StartTLSProxy(port string) { - - tlsListener, tlsErr := net.Listen("tcp", fmt.Sprintf(":%s", port)) - log.Println("Listening on port " + port) - if tlsErr != nil { - log.Fatal(tlsErr) - } - defer tlsListener.Close() - for { - // Wait for TLS Connection - conn, err := tlsListener.Accept() - if err != nil { - log.Printf("Could not accept new TLS connection. Error: %s", err) - continue - } - // Handle connection on a new goroutine and continue accepting other new connections - go func(c net.Conn) { - defer c.Close() - vhostConn, err := vhost.TLS(conn) - if err != nil { - log.Printf("Incoming TLS connection produced an error. Error: %s", err) - return - } - defer vhostConn.Close() - - var targetIP string - targetPort := "443" - - host := vhostConn.ClientHelloMsg.ServerName - match := config.NameFilter.FindStringSubmatch(host) - if len(match) < 2 { - // Not a valid proxy host, try alias hosts - match := config.AliasFilter.FindStringSubmatch(host) - if len(match) < 4 { - // Not valid, just close the connection - return - } else { - alias := match[1] - sessionPrefix := match[2] - instance := core.InstanceFindByAlias(sessionPrefix, alias) - if instance != nil { - targetIP = instance.IP - } else { - return - } - if len(match) == 4 { - targetPort = match[3] - } - } - } else { - // Valid proxy host - ip := strings.Replace(match[1], "-", ".", -1) - if net.ParseIP(ip) == nil { - // Not a valid IP, so treat this is a hostname. - return - } else { - targetIP = ip - } - if len(match) == 3 { - targetPort = match[2] - } - } - - dest := fmt.Sprintf("%s:%s", targetIP, targetPort) - d, err := net.Dial("tcp", dest) - if err != nil { - log.Printf("Error dialing backend %s: %v\n", dest, err) - return - } - - errc := make(chan error, 2) - cp := func(dst io.Writer, src io.Reader) { - _, err := io.Copy(dst, src) - errc <- err - } - go cp(d, vhostConn) - go cp(vhostConn, d) - <-errc - }(conn) - } - -} From d18b11ebd429c55220e803a58055d0779bedb2ac Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Tue, 11 Jul 2017 15:05:21 -0300 Subject: [PATCH 4/6] Add l2 implementation --- Dockerfile.l2 | 30 ++++++++++ api.go | 2 - config/config.go | 8 ++- router/l2/l2.go | 136 +++++++++++++++++++++++++++++++++++++++++++ router/l2/l2_test.go | 47 +++++++++++++++ router/router.go | 7 +++ 6 files changed, 226 insertions(+), 4 deletions(-) create mode 100644 Dockerfile.l2 create mode 100644 router/l2/l2.go create mode 100644 router/l2/l2_test.go diff --git a/Dockerfile.l2 b/Dockerfile.l2 new file mode 100644 index 0000000..372634e --- /dev/null +++ b/Dockerfile.l2 @@ -0,0 +1,30 @@ +FROM golang:1.8 + +# Copy the runtime dockerfile into the context as Dockerfile +COPY . /go/src/github.com/play-with-docker/play-with-docker + +WORKDIR /go/src/github.com/play-with-docker/play-with-docker + +RUN go get -v -d ./... + +RUN ssh-keygen -N "" -t rsa -f /etc/ssh/ssh_host_rsa_key >/dev/null + +WORKDIR /go/src/github.com/play-with-docker/play-with-docker/router/l2 + +RUN CGO_ENABLED=0 go build -a -installsuffix nocgo -o /go/bin/play-with-docker-l2 . + + +FROM alpine + +RUN apk --update add ca-certificates +RUN mkdir /app + +COPY --from=0 /go/bin/play-with-docker-l2 /app/play-with-docker-l2 +COPY --from=0 /etc/ssh/ssh_host_rsa_key /etc/ssh/ssh_host_rsa_key + +WORKDIR /app +CMD ["./play-with-docker-l2", "-ssh_key_path", "/etc/ssh/ssh_host_rsa_key"] + +EXPOSE 22 +EXPOSE 53 +EXPOSE 443 diff --git a/api.go b/api.go index fdcb8bc..9a69667 100644 --- a/api.go +++ b/api.go @@ -76,8 +76,6 @@ func main() { ReadHeaderTimeout: 5 * time.Second, } - go handlers.ListenSSHProxy("0.0.0.0:1022") - log.Println("Listening on port " + config.PortNumber) log.Fatal(httpServer.ListenAndServe()) } diff --git a/config/config.go b/config/config.go index b165620..bbeecdd 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,7 @@ package config import ( "flag" + "log" "os" "regexp" "time" @@ -13,14 +14,14 @@ const ( AliasnameRegex = "[0-9|a-z|A-Z|-]*" AliasSessionRegex = "[0-9|a-z|A-Z]{8}" AliasGroupRegex = "(" + AliasnameRegex + ")-(" + AliasSessionRegex + ")" - PWDHostPortGroupRegex = "^.*pwd(" + PWDHostnameRegex + ")(?:-?(" + PortRegex + "))?\\..*$" + PWDHostPortGroupRegex = "^.*ip(" + PWDHostnameRegex + ")(?:-?(" + PortRegex + "))?(?:\\..*)?$" AliasPortGroupRegex = "^.*pwd" + AliasGroupRegex + "(?:-?(" + PortRegex + "))?\\..*$" ) var NameFilter = regexp.MustCompile(PWDHostPortGroupRegex) var AliasFilter = regexp.MustCompile(AliasPortGroupRegex) -var SSLPortNumber, PortNumber, Key, Cert, SessionsFile, PWDContainerName, PWDCName, HashKey string +var SSLPortNumber, PortNumber, Key, Cert, SessionsFile, PWDContainerName, PWDCName, HashKey, SSHKeyPath string var MaxLoadAvg float64 func ParseFlags() { @@ -33,7 +34,10 @@ func ParseFlags() { flag.StringVar(&PWDCName, "cname", "host1", "CNAME given to this host") flag.StringVar(&HashKey, "hash_key", "salmonrosado", "Hash key to use for cookies") flag.Float64Var(&MaxLoadAvg, "maxload", 100, "Maximum allowed load average before failing ping requests") + flag.StringVar(&SSHKeyPath, "ssh_key_path", "", "SSH Private Key to use") flag.Parse() + + log.Println("*****************************", SSHKeyPath) } func GetDindImageName() string { dindImage := os.Getenv("DIND_IMAGE") diff --git a/router/l2/l2.go b/router/l2/l2.go new file mode 100644 index 0000000..4055294 --- /dev/null +++ b/router/l2/l2.go @@ -0,0 +1,136 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net" + "os" + "strings" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "github.com/play-with-docker/play-with-docker/config" + "github.com/play-with-docker/play-with-docker/router" +) + +func director(host string) (*net.TCPAddr, error) { + chunks := strings.Split(host, ":") + matches := config.NameFilter.FindStringSubmatch(chunks[0]) + + var rawHost, port string + + if len(matches) == 3 { + rawHost = matches[1] + port = matches[2] + } else if len(matches) == 2 { + rawHost = matches[1] + } else { + return nil, fmt.Errorf("Couldn't find host in string") + } + + if port == "" { + if len(chunks) == 2 { + port = chunks[1] + } else { + port = "80" + } + } + + dstHost := strings.Replace(rawHost, "-", ".", -1) + + t, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%s", dstHost, port)) + if err != nil { + return nil, err + } + return t, nil +} + +func connectNetworks() error { + ctx := context.Background() + c, err := client.NewEnvClient() + if err != nil { + log.Fatal(err) + } + + defer c.Close() + + f, err := os.Open(config.SessionsFile) + if err != nil { + return err + } + defer f.Close() + + networks := map[string]*network.EndpointSettings{} + + err = json.NewDecoder(f).Decode(&networks) + if err != nil { + return err + } + + for netId, opts := range networks { + settings := &network.EndpointSettings{} + settings.IPAddress = opts.IPAddress + log.Printf("Connected to network [%s] with ip [%s]\n", netId, opts.IPAddress) + c.NetworkConnect(ctx, netId, config.PWDContainerName, settings) + } + + return nil +} + +func monitorNetworks() { + c, err := client.NewEnvClient() + if err != nil { + log.Fatal(err) + } + + defer c.Close() + + ctx := context.Background() + + args := filters.NewArgs() + + cmsg, _ := c.Events(ctx, types.EventsOptions{Filters: args}) + for { + select { + case m := <-cmsg: + if m.Type == "network" { + // Router has been connected to a new network. Let's get all connections and store them in case of restart. + container, err := c.ContainerInspect(ctx, config.PWDContainerName) + if err != nil { + log.Println(err) + return + } + + f, err := os.Create(config.SessionsFile) + if err != nil { + log.Println(err) + return + } + err = json.NewEncoder(f).Encode(container.NetworkSettings.Networks) + if err != nil { + log.Println(err) + return + } + log.Println("Saved networks") + } + } + } +} + +func main() { + config.ParseFlags() + + err := connectNetworks() + if err != nil && !os.IsNotExist(err) { + log.Fatal("connect networks:", err) + } + go monitorNetworks() + + r := router.NewRouter(director, config.SSHKeyPath) + r.Listen(":443", ":53", ":22") + defer r.Close() +} diff --git a/router/l2/l2_test.go b/router/l2/l2_test.go new file mode 100644 index 0000000..70f3fcf --- /dev/null +++ b/router/l2/l2_test.go @@ -0,0 +1,47 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDirector(t *testing.T) { + addr, err := director("ip10-0-0-1-8080.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:8080", addr.String()) + + addr, err = director("ip10-0-0-1.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:80", addr.String()) + + addr, err = director("ip10-0-0-1.foo.bar:9090") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:9090", addr.String()) + + addr, err = director("ip10-0-0-1-2222.foo.bar:9090") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:2222", addr.String()) + + addr, err = director("lala.ip10-0-0-1-2222.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:2222", addr.String()) + + addr, err = director("lala.ip10-0-0-1-2222") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:2222", addr.String()) + + addr, err = director("ip10-0-0-1-2222") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:2222", addr.String()) + + addr, err = director("ip10-0-0-1") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:80", addr.String()) + + _, err = director("lala10-0-0-1.foo.bar") + assert.NotNil(t, err) + + _, err = director("ip10-0-0-1-10-20") + assert.NotNil(t, err) +} diff --git a/router/router.go b/router/router.go index b46a986..04a3094 100644 --- a/router/router.go +++ b/router/router.go @@ -33,11 +33,14 @@ type proxyRouter struct { } func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { + listenWG := sync.WaitGroup{} + l, err := net.Listen("tcp", httpAddr) if err != nil { log.Fatal(err) } r.httpListener = l + listenWG.Add(1) go func() { for !r.closed { conn, err := r.httpListener.Accept() @@ -46,6 +49,7 @@ func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { } go r.handleConnection(conn) } + listenWG.Done() }() dnsMux := dns.NewServeMux() @@ -71,6 +75,7 @@ func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { log.Fatal("failed to listen for connection: ", err) } r.sshListener = lssh + listenWG.Add(1) go func() { for { nConn, err := lssh.Accept() @@ -80,7 +85,9 @@ func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { go r.sshHandle(nConn) } + listenWG.Done() }() + listenWG.Wait() } func (r *proxyRouter) sshHandle(nConn net.Conn) { From 6eaece99c53a3e6bc56caea8a52a9fffc9606a04 Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Wed, 12 Jul 2017 21:46:57 -0300 Subject: [PATCH 5/6] Add events --- event/event.go | 16 ++++++++++++++++ event/local_broker.go | 34 ++++++++++++++++++++++++++++++++++ event/local_broker_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 event/event.go create mode 100644 event/local_broker.go create mode 100644 event/local_broker_test.go diff --git a/event/event.go b/event/event.go new file mode 100644 index 0000000..6a6541f --- /dev/null +++ b/event/event.go @@ -0,0 +1,16 @@ +package event + +type EventType string + +const INSTANCE_VIEWPORT_RESIZE EventType = "instance viewport resize" +const INSTANCE_DELETE EventType = "instance delete" +const INSTANCE_NEW EventType = "instance new" +const SESSION_END EventType = "session end" +const SESSION_READY EventType = "session ready" + +type Handler func(args ...interface{}) + +type EventApi interface { + Emit(name EventType, args ...interface{}) + On(name EventType, handler Handler) +} diff --git a/event/local_broker.go b/event/local_broker.go new file mode 100644 index 0000000..da68dd4 --- /dev/null +++ b/event/local_broker.go @@ -0,0 +1,34 @@ +package event + +import "sync" + +type localBroker struct { + sync.Mutex + + handlers map[EventType][]Handler +} + +func NewLocalBroker() *localBroker { + return &localBroker{handlers: map[EventType][]Handler{}} +} + +func (b *localBroker) On(name EventType, handler Handler) { + b.Lock() + defer b.Unlock() + + if b.handlers[name] == nil { + b.handlers[name] = []Handler{} + } + b.handlers[name] = append(b.handlers[name], handler) +} + +func (b *localBroker) Emit(name EventType, args ...interface{}) { + b.Lock() + defer b.Unlock() + + if b.handlers[name] != nil { + for _, handler := range b.handlers[name] { + handler(args...) + } + } +} diff --git a/event/local_broker_test.go b/event/local_broker_test.go new file mode 100644 index 0000000..777a4b8 --- /dev/null +++ b/event/local_broker_test.go @@ -0,0 +1,31 @@ +package event + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLocalBroker(t *testing.T) { + broker := NewLocalBroker() + + called := 0 + receivedArgs := []interface{}{} + + wg := sync.WaitGroup{} + wg.Add(1) + + broker.On(INSTANCE_NEW, func(args ...interface{}) { + called++ + receivedArgs = args + wg.Done() + }) + broker.Emit(SESSION_READY) + broker.Emit(INSTANCE_NEW, "foo", "bar") + + wg.Wait() + + assert.Equal(t, 1, called) + assert.Equal(t, []interface{}{"foo", "bar"}, receivedArgs) +} From 4731d8ec9846d5e92647eb05aa37072a91796d99 Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Tue, 18 Jul 2017 10:45:05 -0300 Subject: [PATCH 6/6] Event refactor --- api.go | 10 +++++++- event/event.go | 13 ++++++++-- event/local_broker.go | 29 ++++++++++++++++------ event/local_broker_test.go | 37 +++++++++++++++++++++++++--- handlers/bootstrap.go | 23 ++++++++++++------ pwd/broadcast.go | 34 -------------------------- pwd/broadcast_mock_test.go | 20 --------------- pwd/client.go | 3 ++- pwd/client_test.go | 21 +++++++++------- pwd/instance.go | 11 +++++---- pwd/instance_test.go | 25 ++++++++++--------- pwd/pwd.go | 13 +++++----- pwd/session.go | 14 +++++------ pwd/session_test.go | 9 ++++--- pwd/tasks.go | 9 ++++--- router/l2/l2.go | 2 +- router/router.go | 50 ++++++++++++++++++++++++++++++++++++++ 17 files changed, 198 insertions(+), 125 deletions(-) delete mode 100644 pwd/broadcast.go delete mode 100644 pwd/broadcast_mock_test.go diff --git a/api.go b/api.go index 9a69667..4c70c55 100644 --- a/api.go +++ b/api.go @@ -6,6 +6,7 @@ import ( "os" "time" + "github.com/googollee/go-socket.io" gh "github.com/gorilla/handlers" "github.com/gorilla/mux" "github.com/play-with-docker/play-with-docker/config" @@ -21,7 +22,14 @@ func main() { bypassCaptcha := len(os.Getenv("GOOGLE_RECAPTCHA_DISABLED")) > 0 - server := handlers.Broadcast.GetHandler() + server, err := socketio.NewServer(nil) + if err != nil { + log.Fatal(err) + } + server.On("connection", handlers.WS) + server.On("error", handlers.WSError) + + handlers.RegisterEvents(server) r := mux.NewRouter() corsRouter := mux.NewRouter() diff --git a/event/event.go b/event/event.go index 6a6541f..3216a88 100644 --- a/event/event.go +++ b/event/event.go @@ -2,15 +2,24 @@ package event type EventType string +func (e EventType) String() string { + return string(e) +} + const INSTANCE_VIEWPORT_RESIZE EventType = "instance viewport resize" const INSTANCE_DELETE EventType = "instance delete" const INSTANCE_NEW EventType = "instance new" +const INSTANCE_STATS EventType = "instance stats" +const INSTANCE_TERMINAL_OUT EventType = "instance terminal out" const SESSION_END EventType = "session end" const SESSION_READY EventType = "session ready" +const SESSION_BUILDER_OUT EventType = "session builder out" -type Handler func(args ...interface{}) +type Handler func(sessionId string, args ...interface{}) +type AnyHandler func(eventType EventType, sessionId string, args ...interface{}) type EventApi interface { - Emit(name EventType, args ...interface{}) + Emit(name EventType, sessionId string, args ...interface{}) On(name EventType, handler Handler) + OnAny(handler AnyHandler) } diff --git a/event/local_broker.go b/event/local_broker.go index da68dd4..2f33292 100644 --- a/event/local_broker.go +++ b/event/local_broker.go @@ -5,11 +5,12 @@ import "sync" type localBroker struct { sync.Mutex - handlers map[EventType][]Handler + handlers map[EventType][]Handler + anyHandlers []AnyHandler } func NewLocalBroker() *localBroker { - return &localBroker{handlers: map[EventType][]Handler{}} + return &localBroker{handlers: map[EventType][]Handler{}, anyHandlers: []AnyHandler{}} } func (b *localBroker) On(name EventType, handler Handler) { @@ -22,13 +23,25 @@ func (b *localBroker) On(name EventType, handler Handler) { b.handlers[name] = append(b.handlers[name], handler) } -func (b *localBroker) Emit(name EventType, args ...interface{}) { +func (b *localBroker) OnAny(handler AnyHandler) { b.Lock() defer b.Unlock() - if b.handlers[name] != nil { - for _, handler := range b.handlers[name] { - handler(args...) - } - } + b.anyHandlers = append(b.anyHandlers, handler) +} + +func (b *localBroker) Emit(name EventType, sessionId string, args ...interface{}) { + go func() { + b.Lock() + defer b.Unlock() + + for _, handler := range b.anyHandlers { + handler(name, sessionId, args...) + } + if b.handlers[name] != nil { + for _, handler := range b.handlers[name] { + handler(sessionId, args...) + } + } + }() } diff --git a/event/local_broker_test.go b/event/local_broker_test.go index 777a4b8..90939c8 100644 --- a/event/local_broker_test.go +++ b/event/local_broker_test.go @@ -7,25 +7,54 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLocalBroker(t *testing.T) { +func TestLocalBroker_On(t *testing.T) { broker := NewLocalBroker() called := 0 + receivedSessionId := "" receivedArgs := []interface{}{} wg := sync.WaitGroup{} wg.Add(1) - broker.On(INSTANCE_NEW, func(args ...interface{}) { + broker.On(INSTANCE_NEW, func(sessionId string, args ...interface{}) { called++ + receivedSessionId = sessionId receivedArgs = args wg.Done() }) - broker.Emit(SESSION_READY) - broker.Emit(INSTANCE_NEW, "foo", "bar") + broker.Emit(SESSION_READY, "1") + broker.Emit(INSTANCE_NEW, "2", "foo", "bar") wg.Wait() assert.Equal(t, 1, called) + assert.Equal(t, "2", receivedSessionId) assert.Equal(t, []interface{}{"foo", "bar"}, receivedArgs) } + +func TestLocalBroker_OnAny(t *testing.T) { + broker := NewLocalBroker() + + var receivedEvent EventType + receivedSessionId := "" + receivedArgs := []interface{}{} + + wg := sync.WaitGroup{} + wg.Add(1) + + broker.OnAny(func(eventType EventType, sessionId string, args ...interface{}) { + receivedSessionId = sessionId + receivedArgs = args + receivedEvent = eventType + wg.Done() + }) + broker.Emit(SESSION_READY, "1") + + wg.Wait() + + var expectedArgs []interface{} + assert.Equal(t, SESSION_READY, receivedEvent) + assert.Equal(t, "1", receivedSessionId) + assert.Equal(t, expectedArgs, receivedArgs) +} diff --git a/handlers/bootstrap.go b/handlers/bootstrap.go index 9e8e90b..ffd17f0 100644 --- a/handlers/bootstrap.go +++ b/handlers/bootstrap.go @@ -5,14 +5,17 @@ import ( "os" "github.com/docker/docker/client" + "github.com/googollee/go-socket.io" "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd" "github.com/play-with-docker/play-with-docker/storage" ) var core pwd.PWDApi -var Broadcast pwd.BroadcastApi +var e event.EventApi +var ws *socketio.Server func Bootstrap() { c, err := client.NewEnvClient() @@ -22,18 +25,24 @@ func Bootstrap() { d := docker.NewDocker(c) - Broadcast, err = pwd.NewBroadcast(WS, WSError) - if err != nil { - log.Fatal(err) - } + e = event.NewLocalBroker() - t := pwd.NewScheduler(Broadcast, d) + t := pwd.NewScheduler(e, d) s, err := storage.NewFileStorage(config.SessionsFile) if err != nil && !os.IsNotExist(err) { log.Fatal("Error decoding sessions from disk ", err) } - core = pwd.NewPWD(d, t, Broadcast, s) + core = pwd.NewPWD(d, t, e, s) } + +func RegisterEvents(s *socketio.Server) { + ws = s + e.OnAny(broadcastEvent) +} + +func broadcastEvent(eventType event.EventType, sessionId string, args ...interface{}) { + ws.BroadcastTo(sessionId, eventType.String(), args...) +} diff --git a/pwd/broadcast.go b/pwd/broadcast.go deleted file mode 100644 index 1f54d08..0000000 --- a/pwd/broadcast.go +++ /dev/null @@ -1,34 +0,0 @@ -package pwd - -import ( - "net/http" - - "github.com/googollee/go-socket.io" -) - -type BroadcastApi interface { - BroadcastTo(sessionId, eventName string, args ...interface{}) - GetHandler() http.Handler -} - -type broadcast struct { - sio *socketio.Server -} - -func (b *broadcast) BroadcastTo(sessionId, eventName string, args ...interface{}) { - b.sio.BroadcastTo(sessionId, eventName, args...) -} - -func (b *broadcast) GetHandler() http.Handler { - return b.sio -} - -func NewBroadcast(connectionEvent, errorEvent interface{}) (*broadcast, error) { - server, err := socketio.NewServer(nil) - if err != nil { - return nil, err - } - server.On("connection", connectionEvent) - server.On("error", errorEvent) - return &broadcast{sio: server}, nil -} diff --git a/pwd/broadcast_mock_test.go b/pwd/broadcast_mock_test.go deleted file mode 100644 index 1d40cdc..0000000 --- a/pwd/broadcast_mock_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package pwd - -import "net/http" - -type mockBroadcast struct { - broadcastTo func(sessionId, eventName string, args ...interface{}) - getHandler func() http.Handler -} - -func (m *mockBroadcast) BroadcastTo(sessionId, eventName string, args ...interface{}) { - if m.broadcastTo != nil { - m.broadcastTo(sessionId, eventName, args...) - } -} -func (m *mockBroadcast) GetHandler() http.Handler { - if m.getHandler != nil { - return m.getHandler() - } - return nil -} diff --git a/pwd/client.go b/pwd/client.go index 6fc1021..10660bd 100644 --- a/pwd/client.go +++ b/pwd/client.go @@ -4,6 +4,7 @@ import ( "log" "time" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" ) @@ -41,7 +42,7 @@ func (p *pwd) ClientClose(client *types.Client) { func (p *pwd) notifyClientSmallestViewPort(session *types.Session) { vp := p.SessionGetSmallestViewPort(session) // Resize all terminals in the session - p.broadcast.BroadcastTo(session.Id, "viewport resize", vp.Cols, vp.Rows) + p.event.Emit(event.INSTANCE_VIEWPORT_RESIZE, session.Id, vp.Cols, vp.Rows) for _, instance := range session.Instances { err := p.InstanceResizeTerminal(instance, vp.Rows, vp.Cols) if err != nil { diff --git a/pwd/client_test.go b/pwd/client_test.go index 7ab9243..e682fdb 100644 --- a/pwd/client_test.go +++ b/pwd/client_test.go @@ -1,9 +1,11 @@ package pwd import ( + "sync" "testing" "time" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/stretchr/testify/assert" ) @@ -11,10 +13,10 @@ import ( func TestClientNew(t *testing.T) { docker := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") assert.Nil(t, err) @@ -26,33 +28,34 @@ func TestClientNew(t *testing.T) { } func TestClientResizeViewPort(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(1) docker := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() broadcastedSessionId := "" - broadcastedEventName := "" broadcastedArgs := []interface{}{} - broadcast.broadcastTo = func(sessionId, eventName string, args ...interface{}) { + e.On(event.INSTANCE_VIEWPORT_RESIZE, func(sessionId string, args ...interface{}) { broadcastedSessionId = sessionId - broadcastedEventName = eventName broadcastedArgs = args - } + wg.Done() + }) storage := &mockStorage{} - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") assert.Nil(t, err) client := p.ClientNew("foobar", session) p.ClientResizeViewPort(client, 80, 24) + wg.Wait() assert.Equal(t, types.ViewPort{Cols: 80, Rows: 24}, client.ViewPort) assert.Equal(t, session.Id, broadcastedSessionId) - assert.Equal(t, "viewport resize", broadcastedEventName) assert.Equal(t, uint(80), broadcastedArgs[0]) assert.Equal(t, uint(24), broadcastedArgs[1]) } diff --git a/pwd/instance.go b/pwd/instance.go index 5af7544..4c45c3d 100644 --- a/pwd/instance.go +++ b/pwd/instance.go @@ -11,6 +11,7 @@ import ( "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "golang.org/x/text/encoding" @@ -19,11 +20,11 @@ import ( type sessionWriter struct { sessionId string instanceName string - broadcast BroadcastApi + event event.EventApi } func (s *sessionWriter) Write(p []byte) (n int, err error) { - s.broadcast.BroadcastTo(s.sessionId, "terminal out", s.instanceName, string(p)) + s.event.Emit(event.INSTANCE_TERMINAL_OUT, s.sessionId, s.instanceName, string(p)) return len(p), nil } @@ -52,7 +53,7 @@ func (p *pwd) InstanceAttachTerminal(instance *types.Instance) error { } encoder := encoding.Replacement.NewEncoder() - sw := &sessionWriter{sessionId: instance.Session.Id, instanceName: instance.Name, broadcast: p.broadcast} + sw := &sessionWriter{sessionId: instance.Session.Id, instanceName: instance.Name, event: p.event} instance.Terminal = conn io.Copy(encoder.Writer(sw), conn) @@ -139,7 +140,7 @@ func (p *pwd) InstanceDelete(session *types.Session, instance *types.Instance) e return err } - p.broadcast.BroadcastTo(session.Id, "delete instance", instance.Name) + p.event.Emit(event.INSTANCE_DELETE, session.Id, instance.Name) delete(session.Instances, instance.Name) if err := p.storage.SessionPut(session); err != nil { @@ -238,7 +239,7 @@ func (p *pwd) InstanceNew(session *types.Session, conf InstanceConfig) (*types.I return nil, err } - p.broadcast.BroadcastTo(session.Id, "new instance", instance.Name, instance.IP, instance.Hostname) + p.event.Emit(event.INSTANCE_NEW, session.Id, instance.Name, instance.IP, instance.Hostname) p.setGauges() diff --git a/pwd/instance_test.go b/pwd/instance_test.go index cb82ac1..334ed80 100644 --- a/pwd/instance_test.go +++ b/pwd/instance_test.go @@ -8,6 +8,7 @@ import ( "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/stretchr/testify/assert" ) @@ -27,10 +28,10 @@ func TestInstanceResizeTerminal(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, e, storage) err := p.InstanceResizeTerminal(&types.Instance{Name: "foobar"}, 24, 80) @@ -49,10 +50,10 @@ func TestInstanceNew(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") @@ -99,10 +100,10 @@ func TestInstanceNew_Concurrency(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") @@ -140,10 +141,10 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") @@ -188,10 +189,10 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") @@ -230,10 +231,10 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { func TestInstanceAllowedImages(t *testing.T) { dock := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) expectedImages := []string{config.GetDindImageName(), "franela/dind:overlay2-dev", "franela/ucp:2.4.1"} diff --git a/pwd/pwd.go b/pwd/pwd.go index 1dcc8fa..a4c0496 100644 --- a/pwd/pwd.go +++ b/pwd/pwd.go @@ -5,6 +5,7 @@ import ( "time" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/play-with-docker/play-with-docker/storage" "github.com/prometheus/client_golang/prometheus" @@ -43,10 +44,10 @@ func init() { } type pwd struct { - docker docker.DockerApi - tasks SchedulerApi - broadcast BroadcastApi - storage storage.StorageApi + docker docker.DockerApi + tasks SchedulerApi + event event.EventApi + storage storage.StorageApi } type PWDApi interface { @@ -76,8 +77,8 @@ type PWDApi interface { ClientClose(client *types.Client) } -func NewPWD(d docker.DockerApi, t SchedulerApi, b BroadcastApi, s storage.StorageApi) *pwd { - return &pwd{docker: d, tasks: t, broadcast: b, storage: s} +func NewPWD(d docker.DockerApi, t SchedulerApi, e event.EventApi, s storage.StorageApi) *pwd { + return &pwd{docker: d, tasks: t, event: e, storage: s} } func (p *pwd) setGauges() { diff --git a/pwd/session.go b/pwd/session.go index ad36812..dc0ab7b 100644 --- a/pwd/session.go +++ b/pwd/session.go @@ -11,17 +11,18 @@ import ( "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/twinj/uuid" ) type sessionBuilderWriter struct { sessionId string - broadcast BroadcastApi + event event.EventApi } func (s *sessionBuilderWriter) Write(p []byte) (n int, err error) { - s.broadcast.BroadcastTo(s.sessionId, "session builder out", string(p)) + s.event.Emit(event.SESSION_BUILDER_OUT, s.sessionId, string(p)) return len(p), nil } @@ -87,8 +88,7 @@ func (p *pwd) SessionClose(s *types.Session) error { s.StopTicker() - p.broadcast.BroadcastTo(s.Id, "session end") - p.broadcast.BroadcastTo(s.Id, "disconnect") + p.event.Emit(event.SESSION_END, s.Id) log.Printf("Starting clean up of session [%s]\n", s.Id) for _, i := range s.Instances { err := p.InstanceDelete(s, i) @@ -146,7 +146,7 @@ func (p *pwd) SessionDeployStack(s *types.Session) error { } s.Ready = false - p.broadcast.BroadcastTo(s.Id, "session ready", false) + p.event.Emit(event.SESSION_READY, s.Id, false) i, err := p.InstanceNew(s, InstanceConfig{ImageName: s.ImageName, Host: s.Host}) if err != nil { log.Printf("Error creating instance for stack [%s]: %s\n", s.Stack, err) @@ -162,7 +162,7 @@ func (p *pwd) SessionDeployStack(s *types.Session) error { file := fmt.Sprintf("/var/run/pwd/uploads/%s", fileName) cmd := fmt.Sprintf("docker swarm init --advertise-addr eth0 && docker-compose -f %s pull && docker stack deploy -c %s %s", file, file, s.StackName) - w := sessionBuilderWriter{sessionId: s.Id, broadcast: p.broadcast} + w := sessionBuilderWriter{sessionId: s.Id, event: p.event} code, err := p.docker.ExecAttach(i.Name, []string{"sh", "-c", cmd}, &w) if err != nil { log.Printf("Error executing stack [%s]: %s\n", s.Stack, err) @@ -171,7 +171,7 @@ func (p *pwd) SessionDeployStack(s *types.Session) error { log.Printf("Stack execution finished with code %d\n", code) s.Ready = true - p.broadcast.BroadcastTo(s.Id, "session ready", true) + p.event.Emit(event.SESSION_READY, s.Id, true) if err := p.storage.SessionPut(s); err != nil { return err } diff --git a/pwd/session_test.go b/pwd/session_test.go index bef5a1d..4e9ac53 100644 --- a/pwd/session_test.go +++ b/pwd/session_test.go @@ -7,6 +7,7 @@ import ( "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/stretchr/testify/assert" ) @@ -36,14 +37,14 @@ func TestSessionNew(t *testing.T) { scheduledSession = s } - broadcast := &mockBroadcast{} + ev := event.NewLocalBroker() storage := &mockStorage{} storage.sessionPut = func(s *types.Session) error { saveCalled = true return nil } - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, ev, storage) before := time.Now() @@ -166,10 +167,10 @@ func TestSessionSetup(t *testing.T) { return nil, nil } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + ev := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, ev, storage) s, e := p.SessionNew(time.Hour, "", "", "") assert.Nil(t, e) diff --git a/pwd/tasks.go b/pwd/tasks.go index adc52ce..27d61c2 100644 --- a/pwd/tasks.go +++ b/pwd/tasks.go @@ -15,6 +15,7 @@ import ( "github.com/docker/docker/client" "github.com/docker/go-connections/tlsconfig" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" ) @@ -28,7 +29,7 @@ type SchedulerApi interface { } type scheduler struct { - broadcast BroadcastApi + event event.EventApi periodicTasks []periodicTask } @@ -102,7 +103,7 @@ func (sch *scheduler) Schedule(s *types.Session) { sort.Sort(ins.Ports) ins.CleanUsedPorts() - sch.broadcast.BroadcastTo(ins.Session.Id, "instance stats", ins.Name, ins.Mem, ins.Cpu, ins.IsManager, ins.Ports) + sch.event.Emit(event.INSTANCE_STATS, ins.Session.Id, ins.Name, ins.Mem, ins.Cpu, ins.IsManager, ins.Ports) } } }() @@ -111,8 +112,8 @@ func (sch *scheduler) Schedule(s *types.Session) { func (sch *scheduler) Unschedule(s *types.Session) { } -func NewScheduler(b BroadcastApi, d docker.DockerApi) *scheduler { - s := &scheduler{broadcast: b} +func NewScheduler(e event.EventApi, d docker.DockerApi) *scheduler { + s := &scheduler{event: e} s.periodicTasks = []periodicTask{&collectStatsTask{docker: d}, &checkSwarmStatusTask{}, &checkUsedPortsTask{}, &checkSwarmUsedPortsTask{}} return s } diff --git a/router/l2/l2.go b/router/l2/l2.go index 4055294..1cce468 100644 --- a/router/l2/l2.go +++ b/router/l2/l2.go @@ -131,6 +131,6 @@ func main() { go monitorNetworks() r := router.NewRouter(director, config.SSHKeyPath) - r.Listen(":443", ":53", ":22") + r.ListenAndWait(":443", ":53", ":22") defer r.Close() } diff --git a/router/router.go b/router/router.go index 04a3094..4948370 100644 --- a/router/router.go +++ b/router/router.go @@ -33,6 +33,56 @@ type proxyRouter struct { } func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { + l, err := net.Listen("tcp", httpAddr) + if err != nil { + log.Fatal(err) + } + r.httpListener = l + go func() { + for !r.closed { + conn, err := r.httpListener.Accept() + if err != nil { + continue + } + go r.handleConnection(conn) + } + }() + + dnsMux := dns.NewServeMux() + dnsMux.HandleFunc(".", r.dnsRequest) + r.udpDnsServer = &dns.Server{Addr: dnsAddr, Net: "udp", Handler: dnsMux} + r.tcpDnsServer = &dns.Server{Addr: dnsAddr, Net: "tcp", Handler: dnsMux} + + wg := sync.WaitGroup{} + wg.Add(2) + + r.udpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + r.tcpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + go r.udpDnsServer.ListenAndServe() + go r.tcpDnsServer.ListenAndServe() + wg.Wait() + + lssh, err := net.Listen("tcp", sshAddr) + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + r.sshListener = lssh + go func() { + for { + nConn, err := lssh.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + go r.sshHandle(nConn) + } + }() +} +func (r *proxyRouter) ListenAndWait(httpAddr, dnsAddr, sshAddr string) { listenWG := sync.WaitGroup{} l, err := net.Listen("tcp", httpAddr)