mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-05-15 00:50:25 +02:00
blocks direct Ip.
This commit is contained in:
@@ -6,8 +6,9 @@ Complete TCP/UDP/DNS packet capture implementation using gVisor netstack for And
|
|||||||
|
|
||||||
Provides full packet capture for mobile VPN applications:
|
Provides full packet capture for mobile VPN applications:
|
||||||
- **DNS filtering** through ControlD proxy
|
- **DNS filtering** through ControlD proxy
|
||||||
- **TCP forwarding** for all TCP traffic
|
- **IP whitelisting** - only allows connections to DNS-resolved IPs
|
||||||
- **UDP forwarding** with session tracking
|
- **TCP forwarding** for all TCP traffic (with whitelist enforcement)
|
||||||
|
- **UDP forwarding** with session tracking (with whitelist enforcement)
|
||||||
- **Socket protection** to prevent routing loops
|
- **Socket protection** to prevent routing loops
|
||||||
- **QUIC blocking** for better content filtering
|
- **QUIC blocking** for better content filtering
|
||||||
|
|
||||||
@@ -29,16 +30,19 @@ Real Network (Protected Sockets)
|
|||||||
## Components
|
## Components
|
||||||
|
|
||||||
### DNS Filter (`dns_filter.go`)
|
### DNS Filter (`dns_filter.go`)
|
||||||
Detects DNS packets on port 53 and routes to ControlD proxy.
|
Detects DNS packets on port 53, routes to ControlD proxy, and extracts resolved IPs.
|
||||||
|
|
||||||
### DNS Bridge (`dns_bridge.go`)
|
### DNS Bridge (`dns_bridge.go`)
|
||||||
Tracks DNS queries by transaction ID with 5-second timeout.
|
Tracks DNS queries by transaction ID with 5-second timeout.
|
||||||
|
|
||||||
|
### IP Tracker (`ip_tracker.go`)
|
||||||
|
Maintains whitelist of DNS-resolved IPs with 5-minute TTL.
|
||||||
|
|
||||||
### TCP Forwarder (`tcp_forwarder.go`)
|
### TCP Forwarder (`tcp_forwarder.go`)
|
||||||
Forwards TCP connections using gVisor's `tcp.NewForwarder()`.
|
Forwards TCP connections using gVisor's `tcp.NewForwarder()`. Blocks non-whitelisted IPs.
|
||||||
|
|
||||||
### UDP Forwarder (`udp_forwarder.go`)
|
### UDP Forwarder (`udp_forwarder.go`)
|
||||||
Forwards UDP packets with session tracking and 60-second idle timeout.
|
Forwards UDP packets with 30-second read deadline. Blocks non-whitelisted IPs.
|
||||||
|
|
||||||
### Packet Handler (`packet_handler.go`)
|
### Packet Handler (`packet_handler.go`)
|
||||||
Interface for TUN I/O and socket protection.
|
Interface for TUN I/O and socket protection.
|
||||||
@@ -125,6 +129,34 @@ Drops UDP packets on ports 443 and 80 to force TCP fallback:
|
|||||||
- No user-visible errors
|
- No user-visible errors
|
||||||
- Slightly slower initial connection, then normal
|
- Slightly slower initial connection, then normal
|
||||||
|
|
||||||
|
## IP Blocking (DNS Bypass Prevention)
|
||||||
|
|
||||||
|
Enforces whitelist approach: ONLY allows connections to IPs resolved through ControlD DNS.
|
||||||
|
|
||||||
|
**How it works:**
|
||||||
|
1. DNS responses are parsed to extract A and AAAA records
|
||||||
|
2. Resolved IPs are tracked in memory whitelist for 5 minutes
|
||||||
|
3. TCP/UDP connections to **non-whitelisted** IPs are **BLOCKED**
|
||||||
|
4. Only IPs that went through DNS resolution are allowed
|
||||||
|
|
||||||
|
**Why:**
|
||||||
|
- Prevents DNS bypass via hardcoded/cached IPs
|
||||||
|
- Ensures ALL traffic must go through ControlD DNS first
|
||||||
|
- Blocks apps that try to skip DNS filtering
|
||||||
|
- Enforces strict ControlD policy compliance
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```
|
||||||
|
✅ ALLOWED: App queries "example.com" → 93.184.216.34 → connects to 93.184.216.34
|
||||||
|
❌ BLOCKED: App connects directly to 1.2.3.4 (not resolved via DNS)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Components:**
|
||||||
|
- `ip_tracker.go` - Manages whitelist of DNS-resolved IPs with TTL
|
||||||
|
- `dns_filter.go` - Extracts IPs from DNS responses for whitelist
|
||||||
|
- `tcp_forwarder.go` - Allows only whitelisted IPs, blocks others
|
||||||
|
- `udp_forwarder.go` - Allows only whitelisted IPs, blocks others
|
||||||
|
|
||||||
## Usage (Android)
|
## Usage (Android)
|
||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
@@ -196,10 +228,11 @@ proxy.startFirewall(
|
|||||||
|
|
||||||
- `packet_handler.go` - TUN I/O interface
|
- `packet_handler.go` - TUN I/O interface
|
||||||
- `netstack.go` - gVisor controller
|
- `netstack.go` - gVisor controller
|
||||||
- `dns_filter.go` - DNS packet detection
|
- `dns_filter.go` - DNS packet detection and IP extraction
|
||||||
- `dns_bridge.go` - Transaction tracking
|
- `dns_bridge.go` - Transaction tracking
|
||||||
- `tcp_forwarder.go` - TCP forwarding
|
- `ip_tracker.go` - DNS-resolved IP whitelist with TTL
|
||||||
- `udp_forwarder.go` - UDP forwarding
|
- `tcp_forwarder.go` - TCP forwarding with whitelist enforcement
|
||||||
|
- `udp_forwarder.go` - UDP forwarding with whitelist enforcement
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
)
|
)
|
||||||
@@ -12,12 +13,14 @@ import (
|
|||||||
// DNSFilter intercepts and processes DNS packets.
|
// DNSFilter intercepts and processes DNS packets.
|
||||||
type DNSFilter struct {
|
type DNSFilter struct {
|
||||||
dnsHandler func([]byte) ([]byte, error)
|
dnsHandler func([]byte) ([]byte, error)
|
||||||
|
ipTracker *IPTracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSFilter creates a new DNS filter with the given handler.
|
// NewDNSFilter creates a new DNS filter with the given handler.
|
||||||
func NewDNSFilter(handler func([]byte) ([]byte, error)) *DNSFilter {
|
func NewDNSFilter(handler func([]byte) ([]byte, error), ipTracker *IPTracker) *DNSFilter {
|
||||||
return &DNSFilter{
|
return &DNSFilter{
|
||||||
dnsHandler: handler,
|
dnsHandler: handler,
|
||||||
|
ipTracker: ipTracker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,6 +104,11 @@ func (df *DNSFilter) processIPv4(packet []byte) (bool, []byte, error) {
|
|||||||
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Track IPs from DNS response
|
||||||
|
if df.ipTracker != nil {
|
||||||
|
df.extractAndTrackIPs(dnsResponse)
|
||||||
|
}
|
||||||
|
|
||||||
// Build response packet
|
// Build response packet
|
||||||
responsePacket := df.buildIPv4UDPPacket(
|
responsePacket := df.buildIPv4UDPPacket(
|
||||||
dstIP.As4(), // Swap src/dst
|
dstIP.As4(), // Swap src/dst
|
||||||
@@ -166,6 +174,11 @@ func (df *DNSFilter) processIPv6(packet []byte) (bool, []byte, error) {
|
|||||||
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Track IPs from DNS response
|
||||||
|
if df.ipTracker != nil {
|
||||||
|
df.extractAndTrackIPs(dnsResponse)
|
||||||
|
}
|
||||||
|
|
||||||
// Build response packet
|
// Build response packet
|
||||||
srcIP := ipHdr.SourceAddress()
|
srcIP := ipHdr.SourceAddress()
|
||||||
dstIP := ipHdr.DestinationAddress()
|
dstIP := ipHdr.DestinationAddress()
|
||||||
@@ -322,3 +335,31 @@ func parseUDP(udpHeader []byte) (srcPort, dstPort uint16, ok bool) {
|
|||||||
ok = true
|
ok = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractAndTrackIPs parses DNS response and tracks resolved IP addresses
|
||||||
|
func (df *DNSFilter) extractAndTrackIPs(dnsResponse []byte) {
|
||||||
|
if len(dnsResponse) < 12 {
|
||||||
|
return // Invalid DNS response
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(dnsResponse); err != nil {
|
||||||
|
return // Failed to parse DNS response
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract IPs from answer section
|
||||||
|
for _, answer := range msg.Answer {
|
||||||
|
switch rr := answer.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
// IPv4 address
|
||||||
|
if rr.A != nil {
|
||||||
|
df.ipTracker.TrackIP(rr.A)
|
||||||
|
}
|
||||||
|
case *dns.AAAA:
|
||||||
|
// IPv6 address
|
||||||
|
if rr.AAAA != nil {
|
||||||
|
df.ipTracker.TrackIP(rr.AAAA)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,179 @@
|
|||||||
|
package netstack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPTracker tracks IP addresses that have been resolved through DNS.
|
||||||
|
// This allows blocking direct IP connections that bypass DNS filtering.
|
||||||
|
type IPTracker struct {
|
||||||
|
// Map of IP address string -> expiration time
|
||||||
|
resolvedIPs map[string]time.Time
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// TTL for tracked IPs (how long to remember them)
|
||||||
|
ttl time.Duration
|
||||||
|
|
||||||
|
// Enable IP blocking (only in firewall mode)
|
||||||
|
enabled bool
|
||||||
|
|
||||||
|
// Running state
|
||||||
|
running bool
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIPTracker creates a new IP tracker with the specified TTL and enabled flag
|
||||||
|
func NewIPTracker(ttl time.Duration, enabled bool) *IPTracker {
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = 5 * time.Minute // Default 5 minutes
|
||||||
|
}
|
||||||
|
|
||||||
|
return &IPTracker{
|
||||||
|
resolvedIPs: make(map[string]time.Time),
|
||||||
|
ttl: ttl,
|
||||||
|
enabled: enabled,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled returns whether IP blocking is enabled
|
||||||
|
func (t *IPTracker) IsEnabled() bool {
|
||||||
|
if t == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
return t.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEnabled sets whether IP blocking is enabled
|
||||||
|
func (t *IPTracker) SetEnabled(enabled bool) {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.mu.Lock()
|
||||||
|
t.enabled = enabled
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the IP tracker cleanup routine
|
||||||
|
func (t *IPTracker) Start() {
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.running {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.running = true
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
// Start cleanup goroutine to remove expired IPs
|
||||||
|
t.wg.Add(1)
|
||||||
|
go t.cleanupExpiredIPs()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the IP tracker
|
||||||
|
func (t *IPTracker) Stop() {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
if !t.running {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.running = false
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
// Close stop channel (protected against double close)
|
||||||
|
select {
|
||||||
|
case <-t.stopCh:
|
||||||
|
// Already closed
|
||||||
|
default:
|
||||||
|
close(t.stopCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.wg.Wait()
|
||||||
|
|
||||||
|
// Clear all tracked IPs
|
||||||
|
t.mu.Lock()
|
||||||
|
t.resolvedIPs = make(map[string]time.Time)
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackIP adds an IP address to the tracking list
|
||||||
|
func (t *IPTracker) TrackIP(ip net.IP) {
|
||||||
|
if ip == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize to string format
|
||||||
|
ipStr := ip.String()
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
t.resolvedIPs[ipStr] = time.Now().Add(t.ttl)
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTracked checks if an IP address is in the tracking list
|
||||||
|
func (t *IPTracker) IsTracked(ip net.IP) bool {
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ipStr := ip.String()
|
||||||
|
|
||||||
|
t.mu.RLock()
|
||||||
|
expiration, exists := t.resolvedIPs[ipStr]
|
||||||
|
t.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if expired
|
||||||
|
if time.Now().After(expiration) {
|
||||||
|
// Clean up expired entry
|
||||||
|
t.mu.Lock()
|
||||||
|
delete(t.resolvedIPs, ipStr)
|
||||||
|
t.mu.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTrackedCount returns the number of currently tracked IPs
|
||||||
|
func (t *IPTracker) GetTrackedCount() int {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
return len(t.resolvedIPs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupExpiredIPs periodically removes expired IP entries
|
||||||
|
func (t *IPTracker) cleanupExpiredIPs() {
|
||||||
|
defer t.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.stopCh:
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
now := time.Now()
|
||||||
|
t.mu.Lock()
|
||||||
|
for ip, expiration := range t.resolvedIPs {
|
||||||
|
if now.After(expiration) {
|
||||||
|
delete(t.resolvedIPs, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -37,6 +37,7 @@ type NetstackController struct {
|
|||||||
linkEP *channel.Endpoint
|
linkEP *channel.Endpoint
|
||||||
packetHandler PacketHandler
|
packetHandler PacketHandler
|
||||||
dnsFilter *DNSFilter
|
dnsFilter *DNSFilter
|
||||||
|
ipTracker *IPTracker
|
||||||
tcpForwarder *TCPForwarder
|
tcpForwarder *TCPForwarder
|
||||||
udpForwarder *UDPForwarder
|
udpForwarder *UDPForwarder
|
||||||
|
|
||||||
@@ -64,6 +65,9 @@ type Config struct {
|
|||||||
|
|
||||||
// UpstreamInterface is the real network interface for routing non-DNS traffic
|
// UpstreamInterface is the real network interface for routing non-DNS traffic
|
||||||
UpstreamInterface *net.Interface
|
UpstreamInterface *net.Interface
|
||||||
|
|
||||||
|
// EnableIPBlocking enables IP whitelisting (firewall mode only)
|
||||||
|
EnableIPBlocking bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNetstackController creates a new netstack controller.
|
// NewNetstackController creates a new netstack controller.
|
||||||
@@ -100,14 +104,17 @@ func NewNetstackController(handler PacketHandler, cfg *Config) (*NetstackControl
|
|||||||
// Create link endpoint
|
// Create link endpoint
|
||||||
linkEP := channel.New(channelCapacity, cfg.MTU, "")
|
linkEP := channel.New(channelCapacity, cfg.MTU, "")
|
||||||
|
|
||||||
// Create DNS filter
|
// Create IP tracker (5 minute TTL for tracked IPs, enabled based on config)
|
||||||
dnsFilter := NewDNSFilter(cfg.DNSHandler)
|
ipTracker := NewIPTracker(5*time.Minute, cfg.EnableIPBlocking)
|
||||||
|
|
||||||
// Create TCP forwarder
|
// Create DNS filter with IP tracker
|
||||||
tcpForwarder := NewTCPForwarder(s, handler.ProtectSocket, ctx)
|
dnsFilter := NewDNSFilter(cfg.DNSHandler, ipTracker)
|
||||||
|
|
||||||
// Create UDP forwarder
|
// Create TCP forwarder with IP tracker
|
||||||
udpForwarder := NewUDPForwarder(s, handler.ProtectSocket, ctx)
|
tcpForwarder := NewTCPForwarder(s, handler.ProtectSocket, ctx, ipTracker)
|
||||||
|
|
||||||
|
// Create UDP forwarder with IP tracker
|
||||||
|
udpForwarder := NewUDPForwarder(s, handler.ProtectSocket, ctx, ipTracker)
|
||||||
|
|
||||||
// Create NIC
|
// Create NIC
|
||||||
if err := s.CreateNIC(NICID, linkEP); err != nil {
|
if err := s.CreateNIC(NICID, linkEP); err != nil {
|
||||||
@@ -176,6 +183,7 @@ func NewNetstackController(handler PacketHandler, cfg *Config) (*NetstackControl
|
|||||||
linkEP: linkEP,
|
linkEP: linkEP,
|
||||||
packetHandler: handler,
|
packetHandler: handler,
|
||||||
dnsFilter: dnsFilter,
|
dnsFilter: dnsFilter,
|
||||||
|
ipTracker: ipTracker,
|
||||||
tcpForwarder: tcpForwarder,
|
tcpForwarder: tcpForwarder,
|
||||||
udpForwarder: udpForwarder,
|
udpForwarder: udpForwarder,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -199,6 +207,9 @@ func (nc *NetstackController) Start() error {
|
|||||||
|
|
||||||
nc.started = true
|
nc.started = true
|
||||||
|
|
||||||
|
// Start IP tracker
|
||||||
|
nc.ipTracker.Start()
|
||||||
|
|
||||||
// Start packet reader goroutine (TUN -> netstack)
|
// Start packet reader goroutine (TUN -> netstack)
|
||||||
nc.wg.Add(1)
|
nc.wg.Add(1)
|
||||||
go nc.readPackets()
|
go nc.readPackets()
|
||||||
@@ -207,36 +218,82 @@ func (nc *NetstackController) Start() error {
|
|||||||
nc.wg.Add(1)
|
nc.wg.Add(1)
|
||||||
go nc.writePackets()
|
go nc.writePackets()
|
||||||
|
|
||||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Packet processing started (read/write goroutines)")
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Packet processing started (read/write goroutines + IP tracker)")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFirewallMode enables or disables IP whitelisting at runtime
|
||||||
|
func (nc *NetstackController) SetFirewallMode(enabled bool) {
|
||||||
|
if nc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nc.mu.Lock()
|
||||||
|
defer nc.mu.Unlock()
|
||||||
|
|
||||||
|
if nc.ipTracker != nil {
|
||||||
|
nc.ipTracker.SetEnabled(enabled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the netstack controller and waits for all goroutines to finish.
|
// Stop stops the netstack controller and waits for all goroutines to finish.
|
||||||
func (nc *NetstackController) Stop() error {
|
func (nc *NetstackController) Stop() error {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() called - starting shutdown")
|
||||||
|
|
||||||
nc.mu.Lock()
|
nc.mu.Lock()
|
||||||
if !nc.started {
|
if !nc.started {
|
||||||
nc.mu.Unlock()
|
nc.mu.Unlock()
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - already stopped, returning")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
nc.mu.Unlock()
|
nc.mu.Unlock()
|
||||||
|
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - canceling context")
|
||||||
nc.cancel()
|
nc.cancel()
|
||||||
nc.wg.Wait()
|
|
||||||
|
// Close packet handler FIRST to unblock all pending reads
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - closing packet handler to unblock goroutines")
|
||||||
|
if err := nc.packetHandler.Close(); err != nil {
|
||||||
|
ctrld.ProxyLogger.Load().Error().Msgf("[Netstack] Stop() - failed to close packet handler: %v", err)
|
||||||
|
// Continue shutdown even if close fails
|
||||||
|
}
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - packet handler closed")
|
||||||
|
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - waiting for goroutines (max 2 seconds)")
|
||||||
|
|
||||||
|
// Wait for goroutines with timeout
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
nc.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - all goroutines finished")
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
ctrld.ProxyLogger.Load().Warn().Msg("[Netstack] Stop() - timeout waiting for goroutines, proceeding anyway")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop IP tracker
|
||||||
|
if nc.ipTracker != nil {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - stopping IP tracker")
|
||||||
|
nc.ipTracker.Stop()
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - IP tracker stopped")
|
||||||
|
}
|
||||||
|
|
||||||
// Close UDP forwarder
|
// Close UDP forwarder
|
||||||
if nc.udpForwarder != nil {
|
if nc.udpForwarder != nil {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - closing UDP forwarder")
|
||||||
nc.udpForwarder.Close()
|
nc.udpForwarder.Close()
|
||||||
}
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - UDP forwarder closed")
|
||||||
|
|
||||||
if err := nc.packetHandler.Close(); err != nil {
|
|
||||||
return fmt.Errorf("failed to close packet handler: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nc.mu.Lock()
|
nc.mu.Lock()
|
||||||
nc.started = false
|
nc.started = false
|
||||||
nc.mu.Unlock()
|
nc.mu.Unlock()
|
||||||
|
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - shutdown complete")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,13 +19,15 @@ type TCPForwarder struct {
|
|||||||
protectSocket func(fd int) error
|
protectSocket func(fd int) error
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
forwarder *tcp.Forwarder
|
forwarder *tcp.Forwarder
|
||||||
|
ipTracker *IPTracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPForwarder creates a new TCP forwarder
|
// NewTCPForwarder creates a new TCP forwarder
|
||||||
func NewTCPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context) *TCPForwarder {
|
func NewTCPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context, ipTracker *IPTracker) *TCPForwarder {
|
||||||
f := &TCPForwarder{
|
f := &TCPForwarder{
|
||||||
protectSocket: protectSocket,
|
protectSocket: protectSocket,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
ipTracker: ipTracker,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create gVisor TCP forwarder with handler callback
|
// Create gVisor TCP forwarder with handler callback
|
||||||
@@ -78,11 +80,28 @@ func (f *TCPForwarder) handleConnection(ep *tcp.Endpoint, wq *waiter.Queue, id s
|
|||||||
// - LocalAddress/LocalPort = the destination (where packet is going TO)
|
// - LocalAddress/LocalPort = the destination (where packet is going TO)
|
||||||
// - RemoteAddress/RemotePort = the source (where packet is coming FROM)
|
// - RemoteAddress/RemotePort = the source (where packet is coming FROM)
|
||||||
// We want to dial the DESTINATION (LocalAddress/LocalPort)
|
// We want to dial the DESTINATION (LocalAddress/LocalPort)
|
||||||
|
dstIP := net.IP(id.LocalAddress.AsSlice())
|
||||||
dstAddr := net.TCPAddr{
|
dstAddr := net.TCPAddr{
|
||||||
IP: net.IP(id.LocalAddress.AsSlice()),
|
IP: dstIP,
|
||||||
Port: int(id.LocalPort),
|
Port: int(id.LocalPort),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if IP blocking is enabled (firewall mode only)
|
||||||
|
// Skip blocking for internal VPN subnet (10.0.0.0/24)
|
||||||
|
if f.ipTracker != nil && f.ipTracker.IsEnabled() {
|
||||||
|
// Allow internal VPN traffic (10.0.0.0/24)
|
||||||
|
if !(dstIP[0] == 10 && dstIP[1] == 0 && dstIP[2] == 0) {
|
||||||
|
// Check if destination IP was resolved through ControlD DNS
|
||||||
|
// ONLY allow connections to IPs that went through DNS (whitelist approach)
|
||||||
|
if !f.ipTracker.IsTracked(dstIP) {
|
||||||
|
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msgf("[TCP] BLOCKED hardcoded IP: %s:%d -> %s:%d (not resolved via DNS)",
|
||||||
|
srcAddr, id.RemotePort, dstIP, id.LocalPort)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create outbound connection with socket protection DURING dial
|
// Create outbound connection with socket protection DURING dial
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type UDPForwarder struct {
|
|||||||
protectSocket func(fd int) error
|
protectSocket func(fd int) error
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
forwarder *udp.Forwarder
|
forwarder *udp.Forwarder
|
||||||
|
ipTracker *IPTracker
|
||||||
|
|
||||||
// Track UDP "connections" (address pairs)
|
// Track UDP "connections" (address pairs)
|
||||||
connections map[string]*udpConn
|
connections map[string]*udpConn
|
||||||
@@ -33,10 +34,11 @@ type udpConn struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPForwarder creates a new UDP forwarder
|
// NewUDPForwarder creates a new UDP forwarder
|
||||||
func NewUDPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context) *UDPForwarder {
|
func NewUDPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context, ipTracker *IPTracker) *UDPForwarder {
|
||||||
f := &UDPForwarder{
|
f := &UDPForwarder{
|
||||||
protectSocket: protectSocket,
|
protectSocket: protectSocket,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
ipTracker: ipTracker,
|
||||||
connections: make(map[string]*udpConn),
|
connections: make(map[string]*udpConn),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,11 +104,28 @@ func (f *UDPForwarder) createConnection(req *udp.ForwarderRequest, connKey strin
|
|||||||
// Extract destination address
|
// Extract destination address
|
||||||
// LocalAddress/LocalPort = destination (where packet is going TO)
|
// LocalAddress/LocalPort = destination (where packet is going TO)
|
||||||
// RemoteAddress/RemotePort = source (where packet is coming FROM)
|
// RemoteAddress/RemotePort = source (where packet is coming FROM)
|
||||||
|
dstIP := net.IP(id.LocalAddress.AsSlice())
|
||||||
dstAddr := &net.UDPAddr{
|
dstAddr := &net.UDPAddr{
|
||||||
IP: net.IP(id.LocalAddress.AsSlice()),
|
IP: dstIP,
|
||||||
Port: int(id.LocalPort),
|
Port: int(id.LocalPort),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if IP blocking is enabled (firewall mode only)
|
||||||
|
// Skip blocking for internal VPN subnet (10.0.0.0/24)
|
||||||
|
if f.ipTracker != nil && f.ipTracker.IsEnabled() {
|
||||||
|
// Allow internal VPN traffic (10.0.0.0/24)
|
||||||
|
if !(dstIP[0] == 10 && dstIP[1] == 0 && dstIP[2] == 0) {
|
||||||
|
// Check if destination IP was resolved through ControlD DNS
|
||||||
|
// ONLY allow connections to IPs that went through DNS (whitelist approach)
|
||||||
|
if !f.ipTracker.IsTracked(dstIP) {
|
||||||
|
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msgf("[UDP] BLOCKED hardcoded IP: %s:%d -> %s:%d (not resolved via DNS)",
|
||||||
|
srcAddr, id.RemotePort, dstIP, id.LocalPort)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create dialer with socket protection DURING dial
|
// Create dialer with socket protection DURING dial
|
||||||
dialer := &net.Dialer{}
|
dialer := &net.Dialer{}
|
||||||
|
|
||||||
@@ -214,13 +233,18 @@ func (f *UDPForwarder) forwardUpstreamToTun(conn *udpConn, ctx context.Context,
|
|||||||
|
|
||||||
// Close closes all UDP connections
|
// Close closes all UDP connections
|
||||||
func (f *UDPForwarder) Close() {
|
func (f *UDPForwarder) Close() {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[UDP] Close() called - closing all connections")
|
||||||
|
|
||||||
f.mu.Lock()
|
f.mu.Lock()
|
||||||
defer f.mu.Unlock()
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
for _, conn := range f.connections {
|
ctrld.ProxyLogger.Load().Info().Msgf("[UDP] Close() - closing %d connections", len(f.connections))
|
||||||
|
for key, conn := range f.connections {
|
||||||
|
ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] Close() - closing connection: %s", key)
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
conn.tunEP.Close()
|
conn.tunEP.Close()
|
||||||
conn.upstreamConn.Close()
|
conn.upstreamConn.Close()
|
||||||
}
|
}
|
||||||
f.connections = make(map[string]*udpConn)
|
f.connections = make(map[string]*udpConn)
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[UDP] Close() - all connections closed")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,7 +120,8 @@ func (pc *PacketCaptureController) StartWithPacketCapture(
|
|||||||
MTU: 1500,
|
MTU: 1500,
|
||||||
TUNIPv4: tunIPv4,
|
TUNIPv4: tunIPv4,
|
||||||
DNSHandler: dnsHandler,
|
DNSHandler: dnsHandler,
|
||||||
UpstreamInterface: nil, // Will use default interface
|
UpstreamInterface: nil, // Will use default interface
|
||||||
|
EnableIPBlocking: true, // Enable IP whitelisting in firewall mode
|
||||||
}
|
}
|
||||||
|
|
||||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Netstack TUN IP: %s", tunIP)
|
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Netstack TUN IP: %s", tunIP)
|
||||||
@@ -223,41 +224,66 @@ func (pc *PacketCaptureController) handleDNSQuery(query *netstack.DNSQuery) {
|
|||||||
|
|
||||||
// Stop stops the packet capture controller
|
// Stop stops the packet capture controller
|
||||||
func (pc *PacketCaptureController) Stop(restart bool, pin int64) int {
|
func (pc *PacketCaptureController) Stop(restart bool, pin int64) int {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() called - starting shutdown")
|
||||||
var errorCode = 0
|
var errorCode = 0
|
||||||
|
|
||||||
// Clear global socket protector
|
// Clear global socket protector
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - clearing socket protector")
|
||||||
ctrld.SetSocketProtector(nil)
|
ctrld.SetSocketProtector(nil)
|
||||||
|
|
||||||
// Stop DNS bridge
|
// Stop DNS bridge
|
||||||
if pc.dnsBridge != nil {
|
if pc.dnsBridge != nil {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - stopping DNS bridge")
|
||||||
pc.dnsBridge.Stop()
|
pc.dnsBridge.Stop()
|
||||||
pc.dnsBridge = nil
|
pc.dnsBridge = nil
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - DNS bridge stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop netstack
|
// Stop netstack
|
||||||
if pc.netstackCtrl != nil {
|
if pc.netstackCtrl != nil {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - stopping netstack controller")
|
||||||
if err := pc.netstackCtrl.Stop(); err != nil {
|
if err := pc.netstackCtrl.Stop(); err != nil {
|
||||||
// Log error but continue shutdown
|
// Log error but continue shutdown
|
||||||
fmt.Printf("Error stopping netstack: %v\n", err)
|
ctrld.ProxyLogger.Load().Error().Msgf("[PacketCapture] Stop() - error stopping netstack: %v", err)
|
||||||
}
|
}
|
||||||
pc.netstackCtrl = nil
|
pc.netstackCtrl = nil
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - netstack controller stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close packet stop channel
|
// Close packet stop channel
|
||||||
if pc.packetStopCh != nil {
|
if pc.packetStopCh != nil {
|
||||||
close(pc.packetStopCh)
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - closing packet stop channel")
|
||||||
|
select {
|
||||||
|
case <-pc.packetStopCh:
|
||||||
|
// Already closed
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - packet stop channel already closed")
|
||||||
|
default:
|
||||||
|
close(pc.packetStopCh)
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - packet stop channel closed")
|
||||||
|
}
|
||||||
pc.packetStopCh = make(chan struct{})
|
pc.packetStopCh = make(chan struct{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop base controller
|
// Stop base controller
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - stopping base controller (restart=%v, pin=%d)", restart, pin)
|
||||||
if !restart {
|
if !restart {
|
||||||
errorCode = cli.CheckDeactivationPin(pin, pc.baseController.stopCh)
|
errorCode = cli.CheckDeactivationPin(pin, pc.baseController.stopCh)
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - deactivation pin check returned: %d", errorCode)
|
||||||
}
|
}
|
||||||
if errorCode == 0 && pc.baseController.stopCh != nil {
|
if errorCode == 0 && pc.baseController.stopCh != nil {
|
||||||
close(pc.baseController.stopCh)
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - closing base controller stop channel")
|
||||||
|
select {
|
||||||
|
case <-pc.baseController.stopCh:
|
||||||
|
// Already closed
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - base controller stop channel already closed")
|
||||||
|
default:
|
||||||
|
close(pc.baseController.stopCh)
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - base controller stop channel closed")
|
||||||
|
}
|
||||||
pc.baseController.stopCh = nil
|
pc.baseController.stopCh = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - shutdown complete, errorCode=%d", errorCode)
|
||||||
return errorCode
|
return errorCode
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,3 +296,15 @@ func (pc *PacketCaptureController) IsRunning() bool {
|
|||||||
func (pc *PacketCaptureController) IsPacketMode() bool {
|
func (pc *PacketCaptureController) IsPacketMode() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFirewallMode enables or disables firewall mode (IP whitelisting) at runtime
|
||||||
|
func (pc *PacketCaptureController) SetFirewallMode(enabled bool) {
|
||||||
|
if pc.netstackCtrl != nil {
|
||||||
|
pc.netstackCtrl.SetFirewallMode(enabled)
|
||||||
|
if enabled {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Firewall mode ENABLED - IP whitelisting active")
|
||||||
|
} else {
|
||||||
|
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Firewall mode DISABLED - all IPs allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user