Changes to L1 router to also pass the protocol it is routing.

This commit is contained in:
Jonathan Leibiusky @xetorthio
2017-08-03 11:55:53 -03:00
parent e0626f4176
commit 3906bd3b57
5 changed files with 69 additions and 25 deletions

View File

@@ -16,7 +16,7 @@ import (
"github.com/play-with-docker/play-with-docker/router" "github.com/play-with-docker/play-with-docker/router"
) )
func director(host string) (*net.TCPAddr, error) { func director(protocol router.Protocol, host string) (*net.TCPAddr, error) {
info, err := router.DecodeHost(host) info, err := router.DecodeHost(host)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -29,8 +29,15 @@ func director(host string) (*net.TCPAddr, error) {
} }
if port == 0 { if port == 0 {
// TODO: Should default depending on the protocol if protocol == router.ProtocolHTTP {
port = 80 port = 80
} else if protocol == router.ProtocolHTTPS {
port = 443
} else if protocol == router.ProtocolSSH {
port = 22
} else if protocol == router.ProtocolDNS {
port = 53
}
} }
t, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", info.InstanceIP, port)) t, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", info.InstanceIP, port))

View File

@@ -3,42 +3,55 @@ package main
import ( import (
"testing" "testing"
"github.com/play-with-docker/play-with-docker/router"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestDirector(t *testing.T) { func TestDirector(t *testing.T) {
addr, err := director("ip10-0-0-1-aabb-8080.foo.bar") addr, err := director(router.ProtocolHTTP, "ip10-0-0-1-aabb-8080.foo.bar")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:8080", addr.String()) assert.Equal(t, "10.0.0.1:8080", addr.String())
addr, err = director("ip10-0-0-1-aabb.foo.bar") addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb.foo.bar")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:80", addr.String()) assert.Equal(t, "10.0.0.1:80", addr.String())
addr, err = director("ip10-0-0-1-aabb.foo.bar:9090") addr, err = director(router.ProtocolHTTPS, "ip10-0-0-1-aabb.foo.bar")
assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:443", addr.String())
addr, err = director(router.ProtocolSSH, "ip10-0-0-1-aabb.foo.bar")
assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:22", addr.String())
addr, err = director(router.ProtocolDNS, "ip10-0-0-1-aabb.foo.bar")
assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:53", addr.String())
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb.foo.bar:9090")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:9090", addr.String()) assert.Equal(t, "10.0.0.1:9090", addr.String())
addr, err = director("ip10-0-0-1-aabb-2222.foo.bar:9090") addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb-2222.foo.bar:9090")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:2222", addr.String()) assert.Equal(t, "10.0.0.1:2222", addr.String())
addr, err = director("lala.ip10-0-0-1-aabb-2222.foo.bar") addr, err = director(router.ProtocolHTTP, "lala.ip10-0-0-1-aabb-2222.foo.bar")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:2222", addr.String()) assert.Equal(t, "10.0.0.1:2222", addr.String())
addr, err = director("lala.ip10-0-0-1-aabb-2222") addr, err = director(router.ProtocolHTTP, "lala.ip10-0-0-1-aabb-2222")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:2222", addr.String()) assert.Equal(t, "10.0.0.1:2222", addr.String())
addr, err = director("ip10-0-0-1-aabb-2222") addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb-2222")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:2222", addr.String()) assert.Equal(t, "10.0.0.1:2222", addr.String())
addr, err = director("ip10-0-0-1-aabb") addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.0.0.1:80", addr.String()) assert.Equal(t, "10.0.0.1:80", addr.String())
_, err = director("lala10-0-0-1-aabb.foo.bar") _, err = director(router.ProtocolHTTP, "lala10-0-0-1-aabb.foo.bar")
assert.NotNil(t, err) assert.NotNil(t, err)
} }

View File

