mirror of
https://github.com/zhom/donutbrowser.git
synced 2026-06-11 17:27:54 +02:00
742 lines
23 KiB
Rust
742 lines
23 KiB
Rust
use super::config::{VpnError, WireGuardConfig};
|
|
use boringtun::noise::{Tunn, TunnResult};
|
|
use boringtun::x25519::{PublicKey, StaticSecret};
|
|
use smoltcp::iface::{Config as IfaceConfig, Interface, SocketHandle, SocketSet};
|
|
use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken};
|
|
use smoltcp::socket::tcp::{Socket as TcpSocket, SocketBuffer};
|
|
use smoltcp::time::Instant as SmolInstant;
|
|
use smoltcp::wire::{HardwareAddress, IpAddress, IpCidr, Ipv4Address};
|
|
use std::collections::VecDeque;
|
|
use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
|
|
use std::sync::{Arc, Mutex};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
/// Size in bytes of the receive buffer allocated for each smoltcp TCP socket.
const SMOLTCP_TCP_RX_BUF: usize = 64 * 1024;
/// Size in bytes of the transmit buffer allocated for each smoltcp TCP socket.
const SMOLTCP_TCP_TX_BUF: usize = 64 * 1024;
|
|
|
|
struct WgDevice {
|
|
tunn: Arc<Mutex<Box<Tunn>>>,
|
|
udp_socket: Arc<UdpSocket>,
|
|
peer_addr: SocketAddr,
|
|
rx_queue: VecDeque<Vec<u8>>,
|
|
tx_queue: VecDeque<Vec<u8>>,
|
|
}
|
|
|
|
impl WgDevice {
|
|
fn pump_wg_to_rx(&mut self) {
|
|
let mut recv_buf = vec![0u8; 2048];
|
|
loop {
|
|
match self.udp_socket.recv_from(&mut recv_buf) {
|
|
Ok((len, _)) => {
|
|
let mut dst = vec![0u8; 2048];
|
|
let mut tunn = self.tunn.lock().unwrap();
|
|
let result = tunn.decapsulate(None, &recv_buf[..len], &mut dst);
|
|
match result {
|
|
TunnResult::WriteToTunnelV4(data, _) | TunnResult::WriteToTunnelV6(data, _) => {
|
|
self.rx_queue.push_back(data.to_vec());
|
|
}
|
|
TunnResult::WriteToNetwork(response) => {
|
|
let _ = self.udp_socket.send_to(response, self.peer_addr);
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
}
|
|
|
|
fn flush_tx_queue(&mut self) {
|
|
while let Some(ip_packet) = self.tx_queue.pop_front() {
|
|
let mut dst = vec![0u8; ip_packet.len() + 256];
|
|
let mut tunn = self.tunn.lock().unwrap();
|
|
let result = tunn.encapsulate(&ip_packet, &mut dst);
|
|
match result {
|
|
TunnResult::WriteToNetwork(packet) => {
|
|
if let Err(e) = self.udp_socket.send_to(packet, self.peer_addr) {
|
|
log::error!("[wg] udp send_to failed: {e}");
|
|
}
|
|
}
|
|
TunnResult::Done => {
|
|
// boringtun has nothing to send right now (e.g. handshake not yet
|
|
// complete); silently drop. smoltcp will retransmit.
|
|
}
|
|
TunnResult::Err(e) => {
|
|
log::error!(
|
|
"[wg] encapsulate error for {}B IP packet: {e:?}",
|
|
ip_packet.len()
|
|
);
|
|
}
|
|
TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => {
|
|
log::error!("[wg] encapsulate returned unexpected WriteToTunnel — bug?");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn tick_timers(&mut self) {
|
|
let mut dst = vec![0u8; 2048];
|
|
let mut tunn = self.tunn.lock().unwrap();
|
|
let result = tunn.update_timers(&mut dst);
|
|
if let TunnResult::WriteToNetwork(packet) = result {
|
|
let _ = self.udp_socket.send_to(packet, self.peer_addr);
|
|
}
|
|
}
|
|
}
|
|
|
|
struct WgRxToken {
|
|
data: Vec<u8>,
|
|
}
|
|
|
|
impl RxToken for WgRxToken {
|
|
fn consume<R, F>(self, f: F) -> R
|
|
where
|
|
F: FnOnce(&[u8]) -> R,
|
|
{
|
|
f(&self.data)
|
|
}
|
|
}
|
|
|
|
struct WgTxToken<'a> {
|
|
tx_queue: &'a mut VecDeque<Vec<u8>>,
|
|
}
|
|
|
|
impl<'a> TxToken for WgTxToken<'a> {
|
|
fn consume<R, F>(self, len: usize, f: F) -> R
|
|
where
|
|
F: FnOnce(&mut [u8]) -> R,
|
|
{
|
|
let mut buf = vec![0u8; len];
|
|
let result = f(&mut buf);
|
|
self.tx_queue.push_back(buf);
|
|
result
|
|
}
|
|
}
|
|
|
|
impl Device for WgDevice {
|
|
type RxToken<'a> = WgRxToken;
|
|
type TxToken<'a> = WgTxToken<'a>;
|
|
|
|
fn receive(&mut self, _timestamp: SmolInstant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
|
|
if let Some(data) = self.rx_queue.pop_front() {
|
|
Some((
|
|
WgRxToken { data },
|
|
WgTxToken {
|
|
tx_queue: &mut self.tx_queue,
|
|
},
|
|
))
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
fn transmit(&mut self, _timestamp: SmolInstant) -> Option<Self::TxToken<'_>> {
|
|
Some(WgTxToken {
|
|
tx_queue: &mut self.tx_queue,
|
|
})
|
|
}
|
|
|
|
fn capabilities(&self) -> DeviceCapabilities {
|
|
let mut caps = DeviceCapabilities::default();
|
|
caps.medium = Medium::Ip;
|
|
caps.max_transmission_unit = 1420;
|
|
caps
|
|
}
|
|
}
|
|
|
|
fn parse_key(key: &str) -> Result<[u8; 32], VpnError> {
|
|
let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, key)
|
|
.map_err(|e| VpnError::InvalidWireGuard(format!("Invalid key encoding: {e}")))?;
|
|
if decoded.len() != 32 {
|
|
return Err(VpnError::InvalidWireGuard(format!(
|
|
"Invalid key length: {} (expected 32)",
|
|
decoded.len()
|
|
)));
|
|
}
|
|
let mut key_bytes = [0u8; 32];
|
|
key_bytes.copy_from_slice(&decoded);
|
|
Ok(key_bytes)
|
|
}
|
|
|
|
fn parse_cidr_address(addr: &str) -> Result<(IpCidr, IpAddress), VpnError> {
|
|
let first_addr = addr.split(',').next().unwrap_or(addr).trim();
|
|
|
|
let parts: Vec<&str> = first_addr.split('/').collect();
|
|
let ip_str = parts[0];
|
|
let prefix = if parts.len() > 1 {
|
|
parts[1]
|
|
.parse::<u8>()
|
|
.map_err(|_| VpnError::InvalidWireGuard(format!("Invalid prefix length: {}", parts[1])))?
|
|
} else {
|
|
32
|
|
};
|
|
|
|
let ip: std::net::IpAddr = ip_str
|
|
.parse()
|
|
.map_err(|_| VpnError::InvalidWireGuard(format!("Invalid IP address: {ip_str}")))?;
|
|
|
|
match ip {
|
|
std::net::IpAddr::V4(v4) => {
|
|
let smol_ip = Ipv4Address::new(
|
|
v4.octets()[0],
|
|
v4.octets()[1],
|
|
v4.octets()[2],
|
|
v4.octets()[3],
|
|
);
|
|
Ok((
|
|
IpCidr::new(IpAddress::Ipv4(smol_ip), prefix),
|
|
IpAddress::Ipv4(smol_ip),
|
|
))
|
|
}
|
|
std::net::IpAddr::V6(v6) => {
|
|
let smol_ip = smoltcp::wire::Ipv6Address::from(v6.octets());
|
|
Ok((
|
|
IpCidr::new(IpAddress::Ipv6(smol_ip), prefix),
|
|
IpAddress::Ipv6(smol_ip),
|
|
))
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct WireGuardSocks5Server {
|
|
config: WireGuardConfig,
|
|
port: u16,
|
|
}
|
|
|
|
impl WireGuardSocks5Server {
|
|
pub fn new(config: WireGuardConfig, port: u16) -> Self {
|
|
Self { config, port }
|
|
}
|
|
|
|
fn create_tunnel(&self) -> Result<Box<Tunn>, VpnError> {
|
|
let private_key_bytes = parse_key(&self.config.private_key)?;
|
|
let static_private = StaticSecret::from(private_key_bytes);
|
|
|
|
let peer_public_bytes = parse_key(&self.config.peer_public_key)?;
|
|
let peer_public = PublicKey::from(peer_public_bytes);
|
|
|
|
let preshared_key = if let Some(ref psk) = self.config.preshared_key {
|
|
Some(parse_key(psk)?)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
Ok(Box::new(Tunn::new(
|
|
static_private,
|
|
peer_public,
|
|
preshared_key,
|
|
self.config.persistent_keepalive,
|
|
0,
|
|
None,
|
|
)))
|
|
}
|
|
|
|
fn resolve_endpoint(&self) -> Result<SocketAddr, VpnError> {
|
|
self
|
|
.config
|
|
.peer_endpoint
|
|
.to_socket_addrs()
|
|
.map_err(|e| {
|
|
VpnError::Connection(format!(
|
|
"Failed to resolve endpoint '{}': {e}",
|
|
self.config.peer_endpoint
|
|
))
|
|
})?
|
|
.next()
|
|
.ok_or_else(|| {
|
|
VpnError::Connection(format!(
|
|
"No addresses found for endpoint: {}",
|
|
self.config.peer_endpoint
|
|
))
|
|
})
|
|
}
|
|
|
|
fn do_handshake(
|
|
tunn: &mut Tunn,
|
|
socket: &UdpSocket,
|
|
peer_addr: SocketAddr,
|
|
) -> Result<(), VpnError> {
|
|
socket
|
|
.set_read_timeout(Some(std::time::Duration::from_secs(5)))
|
|
.map_err(|e| VpnError::Connection(format!("Failed to set timeout: {e}")))?;
|
|
|
|
// WireGuard handshakes use UDP which can silently lose packets, especially
|
|
// through Docker port-forwarding layers. Retry the handshake initiation up
|
|
// to 5 times (25s total) before giving up — the protocol is designed for
|
|
// retransmission and peers handle duplicate initiations correctly.
|
|
let max_attempts = 5;
|
|
let mut last_error = String::from("no handshake attempt completed");
|
|
|
|
for attempt in 1..=max_attempts {
|
|
let mut dst = vec![0u8; 2048];
|
|
let result = tunn.format_handshake_initiation(&mut dst, false);
|
|
|
|
match result {
|
|
TunnResult::WriteToNetwork(packet) => {
|
|
socket
|
|
.send_to(packet, peer_addr)
|
|
.map_err(|e| VpnError::Connection(format!("Failed to send handshake: {e}")))?;
|
|
}
|
|
TunnResult::Err(e) => {
|
|
return Err(VpnError::Tunnel(format!(
|
|
"Handshake initiation failed: {e:?}"
|
|
)));
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
let mut recv_buf = vec![0u8; 2048];
|
|
match socket.recv_from(&mut recv_buf) {
|
|
Ok((len, _)) => {
|
|
let result = tunn.decapsulate(None, &recv_buf[..len], &mut dst);
|
|
match result {
|
|
TunnResult::WriteToNetwork(response) => {
|
|
socket
|
|
.send_to(response, peer_addr)
|
|
.map_err(|e| VpnError::Connection(format!("Failed to send response: {e}")))?;
|
|
}
|
|
TunnResult::Done => {}
|
|
TunnResult::Err(e) => {
|
|
last_error = format!("handshake response error: {e:?}");
|
|
log::warn!(
|
|
"[vpn-worker] Handshake attempt {attempt}/{max_attempts} failed: {last_error}"
|
|
);
|
|
continue;
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
socket
|
|
.set_read_timeout(None)
|
|
.map_err(|e| VpnError::Connection(format!("Failed to clear timeout: {e}")))?;
|
|
return Ok(());
|
|
}
|
|
Err(e) if attempt < max_attempts => {
|
|
log::warn!(
|
|
"[vpn-worker] Handshake attempt {attempt}/{max_attempts} timed out: {e}, retrying"
|
|
);
|
|
last_error = format!("timeout: {e}");
|
|
continue;
|
|
}
|
|
Err(e) => {
|
|
last_error = format!("timeout: {e}");
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(VpnError::Connection(format!(
|
|
"Handshake failed after {max_attempts} attempts: {last_error}"
|
|
)))
|
|
}
|
|
|
|
pub async fn run(
|
|
self,
|
|
config_id: String,
|
|
config_path: Option<std::path::PathBuf>,
|
|
) -> Result<(), VpnError> {
|
|
let peer_addr = self.resolve_endpoint()?;
|
|
let mut tunn = self.create_tunnel()?;
|
|
|
|
let udp_socket = UdpSocket::bind("0.0.0.0:0")
|
|
.map_err(|e| VpnError::Connection(format!("Failed to create UDP socket: {e}")))?;
|
|
|
|
Self::do_handshake(&mut tunn, &udp_socket, peer_addr)?;
|
|
|
|
udp_socket
|
|
.set_nonblocking(true)
|
|
.map_err(|e| VpnError::Connection(format!("Failed to set non-blocking: {e}")))?;
|
|
|
|
log::info!("[vpn-worker] WireGuard handshake completed");
|
|
|
|
let (cidr, local_ip) = parse_cidr_address(&self.config.address)?;
|
|
|
|
let tunn_arc = Arc::new(Mutex::new(tunn));
|
|
let udp_arc = Arc::new(udp_socket);
|
|
|
|
let mut device = WgDevice {
|
|
tunn: tunn_arc.clone(),
|
|
udp_socket: udp_arc.clone(),
|
|
peer_addr,
|
|
rx_queue: VecDeque::new(),
|
|
tx_queue: VecDeque::new(),
|
|
};
|
|
|
|
let iface_config = IfaceConfig::new(HardwareAddress::Ip);
|
|
let mut iface = Interface::new(iface_config, &mut device, SmolInstant::now());
|
|
iface.update_ip_addrs(|addrs| {
|
|
let _ = addrs.push(cidr);
|
|
});
|
|
|
|
// Set default gateway
|
|
match local_ip {
|
|
IpAddress::Ipv4(v4) => {
|
|
let octets = v4.octets();
|
|
let gw = Ipv4Address::new(octets[0], octets[1], octets[2], 1);
|
|
iface
|
|
.routes_mut()
|
|
.add_default_ipv4_route(gw)
|
|
.map_err(|e| VpnError::Tunnel(format!("Failed to add default route: {e}")))?;
|
|
}
|
|
IpAddress::Ipv6(_) => {
|
|
// IPv6 routing not yet implemented
|
|
}
|
|
}
|
|
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", self.port))
|
|
.await
|
|
.map_err(|e| VpnError::Connection(format!("Failed to bind SOCKS5 listener: {e}")))?;
|
|
|
|
let actual_port = listener
|
|
.local_addr()
|
|
.map_err(|e| VpnError::Connection(format!("Failed to get local addr: {e}")))?
|
|
.port();
|
|
|
|
// Update config with actual port and local_url. Prefer the explicit
|
|
// config path the worker was started with — see issue #287, where
|
|
// get_storage_dir() in the worker process resolved to a different
|
|
// directory than in the parent (Qubes/sandboxed Linux), causing the
|
|
// write-back to land in the wrong place and the parent to time out.
|
|
let updated = match &config_path {
|
|
Some(path) => crate::vpn_worker_storage::get_vpn_worker_config_from_path(path)
|
|
.or_else(|| crate::vpn_worker_storage::get_vpn_worker_config(&config_id)),
|
|
None => crate::vpn_worker_storage::get_vpn_worker_config(&config_id),
|
|
};
|
|
if let Some(mut wc) = updated {
|
|
wc.local_port = Some(actual_port);
|
|
wc.local_url = Some(format!("socks5://127.0.0.1:{}", actual_port));
|
|
let result = match &config_path {
|
|
Some(path) => crate::vpn_worker_storage::save_vpn_worker_config_to_path(&wc, path)
|
|
.map_err(|e| e.to_string()),
|
|
None => crate::vpn_worker_storage::save_vpn_worker_config(&wc).map_err(|e| e.to_string()),
|
|
};
|
|
if let Err(e) = result {
|
|
log::error!(
|
|
"[vpn-worker] Failed to write back local_url to config: {} (path={:?})",
|
|
e,
|
|
config_path
|
|
);
|
|
}
|
|
} else {
|
|
log::error!(
|
|
"[vpn-worker] Could not load worker config for write-back (id={}, path={:?})",
|
|
config_id,
|
|
config_path
|
|
);
|
|
}
|
|
|
|
log::info!(
|
|
"[vpn-worker] SOCKS5 server listening on 127.0.0.1:{}",
|
|
actual_port
|
|
);
|
|
|
|
let mut sockets = SocketSet::new(vec![]);
|
|
|
|
struct Connection {
|
|
smol_handle: SocketHandle,
|
|
tcp_stream: TcpStream,
|
|
socks_done: bool,
|
|
connecting: bool,
|
|
greeting_done: bool,
|
|
read_buf: Vec<u8>,
|
|
dest_addr: Option<SocketAddr>,
|
|
}
|
|
|
|
let mut connections: Vec<Connection> = Vec::new();
|
|
let mut timer_counter: u64 = 0;
|
|
|
|
loop {
|
|
// Accept new SOCKS5 connections (non-blocking via short timeout)
|
|
if let Ok(Ok((stream, _addr))) =
|
|
tokio::time::timeout(tokio::time::Duration::from_millis(1), listener.accept()).await
|
|
{
|
|
let tcp_rx = SocketBuffer::new(vec![0u8; SMOLTCP_TCP_RX_BUF]);
|
|
let tcp_tx = SocketBuffer::new(vec![0u8; SMOLTCP_TCP_TX_BUF]);
|
|
let tcp_socket = TcpSocket::new(tcp_rx, tcp_tx);
|
|
let handle = sockets.add(tcp_socket);
|
|
|
|
connections.push(Connection {
|
|
smol_handle: handle,
|
|
tcp_stream: stream,
|
|
socks_done: false,
|
|
connecting: false,
|
|
greeting_done: false,
|
|
read_buf: Vec::new(),
|
|
dest_addr: None,
|
|
});
|
|
}
|
|
|
|
// Pump WireGuard packets into smoltcp rx queue
|
|
device.pump_wg_to_rx();
|
|
|
|
// Poll the smoltcp interface
|
|
let timestamp = SmolInstant::now();
|
|
let _changed = iface.poll(timestamp, &mut device, &mut sockets);
|
|
|
|
// Flush encrypted packets out through WireGuard
|
|
device.flush_tx_queue();
|
|
|
|
// Process each connection
|
|
let mut completed = Vec::new();
|
|
for (idx, conn) in connections.iter_mut().enumerate() {
|
|
if conn.connecting {
|
|
let socket = sockets.get_mut::<TcpSocket>(conn.smol_handle);
|
|
if socket.may_send() {
|
|
let _ = conn.tcp_stream.try_write(&[
|
|
0x05,
|
|
0x00,
|
|
0x00,
|
|
0x01,
|
|
127,
|
|
0,
|
|
0,
|
|
1,
|
|
(actual_port >> 8) as u8,
|
|
(actual_port & 0xff) as u8,
|
|
]);
|
|
conn.connecting = false;
|
|
conn.socks_done = true;
|
|
} else if !socket.is_open() {
|
|
let _ = conn
|
|
.tcp_stream
|
|
.try_write(&[0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0]);
|
|
completed.push(idx);
|
|
}
|
|
} else if !conn.socks_done {
|
|
// Handle SOCKS5 handshake
|
|
let mut buf = [0u8; 512];
|
|
match conn.tcp_stream.try_read(&mut buf) {
|
|
Ok(0) => {
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
Ok(n) => {
|
|
conn.read_buf.extend_from_slice(&buf[..n]);
|
|
}
|
|
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
|
|
Err(_) => {
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
if !conn.greeting_done && conn.read_buf.len() >= 3 {
|
|
// SOCKS5 greeting: version, nmethods, methods
|
|
if conn.read_buf[0] != 0x05 {
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
let nmethods = conn.read_buf[1] as usize;
|
|
if conn.read_buf.len() < 2 + nmethods {
|
|
continue;
|
|
}
|
|
// Reply: no auth required
|
|
if conn.tcp_stream.try_write(&[0x05, 0x00]).is_err() {
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
conn.read_buf.drain(..2 + nmethods);
|
|
conn.greeting_done = true;
|
|
}
|
|
|
|
if conn.greeting_done && conn.dest_addr.is_none() && conn.read_buf.len() >= 10 {
|
|
// SOCKS5 connect request
|
|
if conn.read_buf[0] != 0x05 || conn.read_buf[1] != 0x01 {
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
|
|
let (addr, addr_len) = match conn.read_buf[3] {
|
|
0x01 => {
|
|
// IPv4
|
|
if conn.read_buf.len() < 10 {
|
|
continue;
|
|
}
|
|
let ip = std::net::Ipv4Addr::new(
|
|
conn.read_buf[4],
|
|
conn.read_buf[5],
|
|
conn.read_buf[6],
|
|
conn.read_buf[7],
|
|
);
|
|
let port = u16::from_be_bytes([conn.read_buf[8], conn.read_buf[9]]);
|
|
(SocketAddr::new(std::net::IpAddr::V4(ip), port), 10)
|
|
}
|
|
0x03 => {
|
|
// Domain name
|
|
let domain_len = conn.read_buf[4] as usize;
|
|
let needed = 4 + 1 + domain_len + 2;
|
|
if conn.read_buf.len() < needed {
|
|
continue;
|
|
}
|
|
let domain = String::from_utf8_lossy(&conn.read_buf[5..5 + domain_len]).to_string();
|
|
let port_start = 5 + domain_len;
|
|
let port =
|
|
u16::from_be_bytes([conn.read_buf[port_start], conn.read_buf[port_start + 1]]);
|
|
// Resolve domain
|
|
match format!("{}:{}", domain, port).to_socket_addrs() {
|
|
Ok(mut addrs) => {
|
|
if let Some(addr) = addrs.next() {
|
|
(addr, needed)
|
|
} else {
|
|
// Send SOCKS5 error: host unreachable
|
|
let _ = conn
|
|
.tcp_stream
|
|
.try_write(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0]);
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
}
|
|
Err(_) => {
|
|
let _ = conn
|
|
.tcp_stream
|
|
.try_write(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0]);
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
0x04 => {
|
|
// IPv6
|
|
if conn.read_buf.len() < 22 {
|
|
continue;
|
|
}
|
|
let mut octets = [0u8; 16];
|
|
octets.copy_from_slice(&conn.read_buf[4..20]);
|
|
let ip = std::net::Ipv6Addr::from(octets);
|
|
let port = u16::from_be_bytes([conn.read_buf[20], conn.read_buf[21]]);
|
|
(SocketAddr::new(std::net::IpAddr::V6(ip), port), 22)
|
|
}
|
|
_ => {
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
conn.read_buf.drain(..addr_len);
|
|
conn.dest_addr = Some(addr);
|
|
|
|
// Open smoltcp TCP socket to the destination
|
|
let socket = sockets.get_mut::<TcpSocket>(conn.smol_handle);
|
|
let smol_addr = match addr.ip() {
|
|
std::net::IpAddr::V4(v4) => {
|
|
let o = v4.octets();
|
|
IpAddress::Ipv4(Ipv4Address::new(o[0], o[1], o[2], o[3]))
|
|
}
|
|
std::net::IpAddr::V6(v6) => {
|
|
IpAddress::Ipv6(smoltcp::wire::Ipv6Address::from(v6.octets()))
|
|
}
|
|
};
|
|
|
|
let local_port = 10000 + (rand::random::<u16>() % 50000);
|
|
if socket
|
|
.connect(iface.context(), (smol_addr, addr.port()), local_port)
|
|
.is_err()
|
|
{
|
|
let _ = conn
|
|
.tcp_stream
|
|
.try_write(&[0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0]);
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
|
|
conn.connecting = true;
|
|
}
|
|
} else {
|
|
// Data relay between SOCKS5 client and smoltcp socket
|
|
let socket = sockets.get_mut::<TcpSocket>(conn.smol_handle);
|
|
|
|
// Client → smoltcp
|
|
let mut buf = [0u8; 4096];
|
|
match conn.tcp_stream.try_read(&mut buf) {
|
|
Ok(0) => {
|
|
socket.close();
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
Ok(n) => {
|
|
if socket.can_send() {
|
|
let _ = socket.send_slice(&buf[..n]);
|
|
}
|
|
}
|
|
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
|
|
Err(_) => {
|
|
socket.close();
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// smoltcp → Client
|
|
if socket.can_recv() {
|
|
match socket.recv(|data| (data.len(), data.to_vec())) {
|
|
Ok(data) if !data.is_empty() && conn.tcp_stream.try_write(&data).is_err() => {
|
|
socket.close();
|
|
completed.push(idx);
|
|
continue;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
// Check if smoltcp socket closed
|
|
if !socket.is_open() && !socket.is_active() {
|
|
completed.push(idx);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove completed connections (in reverse order)
|
|
completed.sort_unstable();
|
|
completed.dedup();
|
|
for idx in completed.into_iter().rev() {
|
|
let conn = connections.remove(idx);
|
|
sockets.remove(conn.smol_handle);
|
|
}
|
|
|
|
// Timer ticks for WireGuard keepalives
|
|
timer_counter += 1;
|
|
if timer_counter.is_multiple_of(500) {
|
|
device.tick_timers();
|
|
}
|
|
|
|
// Small sleep to avoid busy-spinning
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_parse_cidr_ipv4() {
|
|
let (cidr, ip) = parse_cidr_address("10.0.0.2/24").unwrap();
|
|
assert_eq!(cidr.prefix_len(), 24);
|
|
assert_eq!(ip, IpAddress::Ipv4(Ipv4Address::new(10, 0, 0, 2)));
|
|
}
|
|
|
|
#[test]
|
|
fn test_parse_cidr_no_prefix() {
|
|
let (cidr, _) = parse_cidr_address("10.0.0.2").unwrap();
|
|
assert_eq!(cidr.prefix_len(), 32);
|
|
}
|
|
|
|
#[test]
|
|
fn test_parse_cidr_multi_address() {
|
|
let (_, ip) = parse_cidr_address("10.0.0.2/24, fd00::2/128").unwrap();
|
|
assert_eq!(ip, IpAddress::Ipv4(Ipv4Address::new(10, 0, 0, 2)));
|
|
}
|
|
|
|
#[test]
|
|
fn test_parse_key_valid() {
|
|
let key = "YEocP0e2o1WT5GlvBvQzVF7EeR6z9aCk+ZdZ5NKEuXA=";
|
|
assert!(parse_key(key).is_ok());
|
|
}
|
|
|
|
#[test]
|
|
fn test_parse_key_invalid() {
|
|
assert!(parse_key("not-valid").is_err());
|
|
}
|
|
}
|