mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-25 23:30:41 +01:00
239 lines
5.5 KiB
Go
239 lines
5.5 KiB
Go
package netstack
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
|
"gvisor.dev/gvisor/pkg/waiter"
|
|
)
|
|
|
|
// UDPForwarder handles UDP packets from the TUN interface
|
|
type UDPForwarder struct {
|
|
ctx context.Context
|
|
forwarder *udp.Forwarder
|
|
ipTracker *IPTracker
|
|
|
|
// Track UDP "connections" (address pairs)
|
|
connections map[string]*udpConn
|
|
mu sync.Mutex
|
|
}
|
|
|
|
type udpConn struct {
|
|
tunEP *gonet.UDPConn
|
|
upstreamConn *net.UDPConn
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// NewUDPForwarder creates a new UDP forwarder
|
|
func NewUDPForwarder(s *stack.Stack, ctx context.Context, ipTracker *IPTracker) *UDPForwarder {
|
|
f := &UDPForwarder{
|
|
ctx: ctx,
|
|
ipTracker: ipTracker,
|
|
connections: make(map[string]*udpConn),
|
|
}
|
|
|
|
// Create gVisor UDP forwarder with handler callback
|
|
f.forwarder = udp.NewForwarder(s, f.handlePacket)
|
|
|
|
return f
|
|
}
|
|
|
|
// GetForwarder returns the underlying gVisor forwarder
|
|
func (f *UDPForwarder) GetForwarder() *udp.Forwarder {
|
|
return f.forwarder
|
|
}
|
|
|
|
// handlePacket handles an incoming UDP packet
|
|
func (f *UDPForwarder) handlePacket(req *udp.ForwarderRequest) {
|
|
// Get the endpoint ID
|
|
id := req.ID()
|
|
|
|
// Create connection key (source -> destination)
|
|
connKey := fmt.Sprintf("%s:%d->%s:%d",
|
|
net.IP(id.RemoteAddress.AsSlice()),
|
|
id.RemotePort,
|
|
net.IP(id.LocalAddress.AsSlice()),
|
|
id.LocalPort,
|
|
)
|
|
|
|
f.mu.Lock()
|
|
conn, exists := f.connections[connKey]
|
|
if !exists {
|
|
// Create new connection
|
|
conn = f.createConnection(req, connKey)
|
|
if conn == nil {
|
|
f.mu.Unlock()
|
|
return
|
|
}
|
|
f.connections[connKey] = conn
|
|
|
|
// Log new UDP session
|
|
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
|
dstAddr := net.IP(id.LocalAddress.AsSlice())
|
|
ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] New session: %s:%d -> %s:%d (total: %d)",
|
|
srcAddr, id.RemotePort, dstAddr, id.LocalPort, len(f.connections))
|
|
}
|
|
f.mu.Unlock()
|
|
}
|
|
|
|
func (f *UDPForwarder) createConnection(req *udp.ForwarderRequest, connKey string) *udpConn {
|
|
id := req.ID()
|
|
|
|
// Create waiter queue
|
|
var wq waiter.Queue
|
|
|
|
// Create endpoint from request
|
|
ep, err := req.CreateEndpoint(&wq)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
// Convert to Go UDP conn
|
|
tunConn := gonet.NewUDPConn(&wq, ep)
|
|
|
|
// Extract destination address
|
|
// LocalAddress/LocalPort = destination (where packet is going TO)
|
|
// RemoteAddress/RemotePort = source (where packet is coming FROM)
|
|
dstIP := net.IP(id.LocalAddress.AsSlice())
|
|
dstAddr := &net.UDPAddr{
|
|
IP: dstIP,
|
|
Port: int(id.LocalPort),
|
|
}
|
|
|
|
// Check if IP blocking is enabled (firewall mode only)
|
|
// Skip blocking for internal VPN subnet (10.0.0.0/24)
|
|
if f.ipTracker != nil {
|
|
// Allow internal VPN traffic (10.0.0.0/24)
|
|
if !(dstIP[0] == 10 && dstIP[1] == 0 && dstIP[2] == 0) {
|
|
// Check if destination IP was resolved through ControlD DNS
|
|
// ONLY allow connections to IPs that went through DNS (whitelist approach)
|
|
if !f.ipTracker.IsTracked(dstIP) {
|
|
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
|
ctrld.ProxyLogger.Load().Info().Msgf("[UDP] BLOCKED hardcoded IP: %s:%d -> %s:%d (not resolved via DNS)",
|
|
srcAddr, id.RemotePort, dstIP, id.LocalPort)
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create dialer
|
|
dialer := &net.Dialer{}
|
|
|
|
// Create outbound UDP connection
|
|
dialConn, dialErr := dialer.Dial("udp", dstAddr.String())
|
|
if dialErr != nil {
|
|
tunConn.Close()
|
|
return nil
|
|
}
|
|
|
|
upstreamConn, ok := dialConn.(*net.UDPConn)
|
|
if !ok {
|
|
dialConn.Close()
|
|
tunConn.Close()
|
|
return nil
|
|
}
|
|
|
|
// Create connection context
|
|
ctx, cancel := context.WithCancel(f.ctx)
|
|
|
|
udpConnection := &udpConn{
|
|
tunEP: tunConn,
|
|
upstreamConn: upstreamConn,
|
|
cancel: cancel,
|
|
}
|
|
|
|
// Start forwarding goroutines
|
|
go f.forwardTunToUpstream(udpConnection, ctx)
|
|
go f.forwardUpstreamToTun(udpConnection, ctx, connKey)
|
|
|
|
return udpConnection
|
|
}
|
|
|
|
func (f *UDPForwarder) forwardTunToUpstream(conn *udpConn, ctx context.Context) {
|
|
buffer := make([]byte, 65535)
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
// Read from TUN
|
|
n, err := conn.tunEP.Read(buffer)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Write to upstream
|
|
_, err = conn.upstreamConn.Write(buffer[:n])
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (f *UDPForwarder) forwardUpstreamToTun(conn *udpConn, ctx context.Context, connKey string) {
|
|
defer func() {
|
|
conn.tunEP.Close()
|
|
conn.upstreamConn.Close()
|
|
|
|
f.mu.Lock()
|
|
delete(f.connections, connKey)
|
|
f.mu.Unlock()
|
|
}()
|
|
|
|
buffer := make([]byte, 65535)
|
|
|
|
// Set read timeout
|
|
conn.upstreamConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
// Read from upstream
|
|
n, err := conn.upstreamConn.Read(buffer)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Reset read deadline
|
|
conn.upstreamConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
|
|
|
// Write to TUN
|
|
_, err = conn.tunEP.Write(buffer[:n])
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close closes all UDP connections
|
|
func (f *UDPForwarder) Close() {
|
|
ctrld.ProxyLogger.Load().Info().Msg("[UDP] Close() called - closing all connections")
|
|
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
|
|
ctrld.ProxyLogger.Load().Info().Msgf("[UDP] Close() - closing %d connections", len(f.connections))
|
|
for key, conn := range f.connections {
|
|
ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] Close() - closing connection: %s", key)
|
|
conn.cancel()
|
|
conn.tunEP.Close()
|
|
conn.upstreamConn.Close()
|
|
}
|
|
f.connections = make(map[string]*udpConn)
|
|
ctrld.ProxyLogger.Load().Info().Msg("[UDP] Close() - all connections closed")
|
|
}
|