From 511c4e696ffb384dbbc2a2b7108cd7645145586f Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Thu, 21 Sep 2023 06:06:09 +0000 Subject: [PATCH] cmd/cli: add upstream monitor Some users mentioned that when there is an Internet outage, ctrld fails to recover, crashing or locks up the router. When requests start failing, this results in the clients emitting more queries, creating a resource spiral of death that can brick the device entirely. To guard against this case, this commit implement an upstream monitor approach: - Marking upstream as down after 100 consecutive failed queries. - Start a goroutine to check when the upstream is back again. - When upstream is down, answer all queries with SERVFAIL. - The checking process uses backoff retry to reduce high requests rate. - As long as the query succeeded, marking the upstream as alive then start operate normally. --- cmd/cli/dns_proxy.go | 18 +++++-- cmd/cli/prog.go | 38 ++++++++------ cmd/cli/upstream_monitor.go | 98 +++++++++++++++++++++++++++++++++++++ doh.go | 6 ++- 4 files changed, 141 insertions(+), 19 deletions(-) create mode 100644 cmd/cli/upstream_monitor.go diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go index 13d967a..4cc4f29 100644 --- a/cmd/cli/dns_proxy.go +++ b/cmd/cli/dns_proxy.go @@ -145,7 +145,7 @@ func (p *prog) serveDNS(listenerNum string) error { // processed later, because policy logging want to know whether a network rule // is disregarded in favor of the domain level rule. func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) { - upstreams := []string{"upstream." + defaultUpstreamNum} + upstreams := []string{upstreamPrefix + defaultUpstreamNum} matchedPolicy := "no policy" matchedNetwork := "no network" matchedRule := "no rule" @@ -229,7 +229,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams) if len(upstreamConfigs) == 0 { upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig} - upstreams = []string{"upstream.os"} + upstreams = []string{upstreamOS} } // Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4 if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR { @@ -273,6 +273,12 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i answer, err := resolve1(n, upstreamConfig, msg) if err != nil { ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query") + if errNetworkError(err) { + p.um.increaseFailureCount(upstreams[n]) + if p.um.isDown(upstreams[n]) { + go p.um.checkUpstream(upstreams[n], upstreamConfig) + } + } return nil } return answer @@ -281,6 +287,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i if upstreamConfig == nil { continue } + if p.um.isDown(upstreams[n]) { + ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n]) + continue + } answer := resolve(n, upstreamConfig, msg) if answer == nil { if serveStaleCache && staleAnswer != nil { @@ -312,7 +322,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i } return answer } - ctrld.Log(ctx, mainLog.Load().Error(), "all upstreams failed") + ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams) answer := new(dns.Msg) answer.SetRcode(msg, dns.RcodeServerFailure) return answer @@ -321,7 +331,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig { upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams)) for _, upstream := range upstreams { - upstreamNum := strings.TrimPrefix(upstream, "upstream.") + upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix) upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum]) } return upstreamConfigs diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go index 2ba24b3..47e2304 100644 --- a/cmd/cli/prog.go +++ b/cmd/cli/prog.go @@ -28,6 +28,8 @@ const ( defaultSemaphoreCap = 256 ctrldLogUnixSock = "ctrld_start.sock" ctrldControlUnixSock = "ctrld_control.sock" + upstreamPrefix = "upstream." + upstreamOS = upstreamPrefix + "os" ) var logf = func(format string, args ...any) { @@ -54,6 +56,7 @@ type prog struct { cache dnscache.Cacher sema semaphore ciTable *clientinfo.Table + um *upstreamMonitor router router.Router started chan struct{} @@ -118,6 +121,8 @@ func (p *prog) run() { nc.IPNets = append(nc.IPNets, ipNet) } } + + p.um = newUpstreamMonitor(p.cfg) for n := range p.cfg.Upstream { uc := p.cfg.Upstream[n] uc.Init() @@ -351,20 +356,25 @@ var ( func errUrlNetworkError(err error) bool { var urlErr *url.Error if errors.As(err, &urlErr) { - var opErr *net.OpError - if errors.As(urlErr.Err, &opErr) { - if opErr.Temporary() { - return true - } - switch { - case errors.Is(opErr.Err, syscall.ECONNREFUSED), - errors.Is(opErr.Err, syscall.EINVAL), - errors.Is(opErr.Err, syscall.ENETUNREACH), - errors.Is(opErr.Err, windowsENETUNREACH), - errors.Is(opErr.Err, windowsEINVAL), - errors.Is(opErr.Err, windowsECONNREFUSED): - return true - } + return errNetworkError(urlErr.Err) + } + return false +} + +func errNetworkError(err error) bool { + var opErr *net.OpError + if errors.As(err, &opErr) { + if opErr.Temporary() { + return true + } + switch { + case errors.Is(opErr.Err, syscall.ECONNREFUSED), + errors.Is(opErr.Err, syscall.EINVAL), + errors.Is(opErr.Err, syscall.ENETUNREACH), + errors.Is(opErr.Err, windowsENETUNREACH), + errors.Is(opErr.Err, windowsEINVAL), + errors.Is(opErr.Err, windowsECONNREFUSED): + return true } } return false diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go new file mode 100644 index 0000000..4b3ee69 --- /dev/null +++ b/cmd/cli/upstream_monitor.go @@ -0,0 +1,98 @@ +package cli + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + "tailscale.com/logtail/backoff" + + "github.com/Control-D-Inc/ctrld" +) + +const ( + // maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down. + maxFailureRequest = 100 + // checkUpstreamMaxBackoff is the max backoff time when checking upstream status. + checkUpstreamMaxBackoff = 2 * time.Minute +) + +// upstreamMonitor performs monitoring upstreams health. +type upstreamMonitor struct { + cfg *ctrld.Config + + down map[string]*atomic.Bool + failureReq map[string]*atomic.Uint64 + + mu sync.Mutex + checking map[string]bool +} + +func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { + um := &upstreamMonitor{ + cfg: cfg, + down: make(map[string]*atomic.Bool), + failureReq: make(map[string]*atomic.Uint64), + checking: make(map[string]bool), + } + for n := range cfg.Upstream { + upstream := upstreamPrefix + n + um.down[upstream] = new(atomic.Bool) + um.failureReq[upstream] = new(atomic.Uint64) + } + um.down[upstreamOS] = new(atomic.Bool) + um.failureReq[upstreamOS] = new(atomic.Uint64) + return um +} + +// increaseFailureCount increase failed queries count for an upstream by 1. +func (um *upstreamMonitor) increaseFailureCount(upstream string) { + failedCount := um.failureReq[upstream].Add(1) + um.down[upstream].Store(failedCount >= maxFailureRequest) +} + +// isDown reports whether the given upstream is being marked as down. +func (um *upstreamMonitor) isDown(upstream string) bool { + return um.down[upstream].Load() +} + +// reset marks an upstream as up and set failed queries counter to zero. +func (um *upstreamMonitor) reset(upstream string) { + um.failureReq[upstream].Store(0) + um.down[upstream].Store(false) +} + +// checkUpstream checks the given upstream status, periodically sending query to upstream +// until successfully. An upstream status/counter will be reset once it becomes reachable. +func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) { + um.mu.Lock() + isChecking := um.checking[upstream] + if isChecking { + um.mu.Unlock() + return + } + um.checking[upstream] = true + um.mu.Unlock() + + bo := backoff.NewBackoff("checkUpstream", logf, checkUpstreamMaxBackoff) + resolver, err := ctrld.NewResolver(uc) + if err != nil { + mainLog.Load().Warn().Err(err).Msg("could not check upstream") + return + } + msg := new(dns.Msg) + msg.SetQuestion(".", dns.TypeNS) + ctx := context.Background() + + for { + _, err := resolver.Resolve(ctx, msg) + if err == nil { + mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint) + um.reset(upstream) + return + } + bo.BackOff(ctx, err) + } +} diff --git a/doh.go b/doh.go index 5886881..d0525d4 100644 --- a/doh.go +++ b/doh.go @@ -97,8 +97,10 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { req.Header.Set("Content-Type", headerApplicationDNS) req.Header.Set("Accept", headerApplicationDNS) + printed := false if sendClientInfo { if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil { + printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != "" if ci.Mac != "" { req.Header.Set(dohMacHeader, ci.Mac) } @@ -110,5 +112,7 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { } } } - Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header") + if printed { + Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header") + } }