package netstack

import (
	"net"
	"sync"
	"time"
)

const (
	// defaultIPTrackerTTL is used when NewIPTracker is given a non-positive TTL.
	defaultIPTrackerTTL = 5 * time.Minute

	// ipTrackerCleanupInterval is how often the background goroutine sweeps
	// expired entries out of the tracking map.
	ipTrackerCleanupInterval = 30 * time.Second
)

// IPTracker tracks IP addresses that have been resolved through DNS.
// This allows blocking direct IP connections that bypass DNS filtering.
//
// Every method is a no-op (or returns the zero value) on a nil receiver.
// Map and flag access is guarded by mu.
type IPTracker struct {
	// resolvedIPs maps an IP address string to its expiration time.
	resolvedIPs map[string]time.Time
	mu          sync.RWMutex

	// ttl is how long a tracked IP is remembered. It is set once by
	// NewIPTracker and never modified afterwards.
	ttl time.Duration

	// enabled reports whether IP blocking is on (only in firewall mode).
	enabled bool

	// running is true between a successful Start and the matching Stop.
	running bool
	// stopCh is closed by Stop to end the cleanup goroutine. Start installs
	// a fresh channel for every run.
	stopCh chan struct{}
	wg     sync.WaitGroup
}

// NewIPTracker creates a new IP tracker with the specified TTL and enabled
// flag. A zero or negative ttl falls back to the 5 minute default, since a
// non-positive TTL would make every entry expire as soon as it is stored.
func NewIPTracker(ttl time.Duration, enabled bool) *IPTracker {
	if ttl <= 0 {
		ttl = defaultIPTrackerTTL
	}
	return &IPTracker{
		resolvedIPs: make(map[string]time.Time),
		ttl:         ttl,
		enabled:     enabled,
		stopCh:      make(chan struct{}),
	}
}

// IsEnabled returns whether IP blocking is enabled. A nil tracker reports
// false.
func (t *IPTracker) IsEnabled() bool {
	if t == nil {
		return false
	}
	t.mu.RLock()
	defer t.mu.RUnlock()
	return t.enabled
}

// SetEnabled sets whether IP blocking is enabled.
func (t *IPTracker) SetEnabled(enabled bool) {
	if t == nil {
		return
	}
	t.mu.Lock()
	t.enabled = enabled
	t.mu.Unlock()
}

// Start starts the IP tracker cleanup routine. Calling Start on a tracker
// that is already running does nothing. Start may be called again after Stop.
func (t *IPTracker) Start() {
	if t == nil {
		return
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.running {
		return
	}

	// The channel from a previous run has been closed by Stop, so each run
	// gets its own; otherwise a restarted goroutine would exit immediately.
	stopCh := make(chan struct{})
	t.stopCh = stopCh
	t.running = true

	// Add while still holding the lock so a concurrent Stop cannot reach
	// wg.Wait before the goroutine has been accounted for.
	t.wg.Add(1)
	go t.cleanupExpiredIPs(stopCh)
}

// Stop stops the cleanup routine, waits for it to exit and then forgets all
// tracked IPs. Calling Stop on a tracker that is not running does nothing.
func (t *IPTracker) Stop() {
	if t == nil {
		return
	}
	t.mu.Lock()
	if !t.running {
		t.mu.Unlock()
		return
	}
	t.running = false
	// running was true, so this is the channel installed by the matching
	// Start and it has not been closed yet.
	close(t.stopCh)
	t.mu.Unlock()

	// Wait without holding mu: the cleanup goroutine takes mu while sweeping.
	t.wg.Wait()

	// Clear all tracked IPs.
	t.mu.Lock()
	t.resolvedIPs = make(map[string]time.Time)
	t.mu.Unlock()
}

// TrackIP adds an IP address to the tracking list, or refreshes its
// expiration if it is already tracked. Nil and empty addresses are ignored.
func (t *IPTracker) TrackIP(ip net.IP) {
	if t == nil || len(ip) == 0 {
		return
	}

	// Normalize to string format so the 4-byte and 16-byte forms of an IPv4
	// address share one key.
	key := ip.String()

	t.mu.Lock()
	t.resolvedIPs[key] = time.Now().Add(t.ttl)
	t.mu.Unlock()
}

// IsTracked checks if an IP address is in the tracking list and has not
// expired. An expired entry found during the lookup is removed.
func (t *IPTracker) IsTracked(ip net.IP) bool {
	if t == nil || len(ip) == 0 {
		return false
	}
	key := ip.String()

	t.mu.RLock()
	expiration, found := t.resolvedIPs[key]
	t.mu.RUnlock()

	if !found {
		return false
	}
	if !time.Now().After(expiration) {
		return true
	}

	// The entry looked expired under the read lock. Re-read it under the
	// write lock before deleting: a concurrent TrackIP may have refreshed it
	// in between, and that fresh entry must not be thrown away.
	t.mu.Lock()
	defer t.mu.Unlock()
	expiration, found = t.resolvedIPs[key]
	if !found {
		return false
	}
	if !time.Now().After(expiration) {
		return true
	}
	delete(t.resolvedIPs, key)
	return false
}

// GetTrackedCount returns the number of entries currently held in the
// tracking map. Expired entries that have not been swept yet are included.
func (t *IPTracker) GetTrackedCount() int {
	if t == nil {
		return 0
	}
	t.mu.RLock()
	defer t.mu.RUnlock()
	return len(t.resolvedIPs)
}

// cleanupExpiredIPs periodically removes expired IP entries until stopCh is
// closed. The channel is passed in rather than read from the struct so the
// goroutine keeps watching the channel of its own run even if Start later
// installs a new one.
func (t *IPTracker) cleanupExpiredIPs(stopCh <-chan struct{}) {
	defer t.wg.Done()

	ticker := time.NewTicker(ipTrackerCleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-stopCh:
			return
		case <-ticker.C:
			now := time.Now()
			t.mu.Lock()
			for ip, expiration := range t.resolvedIPs {
				if now.After(expiration) {
					delete(t.resolvedIPs, ip)
				}
			}
			t.mu.Unlock()
		}
	}
}