545 lines
12 KiB
Go
545 lines
12 KiB
Go
package router
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/asn1"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func testSshClient(user string, r *proxyRouter) error {
|
|
reader := rand.Reader
|
|
bitSize := 2048
|
|
|
|
key, err := rsa.GenerateKey(reader, bitSize)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
signer, err := ssh.NewSignerFromKey(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
config := &ssh.ClientConfig{
|
|
User: user,
|
|
Auth: []ssh.AuthMethod{
|
|
ssh.PublicKeys(signer),
|
|
},
|
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
return nil
|
|
},
|
|
}
|
|
chunks := strings.Split(r.ListenSshAddress(), ":")
|
|
l := fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])
|
|
client, err := ssh.Dial("tcp", l, config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer session.Close()
|
|
|
|
return nil
|
|
}
|
|
|
|
func testSshServer(f func(user, pass, ctype string)) (string, error) {
|
|
reader := rand.Reader
|
|
bitSize := 2048
|
|
|
|
key, err := rsa.GenerateKey(reader, bitSize)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
signer, err := ssh.NewSignerFromKey(key)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var receivedUser string
|
|
var receivedPass string
|
|
var receivedChannelType string
|
|
|
|
config := &ssh.ServerConfig{
|
|
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
receivedUser = conn.User()
|
|
receivedPass = string(password)
|
|
|
|
return nil, nil
|
|
},
|
|
}
|
|
config.AddHostKey(signer)
|
|
listener, err := net.Listen("tcp", "0.0.0.0:0")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
go func() {
|
|
defer listener.Close()
|
|
nConn, err := listener.Accept()
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
|
|
if err != nil {
|
|
log.Println(err)
|
|
return
|
|
}
|
|
go ssh.DiscardRequests(reqs)
|
|
defer nConn.Close()
|
|
defer conn.Close()
|
|
|
|
ch := <-chans
|
|
|
|
receivedChannelType = ch.ChannelType()
|
|
|
|
f(receivedUser, receivedPass, receivedChannelType)
|
|
}()
|
|
|
|
return listener.Addr().String(), nil
|
|
}
|
|
|
|
// generateKeys creates a temporary directory containing a freshly generated
// 2048-bit RSA key pair in PEM form. It returns the directory, the path of
// the private key file ("id_rsa") and the path of the public key file
// ("id_rsa.pub"), in that order.
//
// The private key is PKCS #1 encoded in an "RSA PRIVATE KEY" block. The
// public key is the ASN.1 encoding of the rsa.PublicKey structure (modulus
// and exponent), i.e. a PKCS #1 RSAPublicKey, and is labelled
// "RSA PUBLIC KEY" accordingly.
//
// The caller owns the returned directory and is expected to remove it. On
// error the directory is removed and all returned strings are empty.
func generateKeys() (string, string, string, error) {
	dir, err := ioutil.TempDir("", "pwd")
	if err != nil {
		return "", "", "", err
	}

	privatePath, publicPath, err := writeTestKeyPair(dir)
	if err != nil {
		// Best effort: do not leave a half-populated directory behind.
		os.RemoveAll(dir)
		return "", "", "", err
	}
	return dir, privatePath, publicPath, nil
}

// writeTestKeyPair generates an RSA key and writes its private and public
// halves as PEM files into dir, returning both file paths.
func writeTestKeyPair(dir string) (string, string, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return "", "", err
	}

	privatePath := fmt.Sprintf("%s/id_rsa", dir)
	privateKey := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}
	if err := writeTestPEMFile(privatePath, privateKey); err != nil {
		return "", "", err
	}

	asn1Bytes, err := asn1.Marshal(key.PublicKey)
	if err != nil {
		return "", "", err
	}
	publicPath := fmt.Sprintf("%s/id_rsa.pub", dir)
	publicKey := &pem.Block{
		// The bytes are a bare RSAPublicKey, not a SubjectPublicKeyInfo,
		// so the PKCS #1 label is the matching one.
		Type:  "RSA PUBLIC KEY",
		Bytes: asn1Bytes,
	}
	if err := writeTestPEMFile(publicPath, publicKey); err != nil {
		return "", "", err
	}

	return privatePath, publicPath, nil
}

