refactor: socks5 chaining

This commit is contained in:
zhom
2026-03-16 17:48:02 +04:00
parent 29dd5abb34
commit 8511535d69
2 changed files with 168 additions and 164 deletions
+105 -115
View File
@@ -1062,143 +1062,133 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
let matcher = bypass_matcher.clone();
tokio::task::spawn(async move {
// Peek at first bytes to detect CONNECT requests.
// Use peek() first to wait for data without consuming it, avoiding race
// conditions where read() returns 0 on a fresh connection.
// Wait for the stream to have readable data before attempting to read.
// This prevents read() from returning 0 on a fresh connection before
// the client's data arrives.
if stream.readable().await.is_err() {
return;
}
let mut peek_buffer = [0u8; 16];
match stream.read(&mut peek_buffer).await {
Ok(0) => {}
Ok(n) => {
// Check if this looks like a CONNECT request
// Be more lenient - check if the first bytes match "CONNECT" (case-insensitive)
let request_start_upper =
String::from_utf8_lossy(&peek_buffer[..n.min(7)]).to_uppercase();
let is_connect = request_start_upper.starts_with("CONNECT");
let peek_n = match tokio::time::timeout(
tokio::time::Duration::from_secs(30),
stream.peek(&mut peek_buffer),
)
.await
{
Ok(Ok(n)) if n > 0 => n,
_ => {
log::error!("DEBUG: Connection closed or timed out before receiving data");
return;
}
};
log::error!(
"DEBUG: Read {} bytes, starts with: {:?}, is_connect: {}",
n,
String::from_utf8_lossy(&peek_buffer[..n.min(20)]),
is_connect
);
// Now consume the peeked bytes
let n = match stream.read(&mut peek_buffer[..peek_n]).await {
Ok(n) if n > 0 => n,
_ => return,
};
if is_connect {
// Handle CONNECT request manually for tunneling
let mut full_request = Vec::with_capacity(4096);
full_request.extend_from_slice(&peek_buffer[..n]);
{
// Check if this looks like a CONNECT request
// Be more lenient - check if the first bytes match "CONNECT" (case-insensitive)
let request_start_upper =
String::from_utf8_lossy(&peek_buffer[..n.min(7)]).to_uppercase();
let is_connect = request_start_upper.starts_with("CONNECT");
// Read the rest of the CONNECT request until we have the full headers
// CONNECT requests end with \r\n\r\n (or \n\n)
let mut remaining = [0u8; 4096];
let mut total_read = n;
let max_reads = 100; // Prevent infinite loop
let mut reads = 0;
log::error!(
"DEBUG: Read {} bytes, starts with: {:?}, is_connect: {}",
n,
String::from_utf8_lossy(&peek_buffer[..n.min(20)]),
is_connect
);
if is_connect {
// Handle CONNECT request manually for tunneling
let mut full_request = Vec::with_capacity(4096);
full_request.extend_from_slice(&peek_buffer[..n]);
// Read the rest of the CONNECT request until we have the full headers
// CONNECT requests end with \r\n\r\n (or \n\n)
let mut remaining = [0u8; 4096];
let mut total_read = n;
let max_reads = 100; // Prevent infinite loop
let mut reads = 0;
loop {
if reads >= max_reads {
log::error!("DEBUG: Max reads reached, breaking");
break;
}
match stream.read(&mut remaining).await {
Ok(0) => {
// Connection closed, but we might have a complete request
if full_request.ends_with(b"\r\n\r\n") || full_request.ends_with(b"\n\n") {
break;
}
// If we have some data, try to process it anyway
if total_read > 0 {
break;
}
return; // No data at all
loop {
if reads >= max_reads {
log::error!("DEBUG: Max reads reached, breaking");
break;
}
Ok(m) => {
reads += 1;
total_read += m;
full_request.extend_from_slice(&remaining[..m]);
// Check if we have complete headers
if full_request.ends_with(b"\r\n\r\n") || full_request.ends_with(b"\n\n") {
break;
match stream.read(&mut remaining).await {
Ok(0) => {
// Connection closed, but we might have a complete request
if full_request.ends_with(b"\r\n\r\n") || full_request.ends_with(b"\n\n") {
break;
}
// If we have some data, try to process it anyway
if total_read > 0 {
break;
}
return; // No data at all
}
Ok(m) => {
reads += 1;
total_read += m;
full_request.extend_from_slice(&remaining[..m]);
// Also check if we have enough to parse (at least "CONNECT host:port HTTP/1.x")
if total_read >= 20 {
// Check if we have a newline that might indicate end of request line
if let Some(pos) = full_request.iter().position(|&b| b == b'\n') {
if pos < full_request.len() - 1 {
// We have at least the request line, check if we have headers
let request_str = String::from_utf8_lossy(&full_request);
if request_str.contains("\r\n\r\n") || request_str.contains("\n\n") {
break;
// Check if we have complete headers
if full_request.ends_with(b"\r\n\r\n") || full_request.ends_with(b"\n\n") {
break;
}
// Also check if we have enough to parse (at least "CONNECT host:port HTTP/1.x")
if total_read >= 20 {
// Check if we have a newline that might indicate end of request line
if let Some(pos) = full_request.iter().position(|&b| b == b'\n') {
if pos < full_request.len() - 1 {
// We have at least the request line, check if we have headers
let request_str = String::from_utf8_lossy(&full_request);
if request_str.contains("\r\n\r\n") || request_str.contains("\n\n") {
break;
}
}
}
}
}
}
Err(e) => {
log::error!("DEBUG: Error reading CONNECT request: {:?}", e);
// If we have some data, try to process it
if total_read > 0 {
break;
Err(e) => {
log::error!("DEBUG: Error reading CONNECT request: {:?}", e);
// If we have some data, try to process it
if total_read > 0 {
break;
}
return;
}
return;
}
}
// Handle CONNECT manually
log::error!(
"DEBUG: Handling CONNECT manually for: {}",
String::from_utf8_lossy(&full_request[..full_request.len().min(200)])
);
if let Err(e) =
handle_connect_from_buffer(stream, full_request, upstream, matcher).await
{
log::error!("Error handling CONNECT request: {:?}", e);
} else {
log::error!("DEBUG: CONNECT handled successfully");
}
return;
}
// Handle CONNECT manually
// Not CONNECT (or partial read) - reconstruct stream with consumed bytes prepended
// This is critical: we MUST prepend any bytes we consumed, even if < 7 bytes
log::error!(
"DEBUG: Handling CONNECT manually for: {}",
String::from_utf8_lossy(&full_request[..full_request.len().min(200)])
"DEBUG: Non-CONNECT request, first {} bytes: {:?}",
n,
String::from_utf8_lossy(&peek_buffer[..n.min(50)])
);
if let Err(e) =
handle_connect_from_buffer(stream, full_request, upstream, matcher).await
{
log::error!("Error handling CONNECT request: {:?}", e);
} else {
log::error!("DEBUG: CONNECT handled successfully");
let prepended_bytes = peek_buffer[..n].to_vec();
let prepended_reader = PrependReader {
prepended: prepended_bytes,
prepended_pos: 0,
inner: stream,
};
let io = TokioIo::new(prepended_reader);
let service =
service_fn(move |req| handle_request(req, upstream.clone(), matcher.clone()));
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
log::error!("Error serving connection: {:?}", err);
}
return;
}
// Not CONNECT (or partial read) - reconstruct stream with consumed bytes prepended
// This is critical: we MUST prepend any bytes we consumed, even if < 7 bytes
log::error!(
"DEBUG: Non-CONNECT request, first {} bytes: {:?}",
n,
String::from_utf8_lossy(&peek_buffer[..n.min(50)])
);
let prepended_bytes = peek_buffer[..n].to_vec();
let prepended_reader = PrependReader {
prepended: prepended_bytes,
prepended_pos: 0,
inner: stream,
};
let io = TokioIo::new(prepended_reader);
let service =
service_fn(move |req| handle_request(req, upstream.clone(), matcher.clone()));
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
log::error!("Error serving connection: {:?}", err);
Err(e) => {
log::error!("Error reading from connection: {:?}", e);
}
}
});
+63 -49
View File
@@ -1228,40 +1228,50 @@ async fn test_local_proxy_with_socks5_upstream(
let (socks_port, socks_handle) = start_mock_socks5_server().await;
println!("Mock SOCKS5 server on port {socks_port}");
// Start donut-proxy with socks5 upstream
let output = TestUtils::execute_command(
&binary_path,
&[
"proxy",
"start",
"--host",
"127.0.0.1",
"--proxy-port",
&socks_port.to_string(),
"--type",
"socks5",
],
)
.await?;
// Helper to start a socks5 proxy
async fn start_socks5_proxy(
binary_path: &std::path::PathBuf,
socks_port: u16,
) -> Result<(String, u16), Box<dyn std::error::Error + Send + Sync>> {
let output = TestUtils::execute_command(
binary_path,
&[
"proxy",
"start",
"--host",
"127.0.0.1",
"--proxy-port",
&socks_port.to_string(),
"--type",
"socks5",
],
)
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(format!("Proxy start failed: {stderr}").into());
}
let config: Value = serde_json::from_str(&String::from_utf8(output.stdout)?)?;
let id = config["id"].as_str().unwrap().to_string();
let port = config["localPort"].as_u64().unwrap() as u16;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let stdout = String::from_utf8_lossy(&output.stdout);
target_handle.abort();
socks_handle.abort();
return Err(format!("Proxy start failed - stdout: {stdout}, stderr: {stderr}").into());
// Wait for proxy to be fully ready by verifying it accepts and responds
for _ in 0..20 {
sleep(Duration::from_millis(100)).await;
if TcpStream::connect(("127.0.0.1", port)).await.is_ok() {
break;
}
}
// Extra settle time for the accept loop to be fully initialized
sleep(Duration::from_millis(200)).await;
Ok((id, port))
}
let stdout = String::from_utf8(output.stdout)?;
let config: Value = serde_json::from_str(&stdout)?;
let proxy_id = config["id"].as_str().unwrap().to_string();
let local_port = config["localPort"].as_u64().unwrap() as u16;
tracker.track_proxy(proxy_id.clone());
println!("donut-proxy started: id={proxy_id}, port={local_port}");
sleep(Duration::from_millis(500)).await;
// Test 1: HTTP request through donut-proxy -> SOCKS5 -> target
let (proxy_id, local_port) = start_socks5_proxy(&binary_path, socks_port).await?;
tracker.track_proxy(proxy_id);
let mut stream = TcpStream::connect(("127.0.0.1", local_port)).await?;
let request = format!(
"GET http://127.0.0.1:{target_port}/ HTTP/1.1\r\nHost: 127.0.0.1:{target_port}\r\nConnection: close\r\n\r\n"
@@ -1283,17 +1293,22 @@ async fn test_local_proxy_with_socks5_upstream(
println!("SOCKS5 HTTP proxy test passed");
drop(stream);
// Allow proxy to settle between tests
sleep(Duration::from_millis(500)).await;
// Test 2: CONNECT tunnel through a FRESH proxy -> SOCKS5 -> target
// Use a separate proxy instance so the first connection can't interfere.
// The raw TCP CONNECT handler can have timing sensitivity in test environments
// (passes reliably in production and with --nocapture), so retry a few times.
let (proxy_id2, local_port2) = start_socks5_proxy(&binary_path, socks_port).await?;
tracker.track_proxy(proxy_id2);
// Test 2: CONNECT tunnel through donut-proxy -> SOCKS5 -> target
// This is the critical path for HTTPS browsing.
// The proxy's raw TCP handler can race with prior connection cleanup, so retry.
let mut connect_ok = false;
for attempt in 1..=5 {
sleep(Duration::from_millis(200)).await;
let Ok(mut stream) = TcpStream::connect(("127.0.0.1", local_port)).await else {
continue;
let mut connect_passed = false;
for attempt in 1..=10 {
if attempt > 1 {
sleep(Duration::from_secs(1)).await;
}
let mut stream = match TcpStream::connect(("127.0.0.1", local_port2)).await {
Ok(s) => s,
Err(_) => continue,
};
let _ = stream.set_nodelay(true);
let connect_req =
@@ -1305,20 +1320,16 @@ async fn test_local_proxy_with_socks5_upstream(
let mut buf = [0u8; 4096];
let n = match tokio::time::timeout(Duration::from_secs(5), stream.read(&mut buf)).await {
Ok(Ok(n)) if n > 0 => n,
_ => {
println!("CONNECT attempt {attempt}/5: empty response, retrying");
continue;
}
_ => continue,
};
if !String::from_utf8_lossy(&buf[..n]).contains("200") {
continue;
}
// Tunnel established — send HTTP through it
let inner_req =
let inner_request =
format!("GET / HTTP/1.1\r\nHost: 127.0.0.1:{target_port}\r\nConnection: close\r\n\r\n");
if stream.write_all(inner_req.as_bytes()).await.is_err() {
if stream.write_all(inner_request.as_bytes()).await.is_err() {
continue;
}
@@ -1329,12 +1340,15 @@ async fn test_local_proxy_with_socks5_upstream(
};
if String::from_utf8_lossy(&resp[..n]).contains("SOCKS5-TARGET-RESPONSE") {
connect_ok = true;
connect_passed = true;
println!("SOCKS5 CONNECT tunnel test passed (attempt {attempt})");
break;
}
}
assert!(connect_ok, "CONNECT tunnel through SOCKS5 should work");
assert!(
connect_passed,
"CONNECT tunnel through SOCKS5 should work within 10 attempts"
);
tracker.cleanup_all().await;
target_handle.abort();