diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 91ae1fa..0aa22df 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -50,9 +50,10 @@ var ( ) var ( - v = viper.NewWithOptions(viper.KeyDelimiter("::")) - defaultConfigFile = "ctrld.toml" - rootCertPool *x509.CertPool + v = viper.NewWithOptions(viper.KeyDelimiter("::")) + defaultConfigFile = "ctrld.toml" + rootCertPool *x509.CertPool + errSelfCheckNoAnswer = errors.New("no answer from ctrld listener") ) var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"} @@ -280,17 +281,40 @@ func initCLI() { return } - status := selfCheckStatus(s) - switch status { - case service.StatusRunning: + ok, status, err := selfCheckStatus(s) + switch { + case ok && status == service.StatusRunning: mainLog.Load().Notice().Msg("Service started") default: marker := bytes.Repeat([]byte("="), 32) - mainLog.Load().Error().Msg("ctrld service may not have started due to an error or misconfiguration, service log:") - _, _ = mainLog.Load().Write(marker) - for msg := range runCmdLogCh { - _, _ = mainLog.Load().Write([]byte(msg)) + // If ctrld service is not running, emitting log obtained from ctrld process. + if status != service.StatusRunning { + mainLog.Load().Error().Msg("ctrld service may not have started due to an error or misconfiguration, service log:") + _, _ = mainLog.Load().Write(marker) + haveLog := false + for msg := range runCmdLogCh { + _, _ = mainLog.Load().Write([]byte(msg)) + haveLog = true + } + // If we're unable to get log from "ctrld run", notice users about it. + if !haveLog { + mainLog.Load().Write([]byte(`"`)) + } } + // Report any error if occurred. + if err != nil { + _, _ = mainLog.Load().Write(marker) + msg := fmt.Sprintf("An error happened when performing test query: %s", err) + mainLog.Load().Write([]byte(msg)) + } + // If ctrld service is running but selfCheckStatus failed, it could be related + // to user's system firewall configuration, notice users about it. + if status == service.StatusRunning { + _, _ = mainLog.Load().Write(marker) + mainLog.Load().Write([]byte(`ctrld service was running, but somehow DNS query could not be sent to its listener`)) + mainLog.Load().Write([]byte(`Please check your system firewall if it is configured to block/intercept/redirect DNS queries`)) + } + _, _ = mainLog.Load().Write(marker) uninstall(p, s) os.Exit(1) @@ -1346,41 +1370,44 @@ func defaultIfaceName() string { return dri } -func selfCheckStatus(s service.Service) service.Status { +// selfCheckStatus performs the end-to-end DNS test by sending query to ctrld listener. +// It returns a boolean to indicate whether the check is succeeded, the actual status +// of ctrld service, and an additional error if any. +func selfCheckStatus(s service.Service) (bool, service.Status, error) { status, err := s.Status() if err != nil { mainLog.Load().Warn().Err(err).Msg("could not get service status") - return status + return false, service.StatusUnknown, err } // If ctrld is not running, do nothing, just return the status as-is. if status != service.StatusRunning { - return status + return false, status, nil } dir, err := socketDir() if err != nil { mainLog.Load().Error().Err(err).Msg("failed to check ctrld listener status: could not get home directory") - return service.StatusUnknown + return false, status, err } mainLog.Load().Debug().Msg("waiting for ctrld listener to be ready") cc := newSocketControlClient(s, dir) if cc == nil { - return service.StatusUnknown + return false, status, errors.New("could not connect to control server") } resp, err := cc.post(startedPath, nil) if err != nil { mainLog.Load().Error().Err(err).Msg("failed to connect to control server") - return service.StatusUnknown + return false, status, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { mainLog.Load().Error().Msg("ctrld listener is not ready") - return service.StatusUnknown + return false, status, errors.New("ctrld listener is not ready") } // Not a ctrld upstream, return status as-is. if cfg.FirstUpstream().VerifyDomain() == "" { - return status + return true, status, nil } mainLog.Load().Debug().Msg("ctrld listener is ready") @@ -1405,12 +1432,12 @@ func selfCheckStatus(s service.Service) service.Status { domain := cfg.FirstUpstream().VerifyDomain() if domain == "" { // Nothing to do, return the status as-is. - return status + return true, status, nil } watcher, err := fsnotify.NewWatcher() if err != nil { mainLog.Load().Error().Err(err).Msg("could not watch config change") - return service.StatusUnknown + return false, status, err } defer watcher.Close() @@ -1449,14 +1476,18 @@ func selfCheckStatus(s service.Service) service.Status { m := new(dns.Msg) m.SetQuestion(domain+".", dns.TypeA) m.RecursionDesired = true - r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))) + r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))) if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 { mainLog.Load().Debug().Msgf("self-check against %q succeeded", domain) - return status + return true, status, nil + } + // Return early if this is a connection refused. + if errConnectionRefused(exErr) { + return false, status, exErr } lastAnswer = r - lastErr = err - bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", err)) + lastErr = exErr + bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", exErr)) } mainLog.Load().Debug().Msgf("self-check against %q failed", domain) lc := cfg.FirstListener() @@ -1471,9 +1502,9 @@ func selfCheckStatus(s service.Service) service.Status { for _, s := range strings.Split(lastAnswer.String(), "\n") { mainLog.Load().Debug().Msgf("%s", s) } - mainLog.Load().Debug().Msg(marker) + return false, status, errSelfCheckNoAnswer } - return service.StatusUnknown + return false, status, lastErr } func userHomeDir() (string, error) { @@ -2130,3 +2161,10 @@ func ensureUninstall(s service.Service) error { } return errors.Join(err, errors.New("uninstall failed")) } + +// exchangeContextWithTimeout wraps c.ExchangeContext with the given timeout. +func exchangeContextWithTimeout(c *dns.Client, timeout time.Duration, msg *dns.Msg, addr string) (*dns.Msg, time.Duration, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return c.ExchangeContext(ctx, msg, addr) +} diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 8d0cf3e..92eadc8 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -586,6 +586,15 @@ func errNetworkError(err error) bool { return false } +// errConnectionRefused reports whether err is connection refused. +func errConnectionRefused(err error) bool { + var opErr *net.OpError + if !errors.As(err, &opErr) { + return false + } + return errors.Is(opErr.Err, syscall.ECONNREFUSED) || errors.Is(opErr.Err, windowsECONNREFUSED) +} + func ifaceFirstPrivateIP(iface *net.Interface) string { if iface == nil { return ""