cmd/cli: refactoring self-check process

Make the code cleaner and easier to maintain.
This commit is contained in:
Cuong Manh Le
2024-04-12 18:12:16 +07:00
committed by Cuong Manh Le
parent 429a98b690
commit 1dee4305bc
+32 -42
View File
@@ -1647,11 +1647,6 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
mainLog.Load().Debug().Msg("ctrld listener is ready") mainLog.Load().Debug().Msg("ctrld listener is ready")
mainLog.Load().Debug().Msg("performing self-check") mainLog.Load().Debug().Msg("performing self-check")
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
bo.LogLongerThan = 500 * time.Millisecond
ctx := context.Background()
maxAttempts := 20
c := new(dns.Client)
var ( var (
lcChanged map[string]*ctrld.ListenerConfig lcChanged map[string]*ctrld.ListenerConfig
ucChanged map[string]*ctrld.UpstreamConfig ucChanged map[string]*ctrld.UpstreamConfig
@@ -1669,12 +1664,6 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
// Nothing to do, return the status as-is. // Nothing to do, return the status as-is.
return true, status, nil return true, status, nil
} }
watcher, err := fsnotify.NewWatcher()
if err != nil {
mainLog.Load().Error().Err(err).Msg("could not watch config change")
return false, status, err
}
defer watcher.Close()
v.OnConfigChange(func(in fsnotify.Event) { v.OnConfigChange(func(in fsnotify.Event) {
mu.Lock() mu.Lock()
@@ -1683,52 +1672,56 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
mainLog.Load().Error().Msgf("failed to unmarshal listener config: %v", err) mainLog.Load().Error().Msgf("failed to unmarshal listener config: %v", err)
return return
} }
cfg.Listener = lcChanged
if err := v.UnmarshalKey("upstream", &ucChanged); err != nil { if err := v.UnmarshalKey("upstream", &ucChanged); err != nil {
mainLog.Load().Error().Msgf("failed to unmarshal upstream config: %v", err) mainLog.Load().Error().Msgf("failed to unmarshal upstream config: %v", err)
return return
} }
cfg.Upstream = ucChanged
}) })
v.WatchConfig() v.WatchConfig()
lc := cfg.FirstListener()
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
getInternalDomainFn := func() string { return selfCheckInternalTestDomain }
getExternalDomainFn := func() string { return cfg.FirstUpstream().VerifyDomain() }
if err := selfCheckResolveDomain(context.TODO(), addr, "internal", getInternalDomainFn); err != nil {
return false, status, err
}
if err := selfCheckResolveDomain(context.TODO(), addr, "external", getExternalDomainFn); err != nil {
return false, status, err
}
return true, status, nil
}
// selfCheckResolveDomain performs DNS test query against ctrld listener.
func selfCheckResolveDomain(ctx context.Context, addr, scope string, getDomainFn func() string) error {
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
bo.LogLongerThan = 500 * time.Millisecond
maxAttempts := 20
c := new(dns.Client)
var ( var (
lastAnswer *dns.Msg lastAnswer *dns.Msg
lastErr error lastErr error
internalTested bool
) )
domain := ""
for i := 0; i < maxAttempts; i++ { for i := 0; i < maxAttempts; i++ {
mu.Lock() domain = getDomainFn()
if lcChanged != nil {
cfg.Listener = lcChanged
}
if ucChanged != nil {
cfg.Upstream = ucChanged
}
mu.Unlock()
lc := cfg.FirstListener()
domain = cfg.FirstUpstream().VerifyDomain()
if !internalTested {
domain = selfCheckInternalTestDomain
}
if domain == "" { if domain == "" {
continue continue
} }
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion(domain+".", dns.TypeA) m.SetQuestion(domain+".", dns.TypeA)
m.RecursionDesired = true m.RecursionDesired = true
r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))) r, _, exErr := exchangeContextWithTimeout(c, time.Second, m, addr)
if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 { if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 {
internalTested = domain == selfCheckInternalTestDomain mainLog.Load().Debug().Msgf("%s self-check against %q succeeded", scope, domain)
if internalTested { return nil
mainLog.Load().Debug().Msgf("internal self-check against %q succeeded", domain)
continue // internal domain test ok, continue with external test.
} else {
mainLog.Load().Debug().Msgf("external self-check against %q succeeded", domain)
}
return true, status, nil
} }
// Return early if this is a connection refused. // Return early if this is a connection refused.
if errConnectionRefused(exErr) { if errConnectionRefused(exErr) {
return false, status, exErr return exErr
} }
lastAnswer = r lastAnswer = r
lastErr = exErr lastErr = exErr
@@ -1741,8 +1734,6 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint) mainLog.Load().Err(err).Msgf("failed to connect to upstream.%s, endpoint: %s", name, uc.Endpoint)
} }
} }
lc := cfg.FirstListener()
addr := net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
marker := strings.Repeat("=", 32) marker := strings.Repeat("=", 32)
mainLog.Load().Debug().Msg(marker) mainLog.Load().Debug().Msg(marker)
mainLog.Load().Debug().Msgf("listener address : %s", addr) mainLog.Load().Debug().Msgf("listener address : %s", addr)
@@ -1753,9 +1744,8 @@ func selfCheckStatus(s service.Service) (bool, service.Status, error) {
for _, s := range strings.Split(lastAnswer.String(), "\n") { for _, s := range strings.Split(lastAnswer.String(), "\n") {
mainLog.Load().Debug().Msgf("%s", s) mainLog.Load().Debug().Msgf("%s", s)
} }
return false, status, errSelfCheckNoAnswer
} }
return false, status, lastErr return errSelfCheckNoAnswer
} }
func userHomeDir() (string, error) { func userHomeDir() (string, error) {