@@ -17,7 +17,16 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
type Director func(host string) (*net.TCPAddr, error) type Protocol int
const (
ProtocolHTTP Protocol = iota
ProtocolHTTPS
ProtocolSSH
ProtocolDNS
)
type Director func(protocol Protocol, host string) (*net.TCPAddr, error)
type proxyRouter struct { type proxyRouter struct {
sync.Mutex sync.Mutex
@@ -147,7 +156,7 @@ func (r *proxyRouter) sshHandle(nConn net.Conn) {
return return
} }
dstHost, err := r.director(sshCon.User()) dstHost, err := r.director(ProtocolSSH, sshCon.User())
if err != nil { if err != nil {
nConn.Close() nConn.Close()
return return
@@ -258,7 +267,7 @@ func (r *proxyRouter) dnsRequest(w dns.ResponseWriter, req *dns.Msg) {
return return
} }
dstHost, err := r.director(strings.TrimSuffix(question, ".")) dstHost, err := r.director(ProtocolDNS, strings.TrimSuffix(question, "."))
if err != nil { if err != nil {
// Director couldn't resolve it, try to lookup in the system's DNS // Director couldn't resolve it, try to lookup in the system's DNS
ips, err := net.LookupIP(question) ips, err := net.LookupIP(question)
@@ -361,7 +370,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) {
// It is a TLS connection // It is a TLS connection
defer vhostConn.Close() defer vhostConn.Close()
host := vhostConn.ClientHelloMsg.ServerName host := vhostConn.ClientHelloMsg.ServerName
dstHost, err := r.director(host) dstHost, err := r.director(ProtocolHTTPS, host)
if err != nil { if err != nil {
log.Printf("Error directing request: %v\n", err) log.Printf("Error directing request: %v\n", err)
return return
@@ -386,7 +395,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) {
if host == "" { if host == "" {
host = req.Host host = req.Host
} }
dstHost, err := r.director(host) dstHost, err := r.director(ProtocolHTTP, host)
if err != nil { if err != nil {
log.Printf("Error directing request: %v\n", err) log.Printf("Error directing request: %v\n", err)
return return

View File

@@ -224,14 +224,16 @@ func TestProxy_TLS(t *testing.T) {
const msg = "It works!" const msg = "It works!"
var receivedHost string var receivedHost string
var receivedProtocol Protocol
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, msg) fmt.Fprint(w, msg)
})) }))
defer ts.Close() defer ts.Close()
r := NewRouter(func(host string) (*net.TCPAddr, error) { r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
receivedHost = host receivedHost = host
receivedProtocol = protocol
u, _ := url.Parse(ts.URL) u, _ := url.Parse(ts.URL)
a, _ := net.ResolveTCPAddr("tcp", u.Host) a, _ := net.ResolveTCPAddr("tcp", u.Host)
return a, nil return a, nil
@@ -249,6 +251,7 @@ func TestProxy_TLS(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, msg, string(body)) assert.Equal(t, msg, string(body))
assert.Equal(t, "localhost", receivedHost) assert.Equal(t, "localhost", receivedHost)
assert.Equal(t, ProtocolHTTPS, receivedProtocol)
} }
func TestProxy_Http(t *testing.T) { func TestProxy_Http(t *testing.T) {
@@ -258,14 +261,16 @@ func TestProxy_Http(t *testing.T) {
const msg = "It works!" const msg = "It works!"
var receivedHost string var receivedHost string
var receivedProtocol Protocol
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, msg) fmt.Fprint(w, msg)
})) }))
defer ts.Close() defer ts.Close()
r := NewRouter(func(host string) (*net.TCPAddr, error) { r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
receivedHost = host receivedHost = host
receivedProtocol = protocol
u, _ := url.Parse(ts.URL) u, _ := url.Parse(ts.URL)
a, _ := net.ResolveTCPAddr("tcp", u.Host) a, _ := net.ResolveTCPAddr("tcp", u.Host)
return a, nil return a, nil
@@ -285,6 +290,7 @@ func TestProxy_Http(t *testing.T) {
u, _ := url.Parse(getRouterUrl("http", r)) u, _ := url.Parse(getRouterUrl("http", r))
assert.Equal(t, u.Host, receivedHost) assert.Equal(t, u.Host, receivedHost)
assert.Equal(t, ProtocolHTTP, receivedProtocol)
} }
func TestProxy_WS(t *testing.T) { func TestProxy_WS(t *testing.T) {
@@ -317,7 +323,7 @@ func TestProxy_WS(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
r := NewRouter(func(host string) (*net.TCPAddr, error) { r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
u, _ := url.Parse(ts.URL) u, _ := url.Parse(ts.URL)
a, _ := net.ResolveTCPAddr("tcp", u.Host) a, _ := net.ResolveTCPAddr("tcp", u.Host)
return a, nil return a, nil
@@ -375,7 +381,7 @@ func TestProxy_WSS(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
r := NewRouter(func(host string) (*net.TCPAddr, error) { r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
u, _ := url.Parse(ts.URL) u, _ := url.Parse(ts.URL)
a, _ := net.ResolveTCPAddr("tcp", u.Host) a, _ := net.ResolveTCPAddr("tcp", u.Host)
return a, nil return a, nil
@@ -411,9 +417,11 @@ func TestProxy_DNS_UDP(t *testing.T) {
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
var receivedHost string var receivedHost string
var receivedProtocol Protocol
r := NewRouter(func(host string) (*net.TCPAddr, error) { r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
receivedHost = host receivedHost = host
receivedProtocol = protocol
if host == "10_0_0_1.foo.bar" { if host == "10_0_0_1.foo.bar" {
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
return a, nil return a, nil
@@ -427,11 +435,13 @@ func TestProxy_DNS_UDP(t *testing.T) {
ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r) ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost) assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
assert.Equal(t, ProtocolDNS, receivedProtocol)
assert.Equal(t, []string{"10.0.0.1"}, ips) assert.Equal(t, []string{"10.0.0.1"}, ips)
ips, err = routerLookup("udp", "www.google.com", r) ips, err = routerLookup("udp", "www.google.com", r)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "www.google.com", receivedHost) assert.Equal(t, "www.google.com", receivedHost)
assert.Equal(t, ProtocolDNS, receivedProtocol)
expectedIps, err := net.LookupHost("www.google.com") expectedIps, err := net.LookupHost("www.google.com")
assert.Nil(t, err) assert.Nil(t, err)
@@ -443,6 +453,7 @@ func TestProxy_DNS_UDP(t *testing.T) {
ips, err = routerLookup("udp", "localhost", r) ips, err = routerLookup("udp", "localhost", r)
assert.Nil(t, err) assert.Nil(t, err)
assert.NotEqual(t, "localhost", receivedHost) assert.NotEqual(t, "localhost", receivedHost)
assert.Equal(t, ProtocolDNS, receivedProtocol)
assert.Equal(t, []string{"127.0.0.1"}, ips) assert.Equal(t, []string{"127.0.0.1"}, ips)
} }
@@ -452,7 +463,7 @@ func TestProxy_DNS_TCP(t *testing.T) {
var receivedHost string var receivedHost string
r := NewRouter(func(host string) (*net.TCPAddr, error) { r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
receivedHost = host receivedHost = host
if host == "10_0_0_1.foo.bar" { if host == "10_0_0_1.foo.bar" {
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0") a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
@@ -494,6 +505,7 @@ func TestProxy_SSH(t *testing.T) {
var receivedPass string var receivedPass string
var receivedChannelType string var receivedChannelType string
var receivedHost string var receivedHost string
var receivedProtocol Protocol
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
@@ -505,8 +517,9 @@ func TestProxy_SSH(t *testing.T) {
}) })
assert.Nil(t, err) assert.Nil(t, err)
r := NewRouter(func(host string) (*net.TCPAddr, error) { r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
receivedHost = host receivedHost = host
receivedProtocol = protocol
if host == "10-0-0-1-aaaabbbb" { if host == "10-0-0-1-aaaabbbb" {
chunks := strings.Split(laddr, ":") chunks := strings.Split(laddr, ":")
a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])) a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]))
@@ -527,4 +540,5 @@ func TestProxy_SSH(t *testing.T) {
assert.Equal(t, "root", receivedPass) assert.Equal(t, "root", receivedPass)
assert.Equal(t, "session", receivedChannelType) assert.Equal(t, "session", receivedChannelType)
assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost) assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost)
assert.Equal(t, ProtocolSSH, receivedProtocol)
} }

View File

@@ -96,6 +96,7 @@ func (s *scheduler) Start() {
for _, session := range s.scheduledSessions { for _, session := range s.scheduledSessions {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
session.cancel = cancel session.cancel = cancel
session.ticker = time.NewTicker(1 * time.Second)
go s.cron(ctx, session) go s.cron(ctx, session)
} }
s.event.On(event.SESSION_NEW, func(sessionId string, args ...interface{}) { s.event.On(event.SESSION_NEW, func(sessionId string, args ...interface{}) {
@@ -128,7 +129,6 @@ func (s *scheduler) register(session *types.Session) *scheduledSession {
} }
func (s *scheduler) cron(ctx context.Context, session *scheduledSession) { func (s *scheduler) cron(ctx context.Context, session *scheduledSession) {
session.ticker = time.NewTicker(1 * time.Second)
for { for {
select { select {
case <-session.ticker.C: case <-session.ticker.C:
@@ -167,6 +167,7 @@ func (s *scheduler) Schedule(session *types.Session) error {
scheduledSession := s.register(session) scheduledSession := s.register(session)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
scheduledSession.cancel = cancel scheduledSession.cancel = cancel
scheduledSession.ticker = time.NewTicker(1 * time.Second)
go s.cron(ctx, scheduledSession) go s.cron(ctx, scheduledSession)
return nil return nil
} }