mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-03-25 23:30:41 +01:00
tcp/ip stack + firewall mode.
This commit is contained in:
222
cmd/ctrld_library/netstack/README.md
Normal file
222
cmd/ctrld_library/netstack/README.md
Normal file
@@ -0,0 +1,222 @@
|
||||
# Netstack - Full Packet Capture for Mobile VPN
|
||||
|
||||
Complete TCP/UDP/DNS packet capture implementation using gVisor netstack for Android and iOS VPN apps.
|
||||
|
||||
## Overview
|
||||
|
||||
This module provides full packet capture capabilities for mobile VPN applications, handling:
|
||||
- **DNS filtering** through ControlD proxy
|
||||
- **TCP forwarding** for HTTP/HTTPS and all TCP traffic
|
||||
- **UDP forwarding** for games, video streaming, VoIP, etc.
|
||||
- **Socket protection** to prevent routing loops on Android/iOS
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Mobile Apps (Browser, Games, etc)
|
||||
↓
|
||||
VPN TUN Interface (10.0.0.2/24)
|
||||
↓
|
||||
PacketHandler (Read/Write/Protect)
|
||||
↓
|
||||
gVisor Netstack (TCP/IP Stack)
|
||||
├─→ DNS Filter (Port 53)
|
||||
│ └─→ ControlD DNS Proxy (localhost:5354)
|
||||
├─→ TCP Forwarder
|
||||
│ └─→ net.Dial("tcp") + protect(fd)
|
||||
└─→ UDP Forwarder
|
||||
└─→ net.Dial("udp") + protect(fd)
|
||||
↓
|
||||
Real Network (WiFi/Cellular) - Protected Sockets
|
||||
```
|
||||
|
||||
## Key Components
|
||||
|
||||
### 1. DNS Filter (`dns_filter.go`)
|
||||
- Detects DNS packets (UDP port 53)
|
||||
- Extracts DNS query payload
|
||||
- Sends to DNS bridge
|
||||
- Builds DNS response packets
|
||||
|
||||
### 2. DNS Bridge (`dns_bridge.go`)
|
||||
- Transaction ID tracking
|
||||
- Query/response matching
|
||||
- 5-second timeout per query
|
||||
- Channel-based communication
|
||||
|
||||
### 3. TCP Forwarder (`tcp_forwarder.go`)
|
||||
- Uses gVisor's `tcp.NewForwarder()`
|
||||
- Converts gVisor endpoints to Go `net.Conn`
|
||||
- Dials regular TCP sockets (no root required)
|
||||
- Protects sockets using `VpnService.protect()` callback
|
||||
- Bidirectional `io.Copy()` for data forwarding
|
||||
|
||||
### 4. UDP Forwarder (`udp_forwarder.go`)
|
||||
- Uses gVisor's `udp.NewForwarder()`
|
||||
- Per-session connection tracking
|
||||
- Dials regular UDP sockets (no root required)
|
||||
- Protected sockets prevent routing loops
|
||||
- 60-second idle timeout with automatic cleanup
|
||||
|
||||
### 5. Packet Handler (`packet_handler.go`)
|
||||
- Interface for reading/writing raw IP packets
|
||||
- Mobile platforms implement:
|
||||
- `ReadPacket()` - Read from TUN file descriptor
|
||||
- `WritePacket()` - Write to TUN file descriptor
|
||||
- `ProtectSocket(fd)` - Protect socket from VPN routing
|
||||
- `Close()` - Clean up resources
|
||||
|
||||
### 6. Netstack Controller (`netstack.go`)
|
||||
- Manages gVisor stack lifecycle
|
||||
- Coordinates DNS filter and TCP/UDP forwarders
|
||||
- Filters outbound packets (source=10.0.0.x)
|
||||
- Drops return packets (handled by forwarders)
|
||||
|
||||
## Critical Design Decisions
|
||||
|
||||
### Socket Protection
|
||||
|
||||
**Why It's Critical:**
|
||||
Without socket protection, outbound connections would route back through the VPN, creating infinite loops:
|
||||
|
||||
```
|
||||
Bad (without protect):
|
||||
App → VPN → TCP Forwarder → net.Dial() → VPN → TCP Forwarder → LOOP!
|
||||
|
||||
Good (with protect):
|
||||
App → VPN → TCP Forwarder → net.Dial() → [PROTECTED] → WiFi → Internet ✅
|
||||
```
|
||||
|
||||
**Implementation:**
|
||||
```go
|
||||
// Protect socket BEFORE connect() is called
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
protectSocket(int(fd)) // Android: VpnService.protect()
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
**All Protected Sockets:**
|
||||
1. TCP forwarder sockets (user traffic)
|
||||
2. UDP forwarder sockets (user traffic)
|
||||
3. ControlD API HTTP sockets (api.controld.com)
|
||||
4. DoH upstream sockets (freedns.controld.com)
|
||||
|
||||
### Outbound vs Return Packets
|
||||
|
||||
**Outbound packets** (10.0.0.x → Internet):
|
||||
- Source IP: 10.0.0.x
|
||||
- Injected into gVisor netstack
|
||||
- Handled by TCP/UDP forwarders
|
||||
|
||||
**Return packets** (Internet → 10.0.0.x):
|
||||
- Source IP: NOT 10.0.0.x
|
||||
- Dropped by readPackets()
|
||||
- Return through forwarder's upstream connection automatically
|
||||
|
||||
### Address Mapping in gVisor
|
||||
|
||||
For inbound connections to the netstack:
|
||||
- `id.LocalAddress/LocalPort` = **Destination** (where packet is going TO)
|
||||
- `id.RemoteAddress/RemotePort` = **Source** (where packet is coming FROM)
|
||||
|
||||
Therefore, we dial `LocalAddress:LocalPort` (the destination).
|
||||
|
||||
## Usage Example (Android)
|
||||
|
||||
```kotlin
|
||||
// In VpnService
|
||||
val callback = object : PacketAppCallback {
|
||||
override fun readPacket(): ByteArray {
|
||||
// Read from TUN file descriptor
|
||||
val length = inputStream.channel.read(buffer)
|
||||
return packet
|
||||
}
|
||||
|
||||
override fun writePacket(packet: ByteArray) {
|
||||
// Write to TUN file descriptor
|
||||
outputStream.write(packet)
|
||||
}
|
||||
|
||||
override fun protectSocket(fd: Long) {
|
||||
// CRITICAL: Protect socket from VPN routing
|
||||
val success = protect(fd.toInt()) // VpnService.protect()
|
||||
if (!success) throw Exception("Failed to protect socket")
|
||||
}
|
||||
|
||||
override fun closePacketIO() {
|
||||
inputStream?.close()
|
||||
outputStream?.close()
|
||||
}
|
||||
|
||||
override fun exit(s: String) { }
|
||||
override fun hostname(): String = "android-device"
|
||||
override fun lanIp(): String = "10.0.0.2"
|
||||
override fun macAddress(): String = "00:00:00:00:00:00"
|
||||
}
|
||||
|
||||
// Create packet capture controller
|
||||
val controller = Ctrld_library.newPacketCaptureController(callback)
|
||||
|
||||
// Start packet capture
|
||||
controller.startWithPacketCapture(
|
||||
callback,
|
||||
"your-cd-uid",
|
||||
"", "", // provision ID, custom hostname
|
||||
filesDir.absolutePath,
|
||||
"doh", // upstream protocol
|
||||
2, // log level
|
||||
"$filesDir/ctrld.log"
|
||||
)
|
||||
|
||||
// Stop when done
|
||||
controller.stop(false, 0)
|
||||
```
|
||||
|
||||
## Protocol Support
|
||||
|
||||
| Protocol | Support | Details |
|
||||
|----------|---------|---------|
|
||||
| **DNS** | ✅ Full | Filtered through ControlD proxy |
|
||||
| **TCP** | ✅ Full | All ports, bidirectional forwarding |
|
||||
| **UDP** | ✅ Full | All ports except 53, session tracking |
|
||||
| **ICMP** | ⚠️ Partial | Basic support (no forwarding yet) |
|
||||
| **IPv4** | ✅ Full | Complete support |
|
||||
| **IPv6** | ✅ Full | Complete support |
|
||||
|
||||
## Performance
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| **DNS Timeout** | 5 seconds |
|
||||
| **TCP Dial Timeout** | 30 seconds |
|
||||
| **UDP Idle Timeout** | 60 seconds |
|
||||
| **UDP Cleanup Interval** | 30 seconds |
|
||||
| **MTU** | 1500 bytes |
|
||||
| **Overhead per TCP connection** | ~2KB |
|
||||
| **Overhead per UDP session** | ~1KB |
|
||||
|
||||
## Requirements
|
||||
|
||||
- Go 1.23+
|
||||
- gVisor netstack v0.0.0-20240722211153-64c016c92987
|
||||
- For Android: API 24+ (Android 7.0+)
|
||||
- For iOS: iOS 12+
|
||||
|
||||
## No Root Required
|
||||
|
||||
This implementation uses **regular TCP/UDP sockets** instead of raw sockets, making it compatible with non-rooted Android/iOS devices. Socket protection via `VpnService.protect()` (Android) or `NEPacketTunnelFlow` (iOS) prevents routing loops.
|
||||
|
||||
## Files
|
||||
|
||||
- `packet_handler.go` - Interface for TUN I/O and socket protection
|
||||
- `netstack.go` - Main controller managing gVisor stack
|
||||
- `dns_filter.go` - DNS packet detection and response building
|
||||
- `dns_bridge.go` - DNS query/response bridging
|
||||
- `tcp_forwarder.go` - TCP connection forwarding
|
||||
- `udp_forwarder.go` - UDP packet forwarding with session tracking
|
||||
|
||||
## License
|
||||
|
||||
Same as parent ctrld project.
|
||||
228
cmd/ctrld_library/netstack/dns_bridge.go
Normal file
228
cmd/ctrld_library/netstack/dns_bridge.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// DNSBridge provides a bridge between the netstack DNS filter and the existing ctrld DNS proxy.
|
||||
// It allows DNS queries captured from packets to be processed by the same logic as traditional DNS queries.
|
||||
type DNSBridge struct {
|
||||
// Channel for sending DNS queries
|
||||
queryCh chan *DNSQuery
|
||||
|
||||
// Channel for receiving DNS responses
|
||||
responseCh chan *DNSResponse
|
||||
|
||||
// Map to track pending queries by transaction ID
|
||||
pendingQueries map[uint16]*PendingQuery
|
||||
mu sync.RWMutex
|
||||
|
||||
// Timeout for DNS queries
|
||||
queryTimeout time.Duration
|
||||
|
||||
// Running state
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// DNSQuery represents a DNS query to be processed
|
||||
type DNSQuery struct {
|
||||
ID uint16 // Transaction ID for matching response
|
||||
Query []byte // Raw DNS query bytes
|
||||
RespCh chan []byte // Response channel
|
||||
SrcIP string // Source IP for logging
|
||||
SrcPort uint16 // Source port
|
||||
}
|
||||
|
||||
// DNSResponse represents a DNS response
|
||||
type DNSResponse struct {
|
||||
ID uint16
|
||||
Response []byte
|
||||
}
|
||||
|
||||
// PendingQuery tracks a query waiting for response
|
||||
type PendingQuery struct {
|
||||
Query *DNSQuery
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// NewDNSBridge creates a new DNS bridge
|
||||
func NewDNSBridge() *DNSBridge {
|
||||
return &DNSBridge{
|
||||
queryCh: make(chan *DNSQuery, 100),
|
||||
responseCh: make(chan *DNSResponse, 100),
|
||||
pendingQueries: make(map[uint16]*PendingQuery),
|
||||
queryTimeout: 5 * time.Second,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the DNS bridge
|
||||
func (b *DNSBridge) Start() {
|
||||
b.mu.Lock()
|
||||
if b.running {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
b.running = true
|
||||
b.mu.Unlock()
|
||||
|
||||
// Start response handler
|
||||
b.wg.Add(1)
|
||||
go b.handleResponses()
|
||||
|
||||
// Start timeout checker
|
||||
b.wg.Add(1)
|
||||
go b.checkTimeouts()
|
||||
}
|
||||
|
||||
// Stop stops the DNS bridge
|
||||
func (b *DNSBridge) Stop() {
|
||||
b.mu.Lock()
|
||||
if !b.running {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
b.running = false
|
||||
b.mu.Unlock()
|
||||
|
||||
close(b.stopCh)
|
||||
b.wg.Wait()
|
||||
|
||||
// Clean up pending queries
|
||||
b.mu.Lock()
|
||||
for _, pending := range b.pendingQueries {
|
||||
close(pending.Query.RespCh)
|
||||
}
|
||||
b.pendingQueries = make(map[uint16]*PendingQuery)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
// ProcessQuery processes a DNS query and waits for response
|
||||
func (b *DNSBridge) ProcessQuery(query []byte, srcIP string, srcPort uint16) ([]byte, error) {
|
||||
if len(query) < 12 {
|
||||
return nil, fmt.Errorf("invalid DNS query: too short")
|
||||
}
|
||||
|
||||
// Parse DNS message to get transaction ID
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(query); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse DNS query: %v", err)
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
respCh := make(chan []byte, 1)
|
||||
|
||||
// Create query
|
||||
dnsQuery := &DNSQuery{
|
||||
ID: msg.Id,
|
||||
Query: query,
|
||||
RespCh: respCh,
|
||||
SrcIP: srcIP,
|
||||
SrcPort: srcPort,
|
||||
}
|
||||
|
||||
// Store as pending
|
||||
b.mu.Lock()
|
||||
b.pendingQueries[msg.Id] = &PendingQuery{
|
||||
Query: dnsQuery,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
// Send query
|
||||
select {
|
||||
case b.queryCh <- dnsQuery:
|
||||
case <-time.After(time.Second):
|
||||
b.mu.Lock()
|
||||
delete(b.pendingQueries, msg.Id)
|
||||
b.mu.Unlock()
|
||||
return nil, fmt.Errorf("query channel full")
|
||||
}
|
||||
|
||||
// Wait for response with timeout
|
||||
select {
|
||||
case response := <-respCh:
|
||||
b.mu.Lock()
|
||||
delete(b.pendingQueries, msg.Id)
|
||||
b.mu.Unlock()
|
||||
return response, nil
|
||||
|
||||
case <-time.After(b.queryTimeout):
|
||||
b.mu.Lock()
|
||||
delete(b.pendingQueries, msg.Id)
|
||||
b.mu.Unlock()
|
||||
return nil, fmt.Errorf("DNS query timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// GetQueryChannel returns the channel for receiving DNS queries
|
||||
func (b *DNSBridge) GetQueryChannel() <-chan *DNSQuery {
|
||||
return b.queryCh
|
||||
}
|
||||
|
||||
// SendResponse sends a DNS response back to the waiting query
|
||||
func (b *DNSBridge) SendResponse(id uint16, response []byte) error {
|
||||
b.mu.RLock()
|
||||
pending, exists := b.pendingQueries[id]
|
||||
b.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("no pending query for ID %d", id)
|
||||
}
|
||||
|
||||
select {
|
||||
case pending.Query.RespCh <- response:
|
||||
return nil
|
||||
case <-time.After(time.Second):
|
||||
return fmt.Errorf("failed to send response: channel blocked")
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponses handles incoming responses
|
||||
func (b *DNSBridge) handleResponses() {
|
||||
defer b.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.stopCh:
|
||||
return
|
||||
|
||||
case resp := <-b.responseCh:
|
||||
if err := b.SendResponse(resp.ID, resp.Response); err != nil {
|
||||
// Log error but continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkTimeouts periodically checks for and removes timed out queries
|
||||
func (b *DNSBridge) checkTimeouts() {
|
||||
defer b.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.stopCh:
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
b.mu.Lock()
|
||||
for id, pending := range b.pendingQueries {
|
||||
if now.Sub(pending.Timestamp) > b.queryTimeout {
|
||||
close(pending.Query.RespCh)
|
||||
delete(b.pendingQueries, id)
|
||||
}
|
||||
}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
324
cmd/ctrld_library/netstack/dns_filter.go
Normal file
324
cmd/ctrld_library/netstack/dns_filter.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
|
||||
// DNSFilter intercepts and processes DNS packets.
|
||||
type DNSFilter struct {
|
||||
dnsHandler func([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// NewDNSFilter creates a new DNS filter with the given handler.
|
||||
func NewDNSFilter(handler func([]byte) ([]byte, error)) *DNSFilter {
|
||||
return &DNSFilter{
|
||||
dnsHandler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessPacket checks if a packet is a DNS query and processes it.
|
||||
// Returns:
|
||||
// - isDNS: true if this is a DNS packet
|
||||
// - response: DNS response packet (if handled), nil otherwise
|
||||
// - error: any error that occurred
|
||||
func (df *DNSFilter) ProcessPacket(packet []byte) (isDNS bool, response []byte, err error) {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse IP version
|
||||
ipVersion := packet[0] >> 4
|
||||
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
return df.processIPv4(packet)
|
||||
case 6:
|
||||
return df.processIPv6(packet)
|
||||
default:
|
||||
return false, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// processIPv4 processes an IPv4 packet and checks if it's DNS.
|
||||
func (df *DNSFilter) processIPv4(packet []byte) (bool, []byte, error) {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse IPv4 header
|
||||
ipHdr := header.IPv4(packet)
|
||||
if !ipHdr.IsValid(len(packet)) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Check if it's UDP
|
||||
if ipHdr.TransportProtocol() != header.UDPProtocolNumber {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Get IP header length
|
||||
ihl := int(ipHdr.HeaderLength())
|
||||
if len(packet) < ihl+header.UDPMinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse UDP header
|
||||
udpHdr := header.UDP(packet[ihl:])
|
||||
srcPort := udpHdr.SourcePort()
|
||||
dstPort := udpHdr.DestinationPort()
|
||||
|
||||
// Check if destination port is 53 (DNS)
|
||||
if dstPort != 53 {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
srcIP := ipHdr.SourceAddress()
|
||||
dstIP := ipHdr.DestinationAddress()
|
||||
|
||||
// Extract DNS payload
|
||||
udpPayloadOffset := ihl + header.UDPMinimumSize
|
||||
if len(packet) <= udpPayloadOffset {
|
||||
return true, nil, fmt.Errorf("invalid UDP packet length")
|
||||
}
|
||||
|
||||
dnsQuery := packet[udpPayloadOffset:]
|
||||
if len(dnsQuery) == 0 {
|
||||
return true, nil, fmt.Errorf("empty DNS query")
|
||||
}
|
||||
|
||||
// Process DNS query
|
||||
if df.dnsHandler == nil {
|
||||
return true, nil, fmt.Errorf("no DNS handler configured")
|
||||
}
|
||||
|
||||
dnsResponse, err := df.dnsHandler(dnsQuery)
|
||||
if err != nil {
|
||||
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
||||
}
|
||||
|
||||
// Build response packet
|
||||
responsePacket := df.buildIPv4UDPPacket(
|
||||
dstIP.As4(), // Swap src/dst
|
||||
srcIP.As4(),
|
||||
dstPort, // Swap ports
|
||||
srcPort,
|
||||
dnsResponse,
|
||||
)
|
||||
|
||||
return true, responsePacket, nil
|
||||
}
|
||||
|
||||
// processIPv6 processes an IPv6 packet and checks if it's DNS.
|
||||
func (df *DNSFilter) processIPv6(packet []byte) (bool, []byte, error) {
|
||||
if len(packet) < header.IPv6MinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse IPv6 header
|
||||
ipHdr := header.IPv6(packet)
|
||||
if !ipHdr.IsValid(len(packet)) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Check if it's UDP
|
||||
if ipHdr.TransportProtocol() != header.UDPProtocolNumber {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// IPv6 header is fixed size
|
||||
if len(packet) < header.IPv6MinimumSize+header.UDPMinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse UDP header
|
||||
udpHdr := header.UDP(packet[header.IPv6MinimumSize:])
|
||||
srcPort := udpHdr.SourcePort()
|
||||
dstPort := udpHdr.DestinationPort()
|
||||
|
||||
// Check if destination port is 53 (DNS)
|
||||
if dstPort != 53 {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Extract DNS payload
|
||||
udpPayloadOffset := header.IPv6MinimumSize + header.UDPMinimumSize
|
||||
if len(packet) <= udpPayloadOffset {
|
||||
return true, nil, fmt.Errorf("invalid UDP packet length")
|
||||
}
|
||||
|
||||
dnsQuery := packet[udpPayloadOffset:]
|
||||
if len(dnsQuery) == 0 {
|
||||
return true, nil, fmt.Errorf("empty DNS query")
|
||||
}
|
||||
|
||||
// Process DNS query
|
||||
if df.dnsHandler == nil {
|
||||
return true, nil, fmt.Errorf("no DNS handler configured")
|
||||
}
|
||||
|
||||
dnsResponse, err := df.dnsHandler(dnsQuery)
|
||||
if err != nil {
|
||||
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
||||
}
|
||||
|
||||
// Build response packet
|
||||
srcIP := ipHdr.SourceAddress()
|
||||
dstIP := ipHdr.DestinationAddress()
|
||||
|
||||
responsePacket := df.buildIPv6UDPPacket(
|
||||
dstIP.As16(), // Swap src/dst
|
||||
srcIP.As16(),
|
||||
dstPort, // Swap ports
|
||||
srcPort,
|
||||
dnsResponse,
|
||||
)
|
||||
|
||||
return true, responsePacket, nil
|
||||
}
|
||||
|
||||
// buildIPv4UDPPacket builds a complete IPv4/UDP packet with the given payload.
|
||||
func (df *DNSFilter) buildIPv4UDPPacket(srcIP, dstIP [4]byte, srcPort, dstPort uint16, payload []byte) []byte {
|
||||
// Calculate lengths
|
||||
udpLen := header.UDPMinimumSize + len(payload)
|
||||
ipLen := header.IPv4MinimumSize + udpLen
|
||||
packet := make([]byte, ipLen)
|
||||
|
||||
// Build IPv4 header
|
||||
ipHdr := header.IPv4(packet)
|
||||
ipHdr.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(ipLen),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.UDPProtocolNumber),
|
||||
SrcAddr: tcpip.AddrFrom4(srcIP),
|
||||
DstAddr: tcpip.AddrFrom4(dstIP),
|
||||
})
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
|
||||
// Build UDP header
|
||||
udpHdr := header.UDP(packet[header.IPv4MinimumSize:])
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Length: uint16(udpLen),
|
||||
})
|
||||
|
||||
// Copy payload
|
||||
copy(packet[header.IPv4MinimumSize+header.UDPMinimumSize:], payload)
|
||||
|
||||
// Calculate UDP checksum
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.UDPProtocolNumber,
|
||||
tcpip.AddrFrom4(srcIP),
|
||||
tcpip.AddrFrom4(dstIP),
|
||||
uint16(udpLen),
|
||||
)
|
||||
xsum = checksum(payload, xsum)
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
// buildIPv6UDPPacket builds a complete IPv6/UDP packet with the given payload.
|
||||
func (df *DNSFilter) buildIPv6UDPPacket(srcIP, dstIP [16]byte, srcPort, dstPort uint16, payload []byte) []byte {
|
||||
// Calculate lengths
|
||||
udpLen := header.UDPMinimumSize + len(payload)
|
||||
ipLen := header.IPv6MinimumSize + udpLen
|
||||
packet := make([]byte, ipLen)
|
||||
|
||||
// Build IPv6 header
|
||||
ipHdr := header.IPv6(packet)
|
||||
ipHdr.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(udpLen),
|
||||
TransportProtocol: header.UDPProtocolNumber,
|
||||
HopLimit: 64,
|
||||
SrcAddr: tcpip.AddrFrom16(srcIP),
|
||||
DstAddr: tcpip.AddrFrom16(dstIP),
|
||||
})
|
||||
|
||||
// Build UDP header
|
||||
udpHdr := header.UDP(packet[header.IPv6MinimumSize:])
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Length: uint16(udpLen),
|
||||
})
|
||||
|
||||
// Copy payload
|
||||
copy(packet[header.IPv6MinimumSize+header.UDPMinimumSize:], payload)
|
||||
|
||||
// Calculate UDP checksum
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.UDPProtocolNumber,
|
||||
tcpip.AddrFrom16(srcIP),
|
||||
tcpip.AddrFrom16(dstIP),
|
||||
uint16(udpLen),
|
||||
)
|
||||
xsum = checksum(payload, xsum)
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
// checksum calculates the checksum for the given data.
|
||||
func checksum(buf []byte, initial uint16) uint16 {
|
||||
v := uint32(initial)
|
||||
l := len(buf)
|
||||
if l&1 != 0 {
|
||||
l--
|
||||
v += uint32(buf[l]) << 8
|
||||
}
|
||||
for i := 0; i < l; i += 2 {
|
||||
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
|
||||
}
|
||||
return reduceChecksum(v)
|
||||
}
|
||||
|
||||
// reduceChecksum reduces a 32-bit checksum to 16 bits.
|
||||
func reduceChecksum(v uint32) uint16 {
|
||||
v = (v >> 16) + (v & 0xffff)
|
||||
v = (v >> 16) + (v & 0xffff)
|
||||
return uint16(v)
|
||||
}
|
||||
|
||||
// IPv4Address is a helper to create an IPv4 address from a byte array.
|
||||
func IPv4Address(b [4]byte) net.IP {
|
||||
return net.IPv4(b[0], b[1], b[2], b[3])
|
||||
}
|
||||
|
||||
// IPv6Address is a helper to create an IPv6 address from a byte array.
|
||||
func IPv6Address(b [16]byte) net.IP {
|
||||
return net.IP(b[:])
|
||||
}
|
||||
|
||||
// parseIPv4 extracts source and destination IPs from an IPv4 packet.
|
||||
func parseIPv4(packet []byte) (srcIP, dstIP [4]byte, ok bool) {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return
|
||||
}
|
||||
ipHdr := header.IPv4(packet)
|
||||
if !ipHdr.IsValid(len(packet)) {
|
||||
return
|
||||
}
|
||||
srcAddr := ipHdr.SourceAddress().As4()
|
||||
dstAddr := ipHdr.DestinationAddress().As4()
|
||||
copy(srcIP[:], srcAddr[:])
|
||||
copy(dstIP[:], dstAddr[:])
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
// parseUDP extracts UDP header information.
|
||||
func parseUDP(udpHeader []byte) (srcPort, dstPort uint16, ok bool) {
|
||||
if len(udpHeader) < header.UDPMinimumSize {
|
||||
return
|
||||
}
|
||||
srcPort = binary.BigEndian.Uint16(udpHeader[0:2])
|
||||
dstPort = binary.BigEndian.Uint16(udpHeader[2:4])
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
349
cmd/ctrld_library/netstack/netstack.go
Normal file
349
cmd/ctrld_library/netstack/netstack.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
// Default MTU for the TUN interface
|
||||
defaultMTU = 1500
|
||||
|
||||
// NICID is the ID of the network interface
|
||||
NICID = 1
|
||||
|
||||
// Channel capacity for packet buffers
|
||||
channelCapacity = 256
|
||||
)
|
||||
|
||||
// NetstackController manages the gVisor netstack integration for mobile packet capture.
|
||||
type NetstackController struct {
|
||||
stack *stack.Stack
|
||||
linkEP *channel.Endpoint
|
||||
packetHandler PacketHandler
|
||||
dnsFilter *DNSFilter
|
||||
tcpForwarder *TCPForwarder
|
||||
udpForwarder *UDPForwarder
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
started bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Config holds configuration for NetstackController.
|
||||
type Config struct {
|
||||
// MTU is the maximum transmission unit
|
||||
MTU uint32
|
||||
|
||||
// TUNIPv4 is the IPv4 address assigned to the TUN interface
|
||||
TUNIPv4 netip.Addr
|
||||
|
||||
// TUNIPv6 is the IPv6 address assigned to the TUN interface (optional)
|
||||
TUNIPv6 netip.Addr
|
||||
|
||||
// DNSHandler is the function to process DNS queries
|
||||
DNSHandler func([]byte) ([]byte, error)
|
||||
|
||||
// UpstreamInterface is the real network interface for routing non-DNS traffic
|
||||
UpstreamInterface *net.Interface
|
||||
}
|
||||
|
||||
// NewNetstackController creates a new netstack controller.
|
||||
func NewNetstackController(handler PacketHandler, cfg *Config) (*NetstackController, error) {
|
||||
if handler == nil {
|
||||
return nil, fmt.Errorf("packet handler cannot be nil")
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
cfg = &Config{
|
||||
MTU: defaultMTU,
|
||||
TUNIPv4: netip.MustParseAddr("10.0.0.1"),
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.MTU == 0 {
|
||||
cfg.MTU = defaultMTU
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create gVisor stack
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
},
|
||||
})
|
||||
|
||||
// Create link endpoint
|
||||
linkEP := channel.New(channelCapacity, cfg.MTU, "")
|
||||
|
||||
// Create DNS filter
|
||||
dnsFilter := NewDNSFilter(cfg.DNSHandler)
|
||||
|
||||
// Create TCP forwarder
|
||||
tcpForwarder := NewTCPForwarder(s, handler.ProtectSocket, ctx)
|
||||
|
||||
// Create UDP forwarder
|
||||
udpForwarder := NewUDPForwarder(s, handler.ProtectSocket, ctx)
|
||||
|
||||
// Create NIC
|
||||
if err := s.CreateNIC(NICID, linkEP); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Enable spoofing to allow packets with any source IP
|
||||
if err := s.SetSpoofing(NICID, true); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to enable spoofing: %v", err)
|
||||
}
|
||||
|
||||
// Enable promiscuous mode to accept all packets
|
||||
if err := s.SetPromiscuousMode(NICID, true); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to enable promiscuous mode: %v", err)
|
||||
}
|
||||
|
||||
// Add IPv4 address
|
||||
protocolAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(cfg.TUNIPv4.AsSlice()),
|
||||
PrefixLen: 24,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to add IPv4 address: %v", err)
|
||||
}
|
||||
|
||||
// Add IPv6 address if provided
|
||||
if cfg.TUNIPv6.IsValid() {
|
||||
protocolAddr6 := tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(cfg.TUNIPv6.AsSlice()),
|
||||
PrefixLen: 64,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(NICID, protocolAddr6, stack.AddressProperties{}); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to add IPv6 address: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add default routes
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: NICID,
|
||||
},
|
||||
{
|
||||
Destination: header.IPv6EmptySubnet,
|
||||
NIC: NICID,
|
||||
},
|
||||
})
|
||||
|
||||
// Register forwarders with the stack
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.forwarder.HandlePacket)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.forwarder.HandlePacket)
|
||||
|
||||
nc := &NetstackController{
|
||||
stack: s,
|
||||
linkEP: linkEP,
|
||||
packetHandler: handler,
|
||||
dnsFilter: dnsFilter,
|
||||
tcpForwarder: tcpForwarder,
|
||||
udpForwarder: udpForwarder,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
started: false,
|
||||
}
|
||||
|
||||
return nc, nil
|
||||
}
|
||||
|
||||
// Start starts the netstack controller and begins processing packets.
|
||||
func (nc *NetstackController) Start() error {
|
||||
nc.mu.Lock()
|
||||
defer nc.mu.Unlock()
|
||||
|
||||
if nc.started {
|
||||
return fmt.Errorf("netstack controller already started")
|
||||
}
|
||||
|
||||
nc.started = true
|
||||
|
||||
// Start packet reader goroutine (TUN -> netstack)
|
||||
nc.wg.Add(1)
|
||||
go nc.readPackets()
|
||||
|
||||
// Start packet writer goroutine (netstack -> TUN)
|
||||
nc.wg.Add(1)
|
||||
go nc.writePackets()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the netstack controller and waits for all goroutines to finish.
|
||||
func (nc *NetstackController) Stop() error {
|
||||
nc.mu.Lock()
|
||||
if !nc.started {
|
||||
nc.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
nc.mu.Unlock()
|
||||
|
||||
nc.cancel()
|
||||
nc.wg.Wait()
|
||||
|
||||
// Close UDP forwarder
|
||||
if nc.udpForwarder != nil {
|
||||
nc.udpForwarder.Close()
|
||||
}
|
||||
|
||||
if err := nc.packetHandler.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close packet handler: %v", err)
|
||||
}
|
||||
|
||||
nc.mu.Lock()
|
||||
nc.started = false
|
||||
nc.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// readPackets reads packets from the TUN interface and injects them into the netstack.
|
||||
func (nc *NetstackController) readPackets() {
|
||||
defer nc.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-nc.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read packet from TUN
|
||||
packet, err := nc.packetHandler.ReadPacket()
|
||||
if err != nil {
|
||||
if nc.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(packet) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a DNS packet
|
||||
isDNS, response, err := nc.dnsFilter.ProcessPacket(packet)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if isDNS && response != nil {
|
||||
// DNS packet was handled, send response back to TUN
|
||||
nc.packetHandler.WritePacket(response)
|
||||
continue
|
||||
}
|
||||
|
||||
if isDNS {
|
||||
continue
|
||||
}
|
||||
|
||||
// Not a DNS packet - check if it's an OUTBOUND packet (source = 10.0.0.x)
|
||||
// We should ONLY inject outbound packets, not return packets
|
||||
if len(packet) >= 20 {
|
||||
// Check if source is in our VPN subnet (10.0.0.x)
|
||||
isOutbound := packet[12] == 10 && packet[13] == 0 && packet[14] == 0
|
||||
|
||||
if !isOutbound {
|
||||
// This is a return packet (server -> mobile)
|
||||
// Drop it - return packets come through forwarder's upstream connection
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Create packet buffer
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
// Determine protocol number
|
||||
var proto tcpip.NetworkProtocolNumber
|
||||
if len(packet) > 0 {
|
||||
version := packet[0] >> 4
|
||||
switch version {
|
||||
case 4:
|
||||
proto = header.IPv4ProtocolNumber
|
||||
case 6:
|
||||
proto = header.IPv6ProtocolNumber
|
||||
default:
|
||||
pkt.DecRef()
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
pkt.DecRef()
|
||||
continue
|
||||
}
|
||||
|
||||
// Inject into netstack - TCP/UDP forwarders will handle it
|
||||
nc.linkEP.InjectInbound(proto, pkt)
|
||||
}
|
||||
}
|
||||
|
||||
// writePackets reads packets from netstack and writes them to the TUN interface.
|
||||
func (nc *NetstackController) writePackets() {
|
||||
defer nc.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-nc.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read packet from netstack
|
||||
pkt := nc.linkEP.ReadContext(nc.ctx)
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert packet to bytes
|
||||
vv := pkt.ToView()
|
||||
packet := vv.AsSlice()
|
||||
|
||||
// Write to TUN
|
||||
if err := nc.packetHandler.WritePacket(packet); err != nil {
|
||||
// Log error
|
||||
continue
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
115
cmd/ctrld_library/netstack/packet_handler.go
Normal file
115
cmd/ctrld_library/netstack/packet_handler.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PacketHandler defines the interface for reading and writing raw IP packets
|
||||
// from/to the mobile TUN interface.
|
||||
type PacketHandler interface {
|
||||
// ReadPacket reads a raw IP packet from the TUN interface.
|
||||
// This should be a blocking call.
|
||||
ReadPacket() ([]byte, error)
|
||||
|
||||
// WritePacket writes a raw IP packet back to the TUN interface.
|
||||
WritePacket(packet []byte) error
|
||||
|
||||
// Close closes the packet handler and releases resources.
|
||||
Close() error
|
||||
|
||||
// ProtectSocket protects a socket file descriptor from being routed through the VPN.
|
||||
// This is required on Android/iOS to prevent routing loops.
|
||||
// Returns nil if successful, error otherwise.
|
||||
ProtectSocket(fd int) error
|
||||
}
|
||||
|
||||
// MobilePacketHandler implements PacketHandler using callbacks from mobile platforms.
|
||||
// This bridges Go Mobile interface with the netstack implementation.
|
||||
type MobilePacketHandler struct {
|
||||
readFunc func() ([]byte, error)
|
||||
writeFunc func([]byte) error
|
||||
closeFunc func() error
|
||||
protectFunc func(int) error
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewMobilePacketHandler creates a new packet handler with mobile callbacks.
|
||||
func NewMobilePacketHandler(
|
||||
readFunc func() ([]byte, error),
|
||||
writeFunc func([]byte) error,
|
||||
closeFunc func() error,
|
||||
protectFunc func(int) error,
|
||||
) *MobilePacketHandler {
|
||||
return &MobilePacketHandler{
|
||||
readFunc: readFunc,
|
||||
writeFunc: writeFunc,
|
||||
closeFunc: closeFunc,
|
||||
protectFunc: protectFunc,
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPacket reads a packet from mobile TUN interface.
|
||||
func (m *MobilePacketHandler) ReadPacket() ([]byte, error) {
|
||||
m.mu.Lock()
|
||||
closed := m.closed
|
||||
m.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return nil, fmt.Errorf("packet handler is closed")
|
||||
}
|
||||
|
||||
if m.readFunc == nil {
|
||||
return nil, fmt.Errorf("read function not set")
|
||||
}
|
||||
|
||||
return m.readFunc()
|
||||
}
|
||||
|
||||
// WritePacket writes a packet back to mobile TUN interface.
|
||||
func (m *MobilePacketHandler) WritePacket(packet []byte) error {
|
||||
m.mu.Lock()
|
||||
closed := m.closed
|
||||
m.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("packet handler is closed")
|
||||
}
|
||||
|
||||
if m.writeFunc == nil {
|
||||
return fmt.Errorf("write function not set")
|
||||
}
|
||||
|
||||
return m.writeFunc(packet)
|
||||
}
|
||||
|
||||
// Close closes the packet handler.
|
||||
func (m *MobilePacketHandler) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.closed = true
|
||||
|
||||
if m.closeFunc != nil {
|
||||
return m.closeFunc()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProtectSocket protects a socket file descriptor from VPN routing.
|
||||
func (m *MobilePacketHandler) ProtectSocket(fd int) error {
|
||||
if m.protectFunc == nil {
|
||||
// No protect function provided - this is okay for non-VPN scenarios
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.protectFunc(fd)
|
||||
}
|
||||
118
cmd/ctrld_library/netstack/tcp_forwarder.go
Normal file
118
cmd/ctrld_library/netstack/tcp_forwarder.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
// TCPForwarder handles TCP connections from the TUN interface
|
||||
type TCPForwarder struct {
|
||||
protectSocket func(fd int) error
|
||||
ctx context.Context
|
||||
forwarder *tcp.Forwarder
|
||||
}
|
||||
|
||||
// NewTCPForwarder creates a new TCP forwarder
|
||||
func NewTCPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context) *TCPForwarder {
|
||||
f := &TCPForwarder{
|
||||
protectSocket: protectSocket,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
// Create gVisor TCP forwarder with handler callback
|
||||
// rcvWnd=0 (default), maxInFlight=1024
|
||||
f.forwarder = tcp.NewForwarder(s, 0, 1024, f.handleRequest)
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// GetForwarder returns the underlying gVisor forwarder
|
||||
func (f *TCPForwarder) GetForwarder() *tcp.Forwarder {
|
||||
return f.forwarder
|
||||
}
|
||||
|
||||
// handleRequest handles an incoming TCP connection request
|
||||
func (f *TCPForwarder) handleRequest(req *tcp.ForwarderRequest) {
|
||||
// Get the endpoint ID
|
||||
id := req.ID()
|
||||
|
||||
// Create waiter queue
|
||||
var wq waiter.Queue
|
||||
|
||||
// Create endpoint from request
|
||||
ep, err := req.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
req.Complete(true) // Send RST
|
||||
return
|
||||
}
|
||||
|
||||
// Accept the connection
|
||||
req.Complete(false)
|
||||
|
||||
// Cast to TCP endpoint
|
||||
tcpEP, ok := ep.(*tcp.Endpoint)
|
||||
if !ok {
|
||||
ep.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Handle in goroutine
|
||||
go f.handleConnection(tcpEP, &wq, id)
|
||||
}
|
||||
|
||||
func (f *TCPForwarder) handleConnection(ep *tcp.Endpoint, wq *waiter.Queue, id stack.TransportEndpointID) {
|
||||
// Convert endpoint to Go net.Conn
|
||||
tunConn := gonet.NewTCPConn(wq, ep)
|
||||
defer tunConn.Close()
|
||||
|
||||
// In gVisor's TransportEndpointID for an inbound connection:
|
||||
// - LocalAddress/LocalPort = the destination (where packet is going TO)
|
||||
// - RemoteAddress/RemotePort = the source (where packet is coming FROM)
|
||||
// We want to dial the DESTINATION (LocalAddress/LocalPort)
|
||||
dstAddr := net.TCPAddr{
|
||||
IP: net.IP(id.LocalAddress.AsSlice()),
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
|
||||
// Create outbound connection with socket protection DURING dial
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// CRITICAL: Protect socket BEFORE connect() is called
|
||||
if f.protectSocket != nil {
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
f.protectSocket(int(fd))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
upstreamConn, err := dialer.DialContext(f.ctx, "tcp", dstAddr.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer upstreamConn.Close()
|
||||
|
||||
// Bidirectional copy
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
io.Copy(upstreamConn, tunConn)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(tunConn, upstreamConn)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
// Wait for one direction to finish
|
||||
<-done
|
||||
}
|
||||
257
cmd/ctrld_library/netstack/udp_forwarder.go
Normal file
257
cmd/ctrld_library/netstack/udp_forwarder.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
// UDPForwarder handles UDP packets from the TUN interface
|
||||
type UDPForwarder struct {
|
||||
protectSocket func(fd int) error
|
||||
ctx context.Context
|
||||
forwarder *udp.Forwarder
|
||||
|
||||
// Track UDP "connections" (address pairs)
|
||||
connections map[string]*udpConn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
tunEP *gonet.UDPConn
|
||||
upstreamConn *net.UDPConn
|
||||
lastActivity time.Time
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewUDPForwarder creates a new UDP forwarder
|
||||
func NewUDPForwarder(s *stack.Stack, protectSocket func(fd int) error, ctx context.Context) *UDPForwarder {
|
||||
f := &UDPForwarder{
|
||||
protectSocket: protectSocket,
|
||||
ctx: ctx,
|
||||
connections: make(map[string]*udpConn),
|
||||
}
|
||||
|
||||
// Create gVisor UDP forwarder with handler callback
|
||||
f.forwarder = udp.NewForwarder(s, f.handlePacket)
|
||||
|
||||
// Start cleanup goroutine
|
||||
go f.cleanupStaleConnections()
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// GetForwarder returns the underlying gVisor forwarder
|
||||
func (f *UDPForwarder) GetForwarder() *udp.Forwarder {
|
||||
return f.forwarder
|
||||
}
|
||||
|
||||
// handlePacket handles an incoming UDP packet
|
||||
func (f *UDPForwarder) handlePacket(req *udp.ForwarderRequest) {
|
||||
// Get the endpoint ID
|
||||
id := req.ID()
|
||||
|
||||
// Create connection key (source -> destination)
|
||||
connKey := fmt.Sprintf("%s:%d->%s:%d",
|
||||
net.IP(id.RemoteAddress.AsSlice()),
|
||||
id.RemotePort,
|
||||
net.IP(id.LocalAddress.AsSlice()),
|
||||
id.LocalPort,
|
||||
)
|
||||
|
||||
f.mu.Lock()
|
||||
conn, exists := f.connections[connKey]
|
||||
if !exists {
|
||||
// Create new connection
|
||||
conn = f.createConnection(req, connKey)
|
||||
if conn == nil {
|
||||
f.mu.Unlock()
|
||||
return
|
||||
}
|
||||
f.connections[connKey] = conn
|
||||
}
|
||||
conn.lastActivity = time.Now()
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) createConnection(req *udp.ForwarderRequest, connKey string) *udpConn {
|
||||
id := req.ID()
|
||||
|
||||
// Create waiter queue
|
||||
var wq waiter.Queue
|
||||
|
||||
// Create endpoint from request
|
||||
ep, err := req.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert to Go UDP conn
|
||||
tunConn := gonet.NewUDPConn(&wq, ep)
|
||||
|
||||
// Extract destination address
|
||||
// LocalAddress/LocalPort = destination (where packet is going TO)
|
||||
// RemoteAddress/RemotePort = source (where packet is coming FROM)
|
||||
dstAddr := &net.UDPAddr{
|
||||
IP: net.IP(id.LocalAddress.AsSlice()),
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
|
||||
// Create dialer with socket protection DURING dial
|
||||
dialer := &net.Dialer{}
|
||||
|
||||
// CRITICAL: Protect socket BEFORE connect() is called
|
||||
if f.protectSocket != nil {
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
f.protectSocket(int(fd))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Create outbound UDP connection
|
||||
dialConn, dialErr := dialer.Dial("udp", dstAddr.String())
|
||||
if dialErr != nil {
|
||||
tunConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
upstreamConn, ok := dialConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
dialConn.Close()
|
||||
tunConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create connection context
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
|
||||
udpConnection := &udpConn{
|
||||
tunEP: tunConn,
|
||||
upstreamConn: upstreamConn,
|
||||
lastActivity: time.Now(),
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Start forwarding goroutines
|
||||
go f.forwardTunToUpstream(udpConnection, ctx)
|
||||
go f.forwardUpstreamToTun(udpConnection, ctx, connKey)
|
||||
|
||||
return udpConnection
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) forwardTunToUpstream(conn *udpConn, ctx context.Context) {
|
||||
buffer := make([]byte, 65535)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read from TUN
|
||||
n, err := conn.tunEP.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Write to upstream
|
||||
_, err = conn.upstreamConn.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
conn.lastActivity = time.Now()
|
||||
f.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) forwardUpstreamToTun(conn *udpConn, ctx context.Context, connKey string) {
|
||||
defer func() {
|
||||
conn.tunEP.Close()
|
||||
conn.upstreamConn.Close()
|
||||
|
||||
f.mu.Lock()
|
||||
delete(f.connections, connKey)
|
||||
f.mu.Unlock()
|
||||
}()
|
||||
|
||||
buffer := make([]byte, 65535)
|
||||
|
||||
// Set read timeout
|
||||
conn.upstreamConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read from upstream
|
||||
n, err := conn.upstreamConn.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset read deadline
|
||||
conn.upstreamConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
// Write to TUN
|
||||
_, err = conn.tunEP.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
conn.lastActivity = time.Now()
|
||||
f.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) cleanupStaleConnections() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
f.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, conn := range f.connections {
|
||||
if now.Sub(conn.lastActivity) > 60*time.Second {
|
||||
conn.cancel()
|
||||
conn.tunEP.Close()
|
||||
conn.upstreamConn.Close()
|
||||
delete(f.connections, key)
|
||||
}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all UDP connections
|
||||
func (f *UDPForwarder) Close() {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
for _, conn := range f.connections {
|
||||
conn.cancel()
|
||||
conn.tunEP.Close()
|
||||
conn.upstreamConn.Close()
|
||||
}
|
||||
f.connections = make(map[string]*udpConn)
|
||||
}
|
||||
247
cmd/ctrld_library/packet_capture.go
Normal file
247
cmd/ctrld_library/packet_capture.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package ctrld_library
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
"github.com/Control-D-Inc/ctrld/cmd/ctrld_library/netstack"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// PacketAppCallback extends AppCallback with packet read/write capabilities.
|
||||
// Mobile platforms implementing full packet capture should use this interface.
|
||||
type PacketAppCallback interface {
|
||||
AppCallback
|
||||
|
||||
// ReadPacket reads a raw IP packet from the TUN interface.
|
||||
// This should be a blocking call that returns when a packet is available.
|
||||
ReadPacket() ([]byte, error)
|
||||
|
||||
// WritePacket writes a raw IP packet back to the TUN interface.
|
||||
WritePacket(packet []byte) error
|
||||
|
||||
// ClosePacketIO closes packet I/O resources.
|
||||
ClosePacketIO() error
|
||||
|
||||
// ProtectSocket protects a socket file descriptor from being routed through the VPN.
|
||||
// On Android, this calls VpnService.protect() to prevent routing loops.
|
||||
// On iOS, this marks the socket to bypass the VPN.
|
||||
// Returns nil on success, error on failure.
|
||||
ProtectSocket(fd int) error
|
||||
}
|
||||
|
||||
// PacketCaptureController holds state for packet capture mode
|
||||
type PacketCaptureController struct {
|
||||
baseController *Controller
|
||||
|
||||
// Packet capture mode fields
|
||||
netstackCtrl *netstack.NetstackController
|
||||
dnsBridge *netstack.DNSBridge
|
||||
packetStopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewPacketCaptureController creates a new packet capture controller
|
||||
func NewPacketCaptureController(appCallback PacketAppCallback) *PacketCaptureController {
|
||||
return &PacketCaptureController{
|
||||
baseController: &Controller{AppCallback: appCallback},
|
||||
packetStopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// StartWithPacketCapture starts ctrld in full packet capture mode for mobile.
|
||||
// This method enables full IP packet processing with DNS filtering and upstream routing.
|
||||
// It requires a PacketAppCallback that provides packet read/write capabilities.
|
||||
func (pc *PacketCaptureController) StartWithPacketCapture(
|
||||
packetCallback PacketAppCallback,
|
||||
CdUID string,
|
||||
ProvisionID string,
|
||||
CustomHostname string,
|
||||
HomeDir string,
|
||||
UpstreamProto string,
|
||||
logLevel int,
|
||||
logPath string,
|
||||
) error {
|
||||
if pc.baseController.stopCh != nil {
|
||||
return fmt.Errorf("controller already running")
|
||||
}
|
||||
|
||||
// Set up configuration
|
||||
pc.baseController.Config = cli.AppConfig{
|
||||
CdUID: CdUID,
|
||||
ProvisionID: ProvisionID,
|
||||
CustomHostname: CustomHostname,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
}
|
||||
pc.baseController.AppCallback = packetCallback
|
||||
|
||||
// Set global socket protector for HTTP client sockets (API calls, etc)
|
||||
// This prevents routing loops when ctrld makes HTTP requests to api.controld.com
|
||||
ctrld.SetSocketProtector(packetCallback.ProtectSocket)
|
||||
|
||||
// Create DNS bridge for communication between netstack and DNS proxy
|
||||
pc.dnsBridge = netstack.NewDNSBridge()
|
||||
pc.dnsBridge.Start()
|
||||
|
||||
// Create packet handler that wraps the mobile callbacks
|
||||
packetHandler := netstack.NewMobilePacketHandler(
|
||||
packetCallback.ReadPacket,
|
||||
packetCallback.WritePacket,
|
||||
packetCallback.ClosePacketIO,
|
||||
packetCallback.ProtectSocket,
|
||||
)
|
||||
|
||||
// Create DNS handler that uses the bridge
|
||||
dnsHandler := func(query []byte) ([]byte, error) {
|
||||
// Extract source IP from query context if available
|
||||
// For now, use a placeholder
|
||||
return pc.dnsBridge.ProcessQuery(query, "10.0.0.2", 0)
|
||||
}
|
||||
|
||||
// Create netstack configuration
|
||||
tunIPv4, err := netip.ParseAddr("10.0.0.1")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse TUN IPv4: %v", err)
|
||||
}
|
||||
|
||||
netstackCfg := &netstack.Config{
|
||||
MTU: 1500,
|
||||
TUNIPv4: tunIPv4,
|
||||
DNSHandler: dnsHandler,
|
||||
UpstreamInterface: nil, // Will use default interface
|
||||
}
|
||||
|
||||
// Create netstack controller
|
||||
netstackCtrl, err := netstack.NewNetstackController(packetHandler, netstackCfg)
|
||||
if err != nil {
|
||||
pc.dnsBridge.Stop()
|
||||
return fmt.Errorf("failed to create netstack controller: %v", err)
|
||||
}
|
||||
|
||||
pc.netstackCtrl = netstackCtrl
|
||||
|
||||
// Start netstack processing
|
||||
if err := pc.netstackCtrl.Start(); err != nil {
|
||||
pc.dnsBridge.Stop()
|
||||
return fmt.Errorf("failed to start netstack: %v", err)
|
||||
}
|
||||
|
||||
// Start regular ctrld DNS processing in background
|
||||
// This allows us to use existing DNS filtering logic
|
||||
pc.baseController.stopCh = make(chan struct{})
|
||||
|
||||
// Start DNS query processor that receives queries from the bridge
|
||||
// and sends them to the ctrld DNS proxy
|
||||
go pc.processDNSQueries()
|
||||
|
||||
// Start the main ctrld mobile runner
|
||||
go func() {
|
||||
appCallback := mapCallback(pc.baseController.AppCallback)
|
||||
cli.RunMobile(&pc.baseController.Config, &appCallback, pc.baseController.stopCh)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processDNSQueries processes DNS queries from the bridge using the ctrld DNS proxy
|
||||
func (pc *PacketCaptureController) processDNSQueries() {
|
||||
queryCh := pc.dnsBridge.GetQueryChannel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pc.packetStopCh:
|
||||
return
|
||||
case <-pc.baseController.stopCh:
|
||||
return
|
||||
case query := <-queryCh:
|
||||
go pc.handleDNSQuery(query)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleDNSQuery handles a single DNS query
|
||||
func (pc *PacketCaptureController) handleDNSQuery(query *netstack.DNSQuery) {
|
||||
// Parse DNS message
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(query.Query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send query to actual DNS proxy running on localhost:5354
|
||||
client := &dns.Client{
|
||||
Net: "udp",
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
response, _, err := client.Exchange(msg, "127.0.0.1:5354")
|
||||
if err != nil {
|
||||
// Create SERVFAIL response
|
||||
response = new(dns.Msg)
|
||||
response.SetReply(msg)
|
||||
response.Rcode = dns.RcodeServerFailure
|
||||
}
|
||||
|
||||
// Pack response
|
||||
responseBytes, err := response.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send response back through bridge
|
||||
pc.dnsBridge.SendResponse(query.ID, responseBytes)
|
||||
}
|
||||
|
||||
// Stop stops the packet capture controller
|
||||
func (pc *PacketCaptureController) Stop(restart bool, pin int64) int {
|
||||
var errorCode = 0
|
||||
|
||||
// Clear global socket protector
|
||||
ctrld.SetSocketProtector(nil)
|
||||
|
||||
// Stop DNS bridge
|
||||
if pc.dnsBridge != nil {
|
||||
pc.dnsBridge.Stop()
|
||||
pc.dnsBridge = nil
|
||||
}
|
||||
|
||||
// Stop netstack
|
||||
if pc.netstackCtrl != nil {
|
||||
if err := pc.netstackCtrl.Stop(); err != nil {
|
||||
// Log error but continue shutdown
|
||||
fmt.Printf("Error stopping netstack: %v\n", err)
|
||||
}
|
||||
pc.netstackCtrl = nil
|
||||
}
|
||||
|
||||
// Close packet stop channel
|
||||
if pc.packetStopCh != nil {
|
||||
close(pc.packetStopCh)
|
||||
pc.packetStopCh = make(chan struct{})
|
||||
}
|
||||
|
||||
// Stop base controller
|
||||
if !restart {
|
||||
errorCode = cli.CheckDeactivationPin(pin, pc.baseController.stopCh)
|
||||
}
|
||||
if errorCode == 0 && pc.baseController.stopCh != nil {
|
||||
close(pc.baseController.stopCh)
|
||||
pc.baseController.stopCh = nil
|
||||
}
|
||||
|
||||
return errorCode
|
||||
}
|
||||
|
||||
// IsRunning returns true if the controller is running
|
||||
func (pc *PacketCaptureController) IsRunning() bool {
|
||||
return pc.baseController.stopCh != nil
|
||||
}
|
||||
|
||||
// IsPacketMode returns true (always in packet mode for this controller)
|
||||
func (pc *PacketCaptureController) IsPacketMode() bool {
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user