refactor: more robust proxy connection

This commit is contained in:
zhom
2025-12-12 15:04:09 +04:00
parent 0b43c6776b
commit ab8db06dfb
+401 -23
View File
@@ -359,13 +359,321 @@ async fn connect_via_socks(
}
}
async fn handle_http_via_socks4(
req: Request<hyper::body::Incoming>,
upstream_url: &str,
) -> Result<Response<Full<Bytes>>, Infallible> {
// Extract domain for traffic tracking
let domain = req
.uri()
.host()
.map(|h| h.to_string())
.unwrap_or_else(|| "unknown".to_string());
// Parse upstream SOCKS4 proxy URL
let upstream = match Url::parse(upstream_url) {
Ok(url) => url,
Err(e) => {
log::error!("Failed to parse SOCKS4 proxy URL: {}", e);
let mut response = Response::new(Full::new(Bytes::from("Invalid proxy URL")));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
};
let socks_host = upstream.host_str().unwrap_or("127.0.0.1");
let socks_port = upstream.port().unwrap_or(1080);
let socks_addr = format!("{}:{}", socks_host, socks_port);
// Parse target from request URI
let target_uri = req.uri();
let target_host = target_uri.host().unwrap_or("localhost");
let target_port = target_uri.port_u16().unwrap_or(80);
// Connect to SOCKS4 proxy
let mut socks_stream = match TcpStream::connect(&socks_addr).await {
Ok(stream) => stream,
Err(e) => {
log::error!("Failed to connect to SOCKS4 proxy {}: {}", socks_addr, e);
let mut response = Response::new(Full::new(Bytes::from(format!(
"Failed to connect to SOCKS4 proxy: {}",
e
))));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
};
// Resolve target host to IP (SOCKS4 requires IP addresses)
let target_ip = match tokio::net::lookup_host((target_host, target_port)).await {
Ok(mut addrs) => {
if let Some(addr) = addrs.next() {
match addr.ip() {
std::net::IpAddr::V4(ipv4) => ipv4.octets(),
std::net::IpAddr::V6(_) => {
log::error!("SOCKS4 does not support IPv6");
let mut response = Response::new(Full::new(Bytes::from(
"SOCKS4 does not support IPv6 addresses",
)));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
}
} else {
log::error!("Failed to resolve target host: {}", target_host);
let mut response = Response::new(Full::new(Bytes::from(format!(
"Failed to resolve target host: {}",
target_host
))));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
}
Err(e) => {
log::error!("Failed to resolve target host {}: {}", target_host, e);
let mut response = Response::new(Full::new(Bytes::from(format!(
"Failed to resolve target host: {}",
e
))));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
};
// Build SOCKS4 CONNECT request
let mut socks_request = vec![0x04, 0x01]; // SOCKS4, CONNECT
socks_request.extend_from_slice(&target_port.to_be_bytes());
socks_request.extend_from_slice(&target_ip);
socks_request.push(0); // NULL terminator for userid
// Send SOCKS4 CONNECT request
if let Err(e) = socks_stream.write_all(&socks_request).await {
log::error!("Failed to send SOCKS4 CONNECT request: {}", e);
let mut response = Response::new(Full::new(Bytes::from(format!(
"Failed to send SOCKS4 request: {}",
e
))));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
// Read SOCKS4 response
let mut socks_response = [0u8; 8];
if let Err(e) = socks_stream.read_exact(&mut socks_response).await {
log::error!("Failed to read SOCKS4 response: {}", e);
let mut response = Response::new(Full::new(Bytes::from(format!(
"Failed to read SOCKS4 response: {}",
e
))));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
// Check SOCKS4 response (second byte should be 0x5A for success)
if socks_response[1] != 0x5A {
log::error!(
"SOCKS4 connection failed, response code: {}",
socks_response[1]
);
let mut response = Response::new(Full::new(Bytes::from("SOCKS4 connection failed")));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
// Now send the HTTP request through the SOCKS4 connection
// Build HTTP request line
let method = req.method().as_str();
let path = target_uri
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/");
let http_version = if req.version() == hyper::Version::HTTP_11 {
"HTTP/1.1"
} else {
"HTTP/1.0"
};
let mut http_request = format!("{} {} {}\r\n", method, path, http_version);
// Add Host header if not present
let mut has_host = false;
for (name, value) in req.headers().iter() {
if name.as_str().eq_ignore_ascii_case("host") {
has_host = true;
}
// Skip proxy-specific headers
if name.as_str().eq_ignore_ascii_case("proxy-authorization")
|| name.as_str().eq_ignore_ascii_case("proxy-connection")
|| name.as_str().eq_ignore_ascii_case("proxy-authenticate")
{
continue;
}
if let Ok(val) = value.to_str() {
http_request.push_str(&format!("{}: {}\r\n", name.as_str(), val));
}
}
if !has_host {
http_request.push_str(&format!("Host: {}:{}\r\n", target_host, target_port));
}
// Get body
let body_bytes = match req.collect().await {
Ok(collected) => collected.to_bytes(),
Err(_) => Bytes::new(),
};
// Add Content-Length if there's a body
if !body_bytes.is_empty() {
http_request.push_str(&format!("Content-Length: {}\r\n", body_bytes.len()));
}
http_request.push_str("\r\n");
// Send HTTP request
if let Err(e) = socks_stream.write_all(http_request.as_bytes()).await {
log::error!("Failed to send HTTP request through SOCKS4: {}", e);
let mut response = Response::new(Full::new(Bytes::from(format!(
"Failed to send HTTP request: {}",
e
))));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
// Send body if present
if !body_bytes.is_empty() {
if let Err(e) = socks_stream.write_all(&body_bytes).await {
log::error!("Failed to send HTTP body through SOCKS4: {}", e);
let mut response = Response::new(Full::new(Bytes::from(format!(
"Failed to send HTTP body: {}",
e
))));
*response.status_mut() = StatusCode::BAD_GATEWAY;
return Ok(response);
}
}
// Read HTTP response
let mut response_buffer = Vec::with_capacity(8192);
let mut temp_buf = [0u8; 4096];
let mut content_length: Option<usize> = None;
let mut is_chunked = false;
// Read until we have complete headers
loop {
match socks_stream.read(&mut temp_buf).await {
Ok(0) => break, // Connection closed
Ok(n) => {
response_buffer.extend_from_slice(&temp_buf[..n]);
// Check for end of headers (\r\n\r\n)
if let Some(pos) = response_buffer.windows(4).position(|w| w == b"\r\n\r\n") {
// Parse headers
let headers_str = String::from_utf8_lossy(&response_buffer[..pos + 4]);
for line in headers_str.lines() {
let line_lower = line.to_lowercase();
if line_lower.starts_with("content-length:") {
if let Some(len_str) = line.split(':').nth(1) {
if let Ok(len) = len_str.trim().parse::<usize>() {
content_length = Some(len);
}
}
} else if line_lower.starts_with("transfer-encoding:") && line_lower.contains("chunked")
{
is_chunked = true;
}
}
// Read body if Content-Length is specified and we don't have it all
if let Some(cl) = content_length {
let body_start = pos + 4;
let body_received = response_buffer.len() - body_start;
if body_received < cl {
// Read remaining body (but don't use read_exact as connection might close)
let remaining = cl - body_received;
let mut read_so_far = 0;
while read_so_far < remaining {
match socks_stream.read(&mut temp_buf).await {
Ok(0) => break, // Connection closed
Ok(m) => {
let to_read = (remaining - read_so_far).min(m);
response_buffer.extend_from_slice(&temp_buf[..to_read]);
read_so_far += to_read;
if to_read < m {
// More data than needed, might be next response - stop here
break;
}
}
Err(_) => break,
}
}
}
} else if !is_chunked {
// No Content-Length and not chunked - read until connection closes
// But limit to reasonable size to avoid memory issues
let max_body_size = 10 * 1024 * 1024; // 10MB max
while response_buffer.len() < max_body_size {
match socks_stream.read(&mut temp_buf).await {
Ok(0) => break, // Connection closed
Ok(n) => {
response_buffer.extend_from_slice(&temp_buf[..n]);
}
Err(_) => break,
}
}
}
// Note: Chunked encoding is complex to parse manually, so we'll read what we can
// For full chunked support, we'd need a proper HTTP parser
break;
}
}
Err(e) => {
log::error!("Error reading HTTP response from SOCKS4: {}", e);
break;
}
}
}
// Parse HTTP response
let response_str = String::from_utf8_lossy(&response_buffer);
let mut lines = response_str.lines();
let status_line = lines.next().unwrap_or("HTTP/1.1 500 Internal Server Error");
let status_parts: Vec<&str> = status_line.split_whitespace().collect();
let status_code = status_parts
.get(1)
.and_then(|s| s.parse::<u16>().ok())
.unwrap_or(500);
// Find header/body boundary
let header_end = response_buffer
.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|p| p + 4)
.unwrap_or(response_buffer.len());
let body = response_buffer[header_end..].to_vec();
// Record request in traffic tracker
let response_size = body.len() as u64;
if let Some(tracker) = get_traffic_tracker() {
tracker.record_request(&domain, body_bytes.len() as u64, response_size);
}
let mut hyper_response = Response::new(Full::new(Bytes::from(body)));
*hyper_response.status_mut() = StatusCode::from_u16(status_code).unwrap();
Ok(hyper_response)
}
async fn handle_http(
req: Request<hyper::body::Incoming>,
upstream_url: Option<String>,
) -> Result<Response<Full<Bytes>>, Infallible> {
// Use reqwest for all HTTP requests as it handles proxies better
// This is faster and more reliable than trying to use hyper-proxy with version conflicts
use reqwest::Client;
// Extract domain for traffic tracking
let domain = req
.uri()
.host()
.map(|h| h.to_string())
.unwrap_or_else(|| "unknown".to_string());
log::error!(
"DEBUG: Handling HTTP request: {} {} (host: {:?})",
@@ -374,12 +682,20 @@ async fn handle_http(
req.uri().host()
);
// Extract domain for traffic tracking
let domain = req
.uri()
.host()
.map(|h| h.to_string())
.unwrap_or_else(|| "unknown".to_string());
// Check if we need to handle SOCKS4 manually (reqwest doesn't support it)
if let Some(ref upstream) = upstream_url {
if upstream != "DIRECT" {
if let Ok(url) = Url::parse(upstream) {
if url.scheme() == "socks4" {
// Handle SOCKS4 manually for HTTP requests
return handle_http_via_socks4(req, upstream).await;
}
}
}
}
// Use reqwest for HTTP/HTTPS/SOCKS5 proxies
use reqwest::Client;
let client_builder = Client::builder();
let client = if let Some(ref upstream) = upstream_url {
@@ -497,6 +813,7 @@ fn build_reqwest_client_with_proxy(
let proxy = match scheme {
"http" | "https" => {
// For HTTP/HTTPS proxies, reqwest handles them directly
// Note: HTTPS proxy URLs still use HTTP CONNECT method, reqwest handles TLS automatically
Proxy::http(upstream_url)?
}
"socks5" => {
@@ -504,8 +821,9 @@ fn build_reqwest_client_with_proxy(
Proxy::all(upstream_url)?
}
"socks4" => {
// SOCKS4 is not directly supported by reqwest, would need custom handling
return Err("SOCKS4 not supported for HTTP requests via reqwest".into());
// SOCKS4 is handled manually in handle_http_via_socks4
// This should not be reached, but return error as fallback
return Err("SOCKS4 should be handled manually".into());
}
_ => {
return Err(format!("Unsupported proxy scheme: {}", scheme).into());
@@ -693,38 +1011,95 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
tokio::task::spawn(async move {
// Read first bytes to detect CONNECT requests
// CONNECT requests need special handling for tunneling
let mut peek_buffer = [0u8; 8];
// Use a larger buffer to ensure we can detect CONNECT even with partial reads
let mut peek_buffer = [0u8; 16];
match stream.read(&mut peek_buffer).await {
Ok(0) => {
log::error!("DEBUG: Connection closed immediately (0 bytes read)");
}
Ok(n) => {
let request_start = String::from_utf8_lossy(&peek_buffer[..n.min(7)]);
log::error!("DEBUG: Read {} bytes, starts with: {:?}", n, request_start);
if n >= 7 && request_start.starts_with("CONNECT") {
// Check if this looks like a CONNECT request
// Be more lenient - check if the first bytes match "CONNECT" (case-insensitive)
let request_start_upper =
String::from_utf8_lossy(&peek_buffer[..n.min(7)]).to_uppercase();
let is_connect = request_start_upper.starts_with("CONNECT");
log::error!(
"DEBUG: Read {} bytes, starts with: {:?}, is_connect: {}",
n,
String::from_utf8_lossy(&peek_buffer[..n.min(20)]),
is_connect
);
if is_connect {
// Handle CONNECT request manually for tunneling
let mut full_request = Vec::with_capacity(4096);
full_request.extend_from_slice(&peek_buffer[..n]);
// Read the rest of the CONNECT request
// Read the rest of the CONNECT request until we have the full headers
// CONNECT requests end with \r\n\r\n (or \n\n)
let mut remaining = [0u8; 4096];
let mut total_read = n;
let max_reads = 100; // Prevent infinite loop
let mut reads = 0;
loop {
if reads >= max_reads {
log::error!("DEBUG: Max reads reached, breaking");
break;
}
match stream.read(&mut remaining).await {
Ok(0) => break,
Ok(m) => {
full_request.extend_from_slice(&remaining[..m]);
Ok(0) => {
// Connection closed, but we might have a complete request
if full_request.ends_with(b"\r\n\r\n") || full_request.ends_with(b"\n\n") {
break;
}
// If we have some data, try to process it anyway
if total_read > 0 {
break;
}
return; // No data at all
}
Ok(m) => {
reads += 1;
total_read += m;
full_request.extend_from_slice(&remaining[..m]);
// Check if we have complete headers
if full_request.ends_with(b"\r\n\r\n") || full_request.ends_with(b"\n\n") {
break;
}
// Also check if we have enough to parse (at least "CONNECT host:port HTTP/1.x")
if total_read >= 20 {
// Check if we have a newline that might indicate end of request line
if let Some(pos) = full_request.iter().position(|&b| b == b'\n') {
if pos < full_request.len() - 1 {
// We have at least the request line, check if we have headers
let request_str = String::from_utf8_lossy(&full_request);
if request_str.contains("\r\n\r\n") || request_str.contains("\n\n") {
break;
}
}
}
}
}
Err(e) => {
log::error!("DEBUG: Error reading CONNECT request: {:?}", e);
// If we have some data, try to process it
if total_read > 0 {
break;
}
return;
}
Err(_) => break,
}
}
// Handle CONNECT manually
log::error!(
"DEBUG: Handling CONNECT manually for: {}",
String::from_utf8_lossy(&full_request[..full_request.len().min(100)])
String::from_utf8_lossy(&full_request[..full_request.len().min(200)])
);
if let Err(e) = handle_connect_from_buffer(stream, full_request, upstream).await {
log::error!("Error handling CONNECT request: {:?}", e);
@@ -739,7 +1114,7 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
log::error!(
"DEBUG: Non-CONNECT request, first {} bytes: {:?}",
n,
String::from_utf8_lossy(&peek_buffer[..n])
String::from_utf8_lossy(&peek_buffer[..n.min(50)])
);
let prepended_bytes = peek_buffer[..n].to_vec();
let prepended_reader = PrependReader {
@@ -826,7 +1201,10 @@ async fn handle_connect_from_buffer(
match scheme {
"http" | "https" => {
// Connect via HTTP proxy CONNECT
// Connect via HTTP/HTTPS proxy CONNECT
// Note: HTTPS proxy URLs still use HTTP CONNECT method (CONNECT is always HTTP-based)
// For HTTPS proxies, reqwest handles TLS automatically in handle_http
// For manual CONNECT here, we use plain TCP - HTTPS proxy CONNECT typically works over plain TCP
let proxy_host = upstream.host_str().unwrap_or("127.0.0.1");
let proxy_port = upstream.port().unwrap_or(8080);
let mut proxy_stream = TcpStream::connect((proxy_host, proxy_port)).await?;