mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-25 23:30:41 +01:00
180 lines
3.2 KiB
Go
180 lines
3.2 KiB
Go
package netstack
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// IPTracker tracks IP addresses that have been resolved through DNS.
// This allows blocking direct IP connections that bypass DNS filtering.
//
// An entry expires ttl after it was last tracked. Expired entries are dropped
// lazily by IsTracked and periodically by a background sweeper that runs
// between Start and Stop. Every exported method is safe to call on a nil
// *IPTracker, in which case it does nothing (or reports false / zero).
type IPTracker struct {
	// Map of IP address string -> expiration time
	resolvedIPs map[string]time.Time
	mu          sync.RWMutex

	// TTL for tracked IPs (how long to remember them)
	ttl time.Duration

	// Enable IP blocking (only in firewall mode)
	enabled bool

	// Running state. running and stopCh are guarded by mu.
	running bool
	stopCh  chan struct{}
	wg      sync.WaitGroup

	// lifecycleMu serializes Start and Stop so a Start can never interleave
	// with a Stop that is still waiting for the previous sweeper to exit.
	lifecycleMu sync.Mutex
}

// NewIPTracker creates a new IP tracker with the specified TTL and enabled
// flag. A zero or negative ttl selects the default of 5 minutes. The tracker
// records IPs immediately; call Start to also run the periodic sweep of
// expired entries.
func NewIPTracker(ttl time.Duration, enabled bool) *IPTracker {
	if ttl <= 0 {
		// A non-positive TTL would make every entry expire on insertion.
		ttl = 5 * time.Minute // Default 5 minutes
	}

	return &IPTracker{
		resolvedIPs: make(map[string]time.Time),
		ttl:         ttl,
		enabled:     enabled,
		stopCh:      make(chan struct{}),
	}
}

// IsEnabled reports whether IP blocking is enabled. A nil tracker reports
// false.
func (t *IPTracker) IsEnabled() bool {
	if t == nil {
		return false
	}
	t.mu.RLock()
	defer t.mu.RUnlock()
	return t.enabled
}

// SetEnabled sets whether IP blocking is enabled. It is a no-op on a nil
// tracker.
func (t *IPTracker) SetEnabled(enabled bool) {
	if t == nil {
		return
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	t.enabled = enabled
}

// Start launches the background goroutine that periodically removes expired
// entries. It is a no-op if the tracker is nil or already running. A tracker
// may be started again after Stop.
func (t *IPTracker) Start() {
	if t == nil {
		return
	}
	t.lifecycleMu.Lock()
	defer t.lifecycleMu.Unlock()

	t.mu.Lock()
	defer t.mu.Unlock()
	if t.running {
		return
	}

	// Stop closes stopCh, so every run needs a fresh channel; reusing a closed
	// one would make the new sweeper exit immediately. No sweeper is alive at
	// this point: Stop waits for it while holding lifecycleMu.
	t.stopCh = make(chan struct{})
	t.running = true

	// Start cleanup goroutine to remove expired IPs
	t.wg.Add(1)
	go t.cleanupExpiredIPs()
}

// Stop halts the background sweeper, waits for it to exit, and forgets all
// tracked IPs. It is a no-op if the tracker is nil or not running.
func (t *IPTracker) Stop() {
	if t == nil {
		return
	}
	t.lifecycleMu.Lock()
	defer t.lifecycleMu.Unlock()

	t.mu.Lock()
	if !t.running {
		t.mu.Unlock()
		return
	}
	t.running = false
	stopCh := t.stopCh
	t.mu.Unlock()

	// Each Start installs a fresh channel and only the Stop that flips
	// running to false gets here, so the channel is closed exactly once.
	close(stopCh)
	t.wg.Wait()

	// Clear all tracked IPs
	t.mu.Lock()
	t.resolvedIPs = make(map[string]time.Time)
	t.mu.Unlock()
}

// TrackIP records ip as resolved, (re)setting its expiration to ttl from now.
// Nil or empty IPs are ignored, as are calls on a nil tracker. Keys come from
// net.IP.String, so 4-byte and IPv4-mapped 16-byte forms share one entry.
func (t *IPTracker) TrackIP(ip net.IP) {
	if t == nil || len(ip) == 0 {
		return
	}

	// Normalize to string format
	key := ip.String()

	t.mu.Lock()
	t.resolvedIPs[key] = time.Now().Add(t.ttl)
	t.mu.Unlock()
}

// IsTracked reports whether ip has been tracked and its entry has not yet
// expired. An expired entry found here is removed. Nil or empty IPs and a nil
// tracker report false.
func (t *IPTracker) IsTracked(ip net.IP) bool {
	if t == nil || len(ip) == 0 {
		return false
	}

	key := ip.String()
	now := time.Now()

	t.mu.RLock()
	expiration, exists := t.resolvedIPs[key]
	t.mu.RUnlock()

	if !exists {
		return false
	}
	if !now.After(expiration) {
		return true
	}

	// Clean up the expired entry. Re-check under the write lock: TrackIP may
	// have refreshed this IP after the read lock was released, and that fresh
	// entry must survive.
	t.mu.Lock()
	if current, ok := t.resolvedIPs[key]; ok && now.After(current) {
		delete(t.resolvedIPs, key)
	}
	t.mu.Unlock()
	return false
}

// GetTrackedCount returns the number of entries currently held. The count may
// include expired entries that have not been swept yet. A nil tracker
// reports 0.
func (t *IPTracker) GetTrackedCount() int {
	if t == nil {
		return 0
	}
	t.mu.RLock()
	defer t.mu.RUnlock()
	return len(t.resolvedIPs)
}

// cleanupExpiredIPs removes expired entries every 30 seconds until the stop
// channel installed by Start is closed. It must only be launched by Start,
// which performs the matching wg.Add.
func (t *IPTracker) cleanupExpiredIPs() {
	defer t.wg.Done()

	// Snapshot this run's channel. Start replaces t.stopCh only after Stop has
	// waited for this goroutine to exit, so the snapshot stays valid.
	t.mu.RLock()
	stopCh := t.stopCh
	t.mu.RUnlock()

	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-stopCh:
			return

		case <-ticker.C:
			now := time.Now()
			t.mu.Lock()
			for ip, expiration := range t.resolvedIPs {
				if now.After(expiration) {
					delete(t.resolvedIPs, ip)
				}
			}
			t.mu.Unlock()
		}
	}
}
|