diff --git a/Dockerfile.l2 b/Dockerfile.l2 new file mode 100644 index 0000000..372634e --- /dev/null +++ b/Dockerfile.l2 @@ -0,0 +1,30 @@ +FROM golang:1.8 + +# Copy the runtime dockerfile into the context as Dockerfile +COPY . /go/src/github.com/play-with-docker/play-with-docker + +WORKDIR /go/src/github.com/play-with-docker/play-with-docker + +RUN go get -v -d ./... + +RUN ssh-keygen -N "" -t rsa -f /etc/ssh/ssh_host_rsa_key >/dev/null + +WORKDIR /go/src/github.com/play-with-docker/play-with-docker/router/l2 + +RUN CGO_ENABLED=0 go build -a -installsuffix nocgo -o /go/bin/play-with-docker-l2 . + + +FROM alpine + +RUN apk --update add ca-certificates +RUN mkdir /app + +COPY --from=0 /go/bin/play-with-docker-l2 /app/play-with-docker-l2 +COPY --from=0 /etc/ssh/ssh_host_rsa_key /etc/ssh/ssh_host_rsa_key + +WORKDIR /app +CMD ["./play-with-docker-l2", "-ssh_key_path", "/etc/ssh/ssh_host_rsa_key"] + +EXPOSE 22 +EXPOSE 53 +EXPOSE 443 diff --git a/api.go b/api.go index 313e2bf..4c70c55 100644 --- a/api.go +++ b/api.go @@ -1,15 +1,14 @@ package main import ( - "fmt" "log" "net/http" "os" "time" + "github.com/googollee/go-socket.io" gh "github.com/gorilla/handlers" "github.com/gorilla/mux" - "github.com/miekg/dns" "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/handlers" "github.com/play-with-docker/play-with-docker/templates" @@ -18,47 +17,26 @@ import ( ) func main() { - config.ParseFlags() handlers.Bootstrap() bypassCaptcha := len(os.Getenv("GOOGLE_RECAPTCHA_DISABLED")) > 0 - // Start the DNS server - dns.HandleFunc(".", handlers.DnsRequest) - udpDnsServer := &dns.Server{Addr: ":53", Net: "udp"} - go func() { - err := udpDnsServer.ListenAndServe() - if err != nil { - log.Fatal(err) - } - }() - tcpDnsServer := &dns.Server{Addr: ":53", Net: "tcp"} - go func() { - err := tcpDnsServer.ListenAndServe() - if err != nil { - log.Fatal(err) - } - }() + server, err := socketio.NewServer(nil) + if err != nil { + log.Fatal(err) + } + server.On("connection", handlers.WS) + server.On("error", handlers.WSError) - server := handlers.Broadcast.GetHandler() + handlers.RegisterEvents(server) r := mux.NewRouter() corsRouter := mux.NewRouter() - // Reverse proxy (needs to be the first route, to make sure it is the first thing we check) - //proxyHandler := handlers.NewMultipleHostReverseProxy() - //websocketProxyHandler := handlers.NewMultipleHostWebsocketReverseProxy() - - tcpHandler := handlers.NewTCPProxy() - corsHandler := gh.CORS(gh.AllowCredentials(), gh.AllowedHeaders([]string{"x-requested-with", "content-type"}), gh.AllowedOrigins([]string{"*"})) // Specific routes - r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}-{port:%s}.{tld:.*}", config.PWDHostnameRegex, config.PortRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}.{tld:.*}", config.PWDHostnameRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}-{port:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex, config.PortRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex)).Handler(tcpHandler) r.HandleFunc("/ping", handlers.Ping).Methods("GET") corsRouter.HandleFunc("/instances/images", handlers.GetInstanceImages).Methods("GET") corsRouter.HandleFunc("/sessions/{sessionId}", handlers.GetSession).Methods("GET") @@ -106,13 +84,6 @@ func main() { ReadHeaderTimeout: 5 * time.Second, } - go func() { - log.Println("Listening on port " + config.PortNumber) - log.Fatal(httpServer.ListenAndServe()) - }() - - go handlers.ListenSSHProxy("0.0.0.0:1022") - - // Now listen for TLS connections that need to be proxied - handlers.StartTLSProxy(config.SSLPortNumber) + log.Println("Listening on port " + config.PortNumber) + log.Fatal(httpServer.ListenAndServe()) } diff --git a/config/config.go b/config/config.go index b165620..bbeecdd 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,7 @@ package config import ( "flag" + "log" "os" "regexp" "time" @@ -13,14 +14,14 @@ const ( AliasnameRegex = "[0-9|a-z|A-Z|-]*" AliasSessionRegex = "[0-9|a-z|A-Z]{8}" AliasGroupRegex = "(" + AliasnameRegex + ")-(" + AliasSessionRegex + ")" - PWDHostPortGroupRegex = "^.*pwd(" + PWDHostnameRegex + ")(?:-?(" + PortRegex + "))?\\..*$" + PWDHostPortGroupRegex = "^.*ip(" + PWDHostnameRegex + ")(?:-?(" + PortRegex + "))?(?:\\..*)?$" AliasPortGroupRegex = "^.*pwd" + AliasGroupRegex + "(?:-?(" + PortRegex + "))?\\..*$" ) var NameFilter = regexp.MustCompile(PWDHostPortGroupRegex) var AliasFilter = regexp.MustCompile(AliasPortGroupRegex) -var SSLPortNumber, PortNumber, Key, Cert, SessionsFile, PWDContainerName, PWDCName, HashKey string +var SSLPortNumber, PortNumber, Key, Cert, SessionsFile, PWDContainerName, PWDCName, HashKey, SSHKeyPath string var MaxLoadAvg float64 func ParseFlags() { @@ -33,7 +34,10 @@ func ParseFlags() { flag.StringVar(&PWDCName, "cname", "host1", "CNAME given to this host") flag.StringVar(&HashKey, "hash_key", "salmonrosado", "Hash key to use for cookies") flag.Float64Var(&MaxLoadAvg, "maxload", 100, "Maximum allowed load average before failing ping requests") + flag.StringVar(&SSHKeyPath, "ssh_key_path", "", "SSH Private Key to use") flag.Parse() + + log.Println("*****************************", SSHKeyPath) } func GetDindImageName() string { dindImage := os.Getenv("DIND_IMAGE") diff --git a/event/event.go b/event/event.go new file mode 100644 index 0000000..3216a88 --- /dev/null +++ b/event/event.go @@ -0,0 +1,25 @@ +package event + +type EventType string + +func (e EventType) String() string { + return string(e) +} + +const INSTANCE_VIEWPORT_RESIZE EventType = "instance viewport resize" +const INSTANCE_DELETE EventType = "instance delete" +const INSTANCE_NEW EventType = "instance new" +const INSTANCE_STATS EventType = "instance stats" +const INSTANCE_TERMINAL_OUT EventType = "instance terminal out" +const SESSION_END EventType = "session end" +const SESSION_READY EventType = "session ready" +const SESSION_BUILDER_OUT EventType = "session builder out" + +type Handler func(sessionId string, args ...interface{}) +type AnyHandler func(eventType EventType, sessionId string, args ...interface{}) + +type EventApi interface { + Emit(name EventType, sessionId string, args ...interface{}) + On(name EventType, handler Handler) + OnAny(handler AnyHandler) +} diff --git a/event/local_broker.go b/event/local_broker.go new file mode 100644 index 0000000..2f33292 --- /dev/null +++ b/event/local_broker.go @@ -0,0 +1,47 @@ +package event + +import "sync" + +type localBroker struct { + sync.Mutex + + handlers map[EventType][]Handler + anyHandlers []AnyHandler +} + +func NewLocalBroker() *localBroker { + return &localBroker{handlers: map[EventType][]Handler{}, anyHandlers: []AnyHandler{}} +} + +func (b *localBroker) On(name EventType, handler Handler) { + b.Lock() + defer b.Unlock() + + if b.handlers[name] == nil { + b.handlers[name] = []Handler{} + } + b.handlers[name] = append(b.handlers[name], handler) +} + +func (b *localBroker) OnAny(handler AnyHandler) { + b.Lock() + defer b.Unlock() + + b.anyHandlers = append(b.anyHandlers, handler) +} + +func (b *localBroker) Emit(name EventType, sessionId string, args ...interface{}) { + go func() { + b.Lock() + defer b.Unlock() + + for _, handler := range b.anyHandlers { + handler(name, sessionId, args...) + } + if b.handlers[name] != nil { + for _, handler := range b.handlers[name] { + handler(sessionId, args...) + } + } + }() +} diff --git a/event/local_broker_test.go b/event/local_broker_test.go new file mode 100644 index 0000000..90939c8 --- /dev/null +++ b/event/local_broker_test.go @@ -0,0 +1,60 @@ +package event + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLocalBroker_On(t *testing.T) { + broker := NewLocalBroker() + + called := 0 + receivedSessionId := "" + receivedArgs := []interface{}{} + + wg := sync.WaitGroup{} + wg.Add(1) + + broker.On(INSTANCE_NEW, func(sessionId string, args ...interface{}) { + called++ + receivedSessionId = sessionId + receivedArgs = args + wg.Done() + }) + broker.Emit(SESSION_READY, "1") + broker.Emit(INSTANCE_NEW, "2", "foo", "bar") + + wg.Wait() + + assert.Equal(t, 1, called) + assert.Equal(t, "2", receivedSessionId) + assert.Equal(t, []interface{}{"foo", "bar"}, receivedArgs) +} + +func TestLocalBroker_OnAny(t *testing.T) { + broker := NewLocalBroker() + + var receivedEvent EventType + receivedSessionId := "" + receivedArgs := []interface{}{} + + wg := sync.WaitGroup{} + wg.Add(1) + + broker.OnAny(func(eventType EventType, sessionId string, args ...interface{}) { + receivedSessionId = sessionId + receivedArgs = args + receivedEvent = eventType + wg.Done() + }) + broker.Emit(SESSION_READY, "1") + + wg.Wait() + + var expectedArgs []interface{} + assert.Equal(t, SESSION_READY, receivedEvent) + assert.Equal(t, "1", receivedSessionId) + assert.Equal(t, expectedArgs, receivedArgs) +} diff --git a/handlers/bootstrap.go b/handlers/bootstrap.go index 3ce0438..1342b6c 100644 --- a/handlers/bootstrap.go +++ b/handlers/bootstrap.go @@ -5,14 +5,17 @@ import ( "os" "github.com/docker/docker/client" + "github.com/googollee/go-socket.io" "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd" "github.com/play-with-docker/play-with-docker/storage" ) var core pwd.PWDApi -var Broadcast pwd.BroadcastApi +var e event.EventApi +var ws *socketio.Server func Bootstrap() { c, err := client.NewEnvClient() @@ -22,18 +25,24 @@ func Bootstrap() { d := docker.NewDocker(c) - Broadcast, err = pwd.NewBroadcast(WS, WSError) - if err != nil { - log.Fatal(err) - } + e = event.NewLocalBroker() - t := pwd.NewScheduler(Broadcast, d) + t := pwd.NewScheduler(e, d) s, err := storage.NewFileStorage(config.SessionsFile) if err != nil && !os.IsNotExist(err) { log.Fatal("Error initializing StorageAPI: ", err) } - core = pwd.NewPWD(d, t, Broadcast, s) + core = pwd.NewPWD(d, t, e, s) } + +func RegisterEvents(s *socketio.Server) { + ws = s + e.OnAny(broadcastEvent) +} + +func broadcastEvent(eventType event.EventType, sessionId string, args ...interface{}) { + ws.BroadcastTo(sessionId, eventType.String(), args...) +} diff --git a/handlers/dns.go b/handlers/dns.go deleted file mode 100644 index e199296..0000000 --- a/handlers/dns.go +++ /dev/null @@ -1,111 +0,0 @@ -package handlers - -import ( - "fmt" - "log" - "net" - "strings" - - "github.com/miekg/dns" - "github.com/play-with-docker/play-with-docker/config" -) - -func DnsRequest(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) > 0 && config.NameFilter.MatchString(r.Question[0].Name) { - // this is something we know about and we should try to handle - question := r.Question[0].Name - - match := config.NameFilter.FindStringSubmatch(question) - - ip := strings.Replace(match[1], "-", ".", -1) - - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ip)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } else if len(r.Question) > 0 && config.AliasFilter.MatchString(r.Question[0].Name) { - // this is something we know about and we should try to handle - question := r.Question[0].Name - - match := config.AliasFilter.FindStringSubmatch(question) - - i := core.InstanceFindByAlias(match[2], match[1]) - - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, i.IP)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } else { - if len(r.Question) > 0 { - question := r.Question[0].Name - - if question == "localhost." { - log.Printf("Not a PWD host. Asked for [localhost.] returning automatically [127.0.0.1]\n") - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A 127.0.0.1", question)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } - - log.Printf("Not a PWD host. Looking up [%s]\n", question) - ips, err := net.LookupIP(question) - if err != nil { - // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured - w.Close() - // dns.HandleFailed(w, r) - return - } - log.Printf("Not a PWD host. Looking up [%s] got [%s]\n", question, ips) - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - for _, ip := range ips { - ipv4 := ip.To4() - if ipv4 == nil { - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN AAAA %s", question, ip.String())) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - } else { - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ipv4.String())) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - } - } - w.WriteMsg(m) - return - - } else { - log.Printf("Not a PWD host. Got DNS without any question\n") - // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured - w.Close() - // dns.HandleFailed(w, r) - return - } - } -} diff --git a/handlers/reverseproxy.go b/handlers/reverseproxy.go deleted file mode 100644 index ed5a72e..0000000 --- a/handlers/reverseproxy.go +++ /dev/null @@ -1,143 +0,0 @@ -package handlers - -import ( - "crypto/tls" - "fmt" - "io" - "log" - "net" - "net/http" - "strings" - - "github.com/gorilla/mux" - "github.com/play-with-docker/play-with-docker/config" -) - -func getTargetInfo(vars map[string]string, req *http.Request) (string, string) { - node := vars["node"] - port := vars["port"] - alias := vars["alias"] - sessionPrefix := vars["session"] - hostPort := strings.Split(req.Host, ":") - - // give priority to the URL host port - if len(hostPort) > 1 && hostPort[1] != config.PortNumber { - port = hostPort[1] - } else if port == "" { - port = "80" - } - - if alias != "" { - instance := core.InstanceFindByAlias(sessionPrefix, alias) - if instance != nil { - node = instance.IP - return node, port - } - } - - // Node is actually an ip, need to convert underscores by dots. - ip := strings.Replace(node, "-", ".", -1) - - if net.ParseIP(ip) == nil { - // Not a valid IP, so treat this is a hostname. - } else { - node = ip - } - - return node, port - -} - -type tcpProxy struct { - Director func(*http.Request) - ErrorLog *log.Logger - Dial func(network, addr string) (net.Conn, error) -} - -func (p *tcpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - logFunc := log.Printf - if p.ErrorLog != nil { - logFunc = p.ErrorLog.Printf - } - - vars := mux.Vars(r) - instanceIP := vars["node"] - - if i := core.InstanceFindByIP(strings.Replace(instanceIP, "-", ".", -1)); i == nil { - w.WriteHeader(http.StatusServiceUnavailable) - return - } - - outreq := new(http.Request) - // shallow copying - *outreq = *r - p.Director(outreq) - host := outreq.URL.Host - - dial := p.Dial - if dial == nil { - dial = net.Dial - } - - if outreq.URL.Scheme == "wss" || outreq.URL.Scheme == "https" { - var tlsConfig *tls.Config - tlsConfig = &tls.Config{InsecureSkipVerify: true} - dial = func(network, address string) (net.Conn, error) { - return tls.Dial("tcp", host, tlsConfig) - } - } - - d, err := dial("tcp", host) - if err != nil { - http.Error(w, "Error forwarding request.", 500) - logFunc("Error dialing websocket backend %s: %v", outreq.URL, err) - return - } - // All request generated by the http package implement this interface. - hj, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "Not a hijacker?", 500) - return - } - // Hijack() tells the http package not to do anything else with the connection. - // After, it bcomes this functions job to manage it. `nc` is of type *net.Conn. - nc, _, err := hj.Hijack() - if err != nil { - logFunc("Hijack error: %v", err) - return - } - defer nc.Close() // must close the underlying net connection after hijacking - defer d.Close() - - // write the modified incoming request to the dialed connection - err = outreq.Write(d) - if err != nil { - logFunc("Error copying request to target: %v", err) - return - } - errc := make(chan error, 2) - cp := func(dst io.Writer, src io.Reader) { - _, err := io.Copy(dst, src) - errc <- err - } - go cp(d, nc) - go cp(nc, d) - <-errc -} -func NewTCPProxy() http.Handler { - director := func(req *http.Request) { - v := mux.Vars(req) - - node, port := getTargetInfo(v, req) - - if port == "443" { - if strings.Contains(req.URL.Scheme, "http") { - req.URL.Scheme = "https" - } else { - req.URL.Scheme = "wss" - } - } - req.URL.Host = fmt.Sprintf("%s:%s", node, port) - } - return &tcpProxy{Director: director} -} diff --git a/handlers/tlsproxy.go b/handlers/tlsproxy.go deleted file mode 100644 index 4a1ed03..0000000 --- a/handlers/tlsproxy.go +++ /dev/null @@ -1,95 +0,0 @@ -package handlers - -import ( - "fmt" - "io" - "log" - "net" - "strings" - - vhost "github.com/inconshreveable/go-vhost" - "github.com/play-with-docker/play-with-docker/config" -) - -func StartTLSProxy(port string) { - - tlsListener, tlsErr := net.Listen("tcp", fmt.Sprintf(":%s", port)) - log.Println("Listening on port " + port) - if tlsErr != nil { - log.Fatal(tlsErr) - } - defer tlsListener.Close() - for { - // Wait for TLS Connection - conn, err := tlsListener.Accept() - if err != nil { - log.Printf("Could not accept new TLS connection. Error: %s", err) - continue - } - // Handle connection on a new goroutine and continue accepting other new connections - go func(c net.Conn) { - defer c.Close() - vhostConn, err := vhost.TLS(conn) - if err != nil { - log.Printf("Incoming TLS connection produced an error. Error: %s", err) - return - } - defer vhostConn.Close() - - var targetIP string - targetPort := "443" - - host := vhostConn.ClientHelloMsg.ServerName - match := config.NameFilter.FindStringSubmatch(host) - if len(match) < 2 { - // Not a valid proxy host, try alias hosts - match := config.AliasFilter.FindStringSubmatch(host) - if len(match) < 4 { - // Not valid, just close the connection - return - } else { - alias := match[1] - sessionPrefix := match[2] - instance := core.InstanceFindByAlias(sessionPrefix, alias) - if instance != nil { - targetIP = instance.IP - } else { - return - } - if len(match) == 4 { - targetPort = match[3] - } - } - } else { - // Valid proxy host - ip := strings.Replace(match[1], "-", ".", -1) - if net.ParseIP(ip) == nil { - // Not a valid IP, so treat this is a hostname. - return - } else { - targetIP = ip - } - if len(match) == 3 { - targetPort = match[2] - } - } - - dest := fmt.Sprintf("%s:%s", targetIP, targetPort) - d, err := net.Dial("tcp", dest) - if err != nil { - log.Printf("Error dialing backend %s: %v\n", dest, err) - return - } - - errc := make(chan error, 2) - cp := func(dst io.Writer, src io.Reader) { - _, err := io.Copy(dst, src) - errc <- err - } - go cp(d, vhostConn) - go cp(vhostConn, d) - <-errc - }(conn) - } - -} diff --git a/pwd/broadcast.go b/pwd/broadcast.go deleted file mode 100644 index 1f54d08..0000000 --- a/pwd/broadcast.go +++ /dev/null @@ -1,34 +0,0 @@ -package pwd - -import ( - "net/http" - - "github.com/googollee/go-socket.io" -) - -type BroadcastApi interface { - BroadcastTo(sessionId, eventName string, args ...interface{}) - GetHandler() http.Handler -} - -type broadcast struct { - sio *socketio.Server -} - -func (b *broadcast) BroadcastTo(sessionId, eventName string, args ...interface{}) { - b.sio.BroadcastTo(sessionId, eventName, args...) -} - -func (b *broadcast) GetHandler() http.Handler { - return b.sio -} - -func NewBroadcast(connectionEvent, errorEvent interface{}) (*broadcast, error) { - server, err := socketio.NewServer(nil) - if err != nil { - return nil, err - } - server.On("connection", connectionEvent) - server.On("error", errorEvent) - return &broadcast{sio: server}, nil -} diff --git a/pwd/broadcast_mock_test.go b/pwd/broadcast_mock_test.go deleted file mode 100644 index 1d40cdc..0000000 --- a/pwd/broadcast_mock_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package pwd - -import "net/http" - -type mockBroadcast struct { - broadcastTo func(sessionId, eventName string, args ...interface{}) - getHandler func() http.Handler -} - -func (m *mockBroadcast) BroadcastTo(sessionId, eventName string, args ...interface{}) { - if m.broadcastTo != nil { - m.broadcastTo(sessionId, eventName, args...) - } -} -func (m *mockBroadcast) GetHandler() http.Handler { - if m.getHandler != nil { - return m.getHandler() - } - return nil -} diff --git a/pwd/client.go b/pwd/client.go index ed3917c..01b3516 100644 --- a/pwd/client.go +++ b/pwd/client.go @@ -5,6 +5,7 @@ import ( "sync/atomic" "time" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" ) @@ -48,7 +49,7 @@ func (p *pwd) ClientCount() int { func (p *pwd) notifyClientSmallestViewPort(session *types.Session) { vp := p.SessionGetSmallestViewPort(session) // Resize all terminals in the session - p.broadcast.BroadcastTo(session.Id, "viewport resize", vp.Cols, vp.Rows) + p.event.Emit(event.INSTANCE_VIEWPORT_RESIZE, session.Id, vp.Cols, vp.Rows) for _, instance := range session.Instances { err := p.InstanceResizeTerminal(instance, vp.Rows, vp.Cols) if err != nil { diff --git a/pwd/client_test.go b/pwd/client_test.go index 85d0605..c0cffd1 100644 --- a/pwd/client_test.go +++ b/pwd/client_test.go @@ -1,9 +1,11 @@ package pwd import ( + "sync" "testing" "time" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/stretchr/testify/assert" ) @@ -11,10 +13,10 @@ import ( func TestClientNew(t *testing.T) { docker := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") assert.Nil(t, err) @@ -27,10 +29,10 @@ func TestClientNew(t *testing.T) { func TestClientCount(t *testing.T) { docker := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") assert.Nil(t, err) @@ -41,33 +43,34 @@ func TestClientCount(t *testing.T) { } func TestClientResizeViewPort(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(1) docker := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() broadcastedSessionId := "" - broadcastedEventName := "" broadcastedArgs := []interface{}{} - broadcast.broadcastTo = func(sessionId, eventName string, args ...interface{}) { + e.On(event.INSTANCE_VIEWPORT_RESIZE, func(sessionId string, args ...interface{}) { broadcastedSessionId = sessionId - broadcastedEventName = eventName broadcastedArgs = args - } + wg.Done() + }) storage := &mockStorage{} - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") assert.Nil(t, err) client := p.ClientNew("foobar", session) p.ClientResizeViewPort(client, 80, 24) + wg.Wait() assert.Equal(t, types.ViewPort{Cols: 80, Rows: 24}, client.ViewPort) assert.Equal(t, session.Id, broadcastedSessionId) - assert.Equal(t, "viewport resize", broadcastedEventName) assert.Equal(t, uint(80), broadcastedArgs[0]) assert.Equal(t, uint(24), broadcastedArgs[1]) } diff --git a/pwd/instance.go b/pwd/instance.go index a024c0b..32f28da 100644 --- a/pwd/instance.go +++ b/pwd/instance.go @@ -13,6 +13,7 @@ import ( "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "golang.org/x/text/encoding" @@ -21,13 +22,13 @@ import ( type sessionWriter struct { sessionId string instanceName string - broadcast BroadcastApi + event event.EventApi } var terms = make(map[string]map[string]net.Conn) func (s *sessionWriter) Write(p []byte) (n int, err error) { - s.broadcast.BroadcastTo(s.sessionId, "terminal out", s.instanceName, string(p)) + s.event.Emit(event.INSTANCE_TERMINAL_OUT, s.sessionId, s.instanceName, string(p)) return len(p), nil } @@ -60,7 +61,7 @@ func (p *pwd) InstanceAttachTerminal(instance *types.Instance) error { } encoder := encoding.Replacement.NewEncoder() - sw := &sessionWriter{sessionId: instance.Session.Id, instanceName: instance.Name, broadcast: p.broadcast} + sw := &sessionWriter{sessionId: instance.Session.Id, instanceName: instance.Name, event: p.event} if terms[instance.SessionId] == nil { terms[instance.SessionId] = map[string]net.Conn{instance.Name: conn} } else { @@ -177,7 +178,7 @@ func (p *pwd) InstanceDelete(session *types.Session, instance *types.Instance) e return err } - p.broadcast.BroadcastTo(session.Id, "delete instance", instance.Name) + p.event.Emit(event.INSTANCE_DELETE, session.Id, instance.Name) delete(session.Instances, instance.Name) if err := p.storage.InstanceDelete(session.Id, instance.Name); err != nil { @@ -277,7 +278,7 @@ func (p *pwd) InstanceNew(session *types.Session, conf InstanceConfig) (*types.I return nil, err } - p.broadcast.BroadcastTo(session.Id, "new instance", instance.Name, instance.IP, instance.Hostname) + p.event.Emit(event.INSTANCE_NEW, session.Id, instance.Name, instance.IP, instance.Hostname) p.setGauges() diff --git a/pwd/instance_test.go b/pwd/instance_test.go index 2bfb9e5..308ea5a 100644 --- a/pwd/instance_test.go +++ b/pwd/instance_test.go @@ -10,6 +10,7 @@ import ( "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/stretchr/testify/assert" ) @@ -29,10 +30,10 @@ func TestInstanceResizeTerminal(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, e, storage) err := p.InstanceResizeTerminal(&types.Instance{Name: "foobar"}, 24, 80) @@ -51,10 +52,10 @@ func TestInstanceNew(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") @@ -102,10 +103,10 @@ func TestInstanceNew_Concurrency(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") @@ -143,10 +144,10 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") @@ -192,10 +193,10 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, err := p.SessionNew(time.Hour, "", "", "") @@ -235,10 +236,10 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) { func TestInstanceAllowedImages(t *testing.T) { dock := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) expectedImages := []string{config.GetDindImageName(), "franela/dind:overlay2-dev", "franela/ucp:2.4.1"} @@ -256,7 +257,7 @@ func (ec errConn) Read(b []byte) (int, error) { func TestTermConnAssignment(t *testing.T) { dock := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + e := event.NewLocalBroker() storage := &mockStorage{} dock.createAttachConnection = func(name string) (net.Conn, error) { @@ -264,7 +265,7 @@ func TestTermConnAssignment(t *testing.T) { return errConn{}, nil } - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, e, storage) session, _ := p.SessionNew(time.Hour, "", "", "") mockInstance := &types.Instance{ Name: fmt.Sprintf("%s_redis-master", session.Id[:8]), diff --git a/pwd/pwd.go b/pwd/pwd.go index 1d5782f..8300c9a 100644 --- a/pwd/pwd.go +++ b/pwd/pwd.go @@ -5,6 +5,7 @@ import ( "time" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/play-with-docker/play-with-docker/storage" "github.com/prometheus/client_golang/prometheus" @@ -45,7 +46,7 @@ func init() { type pwd struct { docker docker.DockerApi tasks SchedulerApi - broadcast BroadcastApi + event event.EventApi storage storage.StorageApi clientCount int32 } @@ -79,8 +80,8 @@ type PWDApi interface { ClientCount() int } -func NewPWD(d docker.DockerApi, t SchedulerApi, b BroadcastApi, s storage.StorageApi) *pwd { - return &pwd{docker: d, tasks: t, broadcast: b, storage: s} +func NewPWD(d docker.DockerApi, t SchedulerApi, e event.EventApi, s storage.StorageApi) *pwd { + return &pwd{docker: d, tasks: t, event: e, storage: s} } func (p *pwd) setGauges() { diff --git a/pwd/session.go b/pwd/session.go index 9dd26a2..badf98e 100644 --- a/pwd/session.go +++ b/pwd/session.go @@ -12,6 +12,7 @@ import ( "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/twinj/uuid" ) @@ -20,11 +21,11 @@ var preparedSessions = map[string]bool{} type sessionBuilderWriter struct { sessionId string - broadcast BroadcastApi + event event.EventApi } func (s *sessionBuilderWriter) Write(p []byte) (n int, err error) { - s.broadcast.BroadcastTo(s.sessionId, "session builder out", string(p)) + s.event.Emit(event.SESSION_BUILDER_OUT, s.sessionId, string(p)) return len(p), nil } @@ -90,8 +91,7 @@ func (p *pwd) SessionClose(s *types.Session) error { s.StopTicker() - p.broadcast.BroadcastTo(s.Id, "session end") - p.broadcast.BroadcastTo(s.Id, "disconnect") + p.event.Emit(event.SESSION_END, s.Id) log.Printf("Starting clean up of session [%s]\n", s.Id) for _, i := range s.Instances { err := p.InstanceDelete(s, i) @@ -149,7 +149,7 @@ func (p *pwd) SessionDeployStack(s *types.Session) error { } s.Ready = false - p.broadcast.BroadcastTo(s.Id, "session ready", false) + p.event.Emit(event.SESSION_READY, s.Id, false) i, err := p.InstanceNew(s, InstanceConfig{ImageName: s.ImageName, Host: s.Host}) if err != nil { log.Printf("Error creating instance for stack [%s]: %s\n", s.Stack, err) @@ -167,7 +167,7 @@ func (p *pwd) SessionDeployStack(s *types.Session) error { file := fmt.Sprintf("/var/run/pwd/uploads/%s", fileName) cmd := fmt.Sprintf("docker swarm init --advertise-addr eth0 && docker-compose -f %s pull && docker stack deploy -c %s %s", file, file, s.StackName) - w := sessionBuilderWriter{sessionId: s.Id, broadcast: p.broadcast} + w := sessionBuilderWriter{sessionId: s.Id, event: p.event} code, err := p.docker.ExecAttach(i.Name, []string{"sh", "-c", cmd}, &w) if err != nil { log.Printf("Error executing stack [%s]: %s\n", s.Stack, err) @@ -176,7 +176,7 @@ func (p *pwd) SessionDeployStack(s *types.Session) error { log.Printf("Stack execution finished with code %d\n", code) s.Ready = true - p.broadcast.BroadcastTo(s.Id, "session ready", true) + p.event.Emit(event.SESSION_READY, s.Id, true) if err := p.storage.SessionPut(s); err != nil { return err } diff --git a/pwd/session_test.go b/pwd/session_test.go index c31d3cb..fb0f22a 100644 --- a/pwd/session_test.go +++ b/pwd/session_test.go @@ -7,6 +7,7 @@ import ( "github.com/play-with-docker/play-with-docker/config" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" "github.com/stretchr/testify/assert" ) @@ -36,14 +37,14 @@ func TestSessionNew(t *testing.T) { scheduledSession = s } - broadcast := &mockBroadcast{} + ev := event.NewLocalBroker() storage := &mockStorage{} storage.sessionPut = func(s *types.Session) error { saveCalled = true return nil } - p := NewPWD(docker, tasks, broadcast, storage) + p := NewPWD(docker, tasks, ev, storage) before := time.Now() @@ -166,10 +167,10 @@ func TestSessionSetup(t *testing.T) { return nil, nil } tasks := &mockTasks{} - broadcast := &mockBroadcast{} + ev := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, ev, storage) s, e := p.SessionNew(time.Hour, "", "", "") assert.Nil(t, e) @@ -281,10 +282,10 @@ func TestSessionSetup(t *testing.T) { func TestSessionPrepareOnce(t *testing.T) { dock := &mockDocker{} tasks := &mockTasks{} - broadcast := &mockBroadcast{} + ev := event.NewLocalBroker() storage := &mockStorage{} - p := NewPWD(dock, tasks, broadcast, storage) + p := NewPWD(dock, tasks, ev, storage) session := &types.Session{Id: "1234"} prepared, err := p.prepareSession(session) assert.True(t, preparedSessions[session.Id]) diff --git a/pwd/tasks.go b/pwd/tasks.go index 6597217..412c1f2 100644 --- a/pwd/tasks.go +++ b/pwd/tasks.go @@ -15,6 +15,7 @@ import ( "github.com/docker/docker/client" "github.com/docker/go-connections/tlsconfig" "github.com/play-with-docker/play-with-docker/docker" + "github.com/play-with-docker/play-with-docker/event" "github.com/play-with-docker/play-with-docker/pwd/types" ) @@ -28,7 +29,7 @@ type SchedulerApi interface { } type scheduler struct { - broadcast BroadcastApi + event event.EventApi periodicTasks []periodicTask } @@ -102,7 +103,7 @@ func (sch *scheduler) Schedule(s *types.Session) { sort.Sort(ins.Ports) ins.CleanUsedPorts() - sch.broadcast.BroadcastTo(ins.Session.Id, "instance stats", ins.Name, ins.Mem, ins.Cpu, ins.IsManager, ins.Ports) + sch.event.Emit(event.INSTANCE_STATS, ins.Session.Id, ins.Name, ins.Mem, ins.Cpu, ins.IsManager, ins.Ports) } } }() @@ -111,8 +112,8 @@ func (sch *scheduler) Schedule(s *types.Session) { func (sch *scheduler) Unschedule(s *types.Session) { } -func NewScheduler(b BroadcastApi, d docker.DockerApi) *scheduler { - s := &scheduler{broadcast: b} +func NewScheduler(e event.EventApi, d docker.DockerApi) *scheduler { + s := &scheduler{event: e} s.periodicTasks = []periodicTask{&collectStatsTask{docker: d}, &checkSwarmStatusTask{}, &checkUsedPortsTask{}, &checkSwarmUsedPortsTask{}} return s } diff --git a/router/l2/l2.go b/router/l2/l2.go new file mode 100644 index 0000000..1cce468 --- /dev/null +++ b/router/l2/l2.go @@ -0,0 +1,136 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net" + "os" + "strings" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "github.com/play-with-docker/play-with-docker/config" + "github.com/play-with-docker/play-with-docker/router" +) + +func director(host string) (*net.TCPAddr, error) { + chunks := strings.Split(host, ":") + matches := config.NameFilter.FindStringSubmatch(chunks[0]) + + var rawHost, port string + + if len(matches) == 3 { + rawHost = matches[1] + port = matches[2] + } else if len(matches) == 2 { + rawHost = matches[1] + } else { + return nil, fmt.Errorf("Couldn't find host in string") + } + + if port == "" { + if len(chunks) == 2 { + port = chunks[1] + } else { + port = "80" + } + } + + dstHost := strings.Replace(rawHost, "-", ".", -1) + + t, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%s", dstHost, port)) + if err != nil { + return nil, err + } + return t, nil +} + +func connectNetworks() error { + ctx := context.Background() + c, err := client.NewEnvClient() + if err != nil { + log.Fatal(err) + } + + defer c.Close() + + f, err := os.Open(config.SessionsFile) + if err != nil { + return err + } + defer f.Close() + + networks := map[string]*network.EndpointSettings{} + + err = json.NewDecoder(f).Decode(&networks) + if err != nil { + return err + } + + for netId, opts := range networks { + settings := &network.EndpointSettings{} + settings.IPAddress = opts.IPAddress + log.Printf("Connected to network [%s] with ip [%s]\n", netId, opts.IPAddress) + c.NetworkConnect(ctx, netId, config.PWDContainerName, settings) + } + + return nil +} + +func monitorNetworks() { + c, err := client.NewEnvClient() + if err != nil { + log.Fatal(err) + } + + defer c.Close() + + ctx := context.Background() + + args := filters.NewArgs() + + cmsg, _ := c.Events(ctx, types.EventsOptions{Filters: args}) + for { + select { + case m := <-cmsg: + if m.Type == "network" { + // Router has been connected to a new network. Let's get all connections and store them in case of restart. + container, err := c.ContainerInspect(ctx, config.PWDContainerName) + if err != nil { + log.Println(err) + return + } + + f, err := os.Create(config.SessionsFile) + if err != nil { + log.Println(err) + return + } + err = json.NewEncoder(f).Encode(container.NetworkSettings.Networks) + if err != nil { + log.Println(err) + return + } + log.Println("Saved networks") + } + } + } +} + +func main() { + config.ParseFlags() + + err := connectNetworks() + if err != nil && !os.IsNotExist(err) { + log.Fatal("connect networks:", err) + } + go monitorNetworks() + + r := router.NewRouter(director, config.SSHKeyPath) + r.ListenAndWait(":443", ":53", ":22") + defer r.Close() +} diff --git a/router/l2/l2_test.go b/router/l2/l2_test.go new file mode 100644 index 0000000..70f3fcf --- /dev/null +++ b/router/l2/l2_test.go @@ -0,0 +1,47 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDirector(t *testing.T) { + addr, err := director("ip10-0-0-1-8080.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:8080", addr.String()) + + addr, err = director("ip10-0-0-1.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:80", addr.String()) + + addr, err = director("ip10-0-0-1.foo.bar:9090") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:9090", addr.String()) + + addr, err = director("ip10-0-0-1-2222.foo.bar:9090") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:2222", addr.String()) + + addr, err = director("lala.ip10-0-0-1-2222.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:2222", addr.String()) + + addr, err = director("lala.ip10-0-0-1-2222") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:2222", addr.String()) + + addr, err = director("ip10-0-0-1-2222") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:2222", addr.String()) + + addr, err = director("ip10-0-0-1") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:80", addr.String()) + + _, err = director("lala10-0-0-1.foo.bar") + assert.NotNil(t, err) + + _, err = director("ip10-0-0-1-10-20") + assert.NotNil(t, err) +} diff --git a/router/router.go b/router/router.go new file mode 100644 index 0000000..4948370 --- /dev/null +++ b/router/router.go @@ -0,0 +1,482 @@ +package router + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "strings" + "sync" + + "golang.org/x/crypto/ssh" + + vhost "github.com/inconshreveable/go-vhost" + "github.com/miekg/dns" +) + +type Director func(host string) (*net.TCPAddr, error) + +type proxyRouter struct { + sync.Mutex + + keyPath string + director Director + closed bool + httpListener net.Listener + udpDnsServer *dns.Server + tcpDnsServer *dns.Server + sshListener net.Listener + sshConfig *ssh.ServerConfig +} + +func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { + l, err := net.Listen("tcp", httpAddr) + if err != nil { + log.Fatal(err) + } + r.httpListener = l + go func() { + for !r.closed { + conn, err := r.httpListener.Accept() + if err != nil { + continue + } + go r.handleConnection(conn) + } + }() + + dnsMux := dns.NewServeMux() + dnsMux.HandleFunc(".", r.dnsRequest) + r.udpDnsServer = &dns.Server{Addr: dnsAddr, Net: "udp", Handler: dnsMux} + r.tcpDnsServer = &dns.Server{Addr: dnsAddr, Net: "tcp", Handler: dnsMux} + + wg := sync.WaitGroup{} + wg.Add(2) + + r.udpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + r.tcpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + go r.udpDnsServer.ListenAndServe() + go r.tcpDnsServer.ListenAndServe() + wg.Wait() + + lssh, err := net.Listen("tcp", sshAddr) + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + r.sshListener = lssh + go func() { + for { + nConn, err := lssh.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + go r.sshHandle(nConn) + } + }() +} +func (r *proxyRouter) ListenAndWait(httpAddr, dnsAddr, sshAddr string) { + listenWG := sync.WaitGroup{} + + l, err := net.Listen("tcp", httpAddr) + if err != nil { + log.Fatal(err) + } + r.httpListener = l + listenWG.Add(1) + go func() { + for !r.closed { + conn, err := r.httpListener.Accept() + if err != nil { + continue + } + go r.handleConnection(conn) + } + listenWG.Done() + }() + + dnsMux := dns.NewServeMux() + dnsMux.HandleFunc(".", r.dnsRequest) + r.udpDnsServer = &dns.Server{Addr: dnsAddr, Net: "udp", Handler: dnsMux} + r.tcpDnsServer = &dns.Server{Addr: dnsAddr, Net: "tcp", Handler: dnsMux} + + wg := sync.WaitGroup{} + wg.Add(2) + + r.udpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + r.tcpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + go r.udpDnsServer.ListenAndServe() + go r.tcpDnsServer.ListenAndServe() + wg.Wait() + + lssh, err := net.Listen("tcp", sshAddr) + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + r.sshListener = lssh + listenWG.Add(1) + go func() { + for { + nConn, err := lssh.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + go r.sshHandle(nConn) + } + listenWG.Done() + }() + listenWG.Wait() +} + +func (r *proxyRouter) sshHandle(nConn net.Conn) { + sshCon, chans, reqs, err := ssh.NewServerConn(nConn, r.sshConfig) + if err != nil { + nConn.Close() + return + } + + dstHost, err := r.director(sshCon.User()) + if err != nil { + nConn.Close() + return + } + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + newChannel := <-chans + if newChannel == nil { + sshCon.Close() + return + } + + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + return + } + + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatalf("Could not accept channel: %v", err) + } + + stderr := channel.Stderr() + + fmt.Fprintf(stderr, "Connecting to %s\r\n", dstHost.String()) + + clientConfig := &ssh.ClientConfig{ + User: "root", + Auth: []ssh.AuthMethod{ + ssh.Password("root"), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + + client, err := ssh.Dial("tcp", dstHost.String(), clientConfig) + if err != nil { + fmt.Fprintf(stderr, "Connect failed: %v\r\n", err) + channel.Close() + return + } + + go func() { + for newChannel = range chans { + if newChannel == nil { + return + } + + channel2, reqs2, err := client.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()) + if err != nil { + x, ok := err.(*ssh.OpenChannelError) + if ok { + newChannel.Reject(x.Reason, x.Message) + } else { + newChannel.Reject(ssh.Prohibited, "remote server denied channel request") + } + continue + } + + channel, reqs, err := newChannel.Accept() + if err != nil { + channel2.Close() + continue + } + go proxySsh(reqs, reqs2, channel, channel2) + } + }() + + // Forward the session channel + channel2, reqs2, err := client.OpenChannel("session", []byte{}) + if err != nil { + fmt.Fprintf(stderr, "Remote session setup failed: %v\r\n", err) + channel.Close() + return + } + + maskedReqs := make(chan *ssh.Request, 1) + go func() { + for req := range requests { + if req.Type == "auth-agent-req@openssh.com" { + continue + } + maskedReqs <- req + } + }() + proxySsh(maskedReqs, reqs2, channel, channel2) +} + +func (r *proxyRouter) dnsRequest(w dns.ResponseWriter, req *dns.Msg) { + if len(req.Question) > 0 { + question := req.Question[0].Name + + if question == "localhost." { + log.Printf("Asked for [localhost.] returning automatically [127.0.0.1]\n") + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A 127.0.0.1", question)) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + w.WriteMsg(m) + return + } + + dstHost, err := r.director(strings.TrimSuffix(question, ".")) + if err != nil { + // Director couldn't resolve it, try to lookup in the system's DNS + ips, err := net.LookupIP(question) + if err != nil { + // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured + w.Close() + // dns.HandleFailed(w, r) + return + } + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + for _, ip := range ips { + ipv4 := ip.To4() + if ipv4 == nil { + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN AAAA %s", question, ip.String())) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + } else { + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ipv4.String())) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + } + } + w.WriteMsg(m) + return + } + + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, dstHost.IP)) + if err != nil { + log.Println(err) + w.Close() + // dns.HandleFailed(w, r) + return + } + m.Answer = append(m.Answer, a) + w.WriteMsg(m) + return + } +} + +func (r *proxyRouter) Close() { + r.Lock() + defer r.Unlock() + + if r.httpListener != nil { + r.httpListener.Close() + } + if r.udpDnsServer != nil { + r.udpDnsServer.Shutdown() + } + r.closed = true +} + +func (r *proxyRouter) ListenHttpAddress() string { + if r.httpListener != nil { + return r.httpListener.Addr().String() + } + return "" +} + +func (r *proxyRouter) ListenDnsUdpAddress() string { + if r.udpDnsServer != nil && r.udpDnsServer.PacketConn != nil { + return r.udpDnsServer.PacketConn.LocalAddr().String() + } + + return "" +} +func (r *proxyRouter) ListenDnsTcpAddress() string { + if r.tcpDnsServer != nil && r.tcpDnsServer.Listener != nil { + return r.tcpDnsServer.Listener.Addr().String() + } + + return "" +} + +func (r *proxyRouter) ListenSshAddress() string { + if r.sshListener != nil { + return r.sshListener.Addr().String() + } + + return "" +} + +func (r *proxyRouter) handleConnection(c net.Conn) { + defer c.Close() + // first try tls + vhostConn, err := vhost.TLS(c) + + if err == nil { + // It is a TLS connection + defer vhostConn.Close() + host := vhostConn.ClientHelloMsg.ServerName + dstHost, err := r.director(host) + if err != nil { + log.Printf("Error directing request: %v\n", err) + return + } + d, err := net.Dial("tcp", dstHost.String()) + if err != nil { + log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) + return + } + + proxyConn(vhostConn, d) + } else { + // it is not TLS + // treat it as an http connection + + req, err := http.ReadRequest(bufio.NewReader(vhostConn)) + if err != nil { + // It is not http neither. So just close the connection. + return + } + dstHost, err := r.director(req.Host) + if err != nil { + log.Printf("Error directing request: %v\n", err) + return + } + d, err := net.Dial("tcp", dstHost.String()) + if err != nil { + log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) + return + } + err = req.Write(d) + if err != nil { + log.Printf("Error requesting backend %s: %v\n", dstHost.String(), err) + return + } + proxyConn(c, d) + } +} + +func proxySsh(reqs1, reqs2 <-chan *ssh.Request, channel1, channel2 ssh.Channel) { + var closer sync.Once + closeFunc := func() { + channel1.Close() + channel2.Close() + } + + defer closer.Do(closeFunc) + + closerChan := make(chan bool, 1) + + go func() { + io.Copy(channel1, channel2) + closerChan <- true + }() + + go func() { + io.Copy(channel2, channel1) + closerChan <- true + }() + + for { + select { + case req := <-reqs1: + if req == nil { + return + } + b, err := channel2.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + return + } + req.Reply(b, nil) + + case req := <-reqs2: + if req == nil { + return + } + b, err := channel1.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + return + } + req.Reply(b, nil) + case <-closerChan: + return + } + } +} + +func proxyConn(src, dst net.Conn) { + errc := make(chan error, 2) + cp := func(dst io.Writer, src io.Reader) { + _, err := io.Copy(dst, src) + errc <- err + } + go cp(src, dst) + go cp(dst, src) + <-errc +} + +func NewRouter(director Director, keyPath string) *proxyRouter { + var sshConfig = &ssh.ServerConfig{ + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + return nil, nil + }, + } + privateBytes, err := ioutil.ReadFile(keyPath) + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + + sshConfig.AddHostKey(private) + + return &proxyRouter{director: director, sshConfig: sshConfig} +} diff --git a/router/router_test.go b/router/router_test.go new file mode 100644 index 0000000..12f63d0 --- /dev/null +++ b/router/router_test.go @@ -0,0 +1,523 @@ +package router + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/asn1" + "encoding/pem" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "sync" + "testing" + + "golang.org/x/crypto/ssh" + + "github.com/gorilla/websocket" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func testSshClient(user string, r *proxyRouter) error { + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return err + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return err + } + config := &ssh.ClientConfig{ + User: user, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + chunks := strings.Split(r.ListenSshAddress(), ":") + l := fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + client, err := ssh.Dial("tcp", l, config) + if err != nil { + return err + } + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + + return nil +} + +func testSshServer(f func(user, pass, ctype string)) (string, error) { + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return "", err + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return "", err + } + + var receivedUser string + var receivedPass string + var receivedChannelType string + + config := &ssh.ServerConfig{ + PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + receivedUser = conn.User() + receivedPass = string(password) + + return nil, nil + }, + } + config.AddHostKey(signer) + listener, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return "", err + } + + go func() { + defer listener.Close() + nConn, err := listener.Accept() + if err != nil { + log.Println(err) + return + } + conn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Println(err) + return + } + go ssh.DiscardRequests(reqs) + defer nConn.Close() + defer conn.Close() + + ch := <-chans + + receivedChannelType = ch.ChannelType() + + f(receivedUser, receivedPass, receivedChannelType) + }() + + return listener.Addr().String(), nil +} + +func generateKeys() (string, string, string, error) { + dir, err := ioutil.TempDir("", "pwd") + if err != nil { + return "", "", "", err + } + + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return "", "", "", err + } + + privateFile, err := os.OpenFile(fmt.Sprintf("%s/id_rsa", dir), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return "", "", "", err + } + defer privateFile.Close() + + var privateKey = &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + err = pem.Encode(privateFile, privateKey) + if err != nil { + return "", "", "", err + } + + publicFile, err := os.OpenFile(fmt.Sprintf("%s/id_rsa.pub", dir), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return "", "", "", err + } + defer publicFile.Close() + + asn1Bytes, err := asn1.Marshal(key.PublicKey) + if err != nil { + return "", "", "", err + } + + var publicKey = &pem.Block{ + Type: "PUBLIC KEY", + Bytes: asn1Bytes, + } + + err = pem.Encode(publicFile, publicKey) + if err != nil { + return "", "", "", err + } + return dir, privateFile.Name(), publicFile.Name(), nil +} + +func getRouterUrl(scheme string, r *proxyRouter) string { + chunks := strings.Split(r.ListenHttpAddress(), ":") + return fmt.Sprintf("%s://localhost:%s", scheme, chunks[len(chunks)-1]) +} + +func routerLookup(protocol, domain string, r *proxyRouter) ([]string, error) { + c := dns.Client{Net: protocol} + m := dns.Msg{} + + m.SetQuestion(fmt.Sprintf("%s.", domain), dns.TypeA) + var l string + if protocol == "udp" { + chunks := strings.Split(r.ListenDnsUdpAddress(), ":") + l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + } else if protocol == "tcp" { + chunks := strings.Split(r.ListenDnsTcpAddress(), ":") + l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + } + res, _, err := c.Exchange(&m, l) + + if err != nil { + return nil, err + } + + if len(res.Answer) == 0 { + return nil, fmt.Errorf("Didn't receive any answer") + } + addrs := []string{} + for _, a := range res.Answer { + if b, ok := a.(*dns.A); ok { + addrs = append(addrs, b.A.String()) + } else if b, ok := a.(*dns.AAAA); ok { + addrs = append(addrs, b.AAAA.String()) + } + } + + return addrs, nil +} + +func TestProxy_TLS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: tr} + + const msg = "It works!" + + var receivedHost string + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + req, err := http.NewRequest("GET", getRouterUrl("https", r), nil) + assert.Nil(t, err) + + resp, err := client.Do(req) + assert.Nil(t, err) + + body, err := ioutil.ReadAll(resp.Body) + assert.Nil(t, err) + assert.Equal(t, msg, string(body)) + assert.Equal(t, "localhost", receivedHost) +} + +func TestProxy_Http(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + const msg = "It works!" + + var receivedHost string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + req, err := http.NewRequest("GET", getRouterUrl("http", r), nil) + assert.Nil(t, err) + + resp, err := http.DefaultClient.Do(req) + assert.Nil(t, err) + + body, err := ioutil.ReadAll(resp.Body) + assert.Nil(t, err) + assert.Equal(t, msg, string(body)) + + u, _ := url.Parse(getRouterUrl("http", r)) + assert.Equal(t, u.Host, receivedHost) +} + +func TestProxy_WS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + const msg = "It works!" + + var serverReceivedMessage string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + serverReceivedMessage = string(message) + err = c.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + return + } + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + c, _, err := websocket.DefaultDialer.Dial(getRouterUrl("ws", r), nil) + assert.Nil(t, err) + defer c.Close() + + var clientReceivedMessage []byte + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + _, clientReceivedMessage, _ = c.ReadMessage() + wg.Done() + }() + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, msg, string(clientReceivedMessage)) + assert.Equal(t, msg, serverReceivedMessage) +} + +func TestProxy_WSS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + const msg = "It works!" + + var serverReceivedMessage string + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + serverReceivedMessage = string(message) + err = c.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + return + } + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + c, _, err := d.Dial(getRouterUrl("wss", r), nil) + assert.Nil(t, err) + defer c.Close() + + var clientReceivedMessage []byte + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + _, clientReceivedMessage, _ = c.ReadMessage() + wg.Done() + }() + + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, msg, string(clientReceivedMessage)) + assert.Equal(t, msg, serverReceivedMessage) +} + +func TestProxy_DNS_UDP(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedHost string + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10_0_0_1.foo.bar" { + a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r) + assert.Nil(t, err) + assert.Equal(t, "10_0_0_1.foo.bar", receivedHost) + assert.Equal(t, []string{"10.0.0.1"}, ips) + + ips, err = routerLookup("udp", "www.google.com", r) + assert.Nil(t, err) + assert.Equal(t, "www.google.com", receivedHost) + + expectedIps, err := net.LookupHost("www.google.com") + assert.Nil(t, err) + assert.Equal(t, expectedIps, ips) + + ips, err = routerLookup("udp", "localhost", r) + assert.Nil(t, err) + assert.NotEqual(t, "localhost", receivedHost) + assert.Equal(t, []string{"127.0.0.1"}, ips) +} + +func TestProxy_DNS_TCP(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedHost string + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10_0_0_1.foo.bar" { + a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + ips, err := routerLookup("tcp", "10_0_0_1.foo.bar", r) + assert.Nil(t, err) + assert.Equal(t, "10_0_0_1.foo.bar", receivedHost) + assert.Equal(t, []string{"10.0.0.1"}, ips) + + ips, err = routerLookup("tcp", "www.google.com", r) + assert.Nil(t, err) + assert.Equal(t, "www.google.com", receivedHost) + + expectedIps, err := net.LookupHost("www.google.com") + assert.Nil(t, err) + assert.Equal(t, expectedIps, ips) + + ips, err = routerLookup("tcp", "localhost", r) + assert.Nil(t, err) + assert.NotEqual(t, "localhost", receivedHost) + assert.Equal(t, []string{"127.0.0.1"}, ips) +} + +func TestProxy_SSH(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedUser string + var receivedPass string + var receivedChannelType string + var receivedHost string + + wg := sync.WaitGroup{} + wg.Add(1) + laddr, err := testSshServer(func(user, pass, ctype string) { + receivedUser = user + receivedPass = pass + receivedChannelType = ctype + wg.Done() + }) + assert.Nil(t, err) + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10-0-0-1-aaaabbbb" { + chunks := strings.Split(laddr, ":") + a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])) + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + err = testSshClient("10-0-0-1-aaaabbbb", r) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, "root", receivedUser) + assert.Equal(t, "root", receivedPass) + assert.Equal(t, "session", receivedChannelType) + assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost) +}