diff --git a/api.go b/api.go index eff0941..fdcb8bc 100644 --- a/api.go +++ b/api.go @@ -26,12 +26,6 @@ func main() { r := mux.NewRouter() corsRouter := mux.NewRouter() - // Reverse proxy (needs to be the first route, to make sure it is the first thing we check) - //proxyHandler := handlers.NewMultipleHostReverseProxy() - //websocketProxyHandler := handlers.NewMultipleHostWebsocketReverseProxy() - - tcpHandler := handlers.NewTCPProxy() - corsHandler := gh.CORS(gh.AllowCredentials(), gh.AllowedHeaders([]string{"x-requested-with", "content-type"}), gh.AllowedOrigins([]string{"*"})) // Specific routes @@ -84,9 +78,6 @@ func main() { go handlers.ListenSSHProxy("0.0.0.0:1022") - // Now listen for TLS connections that need to be proxied - handlers.StartTLSProxy(config.SSLPortNumber) - log.Println("Listening on port " + config.PortNumber) log.Fatal(httpServer.ListenAndServe()) } diff --git a/router/dns/dns.go b/router/dns/dns.go deleted file mode 100644 index 2160e92..0000000 --- a/router/dns/dns.go +++ /dev/null @@ -1,111 +0,0 @@ -package dns - -import ( - "fmt" - "log" - "net" - "strings" - - "github.com/miekg/dns" - "github.com/play-with-docker/play-with-docker/config" -) - -func DnsRequest(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) > 0 && config.NameFilter.MatchString(r.Question[0].Name) { - // this is something we know about and we should try to handle - question := r.Question[0].Name - - match := config.NameFilter.FindStringSubmatch(question) - - ip := strings.Replace(match[1], "-", ".", -1) - - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ip)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } else if len(r.Question) > 0 && config.AliasFilter.MatchString(r.Question[0].Name) { - // this is something we know about and we should try to handle - question := r.Question[0].Name - - match := config.AliasFilter.FindStringSubmatch(question) - - i := core.InstanceFindByAlias(match[2], match[1]) - - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, i.IP)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } else { - if len(r.Question) > 0 { - question := r.Question[0].Name - - if question == "localhost." { - log.Printf("Not a PWD host. Asked for [localhost.] returning automatically [127.0.0.1]\n") - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A 127.0.0.1", question)) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - w.WriteMsg(m) - return - } - - log.Printf("Not a PWD host. Looking up [%s]\n", question) - ips, err := net.LookupIP(question) - if err != nil { - // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured - w.Close() - // dns.HandleFailed(w, r) - return - } - log.Printf("Not a PWD host. Looking up [%s] got [%s]\n", question, ips) - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - m.RecursionAvailable = true - for _, ip := range ips { - ipv4 := ip.To4() - if ipv4 == nil { - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN AAAA %s", question, ip.String())) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - } else { - a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ipv4.String())) - if err != nil { - log.Fatal(err) - } - m.Answer = append(m.Answer, a) - } - } - w.WriteMsg(m) - return - - } else { - log.Printf("Not a PWD host. Got DNS without any question\n") - // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured - w.Close() - // dns.HandleFailed(w, r) - return - } - } -} diff --git a/router/router.go b/router/router.go index b62895a..b46a986 100644 --- a/router/router.go +++ b/router/router.go @@ -2,13 +2,19 @@ package router import ( "bufio" + "fmt" "io" + "io/ioutil" "log" "net" "net/http" + "strings" "sync" + "golang.org/x/crypto/ssh" + vhost "github.com/inconshreveable/go-vhost" + "github.com/miekg/dns" ) type Director func(host string) (*net.TCPAddr, error) @@ -16,45 +22,279 @@ type Director func(host string) (*net.TCPAddr, error) type proxyRouter struct { sync.Mutex - director Director - listener net.Listener - closed bool + keyPath string + director Director + closed bool + httpListener net.Listener + udpDnsServer *dns.Server + tcpDnsServer *dns.Server + sshListener net.Listener + sshConfig *ssh.ServerConfig } -func (r *proxyRouter) Listen(laddr string) { - l, err := net.Listen("tcp", laddr) +func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { + l, err := net.Listen("tcp", httpAddr) if err != nil { log.Fatal(err) } - r.listener = l + r.httpListener = l go func() { for !r.closed { - conn, err := r.listener.Accept() + conn, err := r.httpListener.Accept() if err != nil { continue } go r.handleConnection(conn) } }() + + dnsMux := dns.NewServeMux() + dnsMux.HandleFunc(".", r.dnsRequest) + r.udpDnsServer = &dns.Server{Addr: dnsAddr, Net: "udp", Handler: dnsMux} + r.tcpDnsServer = &dns.Server{Addr: dnsAddr, Net: "tcp", Handler: dnsMux} + + wg := sync.WaitGroup{} + wg.Add(2) + + r.udpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + r.tcpDnsServer.NotifyStartedFunc = func() { + wg.Done() + } + go r.udpDnsServer.ListenAndServe() + go r.tcpDnsServer.ListenAndServe() + wg.Wait() + + lssh, err := net.Listen("tcp", sshAddr) + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + r.sshListener = lssh + go func() { + for { + nConn, err := lssh.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + go r.sshHandle(nConn) + } + }() +} + +func (r *proxyRouter) sshHandle(nConn net.Conn) { + sshCon, chans, reqs, err := ssh.NewServerConn(nConn, r.sshConfig) + if err != nil { + nConn.Close() + return + } + + dstHost, err := r.director(sshCon.User()) + if err != nil { + nConn.Close() + return + } + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + newChannel := <-chans + if newChannel == nil { + sshCon.Close() + return + } + + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + return + } + + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatalf("Could not accept channel: %v", err) + } + + stderr := channel.Stderr() + + fmt.Fprintf(stderr, "Connecting to %s\r\n", dstHost.String()) + + clientConfig := &ssh.ClientConfig{ + User: "root", + Auth: []ssh.AuthMethod{ + ssh.Password("root"), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + + client, err := ssh.Dial("tcp", dstHost.String(), clientConfig) + if err != nil { + fmt.Fprintf(stderr, "Connect failed: %v\r\n", err) + channel.Close() + return + } + + go func() { + for newChannel = range chans { + if newChannel == nil { + return + } + + channel2, reqs2, err := client.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()) + if err != nil { + x, ok := err.(*ssh.OpenChannelError) + if ok { + newChannel.Reject(x.Reason, x.Message) + } else { + newChannel.Reject(ssh.Prohibited, "remote server denied channel request") + } + continue + } + + channel, reqs, err := newChannel.Accept() + if err != nil { + channel2.Close() + continue + } + go proxySsh(reqs, reqs2, channel, channel2) + } + }() + + // Forward the session channel + channel2, reqs2, err := client.OpenChannel("session", []byte{}) + if err != nil { + fmt.Fprintf(stderr, "Remote session setup failed: %v\r\n", err) + channel.Close() + return + } + + maskedReqs := make(chan *ssh.Request, 1) + go func() { + for req := range requests { + if req.Type == "auth-agent-req@openssh.com" { + continue + } + maskedReqs <- req + } + }() + proxySsh(maskedReqs, reqs2, channel, channel2) +} + +func (r *proxyRouter) dnsRequest(w dns.ResponseWriter, req *dns.Msg) { + if len(req.Question) > 0 { + question := req.Question[0].Name + + if question == "localhost." { + log.Printf("Asked for [localhost.] returning automatically [127.0.0.1]\n") + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A 127.0.0.1", question)) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + w.WriteMsg(m) + return + } + + dstHost, err := r.director(strings.TrimSuffix(question, ".")) + if err != nil { + // Director couldn't resolve it, try to lookup in the system's DNS + ips, err := net.LookupIP(question) + if err != nil { + // we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured + w.Close() + // dns.HandleFailed(w, r) + return + } + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + for _, ip := range ips { + ipv4 := ip.To4() + if ipv4 == nil { + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN AAAA %s", question, ip.String())) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + } else { + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ipv4.String())) + if err != nil { + log.Fatal(err) + } + m.Answer = append(m.Answer, a) + } + } + w.WriteMsg(m) + return + } + + m := new(dns.Msg) + m.SetReply(req) + m.Authoritative = true + m.RecursionAvailable = true + a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, dstHost.IP)) + if err != nil { + log.Println(err) + w.Close() + // dns.HandleFailed(w, r) + return + } + m.Answer = append(m.Answer, a) + w.WriteMsg(m) + return + } } func (r *proxyRouter) Close() { r.Lock() defer r.Unlock() - if r.listener != nil { - r.listener.Close() + if r.httpListener != nil { + r.httpListener.Close() + } + if r.udpDnsServer != nil { + r.udpDnsServer.Shutdown() } r.closed = true } -func (r *proxyRouter) ListenAddress() string { - if r.listener != nil { - return r.listener.Addr().String() +func (r *proxyRouter) ListenHttpAddress() string { + if r.httpListener != nil { + return r.httpListener.Addr().String() } return "" } +func (r *proxyRouter) ListenDnsUdpAddress() string { + if r.udpDnsServer != nil && r.udpDnsServer.PacketConn != nil { + return r.udpDnsServer.PacketConn.LocalAddr().String() + } + + return "" +} +func (r *proxyRouter) ListenDnsTcpAddress() string { + if r.tcpDnsServer != nil && r.tcpDnsServer.Listener != nil { + return r.tcpDnsServer.Listener.Addr().String() + } + + return "" +} + +func (r *proxyRouter) ListenSshAddress() string { + if r.sshListener != nil { + return r.sshListener.Addr().String() + } + + return "" +} + func (r *proxyRouter) handleConnection(c net.Conn) { defer c.Close() // first try tls @@ -75,7 +315,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) { return } - proxy(vhostConn, d) + proxyConn(vhostConn, d) } else { // it is not TLS // treat it as an http connection @@ -100,11 +340,59 @@ func (r *proxyRouter) handleConnection(c net.Conn) { log.Printf("Error requesting backend %s: %v\n", dstHost.String(), err) return } - proxy(c, d) + proxyConn(c, d) } } -func proxy(src, dst net.Conn) { +func proxySsh(reqs1, reqs2 <-chan *ssh.Request, channel1, channel2 ssh.Channel) { + var closer sync.Once + closeFunc := func() { + channel1.Close() + channel2.Close() + } + + defer closer.Do(closeFunc) + + closerChan := make(chan bool, 1) + + go func() { + io.Copy(channel1, channel2) + closerChan <- true + }() + + go func() { + io.Copy(channel2, channel1) + closerChan <- true + }() + + for { + select { + case req := <-reqs1: + if req == nil { + return + } + b, err := channel2.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + return + } + req.Reply(b, nil) + + case req := <-reqs2: + if req == nil { + return + } + b, err := channel1.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + return + } + req.Reply(b, nil) + case <-closerChan: + return + } + } +} + +func proxyConn(src, dst net.Conn) { errc := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { _, err := io.Copy(dst, src) @@ -115,44 +403,23 @@ func proxy(src, dst net.Conn) { <-errc } -func NewRouter(director Director) *proxyRouter { - return &proxyRouter{director: director} -} - -/* - // Start the DNS server - dns.HandleFunc(".", routerDns.DnsRequest) - udpDnsServer := &dns.Server{Addr: ":53", Net: "udp"} - go func() { - err := udpDnsServer.ListenAndServe() - if err != nil { - log.Fatal(err) - } - }() - tcpDnsServer := &dns.Server{Addr: ":53", Net: "tcp"} - go func() { - err := tcpDnsServer.ListenAndServe() - if err != nil { - log.Fatal(err) - } - }() - r := mux.NewRouter() - tcpHandler := handlers.NewTCPProxy() - r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}-{port:%s}.{tld:.*}", config.PWDHostnameRegex, config.PortRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}.{tld:.*}", config.PWDHostnameRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}-{port:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex, config.PortRegex)).Handler(tcpHandler) - r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex)).Handler(tcpHandler) - r.HandleFunc("/ping", handlers.Ping).Methods("GET") - n := negroni.Classic() - n.UseHandler(r) - - httpServer := http.Server{ - Addr: "0.0.0.0:" + config.PortNumber, - Handler: n, - IdleTimeout: 30 * time.Second, - ReadHeaderTimeout: 5 * time.Second, +func NewRouter(director Director, keyPath string) *proxyRouter { + var sshConfig = &ssh.ServerConfig{ + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + return nil, nil + }, } - // Now listen for TLS connections that need to be proxied - tls.StartTLSProxy(config.SSLPortNumber) - http.ListenAndServe() -*/ + privateBytes, err := ioutil.ReadFile(keyPath) + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + + sshConfig.AddHostKey(private) + + return &proxyRouter{director: director, sshConfig: sshConfig} +} diff --git a/router/router_test.go b/router/router_test.go index fcdba89..12f63d0 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -1,7 +1,12 @@ package router import ( + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "encoding/asn1" + "encoding/pem" "fmt" "io/ioutil" "log" @@ -9,20 +14,207 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "strings" "sync" "testing" + "golang.org/x/crypto/ssh" + "github.com/gorilla/websocket" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) +func testSshClient(user string, r *proxyRouter) error { + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return err + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return err + } + config := &ssh.ClientConfig{ + User: user, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + chunks := strings.Split(r.ListenSshAddress(), ":") + l := fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + client, err := ssh.Dial("tcp", l, config) + if err != nil { + return err + } + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + + return nil +} + +func testSshServer(f func(user, pass, ctype string)) (string, error) { + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return "", err + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return "", err + } + + var receivedUser string + var receivedPass string + var receivedChannelType string + + config := &ssh.ServerConfig{ + PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + receivedUser = conn.User() + receivedPass = string(password) + + return nil, nil + }, + } + config.AddHostKey(signer) + listener, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return "", err + } + + go func() { + defer listener.Close() + nConn, err := listener.Accept() + if err != nil { + log.Println(err) + return + } + conn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Println(err) + return + } + go ssh.DiscardRequests(reqs) + defer nConn.Close() + defer conn.Close() + + ch := <-chans + + receivedChannelType = ch.ChannelType() + + f(receivedUser, receivedPass, receivedChannelType) + }() + + return listener.Addr().String(), nil +} + +func generateKeys() (string, string, string, error) { + dir, err := ioutil.TempDir("", "pwd") + if err != nil { + return "", "", "", err + } + + reader := rand.Reader + bitSize := 2048 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return "", "", "", err + } + + privateFile, err := os.OpenFile(fmt.Sprintf("%s/id_rsa", dir), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return "", "", "", err + } + defer privateFile.Close() + + var privateKey = &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + err = pem.Encode(privateFile, privateKey) + if err != nil { + return "", "", "", err + } + + publicFile, err := os.OpenFile(fmt.Sprintf("%s/id_rsa.pub", dir), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return "", "", "", err + } + defer publicFile.Close() + + asn1Bytes, err := asn1.Marshal(key.PublicKey) + if err != nil { + return "", "", "", err + } + + var publicKey = &pem.Block{ + Type: "PUBLIC KEY", + Bytes: asn1Bytes, + } + + err = pem.Encode(publicFile, publicKey) + if err != nil { + return "", "", "", err + } + return dir, privateFile.Name(), publicFile.Name(), nil +} + func getRouterUrl(scheme string, r *proxyRouter) string { - chunks := strings.Split(r.ListenAddress(), ":") + chunks := strings.Split(r.ListenHttpAddress(), ":") return fmt.Sprintf("%s://localhost:%s", scheme, chunks[len(chunks)-1]) } +func routerLookup(protocol, domain string, r *proxyRouter) ([]string, error) { + c := dns.Client{Net: protocol} + m := dns.Msg{} + + m.SetQuestion(fmt.Sprintf("%s.", domain), dns.TypeA) + var l string + if protocol == "udp" { + chunks := strings.Split(r.ListenDnsUdpAddress(), ":") + l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + } else if protocol == "tcp" { + chunks := strings.Split(r.ListenDnsTcpAddress(), ":") + l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]) + } + res, _, err := c.Exchange(&m, l) + + if err != nil { + return nil, err + } + + if len(res.Answer) == 0 { + return nil, fmt.Errorf("Didn't receive any answer") + } + addrs := []string{} + for _, a := range res.Answer { + if b, ok := a.(*dns.A); ok { + addrs = append(addrs, b.A.String()) + } else if b, ok := a.(*dns.AAAA); ok { + addrs = append(addrs, b.AAAA.String()) + } + } + + return addrs, nil +} + func TestProxy_TLS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } @@ -42,8 +234,8 @@ func TestProxy_TLS(t *testing.T) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil - }) - r.Listen(":0") + }, private) + r.Listen(":0", ":0", ":0") defer r.Close() req, err := http.NewRequest("GET", getRouterUrl("https", r), nil) @@ -59,6 +251,9 @@ func TestProxy_TLS(t *testing.T) { } func TestProxy_Http(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + const msg = "It works!" var receivedHost string @@ -73,8 +268,8 @@ func TestProxy_Http(t *testing.T) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil - }) - r.Listen(":0") + }, private) + r.Listen(":0", ":0", ":0") defer r.Close() req, err := http.NewRequest("GET", getRouterUrl("http", r), nil) @@ -92,6 +287,9 @@ func TestProxy_Http(t *testing.T) { } func TestProxy_WS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + const msg = "It works!" var serverReceivedMessage string @@ -122,8 +320,8 @@ func TestProxy_WS(t *testing.T) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil - }) - r.Listen(":0") + }, private) + r.Listen(":0", ":0", ":0") defer r.Close() c, _, err := websocket.DefaultDialer.Dial(getRouterUrl("ws", r), nil) @@ -147,6 +345,9 @@ func TestProxy_WS(t *testing.T) { } func TestProxy_WSS(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + const msg = "It works!" var serverReceivedMessage string @@ -177,8 +378,8 @@ func TestProxy_WSS(t *testing.T) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil - }) - r.Listen(":0") + }, private) + r.Listen(":0", ":0", ":0") defer r.Close() d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} @@ -203,3 +404,120 @@ func TestProxy_WSS(t *testing.T) { assert.Equal(t, msg, string(clientReceivedMessage)) assert.Equal(t, msg, serverReceivedMessage) } + +func TestProxy_DNS_UDP(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedHost string + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10_0_0_1.foo.bar" { + a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r) + assert.Nil(t, err) + assert.Equal(t, "10_0_0_1.foo.bar", receivedHost) + assert.Equal(t, []string{"10.0.0.1"}, ips) + + ips, err = routerLookup("udp", "www.google.com", r) + assert.Nil(t, err) + assert.Equal(t, "www.google.com", receivedHost) + + expectedIps, err := net.LookupHost("www.google.com") + assert.Nil(t, err) + assert.Equal(t, expectedIps, ips) + + ips, err = routerLookup("udp", "localhost", r) + assert.Nil(t, err) + assert.NotEqual(t, "localhost", receivedHost) + assert.Equal(t, []string{"127.0.0.1"}, ips) +} + +func TestProxy_DNS_TCP(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedHost string + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10_0_0_1.foo.bar" { + a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + ips, err := routerLookup("tcp", "10_0_0_1.foo.bar", r) + assert.Nil(t, err) + assert.Equal(t, "10_0_0_1.foo.bar", receivedHost) + assert.Equal(t, []string{"10.0.0.1"}, ips) + + ips, err = routerLookup("tcp", "www.google.com", r) + assert.Nil(t, err) + assert.Equal(t, "www.google.com", receivedHost) + + expectedIps, err := net.LookupHost("www.google.com") + assert.Nil(t, err) + assert.Equal(t, expectedIps, ips) + + ips, err = routerLookup("tcp", "localhost", r) + assert.Nil(t, err) + assert.NotEqual(t, "localhost", receivedHost) + assert.Equal(t, []string{"127.0.0.1"}, ips) +} + +func TestProxy_SSH(t *testing.T) { + dir, private, _, _ := generateKeys() + defer os.RemoveAll(dir) + + var receivedUser string + var receivedPass string + var receivedChannelType string + var receivedHost string + + wg := sync.WaitGroup{} + wg.Add(1) + laddr, err := testSshServer(func(user, pass, ctype string) { + receivedUser = user + receivedPass = pass + receivedChannelType = ctype + wg.Done() + }) + assert.Nil(t, err) + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + if host == "10-0-0-1-aaaabbbb" { + chunks := strings.Split(laddr, ":") + a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])) + return a, nil + } else { + return nil, fmt.Errorf("Not recognized") + } + }, private) + r.Listen(":0", ":0", ":0") + defer r.Close() + + err = testSshClient("10-0-0-1-aaaabbbb", r) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, "root", receivedUser) + assert.Equal(t, "root", receivedPass) + assert.Equal(t, "session", receivedChannelType) + assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost) +} diff --git a/router/tcp/tcp.go b/router/tcp/tcp.go deleted file mode 100644 index 80a1cca..0000000 --- a/router/tcp/tcp.go +++ /dev/null @@ -1,143 +0,0 @@ -package tcp - -import ( - "crypto/tls" - "fmt" - "io" - "log" - "net" - "net/http" - "strings" - - "github.com/franela/pwd.old/config" - "github.com/gorilla/mux" -) - -func getTargetInfo(vars map[string]string, req *http.Request) (string, string) { - node := vars["node"] - port := vars["port"] - alias := vars["alias"] - sessionPrefix := vars["session"] - hostPort := strings.Split(req.Host, ":") - - // give priority to the URL host port - if len(hostPort) > 1 && hostPort[1] != config.PortNumber { - port = hostPort[1] - } else if port == "" { - port = "80" - } - - if alias != "" { - instance := core.InstanceFindByAlias(sessionPrefix, alias) - if instance != nil { - node = instance.IP - return node, port - } - } - - // Node is actually an ip, need to convert underscores by dots. - ip := strings.Replace(node, "-", ".", -1) - - if net.ParseIP(ip) == nil { - // Not a valid IP, so treat this is a hostname. - } else { - node = ip - } - - return node, port - -} - -type tcpProxy struct { - Director func(*http.Request) - ErrorLog *log.Logger - Dial func(network, addr string) (net.Conn, error) -} - -func (p *tcpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { - logFunc := log.Printf - if p.ErrorLog != nil { - logFunc = p.ErrorLog.Printf - } - - vars := mux.Vars(r) - instanceIP := vars["node"] - - if i := core.InstanceFindByIP(strings.Replace(instanceIP, "-", ".", -1)); i == nil { - w.WriteHeader(http.StatusServiceUnavailable) - return - } - - outreq := new(http.Request) - // shallow copying - *outreq = *r - p.Director(outreq) - host := outreq.URL.Host - - dial := p.Dial - if dial == nil { - dial = net.Dial - } - - if outreq.URL.Scheme == "wss" || outreq.URL.Scheme == "https" { - var tlsConfig *tls.Config - tlsConfig = &tls.Config{InsecureSkipVerify: true} - dial = func(network, address string) (net.Conn, error) { - return tls.Dial("tcp", host, tlsConfig) - } - } - - d, err := dial("tcp", host) - if err != nil { - http.Error(w, "Error forwarding request.", 500) - logFunc("Error dialing websocket backend %s: %v", outreq.URL, err) - return - } - // All request generated by the http package implement this interface. - hj, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "Not a hijacker?", 500) - return - } - // Hijack() tells the http package not to do anything else with the connection. - // After, it bcomes this functions job to manage it. `nc` is of type *net.Conn. - nc, _, err := hj.Hijack() - if err != nil { - logFunc("Hijack error: %v", err) - return - } - defer nc.Close() // must close the underlying net connection after hijacking - defer d.Close() - - // write the modified incoming request to the dialed connection - err = outreq.Write(d) - if err != nil { - logFunc("Error copying request to target: %v", err) - return - } - errc := make(chan error, 2) - cp := func(dst io.Writer, src io.Reader) { - _, err := io.Copy(dst, src) - errc <- err - } - go cp(d, nc) - go cp(nc, d) - <-errc -} -func NewTCPProxy() http.Handler { - director := func(req *http.Request) { - v := mux.Vars(req) - - node, port := getTargetInfo(v, req) - - if port == "443" { - if strings.Contains(req.URL.Scheme, "http") { - req.URL.Scheme = "https" - } else { - req.URL.Scheme = "wss" - } - } - req.URL.Host = fmt.Sprintf("%s:%s", node, port) - } - return &tcpProxy{Director: director} -} diff --git a/router/tls/tls.go b/router/tls/tls.go deleted file mode 100644 index a3c2ada..0000000 --- a/router/tls/tls.go +++ /dev/null @@ -1,95 +0,0 @@ -package tls - -import ( - "fmt" - "io" - "log" - "net" - "strings" - - vhost "github.com/inconshreveable/go-vhost" - "github.com/play-with-docker/play-with-docker/config" -) - -func StartTLSProxy(port string) { - - tlsListener, tlsErr := net.Listen("tcp", fmt.Sprintf(":%s", port)) - log.Println("Listening on port " + port) - if tlsErr != nil { - log.Fatal(tlsErr) - } - defer tlsListener.Close() - for { - // Wait for TLS Connection - conn, err := tlsListener.Accept() - if err != nil { - log.Printf("Could not accept new TLS connection. Error: %s", err) - continue - } - // Handle connection on a new goroutine and continue accepting other new connections - go func(c net.Conn) { - defer c.Close() - vhostConn, err := vhost.TLS(conn) - if err != nil { - log.Printf("Incoming TLS connection produced an error. Error: %s", err) - return - } - defer vhostConn.Close() - - var targetIP string - targetPort := "443" - - host := vhostConn.ClientHelloMsg.ServerName - match := config.NameFilter.FindStringSubmatch(host) - if len(match) < 2 { - // Not a valid proxy host, try alias hosts - match := config.AliasFilter.FindStringSubmatch(host) - if len(match) < 4 { - // Not valid, just close the connection - return - } else { - alias := match[1] - sessionPrefix := match[2] - instance := core.InstanceFindByAlias(sessionPrefix, alias) - if instance != nil { - targetIP = instance.IP - } else { - return - } - if len(match) == 4 { - targetPort = match[3] - } - } - } else { - // Valid proxy host - ip := strings.Replace(match[1], "-", ".", -1) - if net.ParseIP(ip) == nil { - // Not a valid IP, so treat this is a hostname. - return - } else { - targetIP = ip - } - if len(match) == 3 { - targetPort = match[2] - } - } - - dest := fmt.Sprintf("%s:%s", targetIP, targetPort) - d, err := net.Dial("tcp", dest) - if err != nil { - log.Printf("Error dialing backend %s: %v\n", dest, err) - return - } - - errc := make(chan error, 2) - cp := func(dst io.Writer, src io.Reader) { - _, err := io.Copy(dst, src) - errc <- err - } - go cp(d, vhostConn) - go cp(vhostConn, d) - <-errc - }(conn) - } - -}