mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-25 23:30:41 +01:00
blocks direct Ip.
This commit is contained in:
179
cmd/ctrld_library/netstack/ip_tracker.go
Normal file
179
cmd/ctrld_library/netstack/ip_tracker.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IPTracker tracks IP addresses that have been resolved through DNS.
|
||||
// This allows blocking direct IP connections that bypass DNS filtering.
|
||||
type IPTracker struct {
|
||||
// Map of IP address string -> expiration time
|
||||
resolvedIPs map[string]time.Time
|
||||
mu sync.RWMutex
|
||||
|
||||
// TTL for tracked IPs (how long to remember them)
|
||||
ttl time.Duration
|
||||
|
||||
// Enable IP blocking (only in firewall mode)
|
||||
enabled bool
|
||||
|
||||
// Running state
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewIPTracker creates a new IP tracker with the specified TTL and enabled flag
|
||||
func NewIPTracker(ttl time.Duration, enabled bool) *IPTracker {
|
||||
if ttl == 0 {
|
||||
ttl = 5 * time.Minute // Default 5 minutes
|
||||
}
|
||||
|
||||
return &IPTracker{
|
||||
resolvedIPs: make(map[string]time.Time),
|
||||
ttl: ttl,
|
||||
enabled: enabled,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled returns whether IP blocking is enabled
|
||||
func (t *IPTracker) IsEnabled() bool {
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.enabled
|
||||
}
|
||||
|
||||
// SetEnabled sets whether IP blocking is enabled
|
||||
func (t *IPTracker) SetEnabled(enabled bool) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
t.mu.Lock()
|
||||
t.enabled = enabled
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// Start starts the IP tracker cleanup routine
|
||||
func (t *IPTracker) Start() {
|
||||
t.mu.Lock()
|
||||
if t.running {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
t.mu.Unlock()
|
||||
|
||||
// Start cleanup goroutine to remove expired IPs
|
||||
t.wg.Add(1)
|
||||
go t.cleanupExpiredIPs()
|
||||
}
|
||||
|
||||
// Stop stops the IP tracker
|
||||
func (t *IPTracker) Stop() {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if !t.running {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t.running = false
|
||||
t.mu.Unlock()
|
||||
|
||||
// Close stop channel (protected against double close)
|
||||
select {
|
||||
case <-t.stopCh:
|
||||
// Already closed
|
||||
default:
|
||||
close(t.stopCh)
|
||||
}
|
||||
|
||||
t.wg.Wait()
|
||||
|
||||
// Clear all tracked IPs
|
||||
t.mu.Lock()
|
||||
t.resolvedIPs = make(map[string]time.Time)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// TrackIP adds an IP address to the tracking list
|
||||
func (t *IPTracker) TrackIP(ip net.IP) {
|
||||
if ip == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize to string format
|
||||
ipStr := ip.String()
|
||||
|
||||
t.mu.Lock()
|
||||
t.resolvedIPs[ipStr] = time.Now().Add(t.ttl)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// IsTracked checks if an IP address is in the tracking list
|
||||
func (t *IPTracker) IsTracked(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ipStr := ip.String()
|
||||
|
||||
t.mu.RLock()
|
||||
expiration, exists := t.resolvedIPs[ipStr]
|
||||
t.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Now().After(expiration) {
|
||||
// Clean up expired entry
|
||||
t.mu.Lock()
|
||||
delete(t.resolvedIPs, ipStr)
|
||||
t.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetTrackedCount returns the number of currently tracked IPs
|
||||
func (t *IPTracker) GetTrackedCount() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return len(t.resolvedIPs)
|
||||
}
|
||||
|
||||
// cleanupExpiredIPs periodically removes expired IP entries
|
||||
func (t *IPTracker) cleanupExpiredIPs() {
|
||||
defer t.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.stopCh:
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
t.mu.Lock()
|
||||
for ip, expiration := range t.resolvedIPs {
|
||||
if now.After(expiration) {
|
||||
delete(t.resolvedIPs, ip)
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user