diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index 1d07863..0e2b642 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -165,9 +165,13 @@ func initCLI() { initLogging() if setupRouter { - if err := router.PreStart(); err != nil { + s, _ := runDNSServerForNTPD() + if err := router.PreRun(); err != nil { mainLog.Fatal().Err(err).Msg("failed to perform router pre-start check") } + if err := s.Shutdown(); err != nil { + mainLog.Fatal().Err(err).Msg("failed to shutdown dns server for ntpd") + } } processCDFlags() @@ -909,7 +913,7 @@ func unsupportedPlatformHelp(cmd *cobra.Command) { func userHomeDir() (string, error) { switch router.Name() { - case router.DDWrt, router.Merlin: + case router.DDWrt, router.Merlin, router.Tomato: exe, err := os.Executable() if err != nil { return "", err diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index a7602f0..5b2e34b 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -459,3 +459,45 @@ func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-cha waitLock.Lock() return s, errCh } + +func runDNSServerForNTPD() (*dns.Server, <-chan error) { + dnsResolver := ctrld.NewBootstrapResolver() + s := &dns.Server{ + Addr: router.ListenAddress(), + Net: "udp", + Handler: dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + mainLog.Debug().Msg("Serving query for ntpd") + resolveCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + if osUpstreamConfig.Timeout > 0 { + timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(osUpstreamConfig.Timeout)) + defer cancel() + resolveCtx = timeoutCtx + } + answer, err := dnsResolver.Resolve(resolveCtx, m) + if err != nil { + mainLog.Error().Err(err).Msgf("could not resolve: %v", m) + return + } + if err := w.WriteMsg(answer); err != nil { + mainLog.Error().Err(err).Msg("runDNSServerForNTPD: failed to send DNS response") + } + }), + } + + waitLock := sync.Mutex{} + waitLock.Lock() + s.NotifyStartedFunc = waitLock.Unlock + + errCh := make(chan error) + go func() { + defer close(errCh) + if err := s.ListenAndServe(); err != nil { + waitLock.Unlock() + mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr) + errCh <- err + } + }() + waitLock.Lock() + return s, errCh +} diff --git a/internal/router/merlin.go b/internal/router/merlin.go index aab05e7..8e20d68 100644 --- a/internal/router/merlin.go +++ b/internal/router/merlin.go @@ -2,16 +2,11 @@ package router import ( "bytes" - "context" - "errors" "fmt" "os" "os/exec" "strings" - "time" "unicode" - - "tailscale.com/logtail/backoff" ) func setupMerlin() error { @@ -92,43 +87,3 @@ func merlinParsePostConf(buf []byte) []byte { } return buf } - -func merlinPreStart() (err error) { - pidFile := "/tmp/ctrld.pid" - - // Remove pid file and trigger dnsmasq restart, so NTP can resolve - // server name and perform time synchronization. - pid, err := os.ReadFile(pidFile) - if err != nil { - return fmt.Errorf("PreStart: os.Readfile: %w", err) - } - if err := os.Remove(pidFile); err != nil { - return fmt.Errorf("PreStart: os.Remove: %w", err) - } - defer func() { - if werr := os.WriteFile(pidFile, pid, 0600); werr != nil { - err = errors.Join(err, werr) - return - } - if rerr := restartDNSMasq(); rerr != nil { - err = errors.Join(err, rerr) - return - } - }() - if err := restartDNSMasq(); err != nil { - return fmt.Errorf("PreStart: restartDNSMasqFn: %w", err) - } - - // Wait until `ntp_ready=1` set. - b := backoff.NewBackoff("PreStart", func(format string, args ...any) {}, 10*time.Second) - for { - out, err := nvram("get", "ntp_ready") - if err != nil { - return fmt.Errorf("PreStart: nvram: %w", err) - } - if out == "1" { - return nil - } - b.BackOff(context.Background(), errors.New("ntp not ready")) - } -} diff --git a/internal/router/router.go b/internal/router/router.go index 81246e4..2a36253 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -2,14 +2,18 @@ package router import ( "bytes" + "context" "errors" + "fmt" "os" "os/exec" "sync" "sync/atomic" + "time" "github.com/fsnotify/fsnotify" "github.com/kardianos/service" + "tailscale.com/logtail/backoff" "github.com/Control-D-Inc/ctrld" ) @@ -106,14 +110,23 @@ func ConfigureService(sc *service.Config) error { return nil } -// PreStart blocks until the router is ready for running ctrld. -func PreStart() (err error) { +// PreRun blocks until the router is ready for running ctrld. +func PreRun() (err error) { // On some routers, NTP may out of sync, so waiting for it to be ready. switch Name() { - case Merlin: - return merlinPreStart() - case Tomato: - return tomatoPreStart() + case Merlin, Tomato: + // Wait until `ntp_ready=1` set. + b := backoff.NewBackoff("PreStart", func(format string, args ...any) {}, 10*time.Second) + for { + out, err := nvram("get", "ntp_ready") + if err != nil { + return fmt.Errorf("PreStart: nvram: %w", err) + } + if out == "1" { + return nil + } + b.BackOff(context.Background(), errors.New("ntp not ready")) + } default: return nil } diff --git a/internal/router/tomato.go b/internal/router/tomato.go index 50f13ba..945e992 100644 --- a/internal/router/tomato.go +++ b/internal/router/tomato.go @@ -1,13 +1,8 @@ package router import ( - "context" - "errors" "fmt" "os/exec" - "time" - - "tailscale.com/logtail/backoff" ) const ( @@ -72,27 +67,6 @@ func cleanupTomato() error { return nil } -func tomatoPreStart() (err error) { - // cleanup to trigger dnsmasq restart, so NTP can resolve - // server name and perform time synchronization. - if err = cleanupTomato(); err != nil { - return err - } - - // Wait until `ntp_ready=1` set. - b := backoff.NewBackoff("PreStart", func(format string, args ...any) {}, 10*time.Second) - for { - out, err := nvram("get", "ntp_ready") - if err != nil { - return fmt.Errorf("PreStart: nvram: %w", err) - } - if out == "1" { - return nil - } - b.BackOff(context.Background(), errors.New("ntp not ready")) - } -} - func tomatoRestartService(name string) error { return tomatoRestartServiceWithKill(name, false) } diff --git a/resolver.go b/resolver.go index 391a4e8..0180762 100644 --- a/resolver.go +++ b/resolver.go @@ -201,3 +201,17 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) } return ips } + +// NewBootstrapResolver returns an OS resolver, which use following nameservers: +// +// - ControlD bootstrap DNS server. +// - Gateway IP address (depends on OS). +// - Input servers. +func NewBootstrapResolver(servers ...string) Resolver { + resolver := &osResolver{nameservers: nameservers()} + resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...) + for _, ns := range servers { + resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...) + } + return resolver +}