mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-25 23:30:41 +01:00
blocks direct Ip.
This commit is contained in:
@@ -6,8 +6,9 @@ Complete TCP/UDP/DNS packet capture implementation using gVisor netstack for And
|
||||
|
||||
Provides full packet capture for mobile VPN applications:
|
||||
- **DNS filtering** through ControlD proxy
|
||||
- **TCP forwarding** for all TCP traffic
|
||||
- **UDP forwarding** with session tracking
|
||||
- **IP whitelisting** - only allows connections to DNS-resolved IPs
|
||||
- **TCP forwarding** for all TCP traffic (with whitelist enforcement)
|
||||
- **UDP forwarding** with session tracking (with whitelist enforcement)
|
||||
- **Socket protection** to prevent routing loops
|
||||
- **QUIC blocking** for better content filtering
|
||||
|
||||
@@ -29,16 +30,19 @@ Real Network (Protected Sockets)
|
||||
## Components
|
||||
|
||||
### DNS Filter (`dns_filter.go`)
|
||||
Detects DNS packets on port 53 and routes to ControlD proxy.
|
||||
Detects DNS packets on port 53, routes to ControlD proxy, and extracts resolved IPs.
|
||||
|
||||
### DNS Bridge (`dns_bridge.go`)
|
||||
Tracks DNS queries by transaction ID with 5-second timeout.
|
||||
|
||||
### IP Tracker (`ip_tracker.go`)
|
||||
Maintains whitelist of DNS-resolved IPs with 5-minute TTL.
|
||||
|
||||
### TCP Forwarder (`tcp_forwarder.go`)
|
||||
Forwards TCP connections using gVisor's `tcp.NewForwarder()`.
|
||||
Forwards TCP connections using gVisor's `tcp.NewForwarder()`. Blocks non-whitelisted IPs.
|
||||
|
||||
### UDP Forwarder (`udp_forwarder.go`)
|
||||
Forwards UDP packets with session tracking and 60-second idle timeout.
|
||||
Forwards UDP packets with 30-second read deadline. Blocks non-whitelisted IPs.
|
||||
|
||||
### Packet Handler (`packet_handler.go`)
|
||||
Interface for TUN I/O and socket protection.
|
||||
@@ -125,6 +129,34 @@ Drops UDP packets on ports 443 and 80 to force TCP fallback:
|
||||
- No user-visible errors
|
||||
- Slightly slower initial connection, then normal
|
||||
|
||||
## IP Blocking (DNS Bypass Prevention)
|
||||
|
||||
Enforces whitelist approach: ONLY allows connections to IPs resolved through ControlD DNS.
|
||||
|
||||
**How it works:**
|
||||
1. DNS responses are parsed to extract A and AAAA records
|
||||
2. Resolved IPs are tracked in memory whitelist for 5 minutes
|
||||
3. TCP/UDP connections to **non-whitelisted** IPs are **BLOCKED**
|
||||
4. Only IPs that went through DNS resolution are allowed
|
||||
|
||||
**Why:**
|
||||
- Prevents DNS bypass via hardcoded/cached IPs
|
||||
- Ensures ALL traffic must go through ControlD DNS first
|
||||
- Blocks apps that try to skip DNS filtering
|
||||
- Enforces strict ControlD policy compliance
|
||||
|
||||
**Example:**
|
||||
```
|
||||
✅ ALLOWED: App queries "example.com" → 93.184.216.34 → connects to 93.184.216.34
|
||||
❌ BLOCKED: App connects directly to 1.2.3.4 (not resolved via DNS)
|
||||
```
|
||||
|
||||
**Components:**
|
||||
- `ip_tracker.go` - Manages whitelist of DNS-resolved IPs with TTL
|
||||
- `dns_filter.go` - Extracts IPs from DNS responses for whitelist
|
||||
- `tcp_forwarder.go` - Allows only whitelisted IPs, blocks others
|
||||
- `udp_forwarder.go` - Allows only whitelisted IPs, blocks others
|
||||
|
||||
## Usage (Android)
|
||||
|
||||
```kotlin
|
||||
@@ -196,10 +228,11 @@ proxy.startFirewall(
|
||||
|
||||
- `packet_handler.go` - TUN I/O interface
|
||||
- `netstack.go` - gVisor controller
|
||||
- `dns_filter.go` - DNS packet detection
|
||||
- `dns_filter.go` - DNS packet detection and IP extraction
|
||||
- `dns_bridge.go` - Transaction tracking
|
||||
- `tcp_forwarder.go` - TCP forwarding
|
||||
- `udp_forwarder.go` - UDP forwarding
|
||||
- `ip_tracker.go` - DNS-resolved IP whitelist with TTL
|
||||
- `tcp_forwarder.go` - TCP forwarding with whitelist enforcement
|
||||
- `udp_forwarder.go` - UDP forwarding with whitelist enforcement
|
||||
|
||||
## License
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
@@ -12,12 +13,14 @@ import (
|
||||
// DNSFilter intercepts and processes DNS packets.
|
||||
type DNSFilter struct {
|
||||
dnsHandler func([]byte) ([]byte, error)
|
||||
ipTracker *IPTracker
|
||||
}
|
||||
|
||||
// NewDNSFilter creates a new DNS filter with the given handler.
|
||||
func NewDNSFilter(handler func([]byte) ([]byte, error)) *DNSFilter {
|
||||
func NewDNSFilter(handler func([]byte) ([]byte, error), ipTracker *IPTracker) *DNSFilter {
|
||||
return &DNSFilter{
|
||||
dnsHandler: handler,
|
||||
ipTracker: ipTracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,6 +104,11 @@ func (df *DNSFilter) processIPv4(packet []byte) (bool, []byte, error) {
|
||||
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
||||
}
|
||||
|
||||
// Track IPs from DNS response
|
||||
if df.ipTracker != nil {
|
||||
df.extractAndTrackIPs(dnsResponse)
|
||||
}
|
||||
|
||||
// Build response packet
|
||||
responsePacket := df.buildIPv4UDPPacket(
|
||||
dstIP.As4(), // Swap src/dst
|
||||
@@ -166,6 +174,11 @@ func (df *DNSFilter) processIPv6(packet []byte) (bool, []byte, error) {
|
||||
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
||||
}
|
||||
|
||||
// Track IPs from DNS response
|
||||
if df.ipTracker != nil {
|
||||
df.extractAndTrackIPs(dnsResponse)
|
||||
}
|
||||
|
||||
// Build response packet
|
||||
srcIP := ipHdr.SourceAddress()
|
||||
dstIP := ipHdr.DestinationAddress()
|
||||
@@ -322,3 +335,31 @@ func parseUDP(udpHeader []byte) (srcPort, dstPort uint16, ok bool) {
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
// extractAndTrackIPs parses DNS response and tracks resolved IP addresses
|
||||
func (df *DNSFilter) extractAndTrackIPs(dnsResponse []byte) {
|
||||
if len(dnsResponse) < 12 {
|
||||
return // Invalid DNS response
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(dnsResponse); err != nil {
|
||||
return // Failed to parse DNS response
|
||||
}
|
||||
|
||||
// Extract IPs from answer section
|
||||
for _, answer := range msg.Answer {
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
// IPv4 address
|
||||
if rr.A != nil {
|
||||
df.ipTracker.TrackIP(rr.A)
|
||||
}
|
||||
case *dns.AAAA:
|
||||
// IPv6 address
|
||||
if rr.AAAA != nil {
|
||||
df.ipTracker.TrackIP(rr.AAAA)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
179
cmd/ctrld_library/netstack/ip_tracker.go
Normal file
179
cmd/ctrld_library/netstack/ip_tracker.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IPTracker tracks IP addresses that have been resolved through DNS.
|
||||
// This allows blocking direct IP connections that bypass DNS filtering.
|
||||
type IPTracker struct {
|
||||
// Map of IP address string -> expiration time
|
||||
resolvedIPs map[string]time.Time
|
||||
mu sync.RWMutex
|
||||
|
||||
// TTL for tracked IPs (how long to remember them)
|
||||
ttl time.Duration
|
||||
|
||||
// Enable IP blocking (only in firewall mode)
|
||||
enabled bool
|
||||
|
||||
// Running state
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewIPTracker creates a new IP tracker with the specified TTL and enabled flag
|
||||
func NewIPTracker(ttl time.Duration, enabled bool) *IPTracker {
|
||||
if ttl == 0 {
|
||||
ttl = 5 * time.Minute // Default 5 minutes
|
||||
}
|
||||
|
||||
return &IPTracker{
|
||||
resolvedIPs: make(map[string]time.Time),
|
||||
ttl: ttl,
|
||||
enabled: enabled,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled returns whether IP blocking is enabled
|
||||
func (t *IPTracker) IsEnabled() bool {
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.enabled
|
||||
}
|
||||
|
||||
// SetEnabled sets whether IP blocking is enabled
|
||||
func (t *IPTracker) SetEnabled(enabled bool) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
t.mu.Lock()
|
||||
t.enabled = enabled
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// Start starts the IP tracker cleanup routine
|
||||
func (t *IPTracker) Start() {
|
||||
t.mu.Lock()
|
||||
if t.running {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
t.mu.Unlock()
|
||||
|
||||
// Start cleanup goroutine to remove expired IPs
|
||||
t.wg.Add(1)
|
||||
go t.cleanupExpiredIPs()
|
||||
}
|
||||
|
||||
// Stop stops the IP tracker
|
||||
func (t *IPTracker) Stop() {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if !t.running {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t.running = false
|
||||
t.mu.Unlock()
|
||||
|
||||
// Close stop channel (protected against double close)
|
||||
select {
|
||||
case <-t.stopCh:
|
||||
// Already closed
|
||||
default:
|
||||
close(t.stopCh)
|
||||
}
|
||||
|
||||
t.wg.Wait()
|
||||
|
||||
// Clear all tracked IPs
|
||||
t.mu.Lock()
|
||||
t.resolvedIPs = make(map[string]time.Time)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// TrackIP adds an IP address to the tracking list
|
||||
func (t *IPTracker) TrackIP(ip net.IP) {
|
||||
if ip == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize to string format
|
||||
ipStr := ip.String()
|
||||
|
||||
t.mu.Lock()
|
||||
t.resolvedIPs[ipStr] = time.Now().Add(t.ttl)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// IsTracked checks if an IP address is in the tracking list
|
||||
func (t *IPTracker) IsTracked(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ipStr := ip.String()
|
||||
|
||||
t.mu.RLock()
|
||||
expiration, exists := t.resolvedIPs[ipStr]
|
||||
t.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Now().After(expiration) {
|
||||
// Clean up expired entry
|
||||
t.mu.Lock()
|
||||
delete(t.resolvedIPs, ipStr)
|
||||
t.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetTrackedCount returns the number of currently tracked IPs
|
||||
func (t *IPTracker) GetTrackedCount() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return len(t.resolvedIPs)
|
||||
}
|
||||
|
||||
// cleanupExpiredIPs periodically removes expired IP entries
|
||||
func (t *IPTracker) cleanupExpiredIPs() {
|
||||
defer t.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.stopCh:
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
t.mu.Lock()
|
||||
for ip, expiration := range t.resolvedIPs {
|
||||
if now.After(expiration) {
|
||||
delete(t.resolvedIPs, ip)
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,7 @@ type NetstackController struct {
|
||||
linkEP *channel.Endpoint
|
||||
packetHandler PacketHandler
|
||||
dnsFilter *DNSFilter
|
||||
ipTracker *IPTracker
|
||||
tcpForwarder *TCPForwarder
|
||||
udpForwarder *UDPForwarder
|
||||
|
||||
@@ -64,6 +65,9 @@ type Config struct {
|
||||
|
||||
// UpstreamInterface is the real network interface for routing non-DNS traffic
|
||||
UpstreamInterface *net.Interface
|
||||
|
||||
// EnableIPBlocking enables IP whitelisting (firewall mode only)
|
||||
EnableIPBlocking bool
|
||||
}
|
||||
|
||||
// NewNetstackController creates a new netstack controller.
|
||||
@@ -100,14 +104,17 @@ func NewNetstackController(handler PacketHandler, cfg *Config) (*NetstackControl
|
||||
// Create link endpoint
|
||||
linkEP := channel.New(channelCapacity, cfg.MTU, "")
|
||||
|
||||
// Create DNS filter
|
||||
dnsFilter := NewDNSFilter(cfg.DNSHandler)
|
||||
// Create IP tracker (5 minute TTL for tracked IPs, enabled based on config)
|
||||
ipTracker := NewIPTracker(5*time.Minute, cfg.EnableIPBlocking)
|
||||
|
||||
// Create TCP forwarder
|
||||
tcpForwarder := NewTCPForwarder(s, handler.ProtectSocket, ctx)
|
||||
// Create DNS filter with IP tracker
|
||||
dnsFilter := NewDNSFilter(cfg.DNSHandler, ipTracker)
|
||||
|
||||
// Create UDP forwarder
|
||||
udpForwarder := NewUDPForwarder(s, handler.ProtectSocket, ctx)
|
||||
// Create TCP forwarder with IP tracker
|
||||
tcpForwarder := NewTCPForwarder(s, handler.ProtectSocket, ctx, ipTracker)
|
||||
|
||||
// Create UDP forwarder with IP tracker
|
||||
udpForwarder := NewUDPForwarder(s, handler.ProtectSocket, ctx, ipTracker)
|
||||
|
||||
// Create NIC
|
||||
if err := s.CreateNIC(NICID, linkEP); err != nil {
|
||||
@@ -176,6 +183,7 @@ func NewNetstackController(handler PacketHandler, cfg *Config) (*NetstackControl
|
||||
linkEP: linkEP,
|
||||
packetHandler: handler,
|
||||
dnsFilter: dnsFilter,
|
||||
ipTracker: ipTracker,
|
||||
tcpForwarder: tcpForwarder,
|
||||
udpForwarder: udpForwarder,
|
||||
ctx: ctx,
|
||||
@@ -199,6 +207,9 @@ func (nc *NetstackController) Start() error {
|
||||
|
||||
nc.started = true
|
||||
|
||||
// Start IP tracker
|
||||
nc.ipTracker.Start()
|
||||
|
||||
// Start packet reader goroutine (TUN -> netstack)
|
||||
nc.wg.Add(1)
|
||||
go nc.readPackets()
|
||||
@@ -207,36 +218,82 @@ func (nc *NetstackController) Start() error {
|
||||
nc.wg.Add(1)
|
||||
go nc.writePackets()
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Packet processing started (read/write goroutines)")
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Packet processing started (read/write goroutines + IP tracker)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetFirewallMode enables or disables IP whitelisting at runtime
|
||||
func (nc *NetstackController) SetFirewallMode(enabled bool) {
|
||||
if nc == nil {
|
||||
return
|
||||
}
|
||||
nc.mu.Lock()
|
||||
defer nc.mu.Unlock()
|
||||
|
||||
if nc.ipTracker != nil {
|
||||
nc.ipTracker.SetEnabled(enabled)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the netstack controller and waits for all goroutines to finish.
|
||||
func (nc *NetstackController) Stop() error {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() called - starting shutdown")
|
||||
|
||||
nc.mu.Lock()
|
||||
if !nc.started {
|
||||
nc.mu.Unlock()
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - already stopped, returning")
|
||||
return nil
|
||||
}
|
||||
nc.mu.Unlock()
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - canceling context")
|
||||
nc.cancel()
|
||||
nc.wg.Wait()
|
||||
|
||||
// Close packet handler FIRST to unblock all pending reads
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - closing packet handler to unblock goroutines")
|
||||
if err := nc.packetHandler.Close(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Msgf("[Netstack] Stop() - failed to close packet handler: %v", err)
|
||||
// Continue shutdown even if close fails
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - packet handler closed")
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - waiting for goroutines (max 2 seconds)")
|
||||
|
||||
// Wait for goroutines with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
nc.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - all goroutines finished")
|
||||
case <-time.After(2 * time.Second):
|
||||
ctrld.ProxyLogger.Load().Warn().Msg("[Netstack] Stop() - timeout waiting for goroutines, proceeding anyway")
|
||||
}
|
||||
|
||||
// Stop IP tracker
|
||||
if nc.ipTracker != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - stopping IP tracker")
|
||||
nc.ipTracker.Stop()
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - IP tracker stopped")
|
||||
}
|
||||
|
||||
// Close UDP forwarder
|
||||
if nc.udpForwarder != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - closing UDP forwarder")
|
||||
nc.udpForwarder.Close()
|
||||
}
|
||||
|
||||
if err := nc.packetHandler.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close packet handler: %v", err)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - UDP forwarder closed")
|
||||
}
|
||||
|
||||
nc.mu.Lock()
|
||||
nc.started = false
|
||||
nc.mu.Unlock()
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - shutdown complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,13 +19,15 @@ type TCPForwarder struct {
|
||||
protectSocket func(fd int) error
|
||||
ctx context.Context
|
||||
forwarder *tcp.Forwarder
|
||||
ipTracker *IPTracker
|
||||
}
|
||||
|
||||
// NewTCPForwarder creates a new TCP forwarder
|
||||
func NewTCPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context) *TCPForwarder {
|
||||
func NewTCPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context, ipTracker *IPTracker) *TCPForwarder {
|
||||
f := &TCPForwarder{
|
||||
protectSocket: protectSocket,
|
||||
ctx: ctx,
|
||||
ipTracker: ipTracker,
|
||||
}
|
||||
|
||||
// Create gVisor TCP forwarder with handler callback
|
||||
@@ -78,11 +80,28 @@ func (f *TCPForwarder) handleConnection(ep *tcp.Endpoint, wq *waiter.Queue, id s
|
||||
// - LocalAddress/LocalPort = the destination (where packet is going TO)
|
||||
// - RemoteAddress/RemotePort = the source (where packet is coming FROM)
|
||||
// We want to dial the DESTINATION (LocalAddress/LocalPort)
|
||||
dstIP := net.IP(id.LocalAddress.AsSlice())
|
||||
dstAddr := net.TCPAddr{
|
||||
IP: net.IP(id.LocalAddress.AsSlice()),
|
||||
IP: dstIP,
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
|
||||
// Check if IP blocking is enabled (firewall mode only)
|
||||
// Skip blocking for internal VPN subnet (10.0.0.0/24)
|
||||
if f.ipTracker != nil && f.ipTracker.IsEnabled() {
|
||||
// Allow internal VPN traffic (10.0.0.0/24)
|
||||
if !(dstIP[0] == 10 && dstIP[1] == 0 && dstIP[2] == 0) {
|
||||
// Check if destination IP was resolved through ControlD DNS
|
||||
// ONLY allow connections to IPs that went through DNS (whitelist approach)
|
||||
if !f.ipTracker.IsTracked(dstIP) {
|
||||
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[TCP] BLOCKED hardcoded IP: %s:%d -> %s:%d (not resolved via DNS)",
|
||||
srcAddr, id.RemotePort, dstIP, id.LocalPort)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create outbound connection with socket protection DURING dial
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
|
||||
@@ -20,6 +20,7 @@ type UDPForwarder struct {
|
||||
protectSocket func(fd int) error
|
||||
ctx context.Context
|
||||
forwarder *udp.Forwarder
|
||||
ipTracker *IPTracker
|
||||
|
||||
// Track UDP "connections" (address pairs)
|
||||
connections map[string]*udpConn
|
||||
@@ -33,10 +34,11 @@ type udpConn struct {
|
||||
}
|
||||
|
||||
// NewUDPForwarder creates a new UDP forwarder
|
||||
func NewUDPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context) *UDPForwarder {
|
||||
func NewUDPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context, ipTracker *IPTracker) *UDPForwarder {
|
||||
f := &UDPForwarder{
|
||||
protectSocket: protectSocket,
|
||||
ctx: ctx,
|
||||
ipTracker: ipTracker,
|
||||
connections: make(map[string]*udpConn),
|
||||
}
|
||||
|
||||
@@ -102,11 +104,28 @@ func (f *UDPForwarder) createConnection(req *udp.ForwarderRequest, connKey strin
|
||||
// Extract destination address
|
||||
// LocalAddress/LocalPort = destination (where packet is going TO)
|
||||
// RemoteAddress/RemotePort = source (where packet is coming FROM)
|
||||
dstIP := net.IP(id.LocalAddress.AsSlice())
|
||||
dstAddr := &net.UDPAddr{
|
||||
IP: net.IP(id.LocalAddress.AsSlice()),
|
||||
IP: dstIP,
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
|
||||
// Check if IP blocking is enabled (firewall mode only)
|
||||
// Skip blocking for internal VPN subnet (10.0.0.0/24)
|
||||
if f.ipTracker != nil && f.ipTracker.IsEnabled() {
|
||||
// Allow internal VPN traffic (10.0.0.0/24)
|
||||
if !(dstIP[0] == 10 && dstIP[1] == 0 && dstIP[2] == 0) {
|
||||
// Check if destination IP was resolved through ControlD DNS
|
||||
// ONLY allow connections to IPs that went through DNS (whitelist approach)
|
||||
if !f.ipTracker.IsTracked(dstIP) {
|
||||
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[UDP] BLOCKED hardcoded IP: %s:%d -> %s:%d (not resolved via DNS)",
|
||||
srcAddr, id.RemotePort, dstIP, id.LocalPort)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create dialer with socket protection DURING dial
|
||||
dialer := &net.Dialer{}
|
||||
|
||||
@@ -214,13 +233,18 @@ func (f *UDPForwarder) forwardUpstreamToTun(conn *udpConn, ctx context.Context,
|
||||
|
||||
// Close closes all UDP connections
|
||||
func (f *UDPForwarder) Close() {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[UDP] Close() called - closing all connections")
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
for _, conn := range f.connections {
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[UDP] Close() - closing %d connections", len(f.connections))
|
||||
for key, conn := range f.connections {
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] Close() - closing connection: %s", key)
|
||||
conn.cancel()
|
||||
conn.tunEP.Close()
|
||||
conn.upstreamConn.Close()
|
||||
}
|
||||
f.connections = make(map[string]*udpConn)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[UDP] Close() - all connections closed")
|
||||
}
|
||||
|
||||
@@ -120,7 +120,8 @@ func (pc *PacketCaptureController) StartWithPacketCapture(
|
||||
MTU: 1500,
|
||||
TUNIPv4: tunIPv4,
|
||||
DNSHandler: dnsHandler,
|
||||
UpstreamInterface: nil, // Will use default interface
|
||||
UpstreamInterface: nil, // Will use default interface
|
||||
EnableIPBlocking: true, // Enable IP whitelisting in firewall mode
|
||||
}
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Netstack TUN IP: %s", tunIP)
|
||||
@@ -223,41 +224,66 @@ func (pc *PacketCaptureController) handleDNSQuery(query *netstack.DNSQuery) {
|
||||
|
||||
// Stop stops the packet capture controller
|
||||
func (pc *PacketCaptureController) Stop(restart bool, pin int64) int {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() called - starting shutdown")
|
||||
var errorCode = 0
|
||||
|
||||
// Clear global socket protector
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - clearing socket protector")
|
||||
ctrld.SetSocketProtector(nil)
|
||||
|
||||
// Stop DNS bridge
|
||||
if pc.dnsBridge != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - stopping DNS bridge")
|
||||
pc.dnsBridge.Stop()
|
||||
pc.dnsBridge = nil
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - DNS bridge stopped")
|
||||
}
|
||||
|
||||
// Stop netstack
|
||||
if pc.netstackCtrl != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - stopping netstack controller")
|
||||
if err := pc.netstackCtrl.Stop(); err != nil {
|
||||
// Log error but continue shutdown
|
||||
fmt.Printf("Error stopping netstack: %v\n", err)
|
||||
ctrld.ProxyLogger.Load().Error().Msgf("[PacketCapture] Stop() - error stopping netstack: %v", err)
|
||||
}
|
||||
pc.netstackCtrl = nil
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - netstack controller stopped")
|
||||
}
|
||||
|
||||
// Close packet stop channel
|
||||
if pc.packetStopCh != nil {
|
||||
close(pc.packetStopCh)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - closing packet stop channel")
|
||||
select {
|
||||
case <-pc.packetStopCh:
|
||||
// Already closed
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - packet stop channel already closed")
|
||||
default:
|
||||
close(pc.packetStopCh)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - packet stop channel closed")
|
||||
}
|
||||
pc.packetStopCh = make(chan struct{})
|
||||
}
|
||||
|
||||
// Stop base controller
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - stopping base controller (restart=%v, pin=%d)", restart, pin)
|
||||
if !restart {
|
||||
errorCode = cli.CheckDeactivationPin(pin, pc.baseController.stopCh)
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - deactivation pin check returned: %d", errorCode)
|
||||
}
|
||||
if errorCode == 0 && pc.baseController.stopCh != nil {
|
||||
close(pc.baseController.stopCh)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - closing base controller stop channel")
|
||||
select {
|
||||
case <-pc.baseController.stopCh:
|
||||
// Already closed
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - base controller stop channel already closed")
|
||||
default:
|
||||
close(pc.baseController.stopCh)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - base controller stop channel closed")
|
||||
}
|
||||
pc.baseController.stopCh = nil
|
||||
}
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - shutdown complete, errorCode=%d", errorCode)
|
||||
return errorCode
|
||||
}
|
||||
|
||||
@@ -270,3 +296,15 @@ func (pc *PacketCaptureController) IsRunning() bool {
|
||||
func (pc *PacketCaptureController) IsPacketMode() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SetFirewallMode enables or disables firewall mode (IP whitelisting) at runtime
|
||||
func (pc *PacketCaptureController) SetFirewallMode(enabled bool) {
|
||||
if pc.netstackCtrl != nil {
|
||||
pc.netstackCtrl.SetFirewallMode(enabled)
|
||||
if enabled {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Firewall mode ENABLED - IP whitelisting active")
|
||||
} else {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Firewall mode DISABLED - all IPs allowed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user