From ffaad303d6c4424a36faa6ea8971fb4a8ace9e12 Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Wed, 21 Jun 2017 16:52:04 -0300 Subject: [PATCH] WIP --- router/router.go | 113 +++++++++++++++++++---------- router/router_test.go | 162 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 236 insertions(+), 39 deletions(-) diff --git a/router/router.go b/router/router.go index 0bb0d8c..b62895a 100644 --- a/router/router.go +++ b/router/router.go @@ -1,10 +1,12 @@ package router import ( - "fmt" + "bufio" "io" "log" "net" + "net/http" + "sync" vhost "github.com/inconshreveable/go-vhost" ) @@ -12,66 +14,105 @@ import ( type Director func(host string) (*net.TCPAddr, error) type proxyRouter struct { + sync.Mutex + director Director + listener net.Listener + closed bool } func (r *proxyRouter) Listen(laddr string) { l, err := net.Listen("tcp", laddr) - defer l.Close() if err != nil { log.Fatal(err) } - for { - conn, err := l.Accept() - if err != nil { - log.Println(err) - continue + r.listener = l + go func() { + for !r.closed { + conn, err := r.listener.Accept() + if err != nil { + continue + } + go r.handleConnection(conn) } - go r.handleConnection(conn) + }() +} + +func (r *proxyRouter) Close() { + r.Lock() + defer r.Unlock() + + if r.listener != nil { + r.listener.Close() } + r.closed = true +} + +func (r *proxyRouter) ListenAddress() string { + if r.listener != nil { + return r.listener.Addr().String() + } + return "" } func (r *proxyRouter) handleConnection(c net.Conn) { defer c.Close() // first try tls vhostConn, err := vhost.TLS(c) - if err != nil { - log.Printf("Incoming TLS connection produced an error. Error: %s", err) - return - } - defer vhostConn.Close() - host := vhostConn.ClientHelloMsg.ServerName - c.LocalAddr() - dstHost, err := r.director(fmt.Sprintf("%s:%d", host, 12)) - if err != nil { - log.Printf("Error directing request: %v\n", err) - return - } + if err == nil { + // It is a TLS connection + defer vhostConn.Close() + host := vhostConn.ClientHelloMsg.ServerName + dstHost, err := r.director(host) + if err != nil { + log.Printf("Error directing request: %v\n", err) + return + } + d, err := net.Dial("tcp", dstHost.String()) + if err != nil { + log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) + return + } - d, err := net.Dial("tcp", dstHost.String()) - if err != nil { - log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) - return - } + proxy(vhostConn, d) + } else { + // it is not TLS + // treat it as an http connection + req, err := http.ReadRequest(bufio.NewReader(vhostConn)) + if err != nil { + // It is not http neither. So just close the connection. + return + } + dstHost, err := r.director(req.Host) + if err != nil { + log.Printf("Error directing request: %v\n", err) + return + } + d, err := net.Dial("tcp", dstHost.String()) + if err != nil { + log.Printf("Error dialing backend %s: %v\n", dstHost.String(), err) + return + } + err = req.Write(d) + if err != nil { + log.Printf("Error requesting backend %s: %v\n", dstHost.String(), err) + return + } + proxy(c, d) + } +} + +func proxy(src, dst net.Conn) { errc := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { _, err := io.Copy(dst, src) errc <- err } - go cp(d, vhostConn) - go cp(vhostConn, d) + go cp(src, dst) + go cp(dst, src) <-errc - /* - req, err := http.ReadRequest(bufio.NewReader(c)) - if err != nil { - log.Println(err) - return - } - - log.Println(req.Header) - */ } func NewRouter(director Director) *proxyRouter { diff --git a/router/router_test.go b/router/router_test.go index 192b524..fcdba89 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -4,15 +4,24 @@ import ( "crypto/tls" "fmt" "io/ioutil" + "log" "net" "net/http" "net/http/httptest" "net/url" + "strings" + "sync" "testing" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) +func getRouterUrl(scheme string, r *proxyRouter) string { + chunks := strings.Split(r.ListenAddress(), ":") + return fmt.Sprintf("%s://localhost:%s", scheme, chunks[len(chunks)-1]) +} + func TestProxy_TLS(t *testing.T) { tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, @@ -34,9 +43,10 @@ func TestProxy_TLS(t *testing.T) { a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil }) - go r.Listen(":8080") + r.Listen(":0") + defer r.Close() - req, err := http.NewRequest("GET", "https://localhost:8080", nil) + req, err := http.NewRequest("GET", getRouterUrl("https", r), nil) assert.Nil(t, err) resp, err := client.Do(req) @@ -45,5 +55,151 @@ func TestProxy_TLS(t *testing.T) { body, err := ioutil.ReadAll(resp.Body) assert.Nil(t, err) assert.Equal(t, msg, string(body)) - assert.Equal(t, "localhost:8080", receivedHost) + assert.Equal(t, "localhost", receivedHost) +} + +func TestProxy_Http(t *testing.T) { + const msg = "It works!" + + var receivedHost string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, msg) + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + receivedHost = host + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }) + r.Listen(":0") + defer r.Close() + + req, err := http.NewRequest("GET", getRouterUrl("http", r), nil) + assert.Nil(t, err) + + resp, err := http.DefaultClient.Do(req) + assert.Nil(t, err) + + body, err := ioutil.ReadAll(resp.Body) + assert.Nil(t, err) + assert.Equal(t, msg, string(body)) + + u, _ := url.Parse(getRouterUrl("http", r)) + assert.Equal(t, u.Host, receivedHost) +} + +func TestProxy_WS(t *testing.T) { + const msg = "It works!" + + var serverReceivedMessage string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + serverReceivedMessage = string(message) + err = c.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + return + } + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }) + r.Listen(":0") + defer r.Close() + + c, _, err := websocket.DefaultDialer.Dial(getRouterUrl("ws", r), nil) + assert.Nil(t, err) + defer c.Close() + + var clientReceivedMessage []byte + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + _, clientReceivedMessage, _ = c.ReadMessage() + wg.Done() + }() + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, msg, string(clientReceivedMessage)) + assert.Equal(t, msg, serverReceivedMessage) +} + +func TestProxy_WSS(t *testing.T) { + const msg = "It works!" + + var serverReceivedMessage string + + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + serverReceivedMessage = string(message) + err = c.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + return + } + })) + defer ts.Close() + + r := NewRouter(func(host string) (*net.TCPAddr, error) { + u, _ := url.Parse(ts.URL) + a, _ := net.ResolveTCPAddr("tcp", u.Host) + return a, nil + }) + r.Listen(":0") + defer r.Close() + + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + c, _, err := d.Dial(getRouterUrl("wss", r), nil) + assert.Nil(t, err) + defer c.Close() + + var clientReceivedMessage []byte + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + _, clientReceivedMessage, _ = c.ReadMessage() + wg.Done() + }() + + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.Nil(t, err) + + wg.Wait() + + assert.Equal(t, msg, string(clientReceivedMessage)) + assert.Equal(t, msg, serverReceivedMessage) }