Changes to L1 router to also pass the protocol it is routing.
This commit is contained in:
@@ -224,14 +224,16 @@ func TestProxy_TLS(t *testing.T) {
|
||||
const msg = "It works!"
|
||||
|
||||
var receivedHost string
|
||||
var receivedProtocol Protocol
|
||||
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, msg)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
receivedProtocol = protocol
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
@@ -249,6 +251,7 @@ func TestProxy_TLS(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, msg, string(body))
|
||||
assert.Equal(t, "localhost", receivedHost)
|
||||
assert.Equal(t, ProtocolHTTPS, receivedProtocol)
|
||||
}
|
||||
|
||||
func TestProxy_Http(t *testing.T) {
|
||||
@@ -258,14 +261,16 @@ func TestProxy_Http(t *testing.T) {
|
||||
const msg = "It works!"
|
||||
|
||||
var receivedHost string
|
||||
var receivedProtocol Protocol
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, msg)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
receivedProtocol = protocol
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
@@ -285,6 +290,7 @@ func TestProxy_Http(t *testing.T) {
|
||||
|
||||
u, _ := url.Parse(getRouterUrl("http", r))
|
||||
assert.Equal(t, u.Host, receivedHost)
|
||||
assert.Equal(t, ProtocolHTTP, receivedProtocol)
|
||||
}
|
||||
|
||||
func TestProxy_WS(t *testing.T) {
|
||||
@@ -317,7 +323,7 @@ func TestProxy_WS(t *testing.T) {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
@@ -375,7 +381,7 @@ func TestProxy_WSS(t *testing.T) {
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
u, _ := url.Parse(ts.URL)
|
||||
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
||||
return a, nil
|
||||
@@ -411,9 +417,11 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
var receivedHost string
|
||||
var receivedProtocol Protocol
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
receivedProtocol = protocol
|
||||
if host == "10_0_0_1.foo.bar" {
|
||||
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
||||
return a, nil
|
||||
@@ -427,11 +435,13 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
||||
ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
|
||||
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||
assert.Equal(t, []string{"10.0.0.1"}, ips)
|
||||
|
||||
ips, err = routerLookup("udp", "www.google.com", r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "www.google.com", receivedHost)
|
||||
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||
|
||||
expectedIps, err := net.LookupHost("www.google.com")
|
||||
assert.Nil(t, err)
|
||||
@@ -443,6 +453,7 @@ func TestProxy_DNS_UDP(t *testing.T) {
|
||||
ips, err = routerLookup("udp", "localhost", r)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, "localhost", receivedHost)
|
||||
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
||||
assert.Equal(t, []string{"127.0.0.1"}, ips)
|
||||
}
|
||||
|
||||
@@ -452,7 +463,7 @@ func TestProxy_DNS_TCP(t *testing.T) {
|
||||
|
||||
var receivedHost string
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
if host == "10_0_0_1.foo.bar" {
|
||||
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
||||
@@ -494,6 +505,7 @@ func TestProxy_SSH(t *testing.T) {
|
||||
var receivedPass string
|
||||
var receivedChannelType string
|
||||
var receivedHost string
|
||||
var receivedProtocol Protocol
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
@@ -505,8 +517,9 @@ func TestProxy_SSH(t *testing.T) {
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
r := NewRouter(func(host string) (*net.TCPAddr, error) {
|
||||
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
||||
receivedHost = host
|
||||
receivedProtocol = protocol
|
||||
if host == "10-0-0-1-aaaabbbb" {
|
||||
chunks := strings.Split(laddr, ":")
|
||||
a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]))
|
||||
@@ -527,4 +540,5 @@ func TestProxy_SSH(t *testing.T) {
|
||||
assert.Equal(t, "root", receivedPass)
|
||||
assert.Equal(t, "session", receivedChannelType)
|
||||
assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost)
|
||||
assert.Equal(t, ProtocolSSH, receivedProtocol)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user