// writeTestPEMFile writes block to path with owner-only permissions,
// truncating any existing file, and reports encode as well as close errors.
func writeTestPEMFile(path string, block *pem.Block) error {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	if err := pem.Encode(f, block); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}
|
|
|
|
func getRouterUrl(scheme string, r *proxyRouter) string {
|
|
chunks := strings.Split(r.ListenHttpAddress(), ":")
|
|
return fmt.Sprintf("%s://localhost:%s", scheme, chunks[len(chunks)-1])
|
|
}
|
|
|
|
func routerLookup(protocol, domain string, r *proxyRouter) ([]string, error) {
|
|
c := dns.Client{Net: protocol}
|
|
m := dns.Msg{}
|
|
|
|
m.SetQuestion(fmt.Sprintf("%s.", domain), dns.TypeA)
|
|
var l string
|
|
if protocol == "udp" {
|
|
chunks := strings.Split(r.ListenDnsUdpAddress(), ":")
|
|
l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])
|
|
} else if protocol == "tcp" {
|
|
chunks := strings.Split(r.ListenDnsTcpAddress(), ":")
|
|
l = fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1])
|
|
}
|
|
res, _, err := c.Exchange(&m, l)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(res.Answer) == 0 {
|
|
return nil, fmt.Errorf("Didn't receive any answer")
|
|
}
|
|
addrs := []string{}
|
|
for _, a := range res.Answer {
|
|
if b, ok := a.(*dns.A); ok {
|
|
addrs = append(addrs, b.A.String())
|
|
} else if b, ok := a.(*dns.AAAA); ok {
|
|
addrs = append(addrs, b.AAAA.String())
|
|
}
|
|
}
|
|
|
|
return addrs, nil
|
|
}
|
|
|
|
func TestProxy_TLS(t *testing.T) {
|
|
dir, private, _, _ := generateKeys()
|
|
defer os.RemoveAll(dir)
|
|
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
}
|
|
client := &http.Client{Transport: tr}
|
|
|
|
const msg = "It works!"
|
|
|
|
var receivedHost string
|
|
var receivedProtocol Protocol
|
|
|
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprint(w, msg)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
|
receivedHost = host
|
|
receivedProtocol = protocol
|
|
u, _ := url.Parse(ts.URL)
|
|
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
|
return a, nil
|
|
}, private)
|
|
r.Listen(":0", ":0", ":0")
|
|
defer r.Close()
|
|
|
|
req, err := http.NewRequest("GET", getRouterUrl("https", r), nil)
|
|
assert.Nil(t, err)
|
|
|
|
resp, err := client.Do(req)
|
|
assert.Nil(t, err)
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, msg, string(body))
|
|
assert.Equal(t, "localhost", receivedHost)
|
|
assert.Equal(t, ProtocolHTTPS, receivedProtocol)
|
|
}
|
|
|
|
func TestProxy_Http(t *testing.T) {
|
|
dir, private, _, _ := generateKeys()
|
|
defer os.RemoveAll(dir)
|
|
|
|
const msg = "It works!"
|
|
|
|
var receivedHost string
|
|
var receivedProtocol Protocol
|
|
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprint(w, msg)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
|
receivedHost = host
|
|
receivedProtocol = protocol
|
|
u, _ := url.Parse(ts.URL)
|
|
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
|
return a, nil
|
|
}, private)
|
|
r.Listen(":0", ":0", ":0")
|
|
defer r.Close()
|
|
|
|
req, err := http.NewRequest("GET", getRouterUrl("http", r), nil)
|
|
assert.Nil(t, err)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
assert.Nil(t, err)
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, msg, string(body))
|
|
|
|
u, _ := url.Parse(getRouterUrl("http", r))
|
|
assert.Equal(t, u.Host, receivedHost)
|
|
assert.Equal(t, ProtocolHTTP, receivedProtocol)
|
|
}
|
|
|
|
func TestProxy_WS(t *testing.T) {
|
|
dir, private, _, _ := generateKeys()
|
|
defer os.RemoveAll(dir)
|
|
|
|
const msg = "It works!"
|
|
|
|
var serverReceivedMessage string
|
|
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var upgrader = websocket.Upgrader{}
|
|
c, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Print("upgrade:", err)
|
|
return
|
|
}
|
|
defer c.Close()
|
|
mt, message, err := c.ReadMessage()
|
|
if err != nil {
|
|
log.Println("read:", err)
|
|
return
|
|
}
|
|
serverReceivedMessage = string(message)
|
|
err = c.WriteMessage(mt, message)
|
|
if err != nil {
|
|
log.Println("write:", err)
|
|
return
|
|
}
|
|
}))
|
|
defer ts.Close()
|
|
|
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
|
u, _ := url.Parse(ts.URL)
|
|
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
|
return a, nil
|
|
}, private)
|
|
r.Listen(":0", ":0", ":0")
|
|
defer r.Close()
|
|
|
|
c, _, err := websocket.DefaultDialer.Dial(getRouterUrl("ws", r), nil)
|
|
assert.Nil(t, err)
|
|
defer c.Close()
|
|
|
|
var clientReceivedMessage []byte
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
_, clientReceivedMessage, _ = c.ReadMessage()
|
|
wg.Done()
|
|
}()
|
|
err = c.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
assert.Nil(t, err)
|
|
|
|
wg.Wait()
|
|
|
|
assert.Equal(t, msg, string(clientReceivedMessage))
|
|
assert.Equal(t, msg, serverReceivedMessage)
|
|
}
|
|
|
|
func TestProxy_WSS(t *testing.T) {
|
|
dir, private, _, _ := generateKeys()
|
|
defer os.RemoveAll(dir)
|
|
|
|
const msg = "It works!"
|
|
|
|
var serverReceivedMessage string
|
|
|
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var upgrader = websocket.Upgrader{}
|
|
c, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Print("upgrade:", err)
|
|
return
|
|
}
|
|
defer c.Close()
|
|
mt, message, err := c.ReadMessage()
|
|
if err != nil {
|
|
log.Println("read:", err)
|
|
return
|
|
}
|
|
serverReceivedMessage = string(message)
|
|
err = c.WriteMessage(mt, message)
|
|
if err != nil {
|
|
log.Println("write:", err)
|
|
return
|
|
}
|
|
}))
|
|
defer ts.Close()
|
|
|
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
|
u, _ := url.Parse(ts.URL)
|
|
a, _ := net.ResolveTCPAddr("tcp", u.Host)
|
|
return a, nil
|
|
}, private)
|
|
r.Listen(":0", ":0", ":0")
|
|
defer r.Close()
|
|
|
|
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
|
c, _, err := d.Dial(getRouterUrl("wss", r), nil)
|
|
assert.Nil(t, err)
|
|
defer c.Close()
|
|
|
|
var clientReceivedMessage []byte
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
_, clientReceivedMessage, _ = c.ReadMessage()
|
|
wg.Done()
|
|
}()
|
|
|
|
err = c.WriteMessage(websocket.TextMessage, []byte(msg))
|
|
assert.Nil(t, err)
|
|
|
|
wg.Wait()
|
|
|
|
assert.Equal(t, msg, string(clientReceivedMessage))
|
|
assert.Equal(t, msg, serverReceivedMessage)
|
|
}
|
|
|
|
func TestProxy_DNS_UDP(t *testing.T) {
|
|
dir, private, _, _ := generateKeys()
|
|
defer os.RemoveAll(dir)
|
|
|
|
var receivedHost string
|
|
var receivedProtocol Protocol
|
|
|
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
|
receivedHost = host
|
|
receivedProtocol = protocol
|
|
if host == "10_0_0_1.foo.bar" {
|
|
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
|
return a, nil
|
|
} else {
|
|
return nil, fmt.Errorf("Not recognized")
|
|
}
|
|
}, private)
|
|
r.Listen(":0", ":0", ":0")
|
|
defer r.Close()
|
|
|
|
ips, err := routerLookup("udp", "10_0_0_1.foo.bar", r)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
|
|
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
|
assert.Equal(t, []string{"10.0.0.1"}, ips)
|
|
|
|
ips, err = routerLookup("udp", "www.google.com", r)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, "www.google.com", receivedHost)
|
|
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
|
|
|
expectedIps, err := net.LookupHost("www.google.com")
|
|
assert.Nil(t, err)
|
|
|
|
sort.Strings(expectedIps)
|
|
sort.Strings(ips)
|
|
assert.Equal(t, expectedIps, ips)
|
|
|
|
ips, err = routerLookup("udp", "localhost", r)
|
|
assert.Nil(t, err)
|
|
assert.NotEqual(t, "localhost", receivedHost)
|
|
assert.Equal(t, ProtocolDNS, receivedProtocol)
|
|
assert.Equal(t, []string{"127.0.0.1"}, ips)
|
|
}
|
|
|
|
func TestProxy_DNS_TCP(t *testing.T) {
|
|
dir, private, _, _ := generateKeys()
|
|
defer os.RemoveAll(dir)
|
|
|
|
var receivedHost string
|
|
|
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
|
receivedHost = host
|
|
if host == "10_0_0_1.foo.bar" {
|
|
a, _ := net.ResolveTCPAddr("tcp", "10.0.0.1:0")
|
|
return a, nil
|
|
} else {
|
|
return nil, fmt.Errorf("Not recognized")
|
|
}
|
|
}, private)
|
|
r.Listen(":0", ":0", ":0")
|
|
defer r.Close()
|
|
|
|
ips, err := routerLookup("tcp", "10_0_0_1.foo.bar", r)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, "10_0_0_1.foo.bar", receivedHost)
|
|
assert.Equal(t, []string{"10.0.0.1"}, ips)
|
|
|
|
ips, err = routerLookup("tcp", "www.google.com", r)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, "www.google.com", receivedHost)
|
|
|
|
expectedIps, err := net.LookupHost("www.google.com")
|
|
assert.Nil(t, err)
|
|
|
|
sort.Strings(expectedIps)
|
|
sort.Strings(ips)
|
|
assert.Equal(t, expectedIps, ips)
|
|
|
|
ips, err = routerLookup("tcp", "localhost", r)
|
|
assert.Nil(t, err)
|
|
assert.NotEqual(t, "localhost", receivedHost)
|
|
assert.Equal(t, []string{"127.0.0.1"}, ips)
|
|
}
|
|
|
|
func TestProxy_SSH(t *testing.T) {
|
|
dir, private, _, _ := generateKeys()
|
|
defer os.RemoveAll(dir)
|
|
|
|
var receivedUser string
|
|
var receivedPass string
|
|
var receivedChannelType string
|
|
var receivedHost string
|
|
var receivedProtocol Protocol
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
laddr, err := testSshServer(func(user, pass, ctype string) {
|
|
receivedUser = user
|
|
receivedPass = pass
|
|
receivedChannelType = ctype
|
|
wg.Done()
|
|
})
|
|
assert.Nil(t, err)
|
|
|
|
r := NewRouter(func(protocol Protocol, host string) (*net.TCPAddr, error) {
|
|
receivedHost = host
|
|
receivedProtocol = protocol
|
|
if host == "10-0-0-1-aaaabbbb" {
|
|
chunks := strings.Split(laddr, ":")
|
|
a, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%s", chunks[len(chunks)-1]))
|
|
return a, nil
|
|
} else {
|
|
return nil, fmt.Errorf("Not recognized")
|
|
}
|
|
}, private)
|
|
r.Listen(":0", ":0", ":0")
|
|
defer r.Close()
|
|
|
|
err = testSshClient("10-0-0-1-aaaabbbb", r)
|
|
assert.Nil(t, err)
|
|
|
|
wg.Wait()
|
|
|
|
assert.Equal(t, "root", receivedUser)
|
|
assert.Equal(t, "root", receivedPass)
|
|
assert.Equal(t, "session", receivedChannelType)
|
|
assert.Equal(t, "10-0-0-1-aaaabbbb", receivedHost)
|
|
assert.Equal(t, ProtocolSSH, receivedProtocol)
|
|
}
|