Changes to L1 router to also pass the protocol it is routing.
This commit is contained in:
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/play-with-docker/play-with-docker/router"
|
||||
)
|
||||
|
||||
func director(host string) (*net.TCPAddr, error) {
|
||||
func director(protocol router.Protocol, host string) (*net.TCPAddr, error) {
|
||||
info, err := router.DecodeHost(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -29,8 +29,15 @@ func director(host string) (*net.TCPAddr, error) {
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
// TODO: Should default depending on the protocol
|
||||
port = 80
|
||||
if protocol == router.ProtocolHTTP {
|
||||
port = 80
|
||||
} else if protocol == router.ProtocolHTTPS {
|
||||
port = 443
|
||||
} else if protocol == router.ProtocolSSH {
|
||||
port = 22
|
||||
} else if protocol == router.ProtocolDNS {
|
||||
port = 53
|
||||
}
|
||||
}
|
||||
|
||||
t, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", info.InstanceIP, port))
|
||||
|
||||
@@ -3,42 +3,55 @@ package main
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/play-with-docker/play-with-docker/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDirector(t *testing.T) {
|
||||
addr, err := director("ip10-0-0-1-aabb-8080.foo.bar")
|
||||
addr, err := director(router.ProtocolHTTP, "ip10-0-0-1-aabb-8080.foo.bar")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:8080", addr.String())
|
||||
|
||||
addr, err = director("ip10-0-0-1-aabb.foo.bar")
|
||||
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb.foo.bar")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:80", addr.String())
|
||||
|
||||
addr, err = director("ip10-0-0-1-aabb.foo.bar:9090")
|
||||
addr, err = director(router.ProtocolHTTPS, "ip10-0-0-1-aabb.foo.bar")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:443", addr.String())
|
||||
|
||||
addr, err = director(router.ProtocolSSH, "ip10-0-0-1-aabb.foo.bar")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:22", addr.String())
|
||||
|
||||
addr, err = director(router.ProtocolDNS, "ip10-0-0-1-aabb.foo.bar")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:53", addr.String())
|
||||
|
||||
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb.foo.bar:9090")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:9090", addr.String())
|
||||
|
||||
addr, err = director("ip10-0-0-1-aabb-2222.foo.bar:9090")
|
||||
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb-2222.foo.bar:9090")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
||||
|
||||
addr, err = director("lala.ip10-0-0-1-aabb-2222.foo.bar")
|
||||
addr, err = director(router.ProtocolHTTP, "lala.ip10-0-0-1-aabb-2222.foo.bar")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
||||
|
||||
addr, err = director("lala.ip10-0-0-1-aabb-2222")
|
||||
addr, err = director(router.ProtocolHTTP, "lala.ip10-0-0-1-aabb-2222")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
||||
|
||||
addr, err = director("ip10-0-0-1-aabb-2222")
|
||||
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb-2222")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:2222", addr.String())
|
||||
|
||||
addr, err = director("ip10-0-0-1-aabb")
|
||||
addr, err = director(router.ProtocolHTTP, "ip10-0-0-1-aabb")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.0.0.1:80", addr.String())
|
||||
|
||||
_, err = director("lala10-0-0-1-aabb.foo.bar")
|
||||
_, err = director(router.ProtocolHTTP, "lala10-0-0-1-aabb.foo.bar")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,16 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type Director func(host string) (*net.TCPAddr, error)
|
||||
type Protocol int
|
||||
|
||||
const (
|
||||
ProtocolHTTP Protocol = iota
|
||||
ProtocolHTTPS
|
||||
ProtocolSSH
|
||||
ProtocolDNS
|
||||
)
|
||||
|
||||
type Director func(protocol Protocol, host string) (*net.TCPAddr, error)
|
||||
|
||||
type proxyRouter struct {
|
||||
sync.Mutex
|
||||
@@ -147,7 +156,7 @@ func (r *proxyRouter) sshHandle(nConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
dstHost, err := r.director(sshCon.User())
|
||||
dstHost, err := r.director(ProtocolSSH, sshCon.User())
|
||||
if err != nil {
|
||||
nConn.Close()
|
||||
return
|
||||
@@ -258,7 +267,7 @@ func (r *proxyRouter) dnsRequest(w dns.ResponseWriter, req *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
dstHost, err := r.director(strings.TrimSuffix(question, "."))
|
||||
dstHost, err := r.director(ProtocolDNS, strings.TrimSuffix(question, "."))
|
||||
if err != nil {
|
||||
// Director couldn't resolve it, try to lookup in the system's DNS
|
||||
ips, err := net.LookupIP(question)
|
||||
@@ -361,7 +370,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) {
|
||||
// It is a TLS connection
|
||||
defer vhostConn.Close()
|
||||
host := vhostConn.ClientHelloMsg.ServerName
|
||||
dstHost, err := r.director(host)
|
||||
dstHost, err := r.director(ProtocolHTTPS, host)
|
||||
if err != nil {
|
||||
log.Printf("Error directing request: %v\n", err)
|
||||
return
|
||||
@@ -386,7 +395,7 @@ func (r *proxyRouter) handleConnection(c net.Conn) {
|
||||
if host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
dstHost, err := r.director(host)
|
||||
dstHost, err := r.director(ProtocolHTTP, host)
|
||||
if err != nil {
|
||||
log.Printf("Error directing request: %v\n", err)
|
||||
return
|
||||
|
||||
@@ -224,14 +224,16 @@ func TestProxy_TLS(t *testing.T) {
|
||||
const msg = "It works!"
|
||||
|
||||
var receivedHost string
|
||||
var receivedProtocol Protocol
|
||||
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, msg)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
receivedProtocol = protocol
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
@@ -249,6 +251,7 @@ func TestProxy_TLS(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, msg, string(body))
|
||||
assert.Equal(t, "localhost", receivedHost)
|
||||
assert.Equal(t, ProtocolHTTPS, receivedProtocol)
|
||||
}
|
||||
|
||||
func TestProxy_Http(t *testing.T) {
|
||||
@@ -258,14 +261,16 @@ func TestProxy_Http(t *testing.T) {
|
||||
const msg = "It works!"
|
||||
|
||||
var receivedHost string
|
||||
var receivedProtocol Protocol
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, msg)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
receivedProtocol = protocol
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
@@ -285,6 +290,7 @@ func TestProxy_Http(t *testing.T) {
|
||||
|
||||
u, _ := url.Parse(getRouterUrl("http", r))
|
||||
assert.Equal(t, u.Host, receivedHost)
|
||||
assert.Equal(t, ProtocolHTTP, receivedProtocol)
|
||||
}
|
||||
|
||||
func TestProxy_WS(t *testing.T) {
|
||||
@@ -317,7 +323,7 @@ func TestProxy_WS(t *testing.T) {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
@@ -375,7 +381,7 @@ func TestProxy_WSS(t *testing.T) {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
@@ -411,9 +417,11 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
var receivedHost string
|
||||
var receivedProtocol Protocol
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
receivedProtocol = protocol
|
||||
if host == "10_0_0_1.foo.bar" {
|
||||
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
||||
return a, nil
|
||||
@@ -427,11 +435,13 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
||||
ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
|
||||
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||
assert.Equal(t, []string{"10.0.0.1"}, ips)
|
||||
|
||||
ips, err = routerLookup("udp", "www.google.com", r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "www.google.com", receivedHost)
|
||||
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||
|
||||
expectedIps, err := net.LookupHost("www.google.com")
|
||||
assert.Nil(t, err)
|
||||
@@ -443,6 +453,7 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
||||
ips, err = routerLookup("udp", "localhost", r)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, "localhost", receivedHost)
|
||||
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||
assert.Equal(t, []string{"127.0.0.1"}, ips)
|
||||
}
|
||||
|
||||
@@ -452,7 +463,7 @@ func TestProxy_DNS_TCP(t *testing.T) {
|
||||
|
||||
var receivedHost string
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
if host == "10_0_0_1.foo.bar" {
|
||||
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
||||
@@ -494,6 +505,7 @@ func TestProxy_SSH(t *testing.T) {
|
||||
var receivedPass string
|
||||
var receivedChannelType string
|
||||
var receivedHost string
|
||||
var receivedProtocol Protocol
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -505,8 +517,9 @@ func TestProxy_SSH(t *testing.T) {
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
receivedProtocol = protocol
|
||||
if host == "10-0-0-1-aaaabbbb" {
|
||||
chunks := strings.Split(laddr, ":")
|
||||
a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]))
|
||||
@@ -527,4 +540,5 @@ func TestProxy_SSH(t *testing.T) {
|
||||
assert.Equal(t, "root", receivedPass)
|
||||
assert.Equal(t, "session", receivedChannelType)
|
||||
assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost)
|
||||
assert.Equal(t, ProtocolSSH, receivedProtocol)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user