diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 5c7795f..b54a631 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1647,11 +1647,6 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { mainLog.Load().Debug().Msg("ctrld listener is ready") mainLog.Load().Debug().Msg("performing self-check") - bo := backoff.NewBackoff("self-check", logf, 10*time.Second) - bo.LogLongerThan = 500 * time.Millisecond - ctx := context.Background() - maxAttempts := 20 - c := new(dns.Client) var ( lcChanged map[string]*ctrld.ListenerConfig ucChanged map[string]*ctrld.UpstreamConfig @@ -1669,12 +1664,6 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { // Nothing to do, return the status as-is. return true, status, nil } - watcher, err := fsnotify.NewWatcher() - if err != nil { - mainLog.Load().Error().Err(err).Msg("could not watch config change") - return false, status, err - } - defer watcher.Close() v.OnConfigChange(func(in fsnotify.Event) { mu.Lock() @@ -1683,52 +1672,56 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { mainLog.Load().Error().Msgf("failed to unmarshal listener config: %v", err) return } + cfg.Listener = lcChanged if err := v.UnmarshalKey("upstream", &ucChanged); err != nil { mainLog.Load().Error().Msgf("failed to unmarshal upstream config: %v", err) return } + cfg.Upstream = ucChanged }) v.WatchConfig() + + lc := cfg.FirstListener() + addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) + getInternalDomainFn := func() string { return selfCheckInternalTestDomain } + getExternalDomainFn := func() string { return cfg.FirstUpstream().VerifyDomain() } + if err := selfCheckResolveDomain(context.TODO(), addr, "internal", getInternalDomainFn); err != nil { + return false, status, err + } + if err := selfCheckResolveDomain(context.TODO(), addr, "external", getExternalDomainFn); err != nil { + return false, status, err + } + return true, status, nil +} + +// selfCheckResolveDomain performs DNS test query against ctrld listener. +func selfCheckResolveDomain(ctx context.Context, addr, scope string, getDomainFn func() string) error { + bo := backoff.NewBackoff("self-check", logf, 10*time.Second) + bo.LogLongerThan = 500 * time.Millisecond + maxAttempts := 20 + c := new(dns.Client) + var ( - lastAnswer *dns.Msg - lastErr error - internalTested bool + lastAnswer *dns.Msg + lastErr error ) + domain := "" for i := 0; i < maxAttempts; i++ { - mu.Lock() - if lcChanged != nil { - cfg.Listener = lcChanged - } - if ucChanged != nil { - cfg.Upstream = ucChanged - } - mu.Unlock() - lc := cfg.FirstListener() - domain = cfg.FirstUpstream().VerifyDomain() - if !internalTested { - domain = selfCheckInternalTestDomain - } + domain = getDomainFn() if domain == "" { continue } - m := new(dns.Msg) m.SetQuestion(domain+".", dns.TypeA) m.RecursionDesired = true - r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))) + r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, addr) if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 { - internalTested = domain == selfCheckInternalTestDomain - if internalTested { - mainLog.Load().Debug().Msgf("internal self-check against %q succeeded", domain) - continue // internal domain test ok, continue with external test. - } else { - mainLog.Load().Debug().Msgf("external self-check against %q succeeded", domain) - } - return true, status, nil + mainLog.Load().Debug().Msgf("%s self-check against %q succeeded", scope, domain) + return nil } // Return early if this is a connection refused. if errConnectionRefused(exErr) { - return false, status, exErr + return exErr } lastAnswer = r lastErr = exErr @@ -1741,8 +1734,6 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) } } - lc := cfg.FirstListener() - addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)) marker := strings.Repeat("=", 32) mainLog.Load().Debug().Msg(marker) mainLog.Load().Debug().Msgf("listener address : %s", addr) @@ -1753,9 +1744,8 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) { for _, s := range strings.Split(lastAnswer.String(), "\n") { mainLog.Load().Debug().Msgf("%s", s) } - return false, status, errSelfCheckNoAnswer } - return false, status, lastErr + return errSelfCheckNoAnswer } func userHomeDir() (string, error) {