Changes to L1 router to also pass the protocol it is routing.
This commit is contained in:
@@ -16,7 +16,7 @@ import (
|
|||||||
"github.com/play-with-docker/play-with-docker/router"
|
"github.com/play-with-docker/play-with-docker/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
func director(host string) (*net.TCPAddr, error) {
|
func director(protocol router.Protocol, host string) (*net.TCPAddr, error) {
|
||||||
info, err := router.DecodeHost(host)
|
info, err := router.DecodeHost(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -29,8 +29,15 @@ func director(host string) (*net.TCPAddr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
// TODO: Should default depending on the protocol
|
if protocol == router.ProtocolHTTP {
|
||||||
port = 80
|
port = 80
|
||||||
|
} else if protocol == router.ProtocolHTTPS {
|
||||||
|
port = 443
|
||||||
|
} else if protocol == router.ProtocolSSH {
|
||||||
|
port = 22
|
||||||
|
} else if protocol == router.ProtocolDNS {
|
||||||
|
port = 53
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", info.InstanceIP, port))
|
t, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", info.InstanceIP, port))
|
||||||
|
|||||||
@@ -3,42 +3,55 @@ package main
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/play-with-docker/play-with-docker/router"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDirector(t *testing.T) {
|
func TestDirector(t *testing.T) {
|
||||||
addr, err := director("ip10-0-0-1-aabb-8080.foo.bar")
|
addr, err := director(router.ProtocolHTTP, "ip10-0-0-1-aabb-8080.foo.bar")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.0.0.1:8080", addr.String())
|
assert.Equal(t, "10.0.0.1:8080", addr.String())
|
||||||
|
|
||||||
addr, err = director("ip10-0-0-1-aabb.foo.bar")
|
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb.foo.bar")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.0.0.1:80", addr.String())
|
assert.Equal(t, "10.0.0.1:80", addr.String())
|
||||||
|
|
||||||
addr, err = director("ip10-0-0-1-aabb.foo.bar:9090")
|
addr, err = director(router.ProtocolHTTPS, "ip10-0-0-1-aabb.foo.bar")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "10.0.0.1:443", addr.String())
|
||||||
|
|
||||||
|
addr, err = director(router.ProtocolSSH, "ip10-0-0-1-aabb.foo.bar")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "10.0.0.1:22", addr.String())
|
||||||
|
|
||||||
|
addr, err = director(router.ProtocolDNS, "ip10-0-0-1-aabb.foo.bar")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "10.0.0.1:53", addr.String())
|
||||||
|
|
||||||
|
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb.foo.bar:9090")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.0.0.1:9090", addr.String())
|
assert.Equal(t, "10.0.0.1:9090", addr.String())
|
||||||
|
|
||||||
addr, err = director("ip10-0-0-1-aabb-2222.foo.bar:9090")
|
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb-2222.foo.bar:9090")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
||||||
|
|
||||||
addr, err = director("lala.ip10-0-0-1-aabb-2222.foo.bar")
|
addr, err = director(router.ProtocolHTTP, "lala.ip10-0-0-1-aabb-2222.foo.bar")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
||||||
|
|
||||||
addr, err = director("lala.ip10-0-0-1-aabb-2222")
|
addr, err = director(router.ProtocolHTTP, "lala.ip10-0-0-1-aabb-2222")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
||||||
|
|
||||||
addr, err = director("ip10-0-0-1-aabb-2222")
|
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb-2222")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
||||||
|
|
||||||
addr, err = director("ip10-0-0-1-aabb")
|
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.0.0.1:80", addr.String())
|
assert.Equal(t, "10.0.0.1:80", addr.String())
|
||||||
|
|
||||||
_, err = director("lala10-0-0-1-aabb.foo.bar")
|
_, err = director(router.ProtocolHTTP, "lala10-0-0-1-aabb.foo.bar")
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,16 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Director func(host string) (*net.TCPAddr, error)
|
type Protocol int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProtocolHTTP Protocol = iota
|
||||||
|
ProtocolHTTPS
|
||||||
|
ProtocolSSH
|
||||||
|
ProtocolDNS
|
||||||
|
)
|
||||||
|
|
||||||
|
type Director func(protocol Protocol, host string) (*net.TCPAddr, error)
|
||||||
|
|
||||||
type proxyRouter struct {
|
type proxyRouter struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
@@ -147,7 +156,7 @@ func (r *proxyRouter) sshHandle(nConn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dstHost, err := r.director(sshCon.User())
|
dstHost, err := r.director(ProtocolSSH, sshCon.User())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nConn.Close()
|
nConn.Close()
|
||||||
return
|
return
|
||||||
@@ -258,7 +267,7 @@ func (r *proxyRouter) dnsRequest(w dns.ResponseWriter, req *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dstHost, err := r.director(strings.TrimSuffix(question, "."))
|
dstHost, err := r.director(ProtocolDNS, strings.TrimSuffix(question, "."))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Director couldn't resolve it, try to lookup in the system's DNS
|
// Director couldn't resolve it, try to lookup in the system's DNS
|
||||||
ips, err := net.LookupIP(question)
|
ips, err := net.LookupIP(question)
|
||||||
@@ -361,7 +370,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) {
|
|||||||
// It is a TLS connection
|
// It is a TLS connection
|
||||||
defer vhostConn.Close()
|
defer vhostConn.Close()
|
||||||
host := vhostConn.ClientHelloMsg.ServerName
|
host := vhostConn.ClientHelloMsg.ServerName
|
||||||
dstHost, err := r.director(host)
|
dstHost, err := r.director(ProtocolHTTPS, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error directing request: %v\n", err)
|
log.Printf("Error directing request: %v\n", err)
|
||||||
return
|
return
|
||||||
@@ -386,7 +395,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) {
|
|||||||
if host == "" {
|
if host == "" {
|
||||||
host = req.Host
|
host = req.Host
|
||||||
}
|
}
|
||||||
dstHost, err := r.director(host)
|
dstHost, err := r.director(ProtocolHTTP, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error directing request: %v\n", err)
|
log.Printf("Error directing request: %v\n", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -224,14 +224,16 @@ func TestProxy_TLS(t *testing.T) {
|
|||||||
const msg = "It works!"
|
const msg = "It works!"
|
||||||
|
|
||||||
var receivedHost string
|
var receivedHost string
|
||||||
|
var receivedProtocol Protocol
|
||||||
|
|
||||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprint(w, msg)
|
fmt.Fprint(w, msg)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||||
receivedHost = host
|
receivedHost = host
|
||||||
|
receivedProtocol = protocol
|
||||||
u, _ := url.Parse(ts.URL)
|
u, _ := url.Parse(ts.URL)
|
||||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||||
return a, nil
|
return a, nil
|
||||||
@@ -249,6 +251,7 @@ func TestProxy_TLS(t *testing.T) {
|
|||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, msg, string(body))
|
assert.Equal(t, msg, string(body))
|
||||||
assert.Equal(t, "localhost", receivedHost)
|
assert.Equal(t, "localhost", receivedHost)
|
||||||
|
assert.Equal(t, ProtocolHTTPS, receivedProtocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_Http(t *testing.T) {
|
func TestProxy_Http(t *testing.T) {
|
||||||
@@ -258,14 +261,16 @@ func TestProxy_Http(t *testing.T) {
|
|||||||
const msg = "It works!"
|
const msg = "It works!"
|
||||||
|
|
||||||
var receivedHost string
|
var receivedHost string
|
||||||
|
var receivedProtocol Protocol
|
||||||
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprint(w, msg)
|
fmt.Fprint(w, msg)
|
||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||||
receivedHost = host
|
receivedHost = host
|
||||||
|
receivedProtocol = protocol
|
||||||
u, _ := url.Parse(ts.URL)
|
u, _ := url.Parse(ts.URL)
|
||||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||||
return a, nil
|
return a, nil
|
||||||
@@ -285,6 +290,7 @@ func TestProxy_Http(t *testing.T) {
|
|||||||
|
|
||||||
u, _ := url.Parse(getRouterUrl("http", r))
|
u, _ := url.Parse(getRouterUrl("http", r))
|
||||||
assert.Equal(t, u.Host, receivedHost)
|
assert.Equal(t, u.Host, receivedHost)
|
||||||
|
assert.Equal(t, ProtocolHTTP, receivedProtocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxy_WS(t *testing.T) {
|
func TestProxy_WS(t *testing.T) {
|
||||||
@@ -317,7 +323,7 @@ func TestProxy_WS(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||||
u, _ := url.Parse(ts.URL)
|
u, _ := url.Parse(ts.URL)
|
||||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||||
return a, nil
|
return a, nil
|
||||||
@@ -375,7 +381,7 @@ func TestProxy_WSS(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||||
u, _ := url.Parse(ts.URL)
|
u, _ := url.Parse(ts.URL)
|
||||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||||
return a, nil
|
return a, nil
|
||||||
@@ -411,9 +417,11 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
|||||||
defer os.RemoveAll(dir)
|
defer os.RemoveAll(dir)
|
||||||
|
|
||||||
var receivedHost string
|
var receivedHost string
|
||||||
|
var receivedProtocol Protocol
|
||||||
|
|
||||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||||
receivedHost = host
|
receivedHost = host
|
||||||
|
receivedProtocol = protocol
|
||||||
if host == "10_0_0_1.foo.bar" {
|
if host == "10_0_0_1.foo.bar" {
|
||||||
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
||||||
return a, nil
|
return a, nil
|
||||||
@@ -427,11 +435,13 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
|||||||
ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r)
|
ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
|
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
|
||||||
|
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||||
assert.Equal(t, []string{"10.0.0.1"}, ips)
|
assert.Equal(t, []string{"10.0.0.1"}, ips)
|
||||||
|
|
||||||
ips, err = routerLookup("udp", "www.google.com", r)
|
ips, err = routerLookup("udp", "www.google.com", r)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "www.google.com", receivedHost)
|
assert.Equal(t, "www.google.com", receivedHost)
|
||||||
|
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||||
|
|
||||||
expectedIps, err := net.LookupHost("www.google.com")
|
expectedIps, err := net.LookupHost("www.google.com")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
@@ -443,6 +453,7 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
|||||||
ips, err = routerLookup("udp", "localhost", r)
|
ips, err = routerLookup("udp", "localhost", r)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.NotEqual(t, "localhost", receivedHost)
|
assert.NotEqual(t, "localhost", receivedHost)
|
||||||
|
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||||
assert.Equal(t, []string{"127.0.0.1"}, ips)
|
assert.Equal(t, []string{"127.0.0.1"}, ips)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -452,7 +463,7 @@ func TestProxy_DNS_TCP(t *testing.T) {
|
|||||||
|
|
||||||
var receivedHost string
|
var receivedHost string
|
||||||
|
|
||||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||||
receivedHost = host
|
receivedHost = host
|
||||||
if host == "10_0_0_1.foo.bar" {
|
if host == "10_0_0_1.foo.bar" {
|
||||||
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
||||||
@@ -494,6 +505,7 @@ func TestProxy_SSH(t *testing.T) {
|
|||||||
var receivedPass string
|
var receivedPass string
|
||||||
var receivedChannelType string
|
var receivedChannelType string
|
||||||
var receivedHost string
|
var receivedHost string
|
||||||
|
var receivedProtocol Protocol
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@@ -505,8 +517,9 @@ func TestProxy_SSH(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||||
receivedHost = host
|
receivedHost = host
|
||||||
|
receivedProtocol = protocol
|
||||||
if host == "10-0-0-1-aaaabbbb" {
|
if host == "10-0-0-1-aaaabbbb" {
|
||||||
chunks := strings.Split(laddr, ":")
|
chunks := strings.Split(laddr, ":")
|
||||||
a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]))
|
a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]))
|
||||||
@@ -527,4 +540,5 @@ func TestProxy_SSH(t *testing.T) {
|
|||||||
assert.Equal(t, "root", receivedPass)
|
assert.Equal(t, "root", receivedPass)
|
||||||
assert.Equal(t, "session", receivedChannelType)
|
assert.Equal(t, "session", receivedChannelType)
|
||||||
assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost)
|
assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost)
|
||||||
|
assert.Equal(t, ProtocolSSH, receivedProtocol)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ func (s *scheduler) Start() {
|
|||||||
for _, session := range s.scheduledSessions {
|
for _, session := range s.scheduledSessions {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
session.cancel = cancel
|
session.cancel = cancel
|
||||||
|
session.ticker = time.NewTicker(1 * time.Second)
|
||||||
go s.cron(ctx, session)
|
go s.cron(ctx, session)
|
||||||
}
|
}
|
||||||
s.event.On(event.SESSION_NEW, func(sessionId string, args ...interface{}) {
|
s.event.On(event.SESSION_NEW, func(sessionId string, args ...interface{}) {
|
||||||
@@ -128,7 +129,6 @@ func (s *scheduler) register(session *types.Session) *scheduledSession {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *scheduler) cron(ctx context.Context, session *scheduledSession) {
|
func (s *scheduler) cron(ctx context.Context, session *scheduledSession) {
|
||||||
session.ticker = time.NewTicker(1 * time.Second)
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-session.ticker.C:
|
case <-session.ticker.C:
|
||||||
@@ -167,6 +167,7 @@ func (s *scheduler) Schedule(session *types.Session) error {
|
|||||||
scheduledSession := s.register(session)
|
scheduledSession := s.register(session)
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
scheduledSession.cancel = cancel
|
scheduledSession.cancel = cancel
|
||||||
|
scheduledSession.ticker = time.NewTicker(1 * time.Second)
|
||||||
go s.cron(ctx, scheduledSession)
|
go s.cron(ctx, scheduledSession)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user