use super::config::{VpnError, WireGuardConfig}; use boringtun::noise::{Tunn, TunnResult}; use boringtun::x25519::{PublicKey, StaticSecret}; use smoltcp::iface::{Config as IfaceConfig, Interface, SocketHandle, SocketSet}; use smoltcp::phy::{Device, DeviceCapabilities, Medium, RxToken, TxToken}; use smoltcp::socket::tcp::{Socket as TcpSocket, SocketBuffer}; use smoltcp::time::Instant as SmolInstant; use smoltcp::wire::{HardwareAddress, IpAddress, IpCidr, Ipv4Address}; use std::collections::VecDeque; use std::net::{SocketAddr, ToSocketAddrs, UdpSocket}; use std::sync::{Arc, Mutex}; use tokio::net::{TcpListener, TcpStream}; const SMOLTCP_TCP_RX_BUF: usize = 65536; const SMOLTCP_TCP_TX_BUF: usize = 65536; struct WgDevice { tunn: Arc>>, udp_socket: Arc, peer_addr: SocketAddr, rx_queue: VecDeque>, tx_queue: VecDeque>, } impl WgDevice { fn pump_wg_to_rx(&mut self) { let mut recv_buf = vec![0u8; 2048]; loop { match self.udp_socket.recv_from(&mut recv_buf) { Ok((len, _)) => { let mut dst = vec![0u8; 2048]; let mut tunn = self.tunn.lock().unwrap(); let result = tunn.decapsulate(None, &recv_buf[..len], &mut dst); match result { TunnResult::WriteToTunnelV4(data, _) | TunnResult::WriteToTunnelV6(data, _) => { self.rx_queue.push_back(data.to_vec()); } TunnResult::WriteToNetwork(response) => { let _ = self.udp_socket.send_to(response, self.peer_addr); } _ => {} } } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break, Err(_) => break, } } } fn flush_tx_queue(&mut self) { while let Some(ip_packet) = self.tx_queue.pop_front() { let mut dst = vec![0u8; ip_packet.len() + 256]; let mut tunn = self.tunn.lock().unwrap(); let result = tunn.encapsulate(&ip_packet, &mut dst); match result { TunnResult::WriteToNetwork(packet) => { if let Err(e) = self.udp_socket.send_to(packet, self.peer_addr) { log::error!("[wg] udp send_to failed: {e}"); } } TunnResult::Done => { // boringtun has nothing to send right now (e.g. handshake not yet // complete); silently drop. smoltcp will retransmit. } TunnResult::Err(e) => { log::error!( "[wg] encapsulate error for {}B IP packet: {e:?}", ip_packet.len() ); } TunnResult::WriteToTunnelV4(_, _) | TunnResult::WriteToTunnelV6(_, _) => { log::error!("[wg] encapsulate returned unexpected WriteToTunnel — bug?"); } } } } fn tick_timers(&mut self) { let mut dst = vec![0u8; 2048]; let mut tunn = self.tunn.lock().unwrap(); let result = tunn.update_timers(&mut dst); if let TunnResult::WriteToNetwork(packet) = result { let _ = self.udp_socket.send_to(packet, self.peer_addr); } } } struct WgRxToken { data: Vec, } impl RxToken for WgRxToken { fn consume(self, f: F) -> R where F: FnOnce(&[u8]) -> R, { f(&self.data) } } struct WgTxToken<'a> { tx_queue: &'a mut VecDeque>, } impl<'a> TxToken for WgTxToken<'a> { fn consume(self, len: usize, f: F) -> R where F: FnOnce(&mut [u8]) -> R, { let mut buf = vec![0u8; len]; let result = f(&mut buf); self.tx_queue.push_back(buf); result } } impl Device for WgDevice { type RxToken<'a> = WgRxToken; type TxToken<'a> = WgTxToken<'a>; fn receive(&mut self, _timestamp: SmolInstant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { if let Some(data) = self.rx_queue.pop_front() { Some(( WgRxToken { data }, WgTxToken { tx_queue: &mut self.tx_queue, }, )) } else { None } } fn transmit(&mut self, _timestamp: SmolInstant) -> Option> { Some(WgTxToken { tx_queue: &mut self.tx_queue, }) } fn capabilities(&self) -> DeviceCapabilities { let mut caps = DeviceCapabilities::default(); caps.medium = Medium::Ip; caps.max_transmission_unit = 1420; caps } } fn parse_key(key: &str) -> Result<[u8; 32], VpnError> { let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, key) .map_err(|e| VpnError::InvalidWireGuard(format!("Invalid key encoding: {e}")))?; if decoded.len() != 32 { return Err(VpnError::InvalidWireGuard(format!( "Invalid key length: {} (expected 32)", decoded.len() ))); } let mut key_bytes = [0u8; 32]; key_bytes.copy_from_slice(&decoded); Ok(key_bytes) } fn parse_cidr_address(addr: &str) -> Result<(IpCidr, IpAddress), VpnError> { let first_addr = addr.split(',').next().unwrap_or(addr).trim(); let parts: Vec<&str> = first_addr.split('/').collect(); let ip_str = parts[0]; let prefix = if parts.len() > 1 { parts[1] .parse::() .map_err(|_| VpnError::InvalidWireGuard(format!("Invalid prefix length: {}", parts[1])))? } else { 32 }; let ip: std::net::IpAddr = ip_str .parse() .map_err(|_| VpnError::InvalidWireGuard(format!("Invalid IP address: {ip_str}")))?; match ip { std::net::IpAddr::V4(v4) => { let smol_ip = Ipv4Address::new( v4.octets()[0], v4.octets()[1], v4.octets()[2], v4.octets()[3], ); Ok(( IpCidr::new(IpAddress::Ipv4(smol_ip), prefix), IpAddress::Ipv4(smol_ip), )) } std::net::IpAddr::V6(v6) => { let smol_ip = smoltcp::wire::Ipv6Address::from(v6.octets()); Ok(( IpCidr::new(IpAddress::Ipv6(smol_ip), prefix), IpAddress::Ipv6(smol_ip), )) } } } pub struct WireGuardSocks5Server { config: WireGuardConfig, port: u16, } impl WireGuardSocks5Server { pub fn new(config: WireGuardConfig, port: u16) -> Self { Self { config, port } } fn create_tunnel(&self) -> Result, VpnError> { let private_key_bytes = parse_key(&self.config.private_key)?; let static_private = StaticSecret::from(private_key_bytes); let peer_public_bytes = parse_key(&self.config.peer_public_key)?; let peer_public = PublicKey::from(peer_public_bytes); let preshared_key = if let Some(ref psk) = self.config.preshared_key { Some(parse_key(psk)?) } else { None }; Ok(Box::new(Tunn::new( static_private, peer_public, preshared_key, self.config.persistent_keepalive, 0, None, ))) } fn resolve_endpoint(&self) -> Result { self .config .peer_endpoint .to_socket_addrs() .map_err(|e| { VpnError::Connection(format!( "Failed to resolve endpoint '{}': {e}", self.config.peer_endpoint )) })? .next() .ok_or_else(|| { VpnError::Connection(format!( "No addresses found for endpoint: {}", self.config.peer_endpoint )) }) } fn do_handshake( tunn: &mut Tunn, socket: &UdpSocket, peer_addr: SocketAddr, ) -> Result<(), VpnError> { socket .set_read_timeout(Some(std::time::Duration::from_secs(5))) .map_err(|e| VpnError::Connection(format!("Failed to set timeout: {e}")))?; // WireGuard handshakes use UDP which can silently lose packets, especially // through Docker port-forwarding layers. Retry the handshake initiation up // to 5 times (25s total) before giving up — the protocol is designed for // retransmission and peers handle duplicate initiations correctly. let max_attempts = 5; let mut last_error = String::from("no handshake attempt completed"); for attempt in 1..=max_attempts { let mut dst = vec![0u8; 2048]; let result = tunn.format_handshake_initiation(&mut dst, false); match result { TunnResult::WriteToNetwork(packet) => { socket .send_to(packet, peer_addr) .map_err(|e| VpnError::Connection(format!("Failed to send handshake: {e}")))?; } TunnResult::Err(e) => { return Err(VpnError::Tunnel(format!( "Handshake initiation failed: {e:?}" ))); } _ => {} } let mut recv_buf = vec![0u8; 2048]; match socket.recv_from(&mut recv_buf) { Ok((len, _)) => { let result = tunn.decapsulate(None, &recv_buf[..len], &mut dst); match result { TunnResult::WriteToNetwork(response) => { socket .send_to(response, peer_addr) .map_err(|e| VpnError::Connection(format!("Failed to send response: {e}")))?; } TunnResult::Done => {} TunnResult::Err(e) => { last_error = format!("handshake response error: {e:?}"); log::warn!( "[vpn-worker] Handshake attempt {attempt}/{max_attempts} failed: {last_error}" ); continue; } _ => {} } socket .set_read_timeout(None) .map_err(|e| VpnError::Connection(format!("Failed to clear timeout: {e}")))?; return Ok(()); } Err(e) if attempt < max_attempts => { log::warn!( "[vpn-worker] Handshake attempt {attempt}/{max_attempts} timed out: {e}, retrying" ); last_error = format!("timeout: {e}"); continue; } Err(e) => { last_error = format!("timeout: {e}"); } } } Err(VpnError::Connection(format!( "Handshake failed after {max_attempts} attempts: {last_error}" ))) } pub async fn run( self, config_id: String, config_path: Option, ) -> Result<(), VpnError> { let peer_addr = self.resolve_endpoint()?; let mut tunn = self.create_tunnel()?; let udp_socket = UdpSocket::bind("0.0.0.0:0") .map_err(|e| VpnError::Connection(format!("Failed to create UDP socket: {e}")))?; Self::do_handshake(&mut tunn, &udp_socket, peer_addr)?; udp_socket .set_nonblocking(true) .map_err(|e| VpnError::Connection(format!("Failed to set non-blocking: {e}")))?; log::info!("[vpn-worker] WireGuard handshake completed"); let (cidr, local_ip) = parse_cidr_address(&self.config.address)?; let tunn_arc = Arc::new(Mutex::new(tunn)); let udp_arc = Arc::new(udp_socket); let mut device = WgDevice { tunn: tunn_arc.clone(), udp_socket: udp_arc.clone(), peer_addr, rx_queue: VecDeque::new(), tx_queue: VecDeque::new(), }; let iface_config = IfaceConfig::new(HardwareAddress::Ip); let mut iface = Interface::new(iface_config, &mut device, SmolInstant::now()); iface.update_ip_addrs(|addrs| { let _ = addrs.push(cidr); }); // Set default gateway match local_ip { IpAddress::Ipv4(v4) => { let octets = v4.octets(); let gw = Ipv4Address::new(octets[0], octets[1], octets[2], 1); iface .routes_mut() .add_default_ipv4_route(gw) .map_err(|e| VpnError::Tunnel(format!("Failed to add default route: {e}")))?; } IpAddress::Ipv6(_) => { // IPv6 routing not yet implemented } } let listener = TcpListener::bind(format!("127.0.0.1:{}", self.port)) .await .map_err(|e| VpnError::Connection(format!("Failed to bind SOCKS5 listener: {e}")))?; let actual_port = listener .local_addr() .map_err(|e| VpnError::Connection(format!("Failed to get local addr: {e}")))? .port(); // Update config with actual port and local_url. Prefer the explicit // config path the worker was started with — see issue #287, where // get_storage_dir() in the worker process resolved to a different // directory than in the parent (Qubes/sandboxed Linux), causing the // write-back to land in the wrong place and the parent to time out. let updated = match &config_path { Some(path) => crate::vpn_worker_storage::get_vpn_worker_config_from_path(path) .or_else(|| crate::vpn_worker_storage::get_vpn_worker_config(&config_id)), None => crate::vpn_worker_storage::get_vpn_worker_config(&config_id), }; if let Some(mut wc) = updated { wc.local_port = Some(actual_port); wc.local_url = Some(format!("socks5://127.0.0.1:{}", actual_port)); let result = match &config_path { Some(path) => crate::vpn_worker_storage::save_vpn_worker_config_to_path(&wc, path) .map_err(|e| e.to_string()), None => crate::vpn_worker_storage::save_vpn_worker_config(&wc).map_err(|e| e.to_string()), }; if let Err(e) = result { log::error!( "[vpn-worker] Failed to write back local_url to config: {} (path={:?})", e, config_path ); } } else { log::error!( "[vpn-worker] Could not load worker config for write-back (id={}, path={:?})", config_id, config_path ); } log::info!( "[vpn-worker] SOCKS5 server listening on 127.0.0.1:{}", actual_port ); let mut sockets = SocketSet::new(vec![]); struct Connection { smol_handle: SocketHandle, tcp_stream: TcpStream, socks_done: bool, connecting: bool, greeting_done: bool, read_buf: Vec, dest_addr: Option, } let mut connections: Vec = Vec::new(); let mut timer_counter: u64 = 0; loop { // Accept new SOCKS5 connections (non-blocking via short timeout) if let Ok(Ok((stream, _addr))) = tokio::time::timeout(tokio::time::Duration::from_millis(1), listener.accept()).await { let tcp_rx = SocketBuffer::new(vec![0u8; SMOLTCP_TCP_RX_BUF]); let tcp_tx = SocketBuffer::new(vec![0u8; SMOLTCP_TCP_TX_BUF]); let tcp_socket = TcpSocket::new(tcp_rx, tcp_tx); let handle = sockets.add(tcp_socket); connections.push(Connection { smol_handle: handle, tcp_stream: stream, socks_done: false, connecting: false, greeting_done: false, read_buf: Vec::new(), dest_addr: None, }); } // Pump WireGuard packets into smoltcp rx queue device.pump_wg_to_rx(); // Poll the smoltcp interface let timestamp = SmolInstant::now(); let _changed = iface.poll(timestamp, &mut device, &mut sockets); // Flush encrypted packets out through WireGuard device.flush_tx_queue(); // Process each connection let mut completed = Vec::new(); for (idx, conn) in connections.iter_mut().enumerate() { if conn.connecting { let socket = sockets.get_mut::(conn.smol_handle); if socket.may_send() { let _ = conn.tcp_stream.try_write(&[ 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, (actual_port >> 8) as u8, (actual_port & 0xff) as u8, ]); conn.connecting = false; conn.socks_done = true; } else if !socket.is_open() { let _ = conn .tcp_stream .try_write(&[0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0]); completed.push(idx); } } else if !conn.socks_done { // Handle SOCKS5 handshake let mut buf = [0u8; 512]; match conn.tcp_stream.try_read(&mut buf) { Ok(0) => { completed.push(idx); continue; } Ok(n) => { conn.read_buf.extend_from_slice(&buf[..n]); } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {} Err(_) => { completed.push(idx); continue; } } if !conn.greeting_done && conn.read_buf.len() >= 3 { // SOCKS5 greeting: version, nmethods, methods if conn.read_buf[0] != 0x05 { completed.push(idx); continue; } let nmethods = conn.read_buf[1] as usize; if conn.read_buf.len() < 2 + nmethods { continue; } // Reply: no auth required if conn.tcp_stream.try_write(&[0x05, 0x00]).is_err() { completed.push(idx); continue; } conn.read_buf.drain(..2 + nmethods); conn.greeting_done = true; } if conn.greeting_done && conn.dest_addr.is_none() && conn.read_buf.len() >= 10 { // SOCKS5 connect request if conn.read_buf[0] != 0x05 || conn.read_buf[1] != 0x01 { completed.push(idx); continue; } let (addr, addr_len) = match conn.read_buf[3] { 0x01 => { // IPv4 if conn.read_buf.len() < 10 { continue; } let ip = std::net::Ipv4Addr::new( conn.read_buf[4], conn.read_buf[5], conn.read_buf[6], conn.read_buf[7], ); let port = u16::from_be_bytes([conn.read_buf[8], conn.read_buf[9]]); (SocketAddr::new(std::net::IpAddr::V4(ip), port), 10) } 0x03 => { // Domain name let domain_len = conn.read_buf[4] as usize; let needed = 4 + 1 + domain_len + 2; if conn.read_buf.len() < needed { continue; } let domain = String::from_utf8_lossy(&conn.read_buf[5..5 + domain_len]).to_string(); let port_start = 5 + domain_len; let port = u16::from_be_bytes([conn.read_buf[port_start], conn.read_buf[port_start + 1]]); // Resolve domain match format!("{}:{}", domain, port).to_socket_addrs() { Ok(mut addrs) => { if let Some(addr) = addrs.next() { (addr, needed) } else { // Send SOCKS5 error: host unreachable let _ = conn .tcp_stream .try_write(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0]); completed.push(idx); continue; } } Err(_) => { let _ = conn .tcp_stream .try_write(&[0x05, 0x04, 0x00, 0x01, 0, 0, 0, 0, 0, 0]); completed.push(idx); continue; } } } 0x04 => { // IPv6 if conn.read_buf.len() < 22 { continue; } let mut octets = [0u8; 16]; octets.copy_from_slice(&conn.read_buf[4..20]); let ip = std::net::Ipv6Addr::from(octets); let port = u16::from_be_bytes([conn.read_buf[20], conn.read_buf[21]]); (SocketAddr::new(std::net::IpAddr::V6(ip), port), 22) } _ => { completed.push(idx); continue; } }; conn.read_buf.drain(..addr_len); conn.dest_addr = Some(addr); // Open smoltcp TCP socket to the destination let socket = sockets.get_mut::(conn.smol_handle); let smol_addr = match addr.ip() { std::net::IpAddr::V4(v4) => { let o = v4.octets(); IpAddress::Ipv4(Ipv4Address::new(o[0], o[1], o[2], o[3])) } std::net::IpAddr::V6(v6) => { IpAddress::Ipv6(smoltcp::wire::Ipv6Address::from(v6.octets())) } }; let local_port = 10000 + (rand::random::() % 50000); if socket .connect(iface.context(), (smol_addr, addr.port()), local_port) .is_err() { let _ = conn .tcp_stream .try_write(&[0x05, 0x05, 0x00, 0x01, 0, 0, 0, 0, 0, 0]); completed.push(idx); continue; } conn.connecting = true; } } else { // Data relay between SOCKS5 client and smoltcp socket let socket = sockets.get_mut::(conn.smol_handle); // Client → smoltcp let mut buf = [0u8; 4096]; match conn.tcp_stream.try_read(&mut buf) { Ok(0) => { socket.close(); completed.push(idx); continue; } Ok(n) => { if socket.can_send() { let _ = socket.send_slice(&buf[..n]); } } Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {} Err(_) => { socket.close(); completed.push(idx); continue; } } // smoltcp → Client if socket.can_recv() { match socket.recv(|data| (data.len(), data.to_vec())) { Ok(data) if !data.is_empty() && conn.tcp_stream.try_write(&data).is_err() => { socket.close(); completed.push(idx); continue; } _ => {} } } // Check if smoltcp socket closed if !socket.is_open() && !socket.is_active() { completed.push(idx); } } } // Remove completed connections (in reverse order) completed.sort_unstable(); completed.dedup(); for idx in completed.into_iter().rev() { let conn = connections.remove(idx); sockets.remove(conn.smol_handle); } // Timer ticks for WireGuard keepalives timer_counter += 1; if timer_counter.is_multiple_of(500) { device.tick_timers(); } // Small sleep to avoid busy-spinning tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_parse_cidr_ipv4() { let (cidr, ip) = parse_cidr_address("10.0.0.2/24").unwrap(); assert_eq!(cidr.prefix_len(), 24); assert_eq!(ip, IpAddress::Ipv4(Ipv4Address::new(10, 0, 0, 2))); } #[test] fn test_parse_cidr_no_prefix() { let (cidr, _) = parse_cidr_address("10.0.0.2").unwrap(); assert_eq!(cidr.prefix_len(), 32); } #[test] fn test_parse_cidr_multi_address() { let (_, ip) = parse_cidr_address("10.0.0.2/24, fd00::2/128").unwrap(); assert_eq!(ip, IpAddress::Ipv4(Ipv4Address::new(10, 0, 0, 2))); } #[test] fn test_parse_key_valid() { let key = "YEocP0e2o1WT5GlvBvQzVF7EeR6z9aCk+ZdZ5NKEuXA="; assert!(parse_key(key).is_ok()); } #[test] fn test_parse_key_invalid() { assert!(parse_key("not-valid").is_err()); } }