diff --git a/cmd/cli/upstream_monitor.go b/cmd/cli/upstream_monitor.go index 83087a4..67ae13d 100644 --- a/cmd/cli/upstream_monitor.go +++ b/cmd/cli/upstream_monitor.go @@ -3,7 +3,6 @@ package cli import ( "context" "sync" - "sync/atomic" "time" "github.com/miekg/dns" @@ -22,45 +21,52 @@ const ( type upstreamMonitor struct { cfg *ctrld.Config - down map[string]*atomic.Bool - failureReq map[string]*atomic.Uint64 - - mu sync.Mutex - checking map[string]bool + mu sync.Mutex + checking map[string]bool + down map[string]bool + failureReq map[string]uint64 } func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor { um := &upstreamMonitor{ cfg: cfg, - down: make(map[string]*atomic.Bool), - failureReq: make(map[string]*atomic.Uint64), checking: make(map[string]bool), + down: make(map[string]bool), + failureReq: make(map[string]uint64), } for n := range cfg.Upstream { upstream := upstreamPrefix + n - um.down[upstream] = new(atomic.Bool) - um.failureReq[upstream] = new(atomic.Uint64) + um.reset(upstream) } - um.down[upstreamOS] = new(atomic.Bool) - um.failureReq[upstreamOS] = new(atomic.Uint64) + um.reset(upstreamOS) return um } // increaseFailureCount increase failed queries count for an upstream by 1. func (um *upstreamMonitor) increaseFailureCount(upstream string) { - failedCount := um.failureReq[upstream].Add(1) - um.down[upstream].Store(failedCount >= maxFailureRequest) + um.mu.Lock() + defer um.mu.Unlock() + + um.failureReq[upstream] += 1 + failedCount := um.failureReq[upstream] + um.down[upstream] = failedCount >= maxFailureRequest } // isDown reports whether the given upstream is being marked as down. func (um *upstreamMonitor) isDown(upstream string) bool { - return um.down[upstream].Load() + um.mu.Lock() + defer um.mu.Unlock() + + return um.down[upstream] } // reset marks an upstream as up and set failed queries counter to zero. func (um *upstreamMonitor) reset(upstream string) { - um.failureReq[upstream].Store(0) - um.down[upstream].Store(false) + um.mu.Lock() + defer um.mu.Unlock() + + um.failureReq[upstream] = 0 + um.down[upstream] = false } // checkUpstream checks the given upstream status, periodically sending query to upstream @@ -74,6 +80,11 @@ func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConf } um.checking[upstream] = true um.mu.Unlock() + defer func() { + um.mu.Lock() + um.checking[upstream] = false + um.mu.Unlock() + }() resolver, err := ctrld.NewResolver(uc) if err != nil {