diff --git a/router/router.go b/router/router.go index 3421186..b601470 100644 --- a/router/router.go +++ b/router/router.go @@ -42,64 +42,23 @@ type proxyRouter struct { } func (r *proxyRouter) Listen(httpAddr, dnsAddr, sshAddr string) { - l, err := net.Listen("tcp", httpAddr) - if err != nil { - log.Fatal(err) - } - r.httpListener = l - go func() { - for !r.closed { - conn, err := r.httpListener.Accept() - if err != nil { - continue - } - go r.handleConnection(conn) - } - }() - - dnsMux := dns.NewServeMux() - dnsMux.HandleFunc(".", r.dnsRequest) - r.udpDnsServer = &dns.Server{Addr: dnsAddr, Net: "udp", Handler: dnsMux} - r.tcpDnsServer = &dns.Server{Addr: dnsAddr, Net: "tcp", Handler: dnsMux} - - wg := sync.WaitGroup{} - wg.Add(2) - - r.udpDnsServer.NotifyStartedFunc = func() { - wg.Done() - } - r.tcpDnsServer.NotifyStartedFunc = func() { - wg.Done() - } - go r.udpDnsServer.ListenAndServe() - go r.tcpDnsServer.ListenAndServe() - wg.Wait() - - lssh, err := net.Listen("tcp", sshAddr) - if err != nil { - log.Fatal("failed to listen for connection: ", err) - } - r.sshListener = lssh - go func() { - for { - nConn, err := lssh.Accept() - if err != nil { - log.Fatal("failed to accept incoming connection: ", err) - } - - go r.sshHandle(nConn) - } - }() + r.listen(&sync.WaitGroup{}, httpAddr, dnsAddr, sshAddr) } + func (r *proxyRouter) ListenAndWait(httpAddr, dnsAddr, sshAddr string) { - listenWG := sync.WaitGroup{} + wg := sync.WaitGroup{} + r.listen(&wg, httpAddr, dnsAddr, sshAddr) + wg.Wait() +} + +func (r *proxyRouter) listen(wg *sync.WaitGroup, httpAddr, dnsAddr, sshAddr string) { l, err := net.Listen("tcp", httpAddr) if err != nil { log.Fatal(err) } r.httpListener = l - listenWG.Add(1) + wg.Add(1) go func() { for !r.closed { conn, err := r.httpListener.Accept() @@ -108,7 +67,7 @@ func (r *proxyRouter) ListenAndWait(httpAddr, dnsAddr, sshAddr string) { } go r.handleConnection(conn) } - listenWG.Done() + wg.Done() }() dnsMux := dns.NewServeMux() @@ -116,25 +75,25 @@ func (r *proxyRouter) ListenAndWait(httpAddr, dnsAddr, sshAddr string) { r.udpDnsServer = &dns.Server{Addr: dnsAddr, Net: "udp", Handler: dnsMux} r.tcpDnsServer = &dns.Server{Addr: dnsAddr, Net: "tcp", Handler: dnsMux} - wg := sync.WaitGroup{} - wg.Add(2) + wgStarted := sync.WaitGroup{} + wgStarted.Add(2) r.udpDnsServer.NotifyStartedFunc = func() { - wg.Done() + wgStarted.Done() } r.tcpDnsServer.NotifyStartedFunc = func() { - wg.Done() + wgStarted.Done() } go r.udpDnsServer.ListenAndServe() go r.tcpDnsServer.ListenAndServe() - wg.Wait() + wgStarted.Wait() lssh, err := net.Listen("tcp", sshAddr) if err != nil { log.Fatal("failed to listen for connection: ", err) } r.sshListener = lssh - listenWG.Add(1) + wg.Add(1) go func() { for { nConn, err := lssh.Accept() @@ -144,9 +103,8 @@ func (r *proxyRouter) ListenAndWait(httpAddr, dnsAddr, sshAddr string) { go r.sshHandle(nConn) } - listenWG.Done() + wg.Done() }() - listenWG.Wait() } func (r *proxyRouter) sshHandle(nConn net.Conn) {