From ab8db06dfb75e177863bb1a932918a2ab31b9b66 Mon Sep 17 00:00:00 2001 From: zhom <2717306+zhom@users.noreply.github.com> Date: Fri, 12 Dec 2025 15:04:09 +0400 Subject: [PATCH] refactor: more robust proxy connection --- src-tauri/src/proxy_server.rs | 424 ++++++++++++++++++++++++++++++++-- 1 file changed, 401 insertions(+), 23 deletions(-) diff --git a/src-tauri/src/proxy_server.rs b/src-tauri/src/proxy_server.rs index 6e3ce4a..8e65211 100644 --- a/src-tauri/src/proxy_server.rs +++ b/src-tauri/src/proxy_server.rs @@ -359,13 +359,321 @@ async fn connect_via_socks( } } +async fn handle_http_via_socks4( + req: Request, + upstream_url: &str, +) -> Result>, Infallible> { + // Extract domain for traffic tracking + let domain = req + .uri() + .host() + .map(|h| h.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + + // Parse upstream SOCKS4 proxy URL + let upstream = match Url::parse(upstream_url) { + Ok(url) => url, + Err(e) => { + log::error!("Failed to parse SOCKS4 proxy URL: {}", e); + let mut response = Response::new(Full::new(Bytes::from("Invalid proxy URL"))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + }; + + let socks_host = upstream.host_str().unwrap_or("127.0.0.1"); + let socks_port = upstream.port().unwrap_or(1080); + let socks_addr = format!("{}:{}", socks_host, socks_port); + + // Parse target from request URI + let target_uri = req.uri(); + let target_host = target_uri.host().unwrap_or("localhost"); + let target_port = target_uri.port_u16().unwrap_or(80); + + // Connect to SOCKS4 proxy + let mut socks_stream = match TcpStream::connect(&socks_addr).await { + Ok(stream) => stream, + Err(e) => { + log::error!("Failed to connect to SOCKS4 proxy {}: {}", socks_addr, e); + let mut response = Response::new(Full::new(Bytes::from(format!( + "Failed to connect to SOCKS4 proxy: {}", + e + )))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + }; + + // Resolve target host to IP (SOCKS4 requires IP addresses) + let target_ip = match tokio::net::lookup_host((target_host, target_port)).await { + Ok(mut addrs) => { + if let Some(addr) = addrs.next() { + match addr.ip() { + std::net::IpAddr::V4(ipv4) => ipv4.octets(), + std::net::IpAddr::V6(_) => { + log::error!("SOCKS4 does not support IPv6"); + let mut response = Response::new(Full::new(Bytes::from( + "SOCKS4 does not support IPv6 addresses", + ))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + } + } else { + log::error!("Failed to resolve target host: {}", target_host); + let mut response = Response::new(Full::new(Bytes::from(format!( + "Failed to resolve target host: {}", + target_host + )))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + } + Err(e) => { + log::error!("Failed to resolve target host {}: {}", target_host, e); + let mut response = Response::new(Full::new(Bytes::from(format!( + "Failed to resolve target host: {}", + e + )))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + }; + + // Build SOCKS4 CONNECT request + let mut socks_request = vec![0x04, 0x01]; // SOCKS4, CONNECT + socks_request.extend_from_slice(&target_port.to_be_bytes()); + socks_request.extend_from_slice(&target_ip); + socks_request.push(0); // NULL terminator for userid + + // Send SOCKS4 CONNECT request + if let Err(e) = socks_stream.write_all(&socks_request).await { + log::error!("Failed to send SOCKS4 CONNECT request: {}", e); + let mut response = Response::new(Full::new(Bytes::from(format!( + "Failed to send SOCKS4 request: {}", + e + )))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + + // Read SOCKS4 response + let mut socks_response = [0u8; 8]; + if let Err(e) = socks_stream.read_exact(&mut socks_response).await { + log::error!("Failed to read SOCKS4 response: {}", e); + let mut response = Response::new(Full::new(Bytes::from(format!( + "Failed to read SOCKS4 response: {}", + e + )))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + + // Check SOCKS4 response (second byte should be 0x5A for success) + if socks_response[1] != 0x5A { + log::error!( + "SOCKS4 connection failed, response code: {}", + socks_response[1] + ); + let mut response = Response::new(Full::new(Bytes::from("SOCKS4 connection failed"))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + + // Now send the HTTP request through the SOCKS4 connection + // Build HTTP request line + let method = req.method().as_str(); + let path = target_uri + .path_and_query() + .map(|pq| pq.as_str()) + .unwrap_or("/"); + let http_version = if req.version() == hyper::Version::HTTP_11 { + "HTTP/1.1" + } else { + "HTTP/1.0" + }; + + let mut http_request = format!("{} {} {}\r\n", method, path, http_version); + + // Add Host header if not present + let mut has_host = false; + for (name, value) in req.headers().iter() { + if name.as_str().eq_ignore_ascii_case("host") { + has_host = true; + } + // Skip proxy-specific headers + if name.as_str().eq_ignore_ascii_case("proxy-authorization") + || name.as_str().eq_ignore_ascii_case("proxy-connection") + || name.as_str().eq_ignore_ascii_case("proxy-authenticate") + { + continue; + } + if let Ok(val) = value.to_str() { + http_request.push_str(&format!("{}: {}\r\n", name.as_str(), val)); + } + } + + if !has_host { + http_request.push_str(&format!("Host: {}:{}\r\n", target_host, target_port)); + } + + // Get body + let body_bytes = match req.collect().await { + Ok(collected) => collected.to_bytes(), + Err(_) => Bytes::new(), + }; + + // Add Content-Length if there's a body + if !body_bytes.is_empty() { + http_request.push_str(&format!("Content-Length: {}\r\n", body_bytes.len())); + } + + http_request.push_str("\r\n"); + + // Send HTTP request + if let Err(e) = socks_stream.write_all(http_request.as_bytes()).await { + log::error!("Failed to send HTTP request through SOCKS4: {}", e); + let mut response = Response::new(Full::new(Bytes::from(format!( + "Failed to send HTTP request: {}", + e + )))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + + // Send body if present + if !body_bytes.is_empty() { + if let Err(e) = socks_stream.write_all(&body_bytes).await { + log::error!("Failed to send HTTP body through SOCKS4: {}", e); + let mut response = Response::new(Full::new(Bytes::from(format!( + "Failed to send HTTP body: {}", + e + )))); + *response.status_mut() = StatusCode::BAD_GATEWAY; + return Ok(response); + } + } + + // Read HTTP response + let mut response_buffer = Vec::with_capacity(8192); + let mut temp_buf = [0u8; 4096]; + let mut content_length: Option = None; + let mut is_chunked = false; + + // Read until we have complete headers + loop { + match socks_stream.read(&mut temp_buf).await { + Ok(0) => break, // Connection closed + Ok(n) => { + response_buffer.extend_from_slice(&temp_buf[..n]); + // Check for end of headers (\r\n\r\n) + if let Some(pos) = response_buffer.windows(4).position(|w| w == b"\r\n\r\n") { + // Parse headers + let headers_str = String::from_utf8_lossy(&response_buffer[..pos + 4]); + for line in headers_str.lines() { + let line_lower = line.to_lowercase(); + if line_lower.starts_with("content-length:") { + if let Some(len_str) = line.split(':').nth(1) { + if let Ok(len) = len_str.trim().parse::() { + content_length = Some(len); + } + } + } else if line_lower.starts_with("transfer-encoding:") && line_lower.contains("chunked") + { + is_chunked = true; + } + } + // Read body if Content-Length is specified and we don't have it all + if let Some(cl) = content_length { + let body_start = pos + 4; + let body_received = response_buffer.len() - body_start; + if body_received < cl { + // Read remaining body (but don't use read_exact as connection might close) + let remaining = cl - body_received; + let mut read_so_far = 0; + while read_so_far < remaining { + match socks_stream.read(&mut temp_buf).await { + Ok(0) => break, // Connection closed + Ok(m) => { + let to_read = (remaining - read_so_far).min(m); + response_buffer.extend_from_slice(&temp_buf[..to_read]); + read_so_far += to_read; + if to_read < m { + // More data than needed, might be next response - stop here + break; + } + } + Err(_) => break, + } + } + } + } else if !is_chunked { + // No Content-Length and not chunked - read until connection closes + // But limit to reasonable size to avoid memory issues + let max_body_size = 10 * 1024 * 1024; // 10MB max + while response_buffer.len() < max_body_size { + match socks_stream.read(&mut temp_buf).await { + Ok(0) => break, // Connection closed + Ok(n) => { + response_buffer.extend_from_slice(&temp_buf[..n]); + } + Err(_) => break, + } + } + } + // Note: Chunked encoding is complex to parse manually, so we'll read what we can + // For full chunked support, we'd need a proper HTTP parser + break; + } + } + Err(e) => { + log::error!("Error reading HTTP response from SOCKS4: {}", e); + break; + } + } + } + + // Parse HTTP response + let response_str = String::from_utf8_lossy(&response_buffer); + let mut lines = response_str.lines(); + let status_line = lines.next().unwrap_or("HTTP/1.1 500 Internal Server Error"); + let status_parts: Vec<&str> = status_line.split_whitespace().collect(); + let status_code = status_parts + .get(1) + .and_then(|s| s.parse::().ok()) + .unwrap_or(500); + + // Find header/body boundary + let header_end = response_buffer + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map(|p| p + 4) + .unwrap_or(response_buffer.len()); + + let body = response_buffer[header_end..].to_vec(); + + // Record request in traffic tracker + let response_size = body.len() as u64; + if let Some(tracker) = get_traffic_tracker() { + tracker.record_request(&domain, body_bytes.len() as u64, response_size); + } + + let mut hyper_response = Response::new(Full::new(Bytes::from(body))); + *hyper_response.status_mut() = StatusCode::from_u16(status_code).unwrap(); + + Ok(hyper_response) +} + async fn handle_http( req: Request, upstream_url: Option, ) -> Result>, Infallible> { - // Use reqwest for all HTTP requests as it handles proxies better - // This is faster and more reliable than trying to use hyper-proxy with version conflicts - use reqwest::Client; + // Extract domain for traffic tracking + let domain = req + .uri() + .host() + .map(|h| h.to_string()) + .unwrap_or_else(|| "unknown".to_string()); log::error!( "DEBUG: Handling HTTP request: {} {} (host: {:?})", @@ -374,12 +682,20 @@ async fn handle_http( req.uri().host() ); - // Extract domain for traffic tracking - let domain = req - .uri() - .host() - .map(|h| h.to_string()) - .unwrap_or_else(|| "unknown".to_string()); + // Check if we need to handle SOCKS4 manually (reqwest doesn't support it) + if let Some(ref upstream) = upstream_url { + if upstream != "DIRECT" { + if let Ok(url) = Url::parse(upstream) { + if url.scheme() == "socks4" { + // Handle SOCKS4 manually for HTTP requests + return handle_http_via_socks4(req, upstream).await; + } + } + } + } + + // Use reqwest for HTTP/HTTPS/SOCKS5 proxies + use reqwest::Client; let client_builder = Client::builder(); let client = if let Some(ref upstream) = upstream_url { @@ -497,6 +813,7 @@ fn build_reqwest_client_with_proxy( let proxy = match scheme { "http" | "https" => { // For HTTP/HTTPS proxies, reqwest handles them directly + // Note: HTTPS proxy URLs still use HTTP CONNECT method, reqwest handles TLS automatically Proxy::http(upstream_url)? } "socks5" => { @@ -504,8 +821,9 @@ fn build_reqwest_client_with_proxy( Proxy::all(upstream_url)? } "socks4" => { - // SOCKS4 is not directly supported by reqwest, would need custom handling - return Err("SOCKS4 not supported for HTTP requests via reqwest".into()); + // SOCKS4 is handled manually in handle_http_via_socks4 + // This should not be reached, but return error as fallback + return Err("SOCKS4 should be handled manually".into()); } _ => { return Err(format!("Unsupported proxy scheme: {}", scheme).into()); @@ -693,38 +1011,95 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box { log::error!("DEBUG: Connection closed immediately (0 bytes read)"); } Ok(n) => { - let request_start = String::from_utf8_lossy(&peek_buffer[..n.min(7)]); - log::error!("DEBUG: Read {} bytes, starts with: {:?}", n, request_start); - if n >= 7 && request_start.starts_with("CONNECT") { + // Check if this looks like a CONNECT request + // Be more lenient - check if the first bytes match "CONNECT" (case-insensitive) + let request_start_upper = + String::from_utf8_lossy(&peek_buffer[..n.min(7)]).to_uppercase(); + let is_connect = request_start_upper.starts_with("CONNECT"); + + log::error!( + "DEBUG: Read {} bytes, starts with: {:?}, is_connect: {}", + n, + String::from_utf8_lossy(&peek_buffer[..n.min(20)]), + is_connect + ); + + if is_connect { // Handle CONNECT request manually for tunneling let mut full_request = Vec::with_capacity(4096); full_request.extend_from_slice(&peek_buffer[..n]); - // Read the rest of the CONNECT request + // Read the rest of the CONNECT request until we have the full headers + // CONNECT requests end with \r\n\r\n (or \n\n) let mut remaining = [0u8; 4096]; + let mut total_read = n; + let max_reads = 100; // Prevent infinite loop + let mut reads = 0; + loop { + if reads >= max_reads { + log::error!("DEBUG: Max reads reached, breaking"); + break; + } + match stream.read(&mut remaining).await { - Ok(0) => break, - Ok(m) => { - full_request.extend_from_slice(&remaining[..m]); + Ok(0) => { + // Connection closed, but we might have a complete request if full_request.ends_with(b"\r\n\r\n") || full_request.ends_with(b"\n\n") { break; } + // If we have some data, try to process it anyway + if total_read > 0 { + break; + } + return; // No data at all + } + Ok(m) => { + reads += 1; + total_read += m; + full_request.extend_from_slice(&remaining[..m]); + + // Check if we have complete headers + if full_request.ends_with(b"\r\n\r\n") || full_request.ends_with(b"\n\n") { + break; + } + + // Also check if we have enough to parse (at least "CONNECT host:port HTTP/1.x") + if total_read >= 20 { + // Check if we have a newline that might indicate end of request line + if let Some(pos) = full_request.iter().position(|&b| b == b'\n') { + if pos < full_request.len() - 1 { + // We have at least the request line, check if we have headers + let request_str = String::from_utf8_lossy(&full_request); + if request_str.contains("\r\n\r\n") || request_str.contains("\n\n") { + break; + } + } + } + } + } + Err(e) => { + log::error!("DEBUG: Error reading CONNECT request: {:?}", e); + // If we have some data, try to process it + if total_read > 0 { + break; + } + return; } - Err(_) => break, } } // Handle CONNECT manually log::error!( "DEBUG: Handling CONNECT manually for: {}", - String::from_utf8_lossy(&full_request[..full_request.len().min(100)]) + String::from_utf8_lossy(&full_request[..full_request.len().min(200)]) ); if let Err(e) = handle_connect_from_buffer(stream, full_request, upstream).await { log::error!("Error handling CONNECT request: {:?}", e); @@ -739,7 +1114,7 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box { - // Connect via HTTP proxy CONNECT + // Connect via HTTP/HTTPS proxy CONNECT + // Note: HTTPS proxy URLs still use HTTP CONNECT method (CONNECT is always HTTP-based) + // For HTTPS proxies, reqwest handles TLS automatically in handle_http + // For manual CONNECT here, we use plain TCP - HTTPS proxy CONNECT typically works over plain TCP let proxy_host = upstream.host_str().unwrap_or("127.0.0.1"); let proxy_port = upstream.port().unwrap_or(8080); let mut proxy_stream = TcpStream::connect((proxy_host, proxy_port)).await?;