From 3906bd3b57d858258db4fa845df480213807d0e3 Mon Sep 17 00:00:00 2001 From: "Jonathan Leibiusky @xetorthio" Date: Thu, 3 Aug 2017 11:55:53 -0300 Subject: [PATCH] Changes to L1 router to also pass the protocol it is routing. --- router/l2/l2.go | 13 ++++++++++--- router/l2/l2_test.go | 31 ++++++++++++++++++++++--------- router/router.go | 19 ++++++++++++++----- router/router_test.go | 28 +++++++++++++++++++++------- scheduler/scheduler.go | 3 ++- 5 files changed, 69 insertions(+), 25 deletions(-) diff --git a/router/l2/l2.go b/router/l2/l2.go index 416c095..38c0e1d 100644 --- a/router/l2/l2.go +++ b/router/l2/l2.go @@ -16,7 +16,7 @@ import ( "github.com/play-with-docker/play-with-docker/router" ) -func director(host string) (*net.TCPAddr, error) { +func director(protocol router.Protocol, host string) (*net.TCPAddr, error) { info, err := router.DecodeHost(host) if err != nil { return nil, err @@ -29,8 +29,15 @@ func director(host string) (*net.TCPAddr, error) { } if port == 0 { - // TODO: Should default depending on the protocol - port = 80 + if protocol == router.ProtocolHTTP { + port = 80 + } else if protocol == router.ProtocolHTTPS { + port = 443 + } else if protocol == router.ProtocolSSH { + port = 22 + } else if protocol == router.ProtocolDNS { + port = 53 + } } t, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", info.InstanceIP, port)) diff --git a/router/l2/l2_test.go b/router/l2/l2_test.go index 5b31040..dc0ba63 100644 --- a/router/l2/l2_test.go +++ b/router/l2/l2_test.go @@ -3,42 +3,55 @@ package main import ( "testing" + "github.com/play-with-docker/play-with-docker/router" "github.com/stretchr/testify/assert" ) func TestDirector(t *testing.T) { - addr, err := director("ip10-0-0-1-aabb-8080.foo.bar") + addr, err := director(router.ProtocolHTTP, "ip10-0-0-1-aabb-8080.foo.bar") assert.Nil(t, err) assert.Equal(t, "10.0.0.1:8080", addr.String()) - addr, err = director("ip10-0-0-1-aabb.foo.bar") + addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb.foo.bar") assert.Nil(t, err) assert.Equal(t, "10.0.0.1:80", addr.String()) - addr, err = director("ip10-0-0-1-aabb.foo.bar:9090") + addr, err = director(router.ProtocolHTTPS, "ip10-0-0-1-aabb.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:443", addr.String()) + + addr, err = director(router.ProtocolSSH, "ip10-0-0-1-aabb.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:22", addr.String()) + + addr, err = director(router.ProtocolDNS, "ip10-0-0-1-aabb.foo.bar") + assert.Nil(t, err) + assert.Equal(t, "10.0.0.1:53", addr.String()) + + addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb.foo.bar:9090") assert.Nil(t, err) assert.Equal(t, "10.0.0.1:9090", addr.String()) - addr, err = director("ip10-0-0-1-aabb-2222.foo.bar:9090") + addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb-2222.foo.bar:9090") assert.Nil(t, err) assert.Equal(t, "10.0.0.1:2222", addr.String()) - addr, err = director("lala.ip10-0-0-1-aabb-2222.foo.bar") + addr, err = director(router.ProtocolHTTP, "lala.ip10-0-0-1-aabb-2222.foo.bar") assert.Nil(t, err) assert.Equal(t, "10.0.0.1:2222", addr.String()) - addr, err = director("lala.ip10-0-0-1-aabb-2222") + addr, err = director(router.ProtocolHTTP, "lala.ip10-0-0-1-aabb-2222") assert.Nil(t, err) assert.Equal(t, "10.0.0.1:2222", addr.String()) - addr, err = director("ip10-0-0-1-aabb-2222") + addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb-2222") assert.Nil(t, err) assert.Equal(t, "10.0.0.1:2222", addr.String()) - addr, err = director("ip10-0-0-1-aabb") + addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb") assert.Nil(t, err) assert.Equal(t, "10.0.0.1:80", addr.String()) - _, err = director("lala10-0-0-1-aabb.foo.bar") + _, err = director(router.ProtocolHTTP, "lala10-0-0-1-aabb.foo.bar") assert.NotNil(t, err) } diff --git a/router/router.go b/router/router.go index b9c0e9b..7a61cdf 100644 --- a/router/router.go +++ b/router/router.go @@ -17,7 +17,16 @@ import ( "github.com/miekg/dns" ) -type Director func(host string) (*net.TCPAddr, error) +type Protocol int + +const ( + ProtocolHTTP Protocol = iota + ProtocolHTTPS + ProtocolSSH + ProtocolDNS +) + +type Director func(protocol Protocol, host string) (*net.TCPAddr, error) type proxyRouter struct { sync.Mutex @@ -147,7 +156,7 @@ func (r *proxyRouter) sshHandle(nConn net.Conn) { return } - dstHost, err := r.director(sshCon.User()) + dstHost, err := r.director(ProtocolSSH, sshCon.User()) if err != nil { nConn.Close() return @@ -258,7 +267,7 @@ func (r *proxyRouter) dnsRequest(w dns.ResponseWriter, req *dns.Msg) { return } - dstHost, err := r.director(strings.TrimSuffix(question, ".")) + dstHost, err := r.director(ProtocolDNS, strings.TrimSuffix(question, ".")) if err != nil { // Director couldn't resolve it, try to lookup in the system's DNS ips, err := net.LookupIP(question) @@ -361,7 +370,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) { // It is a TLS connection defer vhostConn.Close() host := vhostConn.ClientHelloMsg.ServerName - dstHost, err := r.director(host) + dstHost, err := r.director(ProtocolHTTPS, host) if err != nil { log.Printf("Error directing request: %v\n", err) return @@ -386,7 +395,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) { if host == "" { host = req.Host } - dstHost, err := r.director(host) + dstHost, err := r.director(ProtocolHTTP, host) if err != nil { log.Printf("Error directing request: %v\n", err) return diff --git a/router/router_test.go b/router/router_test.go index 3505e23..0ef2bc4 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -224,14 +224,16 @@ func TestProxy_TLS(t *testing.T) { const msg = "It works!" var receivedHost string + var receivedProtocol Protocol ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, msg) })) defer ts.Close() - r := NewRouter(func(host string) (*net.TCPAddr, error) { + r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) { receivedHost = host + receivedProtocol = protocol u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil @@ -249,6 +251,7 @@ func TestProxy_TLS(t *testing.T) { assert.Nil(t, err) assert.Equal(t, msg, string(body)) assert.Equal(t, "localhost", receivedHost) + assert.Equal(t, ProtocolHTTPS, receivedProtocol) } func TestProxy_Http(t *testing.T) { @@ -258,14 +261,16 @@ func TestProxy_Http(t *testing.T) { const msg = "It works!" var receivedHost string + var receivedProtocol Protocol ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, msg) })) defer ts.Close() - r := NewRouter(func(host string) (*net.TCPAddr, error) { + r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) { receivedHost = host + receivedProtocol = protocol u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil @@ -285,6 +290,7 @@ func TestProxy_Http(t *testing.T) { u, _ := url.Parse(getRouterUrl("http", r)) assert.Equal(t, u.Host, receivedHost) + assert.Equal(t, ProtocolHTTP, receivedProtocol) } func TestProxy_WS(t *testing.T) { @@ -317,7 +323,7 @@ func TestProxy_WS(t *testing.T) { })) defer ts.Close() - r := NewRouter(func(host string) (*net.TCPAddr, error) { + r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil @@ -375,7 +381,7 @@ func TestProxy_WSS(t *testing.T) { })) defer ts.Close() - r := NewRouter(func(host string) (*net.TCPAddr, error) { + r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) { u, _ := url.Parse(ts.URL) a, _ := net.ResolveTCPAddr("tcp", u.Host) return a, nil @@ -411,9 +417,11 @@ func TestProxy_DNS_UDP(t *testing.T) { defer os.RemoveAll(dir) var receivedHost string + var receivedProtocol Protocol - r := NewRouter(func(host string) (*net.TCPAddr, error) { + r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) { receivedHost = host + receivedProtocol = protocol if host == "10_0_0_1.foo.bar" { a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") return a, nil @@ -427,11 +435,13 @@ func TestProxy_DNS_UDP(t *testing.T) { ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r) assert.Nil(t, err) assert.Equal(t, "10_0_0_1.foo.bar", receivedHost) + assert.Equal(t, ProtocolDNS, receivedProtocol) assert.Equal(t, []string{"10.0.0.1"}, ips) ips, err = routerLookup("udp", "www.google.com", r) assert.Nil(t, err) assert.Equal(t, "www.google.com", receivedHost) + assert.Equal(t, ProtocolDNS, receivedProtocol) expectedIps, err := net.LookupHost("www.google.com") assert.Nil(t, err) @@ -443,6 +453,7 @@ func TestProxy_DNS_UDP(t *testing.T) { ips, err = routerLookup("udp", "localhost", r) assert.Nil(t, err) assert.NotEqual(t, "localhost", receivedHost) + assert.Equal(t, ProtocolDNS, receivedProtocol) assert.Equal(t, []string{"127.0.0.1"}, ips) } @@ -452,7 +463,7 @@ func TestProxy_DNS_TCP(t *testing.T) { var receivedHost string - r := NewRouter(func(host string) (*net.TCPAddr, error) { + r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) { receivedHost = host if host == "10_0_0_1.foo.bar" { a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") @@ -494,6 +505,7 @@ func TestProxy_SSH(t *testing.T) { var receivedPass string var receivedChannelType string var receivedHost string + var receivedProtocol Protocol wg := sync.WaitGroup{} wg.Add(1) @@ -505,8 +517,9 @@ func TestProxy_SSH(t *testing.T) { }) assert.Nil(t, err) - r := NewRouter(func(host string) (*net.TCPAddr, error) { + r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) { receivedHost = host + receivedProtocol = protocol if host == "10-0-0-1-aaaabbbb" { chunks := strings.Split(laddr, ":") a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])) @@ -527,4 +540,5 @@ func TestProxy_SSH(t *testing.T) { assert.Equal(t, "root", receivedPass) assert.Equal(t, "session", receivedChannelType) assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost) + assert.Equal(t, ProtocolSSH, receivedProtocol) } diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go index 1156ef3..94051b1 100644 --- a/scheduler/scheduler.go +++ b/scheduler/scheduler.go @@ -96,6 +96,7 @@ func (s *scheduler) Start() { for _, session := range s.scheduledSessions { ctx, cancel := context.WithCancel(context.Background()) session.cancel = cancel + session.ticker = time.NewTicker(1 * time.Second) go s.cron(ctx, session) } s.event.On(event.SESSION_NEW, func(sessionId string, args ...interface{}) { @@ -128,7 +129,6 @@ func (s *scheduler) register(session *types.Session) *scheduledSession { } func (s *scheduler) cron(ctx context.Context, session *scheduledSession) { - session.ticker = time.NewTicker(1 * time.Second) for { select { case <-session.ticker.C: @@ -167,6 +167,7 @@ func (s *scheduler) Schedule(session *types.Session) error { scheduledSession := s.register(session) ctx, cancel := context.WithCancel(context.Background()) scheduledSession.cancel = cancel + scheduledSession.ticker = time.NewTicker(1 * time.Second) go s.cron(ctx, scheduledSession) return nil }