mirror of
https://github.com/zhom/donutbrowser.git
synced 2026-06-03 21:58:02 +02:00
feat: extension management
This commit is contained in:
@@ -6,6 +6,7 @@ use hyper::server::conn::http1;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Method, Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use regex_lite::Regex;
|
||||
use std::convert::Infallible;
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
@@ -18,6 +19,38 @@ use tokio::net::TcpListener;
|
||||
use tokio::net::TcpStream;
|
||||
use url::Url;
|
||||
|
||||
enum CompiledRule {
|
||||
Regex(Regex),
|
||||
Exact(String),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct BypassMatcher {
|
||||
rules: Arc<Vec<CompiledRule>>,
|
||||
}
|
||||
|
||||
impl BypassMatcher {
|
||||
pub fn new(rules: &[String]) -> Self {
|
||||
let compiled = rules
|
||||
.iter()
|
||||
.map(|rule| match Regex::new(rule) {
|
||||
Ok(re) => CompiledRule::Regex(re),
|
||||
Err(_) => CompiledRule::Exact(rule.clone()),
|
||||
})
|
||||
.collect();
|
||||
Self {
|
||||
rules: Arc::new(compiled),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn should_bypass(&self, host: &str) -> bool {
|
||||
self.rules.iter().any(|rule| match rule {
|
||||
CompiledRule::Regex(re) => re.is_match(host),
|
||||
CompiledRule::Exact(exact) => host == exact,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper stream that counts bytes read and written
|
||||
struct CountingStream<S> {
|
||||
inner: S,
|
||||
@@ -133,19 +166,21 @@ impl AsyncWrite for PrependReader {
|
||||
async fn handle_request(
|
||||
req: Request<hyper::body::Incoming>,
|
||||
upstream_url: Option<String>,
|
||||
bypass_matcher: BypassMatcher,
|
||||
) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
// Handle CONNECT method for HTTPS tunneling
|
||||
if req.method() == Method::CONNECT {
|
||||
return handle_connect(req, upstream_url).await;
|
||||
return handle_connect(req, upstream_url, bypass_matcher).await;
|
||||
}
|
||||
|
||||
// Handle regular HTTP requests
|
||||
handle_http(req, upstream_url).await
|
||||
handle_http(req, upstream_url, bypass_matcher).await
|
||||
}
|
||||
|
||||
async fn handle_connect(
|
||||
req: Request<hyper::body::Incoming>,
|
||||
upstream_url: Option<String>,
|
||||
bypass_matcher: BypassMatcher,
|
||||
) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
let authority = req.uri().authority().cloned();
|
||||
|
||||
@@ -161,12 +196,13 @@ async fn handle_connect(
|
||||
(&target_addr[..], 443)
|
||||
};
|
||||
|
||||
// If no upstream proxy, connect directly
|
||||
// If no upstream proxy, or bypass rule matches, connect directly
|
||||
if upstream_url.is_none()
|
||||
|| upstream_url
|
||||
.as_ref()
|
||||
.map(|s| s == "DIRECT")
|
||||
.unwrap_or(false)
|
||||
|| bypass_matcher.should_bypass(target_host)
|
||||
{
|
||||
match TcpStream::connect(&target_addr).await {
|
||||
Ok(_stream) => {
|
||||
@@ -674,6 +710,7 @@ async fn handle_http_via_socks4(
|
||||
async fn handle_http(
|
||||
req: Request<hyper::body::Incoming>,
|
||||
upstream_url: Option<String>,
|
||||
bypass_matcher: BypassMatcher,
|
||||
) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
// Extract domain for traffic tracking
|
||||
let domain = req
|
||||
@@ -689,13 +726,17 @@ async fn handle_http(
|
||||
req.uri().host()
|
||||
);
|
||||
|
||||
let should_bypass = bypass_matcher.should_bypass(&domain);
|
||||
|
||||
// Check if we need to handle SOCKS4 manually (reqwest doesn't support it)
|
||||
if let Some(ref upstream) = upstream_url {
|
||||
if upstream != "DIRECT" {
|
||||
if let Ok(url) = Url::parse(upstream) {
|
||||
if url.scheme() == "socks4" {
|
||||
// Handle SOCKS4 manually for HTTP requests
|
||||
return handle_http_via_socks4(req, upstream).await;
|
||||
if !should_bypass {
|
||||
if let Some(ref upstream) = upstream_url {
|
||||
if upstream != "DIRECT" {
|
||||
if let Ok(url) = Url::parse(upstream) {
|
||||
if url.scheme() == "socks4" {
|
||||
// Handle SOCKS4 manually for HTTP requests
|
||||
return handle_http_via_socks4(req, upstream).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -705,7 +746,9 @@ async fn handle_http(
|
||||
use reqwest::Client;
|
||||
|
||||
let client_builder = Client::builder();
|
||||
let client = if let Some(ref upstream) = upstream_url {
|
||||
let client = if should_bypass {
|
||||
client_builder.build().unwrap_or_default()
|
||||
} else if let Some(ref upstream) = upstream_url {
|
||||
if upstream == "DIRECT" {
|
||||
client_builder.build().unwrap_or_default()
|
||||
} else {
|
||||
@@ -1003,6 +1046,8 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
|
||||
}
|
||||
});
|
||||
|
||||
let bypass_matcher = BypassMatcher::new(&config.bypass_rules);
|
||||
|
||||
// Keep the runtime alive with an infinite loop
|
||||
// This ensures the process doesn't exit even if there are no active connections
|
||||
loop {
|
||||
@@ -1014,6 +1059,7 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
|
||||
log::error!("DEBUG: Accepted connection from {:?}", peer_addr);
|
||||
|
||||
let upstream = upstream_url.clone();
|
||||
let matcher = bypass_matcher.clone();
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
// Read first bytes to detect CONNECT requests
|
||||
@@ -1108,7 +1154,9 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
|
||||
"DEBUG: Handling CONNECT manually for: {}",
|
||||
String::from_utf8_lossy(&full_request[..full_request.len().min(200)])
|
||||
);
|
||||
if let Err(e) = handle_connect_from_buffer(stream, full_request, upstream).await {
|
||||
if let Err(e) =
|
||||
handle_connect_from_buffer(stream, full_request, upstream, matcher).await
|
||||
{
|
||||
log::error!("Error handling CONNECT request: {:?}", e);
|
||||
} else {
|
||||
log::error!("DEBUG: CONNECT handled successfully");
|
||||
@@ -1130,7 +1178,8 @@ pub async fn run_proxy_server(config: ProxyConfig) -> Result<(), Box<dyn std::er
|
||||
inner: stream,
|
||||
};
|
||||
let io = TokioIo::new(prepended_reader);
|
||||
let service = service_fn(move |req| handle_request(req, upstream.clone()));
|
||||
let service =
|
||||
service_fn(move |req| handle_request(req, upstream.clone(), matcher.clone()));
|
||||
|
||||
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
|
||||
log::error!("Error serving connection: {:?}", err);
|
||||
@@ -1156,6 +1205,7 @@ async fn handle_connect_from_buffer(
|
||||
mut client_stream: TcpStream,
|
||||
request_buffer: Vec<u8>,
|
||||
upstream_url: Option<String>,
|
||||
bypass_matcher: BypassMatcher,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Parse the CONNECT request from the buffer
|
||||
let request_str = String::from_utf8_lossy(&request_buffer);
|
||||
@@ -1193,6 +1243,7 @@ async fn handle_connect_from_buffer(
|
||||
}
|
||||
|
||||
// Connect to target (directly or via upstream proxy)
|
||||
let should_bypass = bypass_matcher.should_bypass(target_host);
|
||||
let target_stream = match upstream_url.as_ref() {
|
||||
None => {
|
||||
// Direct connection
|
||||
@@ -1202,6 +1253,10 @@ async fn handle_connect_from_buffer(
|
||||
// Direct connection
|
||||
TcpStream::connect((target_host, target_port)).await?
|
||||
}
|
||||
_ if should_bypass => {
|
||||
// Bypass rule matched - connect directly
|
||||
TcpStream::connect((target_host, target_port)).await?
|
||||
}
|
||||
Some(upstream_url_str) => {
|
||||
// Connect via upstream proxy
|
||||
let upstream = Url::parse(upstream_url_str)?;
|
||||
|
||||
Reference in New Issue
Block a user