WIP
This commit is contained in:
@@ -1,111 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/play-with-docker/play-with-docker/config"
|
||||
)
|
||||
|
||||
func DnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) > 0 && config.NameFilter.MatchString(r.Question[0].Name) {
|
||||
// this is something we know about and we should try to handle
|
||||
question := r.Question[0].Name
|
||||
|
||||
match := config.NameFilter.FindStringSubmatch(question)
|
||||
|
||||
ip := strings.Replace(match[1], "-", ".", -1)
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
m.RecursionAvailable = true
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ip))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
} else if len(r.Question) > 0 && config.AliasFilter.MatchString(r.Question[0].Name) {
|
||||
// this is something we know about and we should try to handle
|
||||
question := r.Question[0].Name
|
||||
|
||||
match := config.AliasFilter.FindStringSubmatch(question)
|
||||
|
||||
i := core.InstanceFindByAlias(match[2], match[1])
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
m.RecursionAvailable = true
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, i.IP))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
} else {
|
||||
if len(r.Question) > 0 {
|
||||
question := r.Question[0].Name
|
||||
|
||||
if question == "localhost." {
|
||||
log.Printf("Not a PWD host. Asked for [localhost.] returning automatically [127.0.0.1]\n")
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
m.RecursionAvailable = true
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A 127.0.0.1", question))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Not a PWD host. Looking up [%s]\n", question)
|
||||
ips, err := net.LookupIP(question)
|
||||
if err != nil {
|
||||
// we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured
|
||||
w.Close()
|
||||
// dns.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
log.Printf("Not a PWD host. Looking up [%s] got [%s]\n", question, ips)
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
m.RecursionAvailable = true
|
||||
for _, ip := range ips {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN AAAA %s", question, ip.String()))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
} else {
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ipv4.String()))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
}
|
||||
}
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
|
||||
} else {
|
||||
log.Printf("Not a PWD host. Got DNS without any question\n")
|
||||
// we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured
|
||||
w.Close()
|
||||
// dns.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
377
router/router.go
377
router/router.go
@@ -2,13 +2,19 @@ package router
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
vhost "github.com/inconshreveable/go-vhost"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Director func(host string) (*net.TCPAddr, error)
|
||||
@@ -16,45 +22,279 @@ type Director func(host string) (*net.TCPAddr, error)
|
||||
type proxyRouter struct {
|
||||
sync.Mutex
|
||||
|
||||
director Director
|
||||
listener net.Listener
|
||||
closed bool
|
||||
keyPath string
|
||||
director Director
|
||||
closed bool
|
||||
httpListener net.Listener
|
||||
udpDnsServer *dns.Server
|
||||
tcpDnsServer *dns.Server
|
||||
sshListener net.Listener
|
||||
sshConfig *ssh.ServerConfig
|
||||
}
|
||||
|
||||
func (r *proxyRouter) Listen(laddr string) {
|
||||
l, err := net.Listen("tcp", laddr)
|
||||
func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) {
|
||||
l, err := net.Listen("tcp", httpAddr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
r.listener = l
|
||||
r.httpListener = l
|
||||
go func() {
|
||||
for !r.closed {
|
||||
conn, err := r.listener.Accept()
|
||||
conn, err := r.httpListener.Accept()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
go r.handleConnection(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
dnsMux := dns.NewServeMux()
|
||||
dnsMux.HandleFunc(".", r.dnsRequest)
|
||||
r.udpDnsServer = &dns.Server{Addr: dnsAddr, Net: "udp", Handler: dnsMux}
|
||||
r.tcpDnsServer = &dns.Server{Addr: dnsAddr, Net: "tcp", Handler: dnsMux}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
r.udpDnsServer.NotifyStartedFunc = func() {
|
||||
wg.Done()
|
||||
}
|
||||
r.tcpDnsServer.NotifyStartedFunc = func() {
|
||||
wg.Done()
|
||||
}
|
||||
go r.udpDnsServer.ListenAndServe()
|
||||
go r.tcpDnsServer.ListenAndServe()
|
||||
wg.Wait()
|
||||
|
||||
lssh, err := net.Listen("tcp", sshAddr)
|
||||
if err != nil {
|
||||
log.Fatal("failed to listen for connection: ", err)
|
||||
}
|
||||
r.sshListener = lssh
|
||||
go func() {
|
||||
for {
|
||||
nConn, err := lssh.Accept()
|
||||
if err != nil {
|
||||
log.Fatal("failed to accept incoming connection: ", err)
|
||||
}
|
||||
|
||||
go r.sshHandle(nConn)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *proxyRouter) sshHandle(nConn net.Conn) {
|
||||
sshCon, chans, reqs, err := ssh.NewServerConn(nConn, r.sshConfig)
|
||||
if err != nil {
|
||||
nConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
dstHost, err := r.director(sshCon.User())
|
||||
if err != nil {
|
||||
nConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// The incoming Request channel must be serviced.
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
newChannel := <-chans
|
||||
if newChannel == nil {
|
||||
sshCon.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if newChannel.ChannelType() != "session" {
|
||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
return
|
||||
}
|
||||
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
log.Fatalf("Could not accept channel: %v", err)
|
||||
}
|
||||
|
||||
stderr := channel.Stderr()
|
||||
|
||||
fmt.Fprintf(stderr, "Connecting to %s\r\n", dstHost.String())
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "root",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password("root"),
|
||||
},
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", dstHost.String(), clientConfig)
|
||||
if err != nil {
|
||||
fmt.Fprintf(stderr, "Connect failed: %v\r\n", err)
|
||||
channel.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for newChannel = range chans {
|
||||
if newChannel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
channel2, reqs2, err := client.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData())
|
||||
if err != nil {
|
||||
x, ok := err.(*ssh.OpenChannelError)
|
||||
if ok {
|
||||
newChannel.Reject(x.Reason, x.Message)
|
||||
} else {
|
||||
newChannel.Reject(ssh.Prohibited, "remote server denied channel request")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
channel, reqs, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
channel2.Close()
|
||||
continue
|
||||
}
|
||||
go proxySsh(reqs, reqs2, channel, channel2)
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward the session channel
|
||||
channel2, reqs2, err := client.OpenChannel("session", []byte{})
|
||||
if err != nil {
|
||||
fmt.Fprintf(stderr, "Remote session setup failed: %v\r\n", err)
|
||||
channel.Close()
|
||||
return
|
||||
}
|
||||
|
||||
maskedReqs := make(chan *ssh.Request, 1)
|
||||
go func() {
|
||||
for req := range requests {
|
||||
if req.Type == "auth-agent-req@openssh.com" {
|
||||
continue
|
||||
}
|
||||
maskedReqs <- req
|
||||
}
|
||||
}()
|
||||
proxySsh(maskedReqs, reqs2, channel, channel2)
|
||||
}
|
||||
|
||||
func (r *proxyRouter) dnsRequest(w dns.ResponseWriter, req *dns.Msg) {
|
||||
if len(req.Question) > 0 {
|
||||
question := req.Question[0].Name
|
||||
|
||||
if question == "localhost." {
|
||||
log.Printf("Asked for [localhost.] returning automatically [127.0.0.1]\n")
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
m.Authoritative = true
|
||||
m.RecursionAvailable = true
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A 127.0.0.1", question))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
|
||||
dstHost, err := r.director(strings.TrimSuffix(question, "."))
|
||||
if err != nil {
|
||||
// Director couldn't resolve it, try to lookup in the system's DNS
|
||||
ips, err := net.LookupIP(question)
|
||||
if err != nil {
|
||||
// we have no information about this and we are not a recursive dns server, so we just fail so the client can fallback to the next dns server it has configured
|
||||
w.Close()
|
||||
// dns.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
m.Authoritative = true
|
||||
m.RecursionAvailable = true
|
||||
for _, ip := range ips {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN AAAA %s", question, ip.String()))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
} else {
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, ipv4.String()))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
}
|
||||
}
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
m.Authoritative = true
|
||||
m.RecursionAvailable = true
|
||||
a, err := dns.NewRR(fmt.Sprintf("%s 60 IN A %s", question, dstHost.IP))
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
w.Close()
|
||||
// dns.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
m.Answer = append(m.Answer, a)
|
||||
w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (r *proxyRouter) Close() {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
|
||||
if r.listener != nil {
|
||||
r.listener.Close()
|
||||
if r.httpListener != nil {
|
||||
r.httpListener.Close()
|
||||
}
|
||||
if r.udpDnsServer != nil {
|
||||
r.udpDnsServer.Shutdown()
|
||||
}
|
||||
r.closed = true
|
||||
}
|
||||
|
||||
func (r *proxyRouter) ListenAddress() string {
|
||||
if r.listener != nil {
|
||||
return r.listener.Addr().String()
|
||||
func (r *proxyRouter) ListenHttpAddress() string {
|
||||
if r.httpListener != nil {
|
||||
return r.httpListener.Addr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *proxyRouter) ListenDnsUdpAddress() string {
|
||||
if r.udpDnsServer != nil && r.udpDnsServer.PacketConn != nil {
|
||||
return r.udpDnsServer.PacketConn.LocalAddr().String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
func (r *proxyRouter) ListenDnsTcpAddress() string {
|
||||
if r.tcpDnsServer != nil && r.tcpDnsServer.Listener != nil {
|
||||
return r.tcpDnsServer.Listener.Addr().String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *proxyRouter) ListenSshAddress() string {
|
||||
if r.sshListener != nil {
|
||||
return r.sshListener.Addr().String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *proxyRouter) handleConnection(c net.Conn) {
|
||||
defer c.Close()
|
||||
// first try tls
|
||||
@@ -75,7 +315,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
proxy(vhostConn, d)
|
||||
proxyConn(vhostConn, d)
|
||||
} else {
|
||||
// it is not TLS
|
||||
// treat it as an http connection
|
||||
@@ -100,11 +340,59 @@ func (r *proxyRouter) handleConnection(c net.Conn) {
|
||||
log.Printf("Error requesting backend %s: %v\n", dstHost.String(), err)
|
||||
return
|
||||
}
|
||||
proxy(c, d)
|
||||
proxyConn(c, d)
|
||||
}
|
||||
}
|
||||
|
||||
func proxy(src, dst net.Conn) {
|
||||
func proxySsh(reqs1, reqs2 <-chan *ssh.Request, channel1, channel2 ssh.Channel) {
|
||||
var closer sync.Once
|
||||
closeFunc := func() {
|
||||
channel1.Close()
|
||||
channel2.Close()
|
||||
}
|
||||
|
||||
defer closer.Do(closeFunc)
|
||||
|
||||
closerChan := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
io.Copy(channel1, channel2)
|
||||
closerChan <- true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
io.Copy(channel2, channel1)
|
||||
closerChan <- true
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case req := <-reqs1:
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
b, err := channel2.SendRequest(req.Type, req.WantReply, req.Payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Reply(b, nil)
|
||||
|
||||
case req := <-reqs2:
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
b, err := channel1.SendRequest(req.Type, req.WantReply, req.Payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Reply(b, nil)
|
||||
case <-closerChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func proxyConn(src, dst net.Conn) {
|
||||
errc := make(chan error, 2)
|
||||
cp := func(dst io.Writer, src io.Reader) {
|
||||
_, err := io.Copy(dst, src)
|
||||
@@ -115,44 +403,23 @@ func proxy(src, dst net.Conn) {
|
||||
<-errc
|
||||
}
|
||||
|
||||
func NewRouter(director Director) *proxyRouter {
|
||||
return &proxyRouter{director: director}
|
||||
}
|
||||
|
||||
/*
|
||||
// Start the DNS server
|
||||
dns.HandleFunc(".", routerDns.DnsRequest)
|
||||
udpDnsServer := &dns.Server{Addr: ":53", Net: "udp"}
|
||||
go func() {
|
||||
err := udpDnsServer.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
tcpDnsServer := &dns.Server{Addr: ":53", Net: "tcp"}
|
||||
go func() {
|
||||
err := tcpDnsServer.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}()
|
||||
r := mux.NewRouter()
|
||||
tcpHandler := handlers.NewTCPProxy()
|
||||
r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}-{port:%s}.{tld:.*}", config.PWDHostnameRegex, config.PortRegex)).Handler(tcpHandler)
|
||||
r.Host(fmt.Sprintf("{subdomain:.*}pwd{node:%s}.{tld:.*}", config.PWDHostnameRegex)).Handler(tcpHandler)
|
||||
r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}-{port:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex, config.PortRegex)).Handler(tcpHandler)
|
||||
r.Host(fmt.Sprintf("pwd{alias:%s}-{session:%s}.{tld:.*}", config.AliasnameRegex, config.AliasSessionRegex)).Handler(tcpHandler)
|
||||
r.HandleFunc("/ping", handlers.Ping).Methods("GET")
|
||||
n := negroni.Classic()
|
||||
n.UseHandler(r)
|
||||
|
||||
httpServer := http.Server{
|
||||
Addr: "0.0.0.0:" + config.PortNumber,
|
||||
Handler: n,
|
||||
IdleTimeout: 30 * time.Second,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
func NewRouter(director Director, keyPath string) *proxyRouter {
|
||||
var sshConfig = &ssh.ServerConfig{
|
||||
PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
// Now listen for TLS connections that need to be proxied
|
||||
tls.StartTLSProxy(config.SSLPortNumber)
|
||||
http.ListenAndServe()
|
||||
*/
|
||||
privateBytes, err := ioutil.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to load private key: ", err)
|
||||
}
|
||||
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to parse private key: ", err)
|
||||
}
|
||||
|
||||
sshConfig.AddHostKey(private)
|
||||
|
||||
return &proxyRouter{director: director, sshConfig: sshConfig}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
@@ -9,20 +14,207 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testSshClient(user string, r *proxyRouter) error {
|
||||
reader := rand.Reader
|
||||
bitSize := 2048
|
||||
|
||||
key, err := rsa.GenerateKey(reader, bitSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
chunks := strings.Split(r.ListenSshAddress(), ":")
|
||||
l := fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])
|
||||
client, err := ssh.Dial("tcp", l, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func testSshServer(f func(user, pass, ctype string)) (string, error) {
|
||||
reader := rand.Reader
|
||||
bitSize := 2048
|
||||
|
||||
key, err := rsa.GenerateKey(reader, bitSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var receivedUser string
|
||||
var receivedPass string
|
||||
var receivedChannelType string
|
||||
|
||||
config := &ssh.ServerConfig{
|
||||
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
receivedUser = conn.User()
|
||||
receivedPass = string(password)
|
||||
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
config.AddHostKey(signer)
|
||||
listener, err := net.Listen("tcp", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
nConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
defer nConn.Close()
|
||||
defer conn.Close()
|
||||
|
||||
ch := <-chans
|
||||
|
||||
receivedChannelType = ch.ChannelType()
|
||||
|
||||
f(receivedUser, receivedPass, receivedChannelType)
|
||||
}()
|
||||
|
||||
return listener.Addr().String(), nil
|
||||
}
|
||||
|
||||
func generateKeys() (string, string, string, error) {
|
||||
dir, err := ioutil.TempDir("", "pwd")
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
reader := rand.Reader
|
||||
bitSize := 2048
|
||||
|
||||
key, err := rsa.GenerateKey(reader, bitSize)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
privateFile, err := os.OpenFile(fmt.Sprintf("%s/id_rsa", dir), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
defer privateFile.Close()
|
||||
|
||||
var privateKey = &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
}
|
||||
|
||||
err = pem.Encode(privateFile, privateKey)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
publicFile, err := os.OpenFile(fmt.Sprintf("%s/id_rsa.pub", dir), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
defer publicFile.Close()
|
||||
|
||||
asn1Bytes, err := asn1.Marshal(key.PublicKey)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
var publicKey = &pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: asn1Bytes,
|
||||
}
|
||||
|
||||
err = pem.Encode(publicFile, publicKey)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
return dir, privateFile.Name(), publicFile.Name(), nil
|
||||
}
|
||||
|
||||
func getRouterUrl(scheme string, r *proxyRouter) string {
|
||||
chunks := strings.Split(r.ListenAddress(), ":")
|
||||
chunks := strings.Split(r.ListenHttpAddress(), ":")
|
||||
return fmt.Sprintf("%s://localhost:%s", scheme, chunks[len(chunks)-1])
|
||||
}
|
||||
|
||||
func routerLookup(protocol, domain string, r *proxyRouter) ([]string, error) {
|
||||
c := dns.Client{Net: protocol}
|
||||
m := dns.Msg{}
|
||||
|
||||
m.SetQuestion(fmt.Sprintf("%s.", domain), dns.TypeA)
|
||||
var l string
|
||||
if protocol == "udp" {
|
||||
chunks := strings.Split(r.ListenDnsUdpAddress(), ":")
|
||||
l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])
|
||||
} else if protocol == "tcp" {
|
||||
chunks := strings.Split(r.ListenDnsTcpAddress(), ":")
|
||||
l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])
|
||||
}
|
||||
res, _, err := c.Exchange(&m, l)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(res.Answer) == 0 {
|
||||
return nil, fmt.Errorf("Didn't receive any answer")
|
||||
}
|
||||
addrs := []string{}
|
||||
for _, a := range res.Answer {
|
||||
if b, ok := a.(*dns.A); ok {
|
||||
addrs = append(addrs, b.A.String())
|
||||
} else if b, ok := a.(*dns.AAAA); ok {
|
||||
addrs = append(addrs, b.AAAA.String())
|
||||
}
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func TestProxy_TLS(t *testing.T) {
|
||||
dir, private, _, _ := generateKeys()
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
@@ -42,8 +234,8 @@ func TestProxy_TLS(t *testing.T) {
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
})
|
||||
r.Listen(":0")
|
||||
}, private)
|
||||
r.Listen(":0", ":0", ":0")
|
||||
defer r.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", getRouterUrl("https", r), nil)
|
||||
@@ -59,6 +251,9 @@ func TestProxy_TLS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxy_Http(t *testing.T) {
|
||||
dir, private, _, _ := generateKeys()
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
const msg = "It works!"
|
||||
|
||||
var receivedHost string
|
||||
@@ -73,8 +268,8 @@ func TestProxy_Http(t *testing.T) {
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
})
|
||||
r.Listen(":0")
|
||||
}, private)
|
||||
r.Listen(":0", ":0", ":0")
|
||||
defer r.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", getRouterUrl("http", r), nil)
|
||||
@@ -92,6 +287,9 @@ func TestProxy_Http(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxy_WS(t *testing.T) {
|
||||
dir, private, _, _ := generateKeys()
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
const msg = "It works!"
|
||||
|
||||
var serverReceivedMessage string
|
||||
@@ -122,8 +320,8 @@ func TestProxy_WS(t *testing.T) {
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
})
|
||||
r.Listen(":0")
|
||||
}, private)
|
||||
r.Listen(":0", ":0", ":0")
|
||||
defer r.Close()
|
||||
|
||||
c, _, err := websocket.DefaultDialer.Dial(getRouterUrl("ws", r), nil)
|
||||
@@ -147,6 +345,9 @@ func TestProxy_WS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProxy_WSS(t *testing.T) {
|
||||
dir, private, _, _ := generateKeys()
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
const msg = "It works!"
|
||||
|
||||
var serverReceivedMessage string
|
||||
@@ -177,8 +378,8 @@ func TestProxy_WSS(t *testing.T) {
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
})
|
||||
r.Listen(":0")
|
||||
}, private)
|
||||
r.Listen(":0", ":0", ":0")
|
||||
defer r.Close()
|
||||
|
||||
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
@@ -203,3 +404,120 @@ func TestProxy_WSS(t *testing.T) {
|
||||
assert.Equal(t, msg, string(clientReceivedMessage))
|
||||
assert.Equal(t, msg, serverReceivedMessage)
|
||||
}
|
||||
|
||||
func TestProxy_DNS_UDP(t *testing.T) {
|
||||
dir, private, _, _ := generateKeys()
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
var receivedHost string
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
if host == "10_0_0_1.foo.bar" {
|
||||
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
||||
return a, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Not recognized")
|
||||
}
|
||||
}, private)
|
||||
r.Listen(":0", ":0", ":0")
|
||||
defer r.Close()
|
||||
|
||||
ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
|
||||
assert.Equal(t, []string{"10.0.0.1"}, ips)
|
||||
|
||||
ips, err = routerLookup("udp", "www.google.com", r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "www.google.com", receivedHost)
|
||||
|
||||
expectedIps, err := net.LookupHost("www.google.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectedIps, ips)
|
||||
|
||||
ips, err = routerLookup("udp", "localhost", r)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, "localhost", receivedHost)
|
||||
assert.Equal(t, []string{"127.0.0.1"}, ips)
|
||||
}
|
||||
|
||||
func TestProxy_DNS_TCP(t *testing.T) {
|
||||
dir, private, _, _ := generateKeys()
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
var receivedHost string
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
if host == "10_0_0_1.foo.bar" {
|
||||
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
||||
return a, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Not recognized")
|
||||
}
|
||||
}, private)
|
||||
r.Listen(":0", ":0", ":0")
|
||||
defer r.Close()
|
||||
|
||||
ips, err := routerLookup("tcp", "10_0_0_1.foo.bar", r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
|
||||
assert.Equal(t, []string{"10.0.0.1"}, ips)
|
||||
|
||||
ips, err = routerLookup("tcp", "www.google.com", r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "www.google.com", receivedHost)
|
||||
|
||||
expectedIps, err := net.LookupHost("www.google.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectedIps, ips)
|
||||
|
||||
ips, err = routerLookup("tcp", "localhost", r)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, "localhost", receivedHost)
|
||||
assert.Equal(t, []string{"127.0.0.1"}, ips)
|
||||
}
|
||||
|
||||
func TestProxy_SSH(t *testing.T) {
|
||||
dir, private, _, _ := generateKeys()
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
var receivedUser string
|
||||
var receivedPass string
|
||||
var receivedChannelType string
|
||||
var receivedHost string
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
laddr, err := testSshServer(func(user, pass, ctype string) {
|
||||
receivedUser = user
|
||||
receivedPass = pass
|
||||
receivedChannelType = ctype
|
||||
wg.Done()
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
if host == "10-0-0-1-aaaabbbb" {
|
||||
chunks := strings.Split(laddr, ":")
|
||||
a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]))
|
||||
return a, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("Not recognized")
|
||||
}
|
||||
}, private)
|
||||
r.Listen(":0", ":0", ":0")
|
||||
defer r.Close()
|
||||
|
||||
err = testSshClient("10-0-0-1-aaaabbbb", r)
|
||||
assert.Nil(t, err)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, "root", receivedUser)
|
||||
assert.Equal(t, "root", receivedPass)
|
||||
assert.Equal(t, "session", receivedChannelType)
|
||||
assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost)
|
||||
}
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/franela/pwd.old/config"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func getTargetInfo(vars map[string]string, req *http.Request) (string, string) {
|
||||
node := vars["node"]
|
||||
port := vars["port"]
|
||||
alias := vars["alias"]
|
||||
sessionPrefix := vars["session"]
|
||||
hostPort := strings.Split(req.Host, ":")
|
||||
|
||||
// give priority to the URL host port
|
||||
if len(hostPort) > 1 && hostPort[1] != config.PortNumber {
|
||||
port = hostPort[1]
|
||||
} else if port == "" {
|
||||
port = "80"
|
||||
}
|
||||
|
||||
if alias != "" {
|
||||
instance := core.InstanceFindByAlias(sessionPrefix, alias)
|
||||
if instance != nil {
|
||||
node = instance.IP
|
||||
return node, port
|
||||
}
|
||||
}
|
||||
|
||||
// Node is actually an ip, need to convert underscores by dots.
|
||||
ip := strings.Replace(node, "-", ".", -1)
|
||||
|
||||
if net.ParseIP(ip) == nil {
|
||||
// Not a valid IP, so treat this is a hostname.
|
||||
} else {
|
||||
node = ip
|
||||
}
|
||||
|
||||
return node, port
|
||||
|
||||
}
|
||||
|
||||
type tcpProxy struct {
|
||||
Director func(*http.Request)
|
||||
ErrorLog *log.Logger
|
||||
Dial func(network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func (p *tcpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
logFunc := log.Printf
|
||||
if p.ErrorLog != nil {
|
||||
logFunc = p.ErrorLog.Printf
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
instanceIP := vars["node"]
|
||||
|
||||
if i := core.InstanceFindByIP(strings.Replace(instanceIP, "-", ".", -1)); i == nil {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
outreq := new(http.Request)
|
||||
// shallow copying
|
||||
*outreq = *r
|
||||
p.Director(outreq)
|
||||
host := outreq.URL.Host
|
||||
|
||||
dial := p.Dial
|
||||
if dial == nil {
|
||||
dial = net.Dial
|
||||
}
|
||||
|
||||
if outreq.URL.Scheme == "wss" || outreq.URL.Scheme == "https" {
|
||||
var tlsConfig *tls.Config
|
||||
tlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
dial = func(network, address string) (net.Conn, error) {
|
||||
return tls.Dial("tcp", host, tlsConfig)
|
||||
}
|
||||
}
|
||||
|
||||
d, err := dial("tcp", host)
|
||||
if err != nil {
|
||||
http.Error(w, "Error forwarding request.", 500)
|
||||
logFunc("Error dialing websocket backend %s: %v", outreq.URL, err)
|
||||
return
|
||||
}
|
||||
// All request generated by the http package implement this interface.
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "Not a hijacker?", 500)
|
||||
return
|
||||
}
|
||||
// Hijack() tells the http package not to do anything else with the connection.
|
||||
// After, it bcomes this functions job to manage it. `nc` is of type *net.Conn.
|
||||
nc, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
logFunc("Hijack error: %v", err)
|
||||
return
|
||||
}
|
||||
defer nc.Close() // must close the underlying net connection after hijacking
|
||||
defer d.Close()
|
||||
|
||||
// write the modified incoming request to the dialed connection
|
||||
err = outreq.Write(d)
|
||||
if err != nil {
|
||||
logFunc("Error copying request to target: %v", err)
|
||||
return
|
||||
}
|
||||
errc := make(chan error, 2)
|
||||
cp := func(dst io.Writer, src io.Reader) {
|
||||
_, err := io.Copy(dst, src)
|
||||
errc <- err
|
||||
}
|
||||
go cp(d, nc)
|
||||
go cp(nc, d)
|
||||
<-errc
|
||||
}
|
||||
func NewTCPProxy() http.Handler {
|
||||
director := func(req *http.Request) {
|
||||
v := mux.Vars(req)
|
||||
|
||||
node, port := getTargetInfo(v, req)
|
||||
|
||||
if port == "443" {
|
||||
if strings.Contains(req.URL.Scheme, "http") {
|
||||
req.URL.Scheme = "https"
|
||||
} else {
|
||||
req.URL.Scheme = "wss"
|
||||
}
|
||||
}
|
||||
req.URL.Host = fmt.Sprintf("%s:%s", node, port)
|
||||
}
|
||||
return &tcpProxy{Director: director}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
vhost "github.com/inconshreveable/go-vhost"
|
||||
"github.com/play-with-docker/play-with-docker/config"
|
||||
)
|
||||
|
||||
func StartTLSProxy(port string) {
|
||||
|
||||
tlsListener, tlsErr := net.Listen("tcp", fmt.Sprintf(":%s", port))
|
||||
log.Println("Listening on port " + port)
|
||||
if tlsErr != nil {
|
||||
log.Fatal(tlsErr)
|
||||
}
|
||||
defer tlsListener.Close()
|
||||
for {
|
||||
// Wait for TLS Connection
|
||||
conn, err := tlsListener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Could not accept new TLS connection. Error: %s", err)
|
||||
continue
|
||||
}
|
||||
// Handle connection on a new goroutine and continue accepting other new connections
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
vhostConn, err := vhost.TLS(conn)
|
||||
if err != nil {
|
||||
log.Printf("Incoming TLS connection produced an error. Error: %s", err)
|
||||
return
|
||||
}
|
||||
defer vhostConn.Close()
|
||||
|
||||
var targetIP string
|
||||
targetPort := "443"
|
||||
|
||||
host := vhostConn.ClientHelloMsg.ServerName
|
||||
match := config.NameFilter.FindStringSubmatch(host)
|
||||
if len(match) < 2 {
|
||||
// Not a valid proxy host, try alias hosts
|
||||
match := config.AliasFilter.FindStringSubmatch(host)
|
||||
if len(match) < 4 {
|
||||
// Not valid, just close the connection
|
||||
return
|
||||
} else {
|
||||
alias := match[1]
|
||||
sessionPrefix := match[2]
|
||||
instance := core.InstanceFindByAlias(sessionPrefix, alias)
|
||||
if instance != nil {
|
||||
targetIP = instance.IP
|
||||
} else {
|
||||
return
|
||||
}
|
||||
if len(match) == 4 {
|
||||
targetPort = match[3]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Valid proxy host
|
||||
ip := strings.Replace(match[1], "-", ".", -1)
|
||||
if net.ParseIP(ip) == nil {
|
||||
// Not a valid IP, so treat this is a hostname.
|
||||
return
|
||||
} else {
|
||||
targetIP = ip
|
||||
}
|
||||
if len(match) == 3 {
|
||||
targetPort = match[2]
|
||||
}
|
||||
}
|
||||
|
||||
dest := fmt.Sprintf("%s:%s", targetIP, targetPort)
|
||||
d, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
log.Printf("Error dialing backend %s: %v\n", dest, err)
|
||||
return
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
cp := func(dst io.Writer, src io.Reader) {
|
||||
_, err := io.Copy(dst, src)
|
||||
errc <- err
|
||||
}
|
||||
go cp(d, vhostConn)
|
||||
go cp(vhostConn, d)
|
||||
<-errc
|
||||
}(conn)
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user