diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go
index 13d967a..4cc4f29 100644
--- a/cmd/cli/dns_proxy.go
+++ b/cmd/cli/dns_proxy.go
@@ -145,7 +145,7 @@ func (p *prog) serveDNS(listenerNum string) error {
 // processed later, because policy logging want to know whether a network rule
 // is disregarded in favor of the domain level rule.
 func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
-	upstreams := []string{"upstream." + defaultUpstreamNum}
+	upstreams := []string{upstreamPrefix + defaultUpstreamNum}
 	matchedPolicy := "no policy"
 	matchedNetwork := "no network"
 	matchedRule := "no rule"
@@ -229,7 +229,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int
 	upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
 	if len(upstreamConfigs) == 0 {
 		upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
-		upstreams = []string{"upstream.os"}
+		upstreams = []string{upstreamOS}
 	}
 	// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
 	if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
@@ -273,6 +273,12 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int
 		answer, err := resolve1(n, upstreamConfig, msg)
 		if err != nil {
 			ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
+			if errNetworkError(err) {
+				p.um.increaseFailureCount(upstreams[n])
+				if p.um.isDown(upstreams[n]) {
+					go p.um.checkUpstream(upstreams[n], upstreamConfig)
+				}
+			}
 			return nil
 		}
 		return answer
@@ -281,6 +287,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int
 		if upstreamConfig == nil {
 			continue
 		}
+		if p.um.isDown(upstreams[n]) {
+			ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
+			continue
+		}
 		answer := resolve(n, upstreamConfig, msg)
 		if answer == nil {
 			if serveStaleCache && staleAnswer != nil {
@@ -312,7 +322,7 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int
 		}
 		return answer
 	}
-	ctrld.Log(ctx, mainLog.Load().Error(), "all upstreams failed")
+	ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
 	answer := new(dns.Msg)
 	answer.SetRcode(msg, dns.RcodeServerFailure)
 	return answer
@@ -321,7 +331,7 @@
 func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
 	upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
 	for _, upstream := range upstreams {
-		upstreamNum := strings.TrimPrefix(upstream, "upstream.")
+		upstreamNum := strings.TrimPrefix(upstream, upstreamPrefix)
 		upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
 	}
 	return upstreamConfigs
diff --git a/cmd/cli/prog.go b/cmd/cli/prog.go
index 2ba24b3..47e2304 100644
--- a/cmd/cli/prog.go
+++ b/cmd/cli/prog.go
@@ -28,6 +28,8 @@ const (
 	defaultSemaphoreCap  = 256
 	ctrldLogUnixSock     = "ctrld_start.sock"
 	ctrldControlUnixSock = "ctrld_control.sock"
+	upstreamPrefix       = "upstream."
+	upstreamOS           = upstreamPrefix + "os"
 )
 
 var logf = func(format string, args ...any) {
@@ -54,6 +56,7 @@ type prog struct {
 	cache   dnscache.Cacher
 	sema    semaphore
 	ciTable *clientinfo.Table
+	um      *upstreamMonitor
 	router  router.Router
 
 	started chan struct{}
@@ -118,6 +121,8 @@ func (p *prog) run() {
 			nc.IPNets = append(nc.IPNets, ipNet)
 		}
 	}
+
+	p.um = newUpstreamMonitor(p.cfg)
 	for n := range p.cfg.Upstream {
 		uc := p.cfg.Upstream[n]
 		uc.Init()
@@ -351,20 +356,25 @@ var (
 func errUrlNetworkError(err error) bool {
 	var urlErr *url.Error
 	if errors.As(err, &urlErr) {
-		var opErr *net.OpError
-		if errors.As(urlErr.Err, &opErr) {
-			if opErr.Temporary() {
-				return true
-			}
-			switch {
-			case errors.Is(opErr.Err, syscall.ECONNREFUSED),
-				errors.Is(opErr.Err, syscall.EINVAL),
-				errors.Is(opErr.Err, syscall.ENETUNREACH),
-				errors.Is(opErr.Err, windowsENETUNREACH),
-				errors.Is(opErr.Err, windowsEINVAL),
-				errors.Is(opErr.Err, windowsECONNREFUSED):
-				return true
-			}
+		return errNetworkError(urlErr.Err)
+	}
+	return false
+}
+
+func errNetworkError(err error) bool {
+	var opErr *net.OpError
+	if errors.As(err, &opErr) {
+		if opErr.Temporary() {
+			return true
+		}
+		switch {
+		case errors.Is(opErr.Err, syscall.ECONNREFUSED),
+			errors.Is(opErr.Err, syscall.EINVAL),
+			errors.Is(opErr.Err, syscall.ENETUNREACH),
+			errors.Is(opErr.Err, windowsENETUNREACH),
+			errors.Is(opErr.Err, windowsEINVAL),
+			errors.Is(opErr.Err, windowsECONNREFUSED):
+			return true
 		}
 	}
 	return false
diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go
new file mode 100644
index 0000000..4b3ee69
--- /dev/null
+++ b/cmd/cli/upstream_monitor.go
@@ -0,0 +1,104 @@
+package cli
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/miekg/dns"
+	"tailscale.com/logtail/backoff"
+
+	"github.com/Control-D-Inc/ctrld"
+)
+
+const (
+	// maxFailureRequest is the maximum number of failed queries allowed before an upstream is marked as down.
+	maxFailureRequest = 100
+	// checkUpstreamMaxBackoff is the max backoff time when checking upstream status.
+	checkUpstreamMaxBackoff = 2 * time.Minute
+)
+
+// upstreamMonitor monitors the health of upstreams.
+type upstreamMonitor struct {
+	cfg *ctrld.Config
+
+	down       map[string]*atomic.Bool
+	failureReq map[string]*atomic.Uint64
+
+	mu       sync.Mutex
+	checking map[string]bool
+}
+
+func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
+	um := &upstreamMonitor{
+		cfg:        cfg,
+		down:       make(map[string]*atomic.Bool),
+		failureReq: make(map[string]*atomic.Uint64),
+		checking:   make(map[string]bool),
+	}
+	for n := range cfg.Upstream {
+		upstream := upstreamPrefix + n
+		um.down[upstream] = new(atomic.Bool)
+		um.failureReq[upstream] = new(atomic.Uint64)
+	}
+	um.down[upstreamOS] = new(atomic.Bool)
+	um.failureReq[upstreamOS] = new(atomic.Uint64)
+	return um
+}
+
+// increaseFailureCount increases the failed queries count for an upstream by 1.
+func (um *upstreamMonitor) increaseFailureCount(upstream string) {
+	failedCount := um.failureReq[upstream].Add(1)
+	um.down[upstream].Store(failedCount >= maxFailureRequest)
+}
+
+// isDown reports whether the given upstream is currently marked as down.
+func (um *upstreamMonitor) isDown(upstream string) bool {
+	return um.down[upstream].Load()
+}
+
+// reset marks an upstream as up and sets its failed queries counter to zero.
+func (um *upstreamMonitor) reset(upstream string) {
+	um.failureReq[upstream].Store(0)
+	um.down[upstream].Store(false)
+}
+
+// checkUpstream checks the given upstream status, periodically sending queries to the upstream
+// until one succeeds. An upstream's status/counter is reset once it becomes reachable again.
+func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
+	um.mu.Lock()
+	isChecking := um.checking[upstream]
+	if isChecking {
+		um.mu.Unlock()
+		return
+	}
+	um.checking[upstream] = true
+	um.mu.Unlock()
+	defer func() {
+		// Clear the checking flag so a later down event can trigger a new check.
+		um.mu.Lock()
+		um.checking[upstream] = false
+		um.mu.Unlock()
+	}()
+
+	bo := backoff.NewBackoff("checkUpstream", logf, checkUpstreamMaxBackoff)
+	resolver, err := ctrld.NewResolver(uc)
+	if err != nil {
+		mainLog.Load().Warn().Err(err).Msg("could not check upstream")
+		return
+	}
+	msg := new(dns.Msg)
+	msg.SetQuestion(".", dns.TypeNS)
+	ctx := context.Background()
+
+	for {
+		_, err := resolver.Resolve(ctx, msg)
+		if err == nil {
+			mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
+			um.reset(upstream)
+			return
+		}
+		bo.BackOff(ctx, err)
+	}
+}
diff --git a/doh.go b/doh.go
index 5886881..d0525d4 100644
--- a/doh.go
+++ b/doh.go
@@ -97,8 +97,10 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
 func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
 	req.Header.Set("Content-Type", headerApplicationDNS)
 	req.Header.Set("Accept", headerApplicationDNS)
+	printed := false
 	if sendClientInfo {
 		if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
+			printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
 			if ci.Mac != "" {
 				req.Header.Set(dohMacHeader, ci.Mac)
 			}
@@ -110,5 +112,7 @@ func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
 			}
 		}
 	}
-	Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
+	if printed {
+		Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
+	}
 }
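For context, a minimal, self-contained sketch of the down-marking behaviour this patch introduces: network errors bump a per-upstream counter, crossing `maxFailureRequest` flags the upstream as down so queries skip it, and a successful health check brings it back. The `monitor` and `failureThreshold` names below are illustrative stand-ins, not identifiers from the patch:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// failureThreshold mirrors maxFailureRequest from the patch.
const failureThreshold = 100

// monitor is a stripped-down, single-upstream stand-in for upstreamMonitor.
type monitor struct {
	down       atomic.Bool
	failureReq atomic.Uint64
}

// increaseFailureCount mirrors upstreamMonitor.increaseFailureCount: each
// network error bumps the counter; crossing the threshold flips the down flag.
func (m *monitor) increaseFailureCount() {
	failed := m.failureReq.Add(1)
	m.down.Store(failed >= failureThreshold)
}

// reset mirrors upstreamMonitor.reset: a successful probe zeroes the counter
// and marks the upstream up again.
func (m *monitor) reset() {
	m.failureReq.Store(0)
	m.down.Store(false)
}

func main() {
	m := &monitor{}
	for i := 0; i < failureThreshold; i++ {
		m.increaseFailureCount()
	}
	fmt.Println(m.down.Load()) // true: queries now skip this upstream

	m.reset()
	fmt.Println(m.down.Load()) // false: it is served again
}
```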
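The recovery path follows the same shape as `checkUpstream` above: keep sending a `. NS` probe until one succeeds, backing off between attempts up to the 2-minute cap, then reset the upstream's counters. A rough standalone equivalent, assuming a plain-DNS exchange via github.com/miekg/dns and a simple doubling sleep in place of tailscale's backoff package; the `127.0.0.1:53` endpoint is only an example, where the patch instead resolves through the configured UpstreamConfig:

```go
package main

import (
	"fmt"
	"time"

	"github.com/miekg/dns"
)

func main() {
	// The same probe the patch sends: an NS query for the root zone.
	msg := new(dns.Msg)
	msg.SetQuestion(".", dns.TypeNS)

	client := new(dns.Client)
	const maxBackoff = 2 * time.Minute // mirrors checkUpstreamMaxBackoff
	backoff := time.Second

	for {
		// Example endpoint only; not an address from the patch.
		_, _, err := client.Exchange(msg, "127.0.0.1:53")
		if err == nil {
			fmt.Println("upstream is online again")
			return // the real code calls um.reset(upstream) here
		}
		time.Sleep(backoff)
		backoff *= 2
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
	}
}
```

A root NS query makes a cheap liveness probe: any functioning resolver can answer it, so a failure points at the transport or endpoint rather than the query itself.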