diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index bc52332..3098cc1 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -12,6 +12,7 @@ import ( "time" "github.com/miekg/dns" + "golang.org/x/sync/errgroup" "github.com/Control-D-Inc/ctrld" "github.com/Control-D-Inc/ctrld/internal/dnscache" @@ -20,7 +21,7 @@ import ( const staleTTL = 60 * time.Second -func (p *prog) serveUDP(listenerNum string) error { +func (p *prog) serveDNS(listenerNum string) error { listenerConfig := p.cfg.Listener[listenerNum] // make sure ip is allocated if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil { @@ -55,27 +56,38 @@ func (p *prog) serveUDP(listenerNum string) error { } }) - // On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can - // listen on ::1, then spawn a listener for receiving DNS requests. - if runtime.GOOS == "windows" && ctrldnet.SupportsIPv6ListenLocal() { - go func() { + g := new(errgroup.Group) + for _, proto := range []string{"udp", "tcp"} { + proto := proto + // On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can + // listen on ::1, then spawn a listener for receiving DNS requests. + if runtime.GOOS == "windows" && ctrldnet.SupportsIPv6ListenLocal() { + g.Go(func() error { + s := &dns.Server{ + Addr: net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), + Net: proto, + Handler: handler, + } + if err := s.ListenAndServe(); err != nil { + mainLog.Error().Err(err).Msg("could not serving on ::1") + } + return nil + }) + } + g.Go(func() error { s := &dns.Server{ - Addr: net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), - Net: "udp", + Addr: net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)), + Net: proto, Handler: handler, } if err := s.ListenAndServe(); err != nil { - mainLog.Error().Err(err).Msg("could not serving on ::1") + mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr) + return err } - }() + return nil + }) } - - s := &dns.Server{ - Addr: net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)), - Net: "udp", - Handler: handler, - } - return s.ListenAndServe() + return g.Wait() } func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) { diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 6b58116..52b9c85 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -85,7 +85,7 @@ func (p *prog) run() { } addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)) mainLog.Info().Msgf("Starting DNS server on listener.%s: %s", listenerNum, addr) - err := p.serveUDP(listenerNum) + err := p.serveDNS(listenerNum) if err != nil && !defaultConfigWritten && cdUID == "" { mainLog.Fatal().Err(err).Msgf("Unable to start dns proxy on listener.%s", listenerNum) return @@ -109,7 +109,7 @@ func (p *prog) run() { p.cfg.Service.AllocateIP = true p.preRun() mainLog.Info().Msgf("Starting DNS server on listener.%s: %s", listenerNum, net.JoinHostPort(ip, strconv.Itoa(port))) - if err := p.serveUDP(listenerNum); err != nil { + if err := p.serveDNS(listenerNum); err != nil { mainLog.Fatal().Err(err).Msgf("Unable to start dns proxy on listener.%s", listenerNum) return }