feat: extension management

This commit is contained in:
zhom
2026-03-02 07:26:42 +04:00
parent a723c8b30b
commit 8a96d18e46
36 changed files with 3915 additions and 86 deletions
+67 -12
View File
@@ -6,6 +6,7 @@ use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use regex_lite::Regex;
use std::convert::Infallible;
use std::io;
use std::net::SocketAddr;
@@ -18,6 +19,38 @@ use tokio::net::TcpListener;
use tokio::net::TcpStream;
use url::Url;
enum CompiledRule {
Regex(Regex),
Exact(String),
}
#[derive(Clone)]
pub struct BypassMatcher {
rules: Arc<Vec<CompiledRule>>,
}
impl BypassMatcher {
pub fn new(rules: &[String]) -> Self {
let compiled = rules
.iter()
.map(|rule| match Regex::new(rule) {
Ok(re) => CompiledRule::Regex(re),
Err(_) => CompiledRule::Exact(rule.clone()),
})
.collect();
Self {
rules: Arc::new(compiled),
}
}
pub fn should_bypass(&self, host: &str) -> bool {
self.rules.iter().any(|rule| match rule {
CompiledRule::Regex(re) => re.is_match(host),
CompiledRule::Exact(exact) => host == exact,
})
}
}
/// Wrapper stream that counts bytes read and written
struct CountingStream<S> {
inner: S,
@@ -133,19 +166,21 @@ impl AsyncWrite for PrependReader {
async fn handle_request(
req: Request<hyper::body::Incoming>,
upstream_url: Option<String>,
bypass_matcher: BypassMatcher,
) -> Result<Response<Full<Bytes>>, Infallible> {
// Handle CONNECT method for HTTPS tunneling
if req.method() == Method::CONNECT {
return handle_connect(req, upstream_url).await;
return handle_connect(req, upstream_url, bypass_matcher).await;
}
// Handle regular HTTP requests
handle_http(req, upstream_url).await
handle_http(req, upstream_url, bypass_matcher).await
}
async fn handle_connect(
req: Request<hyper::body::Incoming>,
upstream_url: Option<String>,
bypass_matcher: BypassMatcher,
) -> Result<Response<Full<Bytes>>, Infallible> {
let authority = req.uri().authority().cloned();
@@ -161,12 +196,13 @@ async fn handle_connect(
(&target_addr[..], 443)
};
// If no upstream proxy, connect directly
// If no upstream proxy, or bypass rule matches, connect directly
if upstream_url.is_none()
|| upstream_url
.as_ref()
.map(|s| s == "DIRECT")
.unwrap_or(false)
|| bypass_matcher.should_bypass(target_host)
{
match TcpStream::connect(&target_addr).await {
Ok(_stream) => {
@@ -674,6 +710,7 @@ async fn handle_http_via_socks4(
async fn handle_http(
req: Request<hyper::body::Incoming>,
upstream_url: Option<String>,
bypass_matcher: BypassMatcher,
) -> Result<Response<Full<Bytes>>, Infallible> {
// Extract domain for traffic tracking
let domain = req
@@ -689,13 +726,17 @@ async fn handle_http(
req.uri().host()
);
let should_bypass = bypass_matcher.should_bypass(&domain);
// Check if we need to handle SOCKS4 manually (reqwest doesn't support it)
if let Some(ref upstream) = upstream_url {
if upstream != "DIRECT" {
if let Ok(url) = Url::parse(upstream) {
if url.scheme() == "socks4" {
// Handle SOCKS4 manually for HTTP requests
return handle_http_via_socks4(req, upstream).await;
if !should_bypass {
if let Some(ref upstream) = upstream_url {
if upstream != "DIRECT" {
if let Ok(url) = Url::parse(upstream) {
if url.scheme() == "socks4" {
// Handle SOCKS4 manually for HTTP requests
return handle_http_via_socks4(req, upstream).await;
}
}
}
}
@@ -705,7 +746,9 @@ async fn handle_http(
use reqwest::Client;
let client_builder = Client::builder();
let client = if let Some(ref upstream) = upstream_url {
let client = if should_bypass {
client_builder.build().unwrap_or_default()
} else if let Some(ref upstream) = upstream_url {
if upstream == "DIRECT" {
client_builder.build().unwrap_or_default()
} else {
@@ -1003,6 +1046,8 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
}
});
let bypass_matcher = BypassMatcher::new(&config.bypass_rules);
// Keep the runtime alive with an infinite loop
// This ensures the process doesn't exit even if there are no active connections
loop {
@@ -1014,6 +1059,7 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
log::error!("DEBUG: Accepted connection from {:?}", peer_addr);
let upstream = upstream_url.clone();
let matcher = bypass_matcher.clone();
tokio::task::spawn(async move {
// Read first bytes to detect CONNECT requests
@@ -1108,7 +1154,9 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
"DEBUG: Handling CONNECT manually for: {}",
String::from_utf8_lossy(&full_request[..full_request.len().min(200)])
);
if let Err(e) = handle_connect_from_buffer(stream, full_request, upstream).await {
if let Err(e) =
handle_connect_from_buffer(stream, full_request, upstream, matcher).await
{
log::error!("Error handling CONNECT request: {:?}", e);
} else {
log::error!("DEBUG: CONNECT handled successfully");
@@ -1130,7 +1178,8 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
inner: stream,
};
let io = TokioIo::new(prepended_reader);
let service = service_fn(move |req| handle_request(req, upstream.clone()));
let service =
service_fn(move |req| handle_request(req, upstream.clone(), matcher.clone()));
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
log::error!("Error serving connection: {:?}", err);
@@ -1156,6 +1205,7 @@ async fn handle_connect_from_buffer(
mut client_stream: TcpStream,
request_buffer: Vec<u8>,
upstream_url: Option<String>,
bypass_matcher: BypassMatcher,
) -> Result<(), Box<dyn std::error::Error>> {
// Parse the CONNECT request from the buffer
let request_str = String::from_utf8_lossy(&request_buffer);
@@ -1193,6 +1243,7 @@ async fn handle_connect_from_buffer(
}
// Connect to target (directly or via upstream proxy)
let should_bypass = bypass_matcher.should_bypass(target_host);
let target_stream = match upstream_url.as_ref() {
None => {
// Direct connection
@@ -1202,6 +1253,10 @@ async fn handle_connect_from_buffer(
// Direct connection
TcpStream::connect((target_host, target_port)).await?
}
_ if should_bypass => {
// Bypass rule matched - connect directly
TcpStream::connect((target_host, target_port)).await?
}
Some(upstream_url_str) => {
// Connect via upstream proxy
let upstream = Url::parse(upstream_url_str)?;