Files
ctrld/cmd/ctrld_library/netstack/udp_forwarder.go
2026-03-19 16:53:34 -04:00

251 lines
5.9 KiB
Go

package netstack
import (
"context"
"fmt"
"net"
"sync"
"syscall"
"time"
"github.com/Control-D-Inc/ctrld"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
// UDPForwarder handles UDP packets from the TUN interface. For each new
// source/destination address pair it keeps one udpConn that relays datagrams
// between the netstack endpoint and a socket dialed to the destination.
type UDPForwarder struct {
// protectSocket, when non-nil, is called with the file descriptor of every
// outbound socket from the dialer's Control hook in createConnection.
protectSocket func(fd int) error
// ctx is the parent of the per-session contexts created in createConnection.
ctx context.Context
// forwarder is the gVisor UDP forwarder whose handler is handlePacket.
forwarder *udp.Forwarder
// ipTracker, when non-nil and enabled, is consulted in createConnection to
// decide whether a destination IP may be reached.
ipTracker *IPTracker
// Track UDP "connections" (address pairs), keyed by the
// "src:port->dst:port" string built in handlePacket.
connections map[string]*udpConn
// mu guards connections.
mu sync.Mutex
}
// udpConn holds the two ends of one forwarded UDP session together with the
// function that cancels the session's context.
type udpConn struct {
// tunEP is the netstack-side connection created from the forwarder request.
tunEP *gonet.UDPConn
// upstreamConn is the socket dialed to the session's destination address.
upstreamConn *net.UDPConn
// cancel cancels the context handed to the session's forwarding goroutines.
cancel context.CancelFunc
}
// NewUDPForwarder returns a UDPForwarder attached to the netstack s.
//
// protectSocket may be nil; when set it is called on every outbound socket
// before it is connected. ctx becomes the parent of every per-session context.
// ipTracker may be nil, in which case no destination filtering is performed.
func NewUDPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context, ipTracker *IPTracker) *UDPForwarder {
	fwd := new(UDPForwarder)
	fwd.protectSocket = protectSocket
	fwd.ctx = ctx
	fwd.ipTracker = ipTracker
	fwd.connections = map[string]*udpConn{}

	// The gVisor forwarder calls back into handlePacket, so it can only be
	// built once fwd itself exists.
	fwd.forwarder = udp.NewForwarder(s, fwd.handlePacket)
	return fwd
}
// GetForwarder returns the underlying gVisor UDP forwarder, i.e. the one
// created in NewUDPForwarder with handlePacket as its handler.
func (f *UDPForwarder) GetForwarder() *udp.Forwarder {
return f.forwarder
}
// handlePacket is the handler registered with the gVisor UDP forwarder. It
// looks the packet's source -> destination pair up in the session table and,
// when no session is registered for it yet, creates and records one. Requests
// for a pair that is already registered are left untouched.
func (f *UDPForwarder) handlePacket(req *udp.ForwarderRequest) {
	id := req.ID()

	// Remote* is the sender of the packet, Local* is where it is headed.
	srcIP := net.IP(id.RemoteAddress.AsSlice())
	dstIP := net.IP(id.LocalAddress.AsSlice())
	connKey := fmt.Sprintf("%s:%d->%s:%d", srcIP, id.RemotePort, dstIP, id.LocalPort)

	// The lock is held across createConnection so that lookup, creation and
	// registration of a session happen as one step.
	f.mu.Lock()
	defer f.mu.Unlock()

	if _, known := f.connections[connKey]; known {
		return
	}

	session := f.createConnection(req, connKey)
	if session == nil {
		return
	}
	f.connections[connKey] = session

	ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] New session: %s:%d -> %s:%d (total: %d)",
		srcIP, id.RemotePort, dstIP, id.LocalPort, len(f.connections))
}
// createConnection sets up forwarding state for a new UDP flow. It creates the
// netstack endpoint for the request, applies the optional destination
// allow-list, dials the destination (running protectSocket on the socket
// before it is connected, when configured) and starts the two goroutines that
// copy datagrams in each direction.
//
// It returns nil when the flow is rejected or any step fails; every resource
// created up to that point is released before returning. connKey is passed on
// to forwardUpstreamToTun, which uses it to remove the session from
// f.connections when the session ends.
func (f *UDPForwarder) createConnection(req *udp.ForwarderRequest, connKey string) *udpConn {
	id := req.ID()

	// Create the netstack endpoint for this flow and wrap it as a Go conn.
	var wq waiter.Queue
	ep, err := req.CreateEndpoint(&wq)
	if err != nil {
		ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] Failed to create endpoint for %s: %v", connKey, err)
		return nil
	}
	tunConn := gonet.NewUDPConn(&wq, ep)

	// LocalAddress/LocalPort = destination (where the packet is going TO).
	// RemoteAddress/RemotePort = source (where the packet is coming FROM).
	dstIP := net.IP(id.LocalAddress.AsSlice())
	dstAddr := &net.UDPAddr{
		IP:   dstIP,
		Port: int(id.LocalPort),
	}

	// Firewall mode only: allow the internal VPN subnet (10.0.0.0/24) and
	// otherwise only destinations the tracker has recorded (whitelist).
	if f.ipTracker != nil && f.ipTracker.IsEnabled() {
		// To4 yields nil for anything that is not an IPv4 address, so an IPv6
		// address whose leading bytes happen to be 0a 00 00 is not mistaken
		// for the VPN subnet and a short slice cannot cause an index panic.
		ip4 := dstIP.To4()
		internal := ip4 != nil && ip4[0] == 10 && ip4[1] == 0 && ip4[2] == 0
		if !internal && !f.ipTracker.IsTracked(dstIP) {
			srcAddr := net.IP(id.RemoteAddress.AsSlice())
			ctrld.ProxyLogger.Load().Info().Msgf("[UDP] BLOCKED hardcoded IP: %s:%d -> %s:%d (not resolved via DNS)",
				srcAddr, id.RemotePort, dstIP, id.LocalPort)
			// The endpoint was already created above; close it so a rejected
			// flow does not leave an endpoint behind that nobody reads from.
			tunConn.Close()
			return nil
		}
	}

	// The socket must be protected BEFORE connect() is called, which is what
	// the dialer's Control hook guarantees.
	dialer := &net.Dialer{}
	if f.protectSocket != nil {
		dialer.Control = func(network, address string, c syscall.RawConn) error {
			var protectErr error
			if ctrlErr := c.Control(func(fd uintptr) {
				protectErr = f.protectSocket(int(fd))
			}); ctrlErr != nil {
				return ctrlErr
			}
			if protectErr != nil {
				// Fail the dial rather than connect an unprotected socket.
				return fmt.Errorf("protecting socket: %w", protectErr)
			}
			return nil
		}
	}

	// Create the outbound UDP connection.
	dialConn, dialErr := dialer.Dial("udp", dstAddr.String())
	if dialErr != nil {
		ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] Failed to dial %s: %v", dstAddr, dialErr)
		tunConn.Close()
		return nil
	}
	upstreamConn, ok := dialConn.(*net.UDPConn)
	if !ok {
		dialConn.Close()
		tunConn.Close()
		return nil
	}

	// Each session gets its own context derived from the forwarder's.
	ctx, cancel := context.WithCancel(f.ctx)
	udpConnection := &udpConn{
		tunEP:        tunConn,
		upstreamConn: upstreamConn,
		cancel:       cancel,
	}

	// One goroutine per direction.
	go f.forwardTunToUpstream(udpConnection, ctx)
	go f.forwardUpstreamToTun(udpConnection, ctx, connKey)
	return udpConnection
}
// forwardTunToUpstream copies datagrams read from the netstack side of conn to
// its upstream socket until ctx is cancelled or a read or write fails.
//
// ctx is only checked between reads; a Read that is already blocked returns
// once conn.tunEP is closed.
func (f *UDPForwarder) forwardTunToUpstream(conn *udpConn, ctx context.Context) {
	// Once this direction stops, nothing reads from the TUN side any more.
	// Closing the upstream socket makes the upstream reader fail as well, so
	// the session is torn down instead of lingering half-open. Closing a
	// socket that is already closed only returns an error, which is ignored.
	defer conn.upstreamConn.Close()

	buffer := make([]byte, 65535)
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		// Read from TUN.
		n, err := conn.tunEP.Read(buffer)
		if err != nil {
			return
		}

		// Write to upstream.
		if _, err := conn.upstreamConn.Write(buffer[:n]); err != nil {
			return
		}
	}
}
// forwardUpstreamToTun copies datagrams read from the upstream socket of conn
// to its netstack side. It stops when ctx is cancelled, when a read or write
// fails, or when no upstream datagram arrives within the idle timeout.
//
// On exit it releases the session: the session context is cancelled, both
// ends are closed and the entry stored under connKey is removed from
// f.connections if it still refers to conn.
func (f *UDPForwarder) forwardUpstreamToTun(conn *udpConn, ctx context.Context, connKey string) {
	// idleTimeout is how long the upstream side may stay silent.
	const idleTimeout = 30 * time.Second

	defer func() {
		// Cancel the per-session context so it is detached from its parent;
		// otherwise it stays referenced until the parent itself is cancelled.
		conn.cancel()
		conn.tunEP.Close()
		conn.upstreamConn.Close()

		f.mu.Lock()
		// Only remove our own entry: the key may have been reused by a newer
		// session, e.g. after Close replaced the map.
		if f.connections[connKey] == conn {
			delete(f.connections, connKey)
		}
		f.mu.Unlock()
	}()

	buffer := make([]byte, 65535)

	// Arm the idle timeout.
	conn.upstreamConn.SetReadDeadline(time.Now().Add(idleTimeout))
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		// Read from upstream; a deadline expiry surfaces here as an error.
		n, err := conn.upstreamConn.Read(buffer)
		if err != nil {
			return
		}

		// Traffic was seen, so push the idle deadline out again.
		conn.upstreamConn.SetReadDeadline(time.Now().Add(idleTimeout))

		// Write to TUN.
		if _, err := conn.tunEP.Write(buffer[:n]); err != nil {
			return
		}
	}
}
// Close tears down every tracked UDP session: each session's context is
// cancelled and both of its ends are closed. The session table is then
// replaced with an empty one. Progress is logged along the way.
func (f *UDPForwarder) Close() {
	// Bound method value: every call still loads the current logger.
	logger := ctrld.ProxyLogger.Load

	logger().Info().Msg("[UDP] Close() called - closing all connections")

	f.mu.Lock()
	defer f.mu.Unlock()

	logger().Info().Msgf("[UDP] Close() - closing %d connections", len(f.connections))
	for key, session := range f.connections {
		logger().Debug().Msgf("[UDP] Close() - closing connection: %s", key)
		session.cancel()
		session.tunEP.Close()
		session.upstreamConn.Close()
	}

	// Start over with a fresh table rather than deleting entries one by one.
	f.connections = make(map[string]*udpConn)
	logger().Info().Msg("[UDP] Close() - all connections closed")
}