feat: daemon support, general improvement, and preparation for Windows release

This commit is contained in:
zhom
2026-02-01 20:55:09 +04:00
parent e9f4edd120
commit 4a59459eb2
58 changed files with 9763 additions and 296 deletions
+8 -60
View File
@@ -9,10 +9,10 @@ use std::path::PathBuf;
use std::process;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::time::{Duration, Instant};
use muda::MenuEvent;
use serde::{Deserialize, Serialize};
use single_instance::SingleInstance;
use tao::event::{Event, StartCause};
use tao::event_loop::{ControlFlow, EventLoopBuilder};
use tokio::runtime::Runtime;
@@ -69,52 +69,6 @@ fn write_state(state: &DaemonState) -> std::io::Result<()> {
fs::write(path, content)
}
fn detach_from_parent() {
#[cfg(unix)]
{
unsafe {
libc::setsid();
}
}
}
fn spawn_detached() {
#[cfg(unix)]
{
match unsafe { libc::fork() } {
-1 => {
eprintln!("Fork failed");
process::exit(1);
}
0 => {
detach_from_parent();
}
_ => {
process::exit(0);
}
}
}
#[cfg(windows)]
{
use std::os::windows::process::CommandExt;
use std::process::{Command, Stdio};
const DETACHED_PROCESS: u32 = 0x00000008;
const CREATE_NEW_PROCESS_GROUP: u32 = 0x00000200;
let current_exe = env::current_exe().expect("Failed to get current exe path");
let _ = Command::new(current_exe)
.arg("--daemon-internal")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.creation_flags(DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP)
.spawn();
process::exit(0);
}
}
fn set_high_priority() {
#[cfg(unix)]
{
@@ -174,13 +128,6 @@ fn run_daemon() {
process::exit(1);
}
let instance =
SingleInstance::new("donut-browser-daemon").expect("Failed to create single instance lock");
if !instance.is_single() {
eprintln!("Daemon is already running");
process::exit(1);
}
log::info!("[daemon] Starting with PID {}", process::id());
// Create tokio runtime for async operations
@@ -231,7 +178,8 @@ fn run_daemon() {
// Run the event loop
event_loop.run(move |event, _, control_flow| {
*control_flow = ControlFlow::Poll;
// Use WaitUntil to check for menu events periodically while staying low on CPU
*control_flow = ControlFlow::WaitUntil(Instant::now() + Duration::from_millis(100));
match event {
Event::NewEvents(StartCause::Init) => {
@@ -290,7 +238,8 @@ fn run_daemon() {
}
}
if SHOULD_QUIT.load(Ordering::SeqCst) {
// Use swap to only run cleanup once
if SHOULD_QUIT.swap(false, Ordering::SeqCst) {
// Cleanup
let mut state = read_state();
state.daemon_pid = None;
@@ -405,8 +354,10 @@ fn main() {
match args[1].as_str() {
"start" => {
// "start" is now an alias for "run"
// On macOS, the daemon should be started via launchctl (see daemon_spawn.rs)
// This command is kept for backward compatibility
eprintln!("Starting daemon...");
spawn_detached();
run_daemon();
}
"stop" => {
@@ -418,9 +369,6 @@ fn main() {
"run" => {
run_daemon();
}
"--daemon-internal" => {
run_daemon();
}
"autostart" => {
if args.len() < 3 {
eprintln!("Usage: donut-daemon autostart <enable|disable|status>");
+94 -15
View File
@@ -80,6 +80,12 @@ pub fn enable_autostart() -> io::Result<()> {
let plist_path = plist_dir.join("com.donutbrowser.daemon.plist");
// Get log directory (use data directory instead of /tmp)
let log_dir = get_data_dir()
.unwrap_or_else(|| PathBuf::from("/tmp"))
.join("logs");
fs::create_dir_all(&log_dir)?;
let plist_content = format!(
r#"<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
@@ -89,21 +95,29 @@ pub fn enable_autostart() -> io::Result<()> {
<string>com.donutbrowser.daemon</string>
<key>ProgramArguments</key>
<array>
<string>{}</string>
<string>start</string>
<string>{daemon_path}</string>
<string>run</string>
</array>
<key>RunAtLoad</key>
<true/>
<key>LimitLoadToSessionType</key>
<string>Aqua</string>
<key>KeepAlive</key>
<false/>
<dict>
<key>SuccessfulExit</key>
<false/>
</dict>
<key>ProcessType</key>
<string>Interactive</string>
<key>StandardOutPath</key>
<string>/tmp/donut-daemon.out.log</string>
<string>{log_dir}/daemon.out.log</string>
<key>StandardErrorPath</key>
<string>/tmp/donut-daemon.err.log</string>
<string>{log_dir}/daemon.err.log</string>
</dict>
</plist>
"#,
daemon_path.display()
daemon_path = daemon_path.display(),
log_dir = log_dir.display()
);
fs::write(&plist_path, plist_content)?;
@@ -112,13 +126,19 @@ pub fn enable_autostart() -> io::Result<()> {
Ok(())
}
#[cfg(target_os = "macos")]
pub fn get_plist_path() -> Option<PathBuf> {
dirs::home_dir().map(|h| h.join("Library/LaunchAgents/com.donutbrowser.daemon.plist"))
}
#[cfg(target_os = "macos")]
pub fn disable_autostart() -> io::Result<()> {
let plist_path = dirs::home_dir()
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Home directory not found"))?
.join("Library/LaunchAgents/com.donutbrowser.daemon.plist");
let plist_path = get_plist_path()
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Home directory not found"))?;
if plist_path.exists() {
// First unload the launch agent if it's loaded
let _ = unload_launch_agent();
fs::remove_file(&plist_path)?;
log::info!("Removed launch agent at {:?}", plist_path);
}
@@ -128,12 +148,71 @@ pub fn disable_autostart() -> io::Result<()> {
#[cfg(target_os = "macos")]
pub fn is_autostart_enabled() -> bool {
dirs::home_dir()
.map(|h| {
h.join("Library/LaunchAgents/com.donutbrowser.daemon.plist")
.exists()
})
.unwrap_or(false)
get_plist_path().is_some_and(|p| p.exists())
}
#[cfg(target_os = "macos")]
pub fn load_launch_agent() -> io::Result<()> {
use std::process::Command;
let plist_path = get_plist_path()
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not determine plist path"))?;
if !plist_path.exists() {
return Err(io::Error::new(
io::ErrorKind::NotFound,
"Launch agent plist does not exist",
));
}
// Use launchctl load to start the daemon via launchd
// The -w flag writes the "disabled" key to the override plist
let output = Command::new("launchctl")
.args(["load", "-w"])
.arg(&plist_path)
.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
// "already loaded" is not an error condition for us
if !stderr.contains("already loaded") {
return Err(io::Error::other(format!(
"launchctl load failed: {}",
stderr
)));
}
}
log::info!("Loaded launch agent via launchctl");
Ok(())
}
#[cfg(target_os = "macos")]
pub fn unload_launch_agent() -> io::Result<()> {
use std::process::Command;
let plist_path = get_plist_path()
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not determine plist path"))?;
if !plist_path.exists() {
return Ok(());
}
let output = Command::new("launchctl")
.args(["unload"])
.arg(&plist_path)
.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
// Not being loaded is not an error
if !stderr.contains("Could not find specified service") {
log::warn!("launchctl unload warning: {}", stderr);
}
}
log::info!("Unloaded launch agent via launchctl");
Ok(())
}
#[cfg(target_os = "linux")]
+11 -5
View File
@@ -6,7 +6,9 @@ use tray_icon::{Icon, TrayIcon, TrayIconBuilder};
static GUI_RUNNING: AtomicBool = AtomicBool::new(false);
pub fn load_icon() -> Icon {
let icon_bytes = include_bytes!("../../icons/32x32.png");
// Use the generated template icon (44x44 for retina, macOS standard menu bar size)
// This is the donut logo converted to template format (black with alpha)
let icon_bytes = include_bytes!("../../icons/tray-icon-44.png");
let image = image::load_from_memory(icon_bytes)
.expect("Failed to load icon")
@@ -89,12 +91,16 @@ impl TrayMenu {
}
pub fn create_tray_icon(icon: Icon, menu: &Menu) -> TrayIcon {
TrayIconBuilder::new()
let builder = TrayIconBuilder::new()
.with_icon(icon)
.with_tooltip("Donut Browser")
.with_menu(Box::new(menu.clone()))
.build()
.expect("Failed to create tray icon")
.with_menu(Box::new(menu.clone()));
// On macOS, template icons are automatically colored by the system for light/dark mode
#[cfg(target_os = "macos")]
let builder = builder.with_icon_as_template(true);
builder.build().expect("Failed to create tray icon")
}
pub fn open_gui() {
+161 -53
View File
@@ -60,6 +60,38 @@ fn is_daemon_running() -> bool {
}
}
fn is_dev_mode() -> bool {
if let Ok(current_exe) = std::env::current_exe() {
let path_str = current_exe.to_string_lossy();
path_str.contains("target/debug") || path_str.contains("target/release")
} else {
false
}
}
#[cfg(target_os = "macos")]
fn get_daemon_path() -> Option<PathBuf> {
// First try to find the daemon binary next to the current executable
if let Ok(current_exe) = std::env::current_exe() {
if let Some(exe_dir) = current_exe.parent() {
let daemon_path = exe_dir.join("donut-daemon");
if daemon_path.exists() {
return Some(daemon_path);
}
}
}
// Try common installation paths
let paths = [
PathBuf::from("/Applications/Donut Browser.app/Contents/MacOS/donut-daemon"),
dirs::home_dir()
.map(|h| h.join("Applications/Donut Browser.app/Contents/MacOS/donut-daemon"))
.unwrap_or_default(),
];
paths.into_iter().find(|path| path.exists())
}
#[cfg(any(target_os = "linux", windows))]
fn get_daemon_path() -> Option<PathBuf> {
// First, try to find it next to the current executable
if let Ok(current_exe) = std::env::current_exe() {
@@ -68,25 +100,13 @@ fn get_daemon_path() -> Option<PathBuf> {
// Check for daemon binary in same directory
#[cfg(target_os = "windows")]
let daemon_name = "donut-daemon.exe";
#[cfg(not(target_os = "windows"))]
#[cfg(target_os = "linux")]
let daemon_name = "donut-daemon";
let daemon_path = exe_dir.join(daemon_name);
if daemon_path.exists() {
return Some(daemon_path);
}
// On macOS, check inside the app bundle
#[cfg(target_os = "macos")]
{
// If we're in Contents/MacOS, daemon should be there too
if exe_dir.ends_with("Contents/MacOS") {
let daemon_path = exe_dir.join(daemon_name);
if daemon_path.exists() {
return Some(daemon_path);
}
}
}
}
// Try to find it in PATH
@@ -101,7 +121,7 @@ fn get_daemon_path() -> Option<PathBuf> {
}
}
#[cfg(unix)]
#[cfg(target_os = "linux")]
{
if let Ok(output) = Command::new("which").arg("donut-daemon").output() {
if output.status.success() {
@@ -118,60 +138,39 @@ fn get_daemon_path() -> Option<PathBuf> {
}
pub fn spawn_daemon() -> Result<(), String> {
// Log the daemon state for debugging
let state = read_state();
log::info!("Daemon state before spawn: pid={:?}", state.daemon_pid);
// Check if already running
if is_daemon_running() {
log::info!("Daemon is already running");
log::info!("Daemon is already running (verified by PID check)");
return Ok(());
}
log::info!("Daemon is not running, attempting to start...");
// Log current exe location for debugging
let current_exe = std::env::current_exe().ok();
log::info!("Current exe: {:?}", current_exe);
let daemon_path = get_daemon_path().ok_or_else(|| {
format!(
"Could not find daemon binary. Current exe: {:?}",
current_exe
)
})?;
log::info!("Spawning daemon from: {:?}", daemon_path);
// Use "run" instead of "start" - we handle detachment here
#[cfg(unix)]
// On macOS, use launchctl to start the daemon via launchd
// This ensures the daemon runs in the user's Aqua session with WindowServer access
// and survives app termination since it's managed by launchd, not as a child process
#[cfg(target_os = "macos")]
{
use std::os::unix::process::CommandExt;
spawn_daemon_macos()?;
}
// Create a new process group so daemon survives parent exit
// Note: We don't call setsid() because on macOS that disconnects from the WindowServer
// which prevents the tray icon from appearing. Instead, we just set a new process group.
let mut cmd = Command::new(&daemon_path);
cmd
.arg("run")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.process_group(0); // Create new process group without new session
cmd
.spawn()
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
// On Linux, use direct spawn
#[cfg(target_os = "linux")]
{
spawn_daemon_unix()?;
}
#[cfg(windows)]
{
use std::os::windows::process::CommandExt;
const DETACHED_PROCESS: u32 = 0x00000008;
const CREATE_NEW_PROCESS_GROUP: u32 = 0x00000200;
Command::new(&daemon_path)
.arg("run")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.creation_flags(DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP)
.spawn()
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
spawn_daemon_windows()?;
}
// Wait for daemon to start (max 3 seconds)
@@ -196,6 +195,115 @@ pub fn spawn_daemon() -> Result<(), String> {
Err("Daemon did not start within timeout".to_string())
}
#[cfg(target_os = "macos")]
fn spawn_daemon_macos() -> Result<(), String> {
use std::os::unix::process::CommandExt;
// In dev mode, use direct spawn instead of launchctl
// This avoids issues with plist paths pointing to wrong binaries
if is_dev_mode() {
log::info!("Dev mode detected, using direct spawn instead of launchctl");
let daemon_path = get_daemon_path().ok_or_else(|| {
format!(
"Could not find daemon binary. Current exe: {:?}",
std::env::current_exe().ok()
)
})?;
log::info!("Spawning daemon from: {:?}", daemon_path);
// Create a new process group so daemon survives parent exit
let mut cmd = Command::new(&daemon_path);
cmd
.arg("run")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.process_group(0);
cmd
.spawn()
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
return Ok(());
}
// Production mode: use launchctl for proper daemon management
// First, ensure the LaunchAgent plist is installed
let autostart_enabled = autostart::is_autostart_enabled();
log::info!("LaunchAgent plist exists: {}", autostart_enabled);
if !autostart_enabled {
log::info!("Installing LaunchAgent plist for daemon management");
autostart::enable_autostart().map_err(|e| format!("Failed to install LaunchAgent: {}", e))?;
log::info!("LaunchAgent plist installed successfully");
}
// Load the launch agent via launchctl
log::info!("Loading daemon via launchctl...");
autostart::load_launch_agent().map_err(|e| format!("Failed to load LaunchAgent: {}", e))?;
log::info!("launchctl load completed");
Ok(())
}
#[cfg(target_os = "linux")]
fn spawn_daemon_unix() -> Result<(), String> {
use std::os::unix::process::CommandExt;
let daemon_path = get_daemon_path().ok_or_else(|| {
format!(
"Could not find daemon binary. Current exe: {:?}",
std::env::current_exe().ok()
)
})?;
log::info!("Spawning daemon from: {:?}", daemon_path);
// Create a new process group so daemon survives parent exit
let mut cmd = Command::new(&daemon_path);
cmd
.arg("run")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.process_group(0);
cmd
.spawn()
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
Ok(())
}
#[cfg(windows)]
fn spawn_daemon_windows() -> Result<(), String> {
use std::os::windows::process::CommandExt;
const DETACHED_PROCESS: u32 = 0x00000008;
const CREATE_NEW_PROCESS_GROUP: u32 = 0x00000200;
let daemon_path = get_daemon_path().ok_or_else(|| {
format!(
"Could not find daemon binary. Current exe: {:?}",
std::env::current_exe().ok()
)
})?;
log::info!("Spawning daemon from: {:?}", daemon_path);
Command::new(&daemon_path)
.arg("run")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.creation_flags(DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP)
.spawn()
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
Ok(())
}
pub fn ensure_daemon_running() -> Result<(), String> {
if !is_daemon_running() {
spawn_daemon()?;
+81 -5
View File
@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use tokio_util::sync::CancellationToken;
use crate::api_client::ApiClient;
use crate::browser::{create_browser, BrowserType};
@@ -13,6 +14,8 @@ use crate::events;
lazy_static::lazy_static! {
static ref DOWNLOADING_BROWSERS: std::sync::Arc<Mutex<std::collections::HashSet<String>>> =
std::sync::Arc::new(Mutex::new(std::collections::HashSet::new()));
static ref DOWNLOAD_CANCELLATION_TOKENS: std::sync::Arc<Mutex<std::collections::HashMap<String, CancellationToken>>> =
std::sync::Arc::new(Mutex::new(std::collections::HashMap::new()));
}
#[derive(Debug, Serialize, Deserialize, Clone)]
@@ -438,6 +441,7 @@ impl Downloader {
version: &str,
download_info: &DownloadInfo,
dest_path: &Path,
cancel_token: Option<&CancellationToken>,
) -> Result<PathBuf, Box<dyn std::error::Error + Send + Sync>> {
let file_path = dest_path.join(&download_info.filename);
@@ -573,6 +577,13 @@ impl Downloader {
use futures_util::StreamExt;
while let Some(chunk) = stream.next().await {
if let Some(token) = cancel_token {
if token.is_cancelled() {
drop(file);
let _ = std::fs::remove_file(&file_path);
return Err("Download cancelled".into());
}
}
let chunk = chunk?;
io::copy(&mut chunk.as_ref(), &mut file)?;
downloaded += chunk.len() as u64;
@@ -635,21 +646,27 @@ impl Downloader {
browser_str: String,
version: String,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
// Check if Wayfern terms have been accepted before allowing any browser downloads
if !crate::wayfern_terms::WayfernTermsManager::instance().is_terms_accepted() {
// Only check Wayfern terms if Wayfern is already downloaded
let terms_manager = crate::wayfern_terms::WayfernTermsManager::instance();
if terms_manager.is_wayfern_downloaded() && !terms_manager.is_terms_accepted() {
return Err("Please accept Wayfern Terms and Conditions before downloading browsers".into());
}
// Check if this browser-version pair is already being downloaded
let download_key = format!("{browser_str}-{version}");
{
let cancel_token = {
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
if downloading.contains(&download_key) {
return Err(format!("Browser '{browser_str}' version '{version}' is already being downloaded. Please wait for the current download to complete.").into());
}
// Mark this browser-version pair as being downloaded
downloading.insert(download_key.clone());
}
let token = CancellationToken::new();
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
tokens.insert(download_key.clone(), token.clone());
token
};
let browser_type =
BrowserType::from_str(&browser_str).map_err(|e| format!("Invalid browser type: {e}"))?;
@@ -681,6 +698,9 @@ impl Downloader {
// Remove from downloading set since it's already downloaded
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
downloading.remove(&download_key);
drop(downloading);
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
tokens.remove(&download_key);
return Ok(version);
} else {
// Registry says it's downloaded but files don't exist - clean up registry
@@ -702,6 +722,9 @@ impl Downloader {
// Remove from downloading set on error
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
downloading.remove(&download_key);
drop(downloading);
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
tokens.remove(&download_key);
return Err(
format!(
"Browser '{}' is not supported on your platform ({} {}). Supported browsers: {}",
@@ -741,6 +764,7 @@ impl Downloader {
&version,
&download_info,
&browser_dir,
Some(&cancel_token),
)
.await
{
@@ -752,6 +776,25 @@ impl Downloader {
let _ = self.registry.save();
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
downloading.remove(&download_key);
drop(downloading);
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
tokens.remove(&download_key);
// Emit cancelled stage if the download was cancelled by user
if cancel_token.is_cancelled() {
let progress = DownloadProgress {
browser: browser_str.clone(),
version: version.clone(),
downloaded_bytes: 0,
total_bytes: None,
percentage: 0.0,
speed_bytes_per_sec: 0.0,
eta_seconds: None,
stage: "cancelled".to_string(),
};
let _ = events::emit("download-progress", &progress);
}
return Err(format!("Failed to download browser: {e}").into());
}
};
@@ -782,6 +825,10 @@ impl Downloader {
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
downloading.remove(&download_key);
}
{
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
tokens.remove(&download_key);
}
return Err(format!("Failed to extract browser: {e}").into());
}
}
@@ -869,6 +916,10 @@ impl Downloader {
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
downloading.remove(&download_key);
}
{
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
tokens.remove(&download_key);
}
return Err(error_details.into());
}
@@ -941,11 +992,15 @@ impl Downloader {
};
let _ = events::emit("download-progress", &progress);
// Remove browser-version pair from downloading set
// Remove browser-version pair from downloading set and cancel token
{
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
downloading.remove(&download_key);
}
{
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
tokens.remove(&download_key);
}
Ok(version)
}
@@ -964,6 +1019,24 @@ pub async fn download_browser(
.map_err(|e| format!("Failed to download browser: {e}"))
}
#[tauri::command]
pub async fn cancel_download(browser_str: String, version: String) -> Result<(), String> {
let download_key = format!("{browser_str}-{version}");
let token = {
let tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
tokens.get(&download_key).cloned()
};
if let Some(token) = token {
token.cancel();
Ok(())
} else {
Err(format!(
"No active download found for {browser_str} {version}"
))
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -1074,6 +1147,7 @@ mod tests {
"139.0",
&download_info,
dest_path,
None,
)
.await;
@@ -1118,6 +1192,7 @@ mod tests {
"139.0",
&download_info,
dest_path,
None,
)
.await;
@@ -1163,6 +1238,7 @@ mod tests {
"1465660",
&download_info,
dest_path,
None,
)
.await;
+248 -4
View File
@@ -47,6 +47,7 @@ pub mod events;
mod mcp_server;
mod tag_manager;
mod version_updater;
pub mod vpn;
use browser_runner::{
check_browser_exists, kill_browser_profile, launch_browser_profile, open_url_with_profile,
@@ -68,12 +69,12 @@ use downloaded_browsers_registry::{
check_missing_binaries, ensure_all_binaries_exist, get_downloaded_browser_versions,
};
use downloader::download_browser;
use downloader::{cancel_download, download_browser};
use settings_manager::{
decline_launch_on_login, enable_launch_on_login, get_app_settings, get_sync_settings,
get_table_sorting_settings, save_app_settings, save_sync_settings, save_table_sorting_settings,
should_show_launch_on_login_prompt,
get_system_language, get_table_sorting_settings, save_app_settings, save_sync_settings,
save_table_sorting_settings, should_show_launch_on_login_prompt,
};
use sync::{
@@ -232,6 +233,41 @@ fn get_cached_proxy_check(proxy_id: String) -> Option<crate::proxy_manager::Prox
crate::proxy_manager::PROXY_MANAGER.get_cached_proxy_check(&proxy_id)
}
#[tauri::command]
fn export_proxies(format: String) -> Result<String, String> {
match format.as_str() {
"json" => crate::proxy_manager::PROXY_MANAGER.export_proxies_json(),
"txt" => Ok(crate::proxy_manager::PROXY_MANAGER.export_proxies_txt()),
_ => Err(format!("Unsupported export format: {format}")),
}
}
#[tauri::command]
async fn import_proxies_json(
app_handle: tauri::AppHandle,
content: String,
) -> Result<crate::proxy_manager::ProxyImportResult, String> {
crate::proxy_manager::PROXY_MANAGER
.import_proxies_json(&app_handle, &content)
.map_err(|e| format!("Failed to import proxies: {e}"))
}
#[tauri::command]
fn parse_txt_proxies(content: String) -> Vec<crate::proxy_manager::ProxyParseResult> {
crate::proxy_manager::ProxyManager::parse_txt_proxies(&content)
}
#[tauri::command]
async fn import_proxies_from_parsed(
app_handle: tauri::AppHandle,
parsed_proxies: Vec<crate::proxy_manager::ParsedProxyLine>,
name_prefix: Option<String>,
) -> Result<crate::proxy_manager::ProxyImportResult, String> {
crate::proxy_manager::PROXY_MANAGER
.import_proxies_from_parsed(&app_handle, parsed_proxies, name_prefix)
.map_err(|e| format!("Failed to import proxies: {e}"))
}
#[tauri::command]
fn read_profile_cookies(profile_id: String) -> Result<cookie_manager::CookieReadResult, String> {
cookie_manager::CookieManager::read_cookies(&profile_id)
@@ -250,6 +286,11 @@ fn check_wayfern_terms_accepted() -> bool {
wayfern_terms::WayfernTermsManager::instance().is_terms_accepted()
}
#[tauri::command]
fn check_wayfern_downloaded() -> bool {
wayfern_terms::WayfernTermsManager::instance().is_wayfern_downloaded()
}
#[tauri::command]
async fn accept_wayfern_terms() -> Result<(), String> {
wayfern_terms::WayfernTermsManager::instance()
@@ -374,6 +415,172 @@ async fn download_geoip_database(app_handle: tauri::AppHandle) -> Result<(), Str
.map_err(|e| format!("Failed to download GeoIP database: {e}"))
}
// VPN commands
#[tauri::command]
async fn import_vpn_config(
content: String,
filename: String,
name: Option<String>,
) -> Result<vpn::VpnImportResult, String> {
let storage = vpn::VPN_STORAGE
.lock()
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
match storage.import_config(&content, &filename, name.clone()) {
Ok(config) => Ok(vpn::VpnImportResult {
success: true,
vpn_id: Some(config.id),
vpn_type: Some(config.vpn_type),
name: config.name,
error: None,
}),
Err(e) => Ok(vpn::VpnImportResult {
success: false,
vpn_id: None,
vpn_type: None,
name: name.unwrap_or_else(|| filename.clone()),
error: Some(e.to_string()),
}),
}
}
#[tauri::command]
async fn list_vpn_configs() -> Result<Vec<vpn::VpnConfig>, String> {
let storage = vpn::VPN_STORAGE
.lock()
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
storage
.list_configs()
.map_err(|e| format!("Failed to list VPN configs: {e}"))
}
#[tauri::command]
async fn get_vpn_config(vpn_id: String) -> Result<vpn::VpnConfig, String> {
let storage = vpn::VPN_STORAGE
.lock()
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
storage
.load_config(&vpn_id)
.map_err(|e| format!("Failed to load VPN config: {e}"))
}
#[tauri::command]
async fn delete_vpn_config(vpn_id: String) -> Result<(), String> {
// First disconnect if connected
{
let mut manager = vpn::TUNNEL_MANAGER.lock().await;
if manager.is_tunnel_active(&vpn_id) {
if let Some(tunnel) = manager.get_tunnel_mut(&vpn_id) {
let _ = tunnel.disconnect().await;
}
manager.remove_tunnel(&vpn_id);
}
}
// Then delete from storage
let storage = vpn::VPN_STORAGE
.lock()
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
storage
.delete_config(&vpn_id)
.map_err(|e| format!("Failed to delete VPN config: {e}"))
}
#[tauri::command]
async fn connect_vpn(vpn_id: String) -> Result<(), String> {
// Load config from storage
let config = {
let storage = vpn::VPN_STORAGE
.lock()
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
storage
.load_config(&vpn_id)
.map_err(|e| format!("Failed to load VPN config: {e}"))?
};
// Create and connect the appropriate tunnel
let mut manager = vpn::TUNNEL_MANAGER.lock().await;
// Check if already connected
if manager.is_tunnel_active(&vpn_id) {
return Ok(());
}
let mut tunnel: Box<dyn vpn::VpnTunnel> = match config.vpn_type {
vpn::VpnType::WireGuard => {
let wg_config = vpn::parse_wireguard_config(&config.config_data)
.map_err(|e| format!("Invalid WireGuard config: {e}"))?;
Box::new(vpn::WireGuardTunnel::new(vpn_id.clone(), wg_config))
}
vpn::VpnType::OpenVPN => {
let ovpn_config = vpn::parse_openvpn_config(&config.config_data)
.map_err(|e| format!("Invalid OpenVPN config: {e}"))?;
Box::new(vpn::OpenVpnTunnel::new(vpn_id.clone(), ovpn_config))
}
};
tunnel
.connect()
.await
.map_err(|e| format!("Failed to connect VPN: {e}"))?;
manager.register_tunnel(vpn_id.clone(), tunnel);
// Update last_used timestamp
{
let storage = vpn::VPN_STORAGE
.lock()
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
let _ = storage.update_last_used(&vpn_id);
}
Ok(())
}
#[tauri::command]
async fn disconnect_vpn(vpn_id: String) -> Result<(), String> {
let mut manager = vpn::TUNNEL_MANAGER.lock().await;
if let Some(tunnel) = manager.get_tunnel_mut(&vpn_id) {
tunnel
.disconnect()
.await
.map_err(|e| format!("Failed to disconnect VPN: {e}"))?;
}
manager.remove_tunnel(&vpn_id);
Ok(())
}
#[tauri::command]
async fn get_vpn_status(vpn_id: String) -> Result<vpn::VpnStatus, String> {
let manager = vpn::TUNNEL_MANAGER.lock().await;
if let Some(tunnel) = manager.get_tunnel(&vpn_id) {
Ok(tunnel.get_status())
} else {
// Not connected
Ok(vpn::VpnStatus {
connected: false,
vpn_id,
connected_at: None,
bytes_sent: None,
bytes_received: None,
last_handshake: None,
})
}
}
#[tauri::command]
async fn list_active_vpn_connections() -> Result<Vec<vpn::VpnStatus>, String> {
let manager = vpn::TUNNEL_MANAGER.lock().await;
Ok(manager.get_all_statuses())
}
#[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() {
let args: Vec<String> = env::args().collect();
@@ -883,6 +1090,7 @@ pub fn run() {
get_supported_browsers,
is_browser_supported_on_platform,
download_browser,
cancel_download,
delete_profile,
check_browser_exists,
create_browser_profile_new,
@@ -907,6 +1115,7 @@ pub fn run() {
decline_launch_on_login,
get_table_sorting_settings,
save_table_sorting_settings,
get_system_language,
clear_all_version_cache_and_refetch,
is_default_browser,
open_url_with_profile,
@@ -931,6 +1140,10 @@ pub fn run() {
delete_stored_proxy,
check_proxy_validity,
get_cached_proxy_check,
export_proxies,
import_proxies_json,
parse_txt_proxies,
import_proxies_from_parsed,
update_camoufox_config,
update_wayfern_config,
get_profile_groups,
@@ -959,6 +1172,7 @@ pub fn run() {
read_profile_cookies,
copy_profile_cookies,
check_wayfern_terms_accepted,
check_wayfern_downloaded,
accept_wayfern_terms,
get_commercial_trial_status,
acknowledge_trial_expiration,
@@ -966,7 +1180,16 @@ pub fn run() {
start_mcp_server,
stop_mcp_server,
get_mcp_server_status,
get_mcp_config
get_mcp_config,
// VPN commands
import_vpn_config,
list_vpn_configs,
get_vpn_config,
delete_vpn_config,
connect_vpn,
disconnect_vpn,
get_vpn_status,
list_active_vpn_connections
])
.run(tauri::generate_context!())
.expect("error while running tauri application");
@@ -987,6 +1210,18 @@ mod tests {
}
fn check_unused_commands(verbose: bool) {
// Commands that are intentionally not used in the frontend
// but are used via MCP server or other programmatic APIs
let mcp_only_commands = [
"list_vpn_configs",
"get_vpn_config",
"delete_vpn_config",
"connect_vpn",
"disconnect_vpn",
"get_vpn_status",
"list_active_vpn_connections",
];
// Extract command names from the generate_handler! macro in this file
let lib_rs_content = fs::read_to_string("src/lib.rs").expect("Failed to read lib.rs");
let commands = extract_tauri_commands(&lib_rs_content);
@@ -999,6 +1234,15 @@ mod tests {
let mut used_commands = Vec::new();
for command in &commands {
// Skip commands that are intentionally MCP-only
if mcp_only_commands.contains(&command.as_str()) {
used_commands.push(command.clone());
if verbose {
println!("{command} (MCP-only)");
}
continue;
}
let mut is_used = false;
for file_content in &frontend_files {
+533 -2
View File
@@ -553,6 +553,132 @@ impl McpServer {
"required": ["proxy_id"]
}),
},
McpTool {
name: "export_proxies".to_string(),
description: "Export all proxy configurations".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"format": {
"type": "string",
"enum": ["json", "txt"],
"description": "Export format (json for structured data, txt for URL format)"
}
},
"required": ["format"]
}),
},
McpTool {
name: "import_proxies".to_string(),
description: "Import proxy configurations from JSON or TXT content".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The proxy configuration content to import"
},
"format": {
"type": "string",
"enum": ["json", "txt"],
"description": "Import format (json or txt)"
},
"name_prefix": {
"type": "string",
"description": "Optional prefix for imported proxy names (default: 'Imported')"
}
},
"required": ["content", "format"]
}),
},
// VPN management tools
McpTool {
name: "import_vpn".to_string(),
description: "Import a WireGuard (.conf) or OpenVPN (.ovpn) configuration".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "Raw VPN config file content"
},
"filename": {
"type": "string",
"description": "Original filename (.conf or .ovpn) for type detection"
},
"name": {
"type": "string",
"description": "Optional display name for the VPN config"
}
},
"required": ["content", "filename"]
}),
},
McpTool {
name: "list_vpn_configs".to_string(),
description: "List all stored VPN configurations".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {},
"required": []
}),
},
McpTool {
name: "delete_vpn".to_string(),
description: "Delete a VPN configuration".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"vpn_id": {
"type": "string",
"description": "The UUID of the VPN config to delete"
}
},
"required": ["vpn_id"]
}),
},
McpTool {
name: "connect_vpn".to_string(),
description: "Connect to a VPN configuration".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"vpn_id": {
"type": "string",
"description": "The UUID of the VPN config to connect"
}
},
"required": ["vpn_id"]
}),
},
McpTool {
name: "disconnect_vpn".to_string(),
description: "Disconnect from a VPN".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"vpn_id": {
"type": "string",
"description": "The UUID of the VPN to disconnect"
}
},
"required": ["vpn_id"]
}),
},
McpTool {
name: "get_vpn_status".to_string(),
description: "Get the connection status of a VPN".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"vpn_id": {
"type": "string",
"description": "The UUID of the VPN to check"
}
},
"required": ["vpn_id"]
}),
},
]
}
@@ -641,6 +767,16 @@ impl McpServer {
"create_proxy" => self.handle_create_proxy(&arguments).await,
"update_proxy" => self.handle_update_proxy(&arguments).await,
"delete_proxy" => self.handle_delete_proxy(&arguments).await,
// Proxy import/export
"export_proxies" => self.handle_export_proxies(&arguments).await,
"import_proxies" => self.handle_import_proxies(&arguments).await,
// VPN management
"import_vpn" => self.handle_import_vpn(&arguments).await,
"list_vpn_configs" => self.handle_list_vpn_configs().await,
"delete_vpn" => self.handle_delete_vpn(&arguments).await,
"connect_vpn" => self.handle_connect_vpn(&arguments).await,
"disconnect_vpn" => self.handle_disconnect_vpn(&arguments).await,
"get_vpn_status" => self.handle_get_vpn_status(&arguments).await,
_ => Err(McpError {
code: -32602,
message: format!("Unknown tool: {tool_name}"),
@@ -1361,6 +1497,391 @@ impl McpServer {
}]
}))
}
async fn handle_export_proxies(
&self,
arguments: &serde_json::Value,
) -> Result<serde_json::Value, McpError> {
let format = arguments
.get("format")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing format".to_string(),
})?;
let content = match format {
"json" => PROXY_MANAGER.export_proxies_json().map_err(|e| McpError {
code: -32000,
message: format!("Failed to export proxies: {e}"),
})?,
"txt" => PROXY_MANAGER.export_proxies_txt(),
_ => {
return Err(McpError {
code: -32602,
message: format!("Invalid format '{}', must be 'json' or 'txt'", format),
})
}
};
Ok(serde_json::json!({
"content": [{
"type": "text",
"text": content
}]
}))
}
async fn handle_import_proxies(
&self,
arguments: &serde_json::Value,
) -> Result<serde_json::Value, McpError> {
let content = arguments
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing content".to_string(),
})?;
let format = arguments
.get("format")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing format".to_string(),
})?;
let name_prefix = arguments
.get("name_prefix")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let inner = self.inner.lock().await;
let app_handle = inner.app_handle.as_ref().ok_or_else(|| McpError {
code: -32000,
message: "MCP server not properly initialized".to_string(),
})?;
let result = match format {
"json" => PROXY_MANAGER
.import_proxies_json(app_handle, content)
.map_err(|e| McpError {
code: -32000,
message: format!("Failed to import proxies: {e}"),
})?,
"txt" => {
use crate::proxy_manager::{ProxyManager, ProxyParseResult};
let parse_results = ProxyManager::parse_txt_proxies(content);
let parsed: Vec<_> = parse_results
.into_iter()
.filter_map(|r| {
if let ProxyParseResult::Parsed(p) = r {
Some(p)
} else {
None
}
})
.collect();
if parsed.is_empty() {
return Err(McpError {
code: -32000,
message: "No valid proxies found in content".to_string(),
});
}
PROXY_MANAGER
.import_proxies_from_parsed(app_handle, parsed, name_prefix)
.map_err(|e| McpError {
code: -32000,
message: format!("Failed to import proxies: {e}"),
})?
}
_ => {
return Err(McpError {
code: -32602,
message: format!("Invalid format '{}', must be 'json' or 'txt'", format),
})
}
};
Ok(serde_json::json!({
"content": [{
"type": "text",
"text": format!(
"Import complete: {} imported, {} skipped, {} errors",
result.imported_count,
result.skipped_count,
result.errors.len()
)
}]
}))
}
// VPN management handlers
async fn handle_import_vpn(
&self,
arguments: &serde_json::Value,
) -> Result<serde_json::Value, McpError> {
let content = arguments
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing content".to_string(),
})?;
let filename = arguments
.get("filename")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing filename".to_string(),
})?;
let name = arguments
.get("name")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
code: -32000,
message: format!("Failed to lock VPN storage: {e}"),
})?;
let config = storage
.import_config(content, filename, name)
.map_err(|e| McpError {
code: -32000,
message: format!("Failed to import VPN config: {e}"),
})?;
Ok(serde_json::json!({
"content": [{
"type": "text",
"text": format!(
"VPN '{}' ({}) imported successfully with ID: {}",
config.name,
config.vpn_type,
config.id
)
}]
}))
}
async fn handle_list_vpn_configs(&self) -> Result<serde_json::Value, McpError> {
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
code: -32000,
message: format!("Failed to lock VPN storage: {e}"),
})?;
let configs = storage.list_configs().map_err(|e| McpError {
code: -32000,
message: format!("Failed to list VPN configs: {e}"),
})?;
Ok(serde_json::json!({
"content": [{
"type": "text",
"text": serde_json::to_string_pretty(&configs).unwrap_or_default()
}]
}))
}
async fn handle_delete_vpn(
&self,
arguments: &serde_json::Value,
) -> Result<serde_json::Value, McpError> {
let vpn_id = arguments
.get("vpn_id")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing vpn_id".to_string(),
})?;
// First disconnect if connected
{
let mut manager = crate::vpn::TUNNEL_MANAGER.lock().await;
if manager.is_tunnel_active(vpn_id) {
if let Some(tunnel) = manager.get_tunnel_mut(vpn_id) {
let _ = tunnel.disconnect().await;
}
manager.remove_tunnel(vpn_id);
}
}
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
code: -32000,
message: format!("Failed to lock VPN storage: {e}"),
})?;
storage.delete_config(vpn_id).map_err(|e| McpError {
code: -32000,
message: format!("Failed to delete VPN config: {e}"),
})?;
Ok(serde_json::json!({
"content": [{
"type": "text",
"text": format!("VPN '{}' deleted successfully", vpn_id)
}]
}))
}
async fn handle_connect_vpn(
&self,
arguments: &serde_json::Value,
) -> Result<serde_json::Value, McpError> {
let vpn_id = arguments
.get("vpn_id")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing vpn_id".to_string(),
})?;
// Load config from storage
let config = {
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
code: -32000,
message: format!("Failed to lock VPN storage: {e}"),
})?;
storage.load_config(vpn_id).map_err(|e| McpError {
code: -32000,
message: format!("Failed to load VPN config: {e}"),
})?
};
let mut manager = crate::vpn::TUNNEL_MANAGER.lock().await;
// Check if already connected
if manager.is_tunnel_active(vpn_id) {
return Ok(serde_json::json!({
"content": [{
"type": "text",
"text": format!("VPN '{}' is already connected", config.name)
}]
}));
}
let mut tunnel: Box<dyn crate::vpn::VpnTunnel> = match config.vpn_type {
crate::vpn::VpnType::WireGuard => {
let wg_config =
crate::vpn::parse_wireguard_config(&config.config_data).map_err(|e| McpError {
code: -32000,
message: format!("Invalid WireGuard config: {e}"),
})?;
Box::new(crate::vpn::WireGuardTunnel::new(
vpn_id.to_string(),
wg_config,
))
}
crate::vpn::VpnType::OpenVPN => {
let ovpn_config =
crate::vpn::parse_openvpn_config(&config.config_data).map_err(|e| McpError {
code: -32000,
message: format!("Invalid OpenVPN config: {e}"),
})?;
Box::new(crate::vpn::OpenVpnTunnel::new(
vpn_id.to_string(),
ovpn_config,
))
}
};
tunnel.connect().await.map_err(|e| McpError {
code: -32000,
message: format!("Failed to connect VPN: {e}"),
})?;
manager.register_tunnel(vpn_id.to_string(), tunnel);
// Update last_used timestamp
{
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
code: -32000,
message: format!("Failed to lock VPN storage: {e}"),
})?;
let _ = storage.update_last_used(vpn_id);
}
Ok(serde_json::json!({
"content": [{
"type": "text",
"text": format!("VPN '{}' connected successfully", config.name)
}]
}))
}
async fn handle_disconnect_vpn(
&self,
arguments: &serde_json::Value,
) -> Result<serde_json::Value, McpError> {
let vpn_id = arguments
.get("vpn_id")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing vpn_id".to_string(),
})?;
let mut manager = crate::vpn::TUNNEL_MANAGER.lock().await;
if let Some(tunnel) = manager.get_tunnel_mut(vpn_id) {
tunnel.disconnect().await.map_err(|e| McpError {
code: -32000,
message: format!("Failed to disconnect VPN: {e}"),
})?;
}
manager.remove_tunnel(vpn_id);
Ok(serde_json::json!({
"content": [{
"type": "text",
"text": format!("VPN '{}' disconnected successfully", vpn_id)
}]
}))
}
async fn handle_get_vpn_status(
&self,
arguments: &serde_json::Value,
) -> Result<serde_json::Value, McpError> {
let vpn_id = arguments
.get("vpn_id")
.and_then(|v| v.as_str())
.ok_or_else(|| McpError {
code: -32602,
message: "Missing vpn_id".to_string(),
})?;
let manager = crate::vpn::TUNNEL_MANAGER.lock().await;
let status = if let Some(tunnel) = manager.get_tunnel(vpn_id) {
tunnel.get_status()
} else {
crate::vpn::VpnStatus {
connected: false,
vpn_id: vpn_id.to_string(),
connected_at: None,
bytes_sent: None,
bytes_received: None,
last_handshake: None,
}
};
Ok(serde_json::json!({
"content": [{
"type": "text",
"text": serde_json::to_string_pretty(&status).unwrap_or_default()
}]
}))
}
}
lazy_static::lazy_static! {
@@ -1376,8 +1897,8 @@ mod tests {
let server = McpServer::new();
let tools = server.get_tools();
// Should have all 16 tools
assert!(tools.len() >= 16);
// Should have at least 24 tools (18 + 6 VPN tools)
assert!(tools.len() >= 24);
// Check tool names
let tool_names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
@@ -1400,6 +1921,16 @@ mod tests {
assert!(tool_names.contains(&"create_proxy"));
assert!(tool_names.contains(&"update_proxy"));
assert!(tool_names.contains(&"delete_proxy"));
// Proxy import/export tools
assert!(tool_names.contains(&"export_proxies"));
assert!(tool_names.contains(&"import_proxies"));
// VPN tools
assert!(tool_names.contains(&"import_vpn"));
assert!(tool_names.contains(&"list_vpn_configs"));
assert!(tool_names.contains(&"delete_vpn"));
assert!(tool_names.contains(&"connect_vpn"));
assert!(tool_names.contains(&"disconnect_vpn"));
assert!(tool_names.contains(&"get_vpn_status"));
}
#[test]
+380
View File
@@ -1,3 +1,4 @@
use chrono::Utc;
use directories::BaseDirs;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -12,6 +13,60 @@ use crate::browser::ProxySettings;
use crate::events;
use crate::ip_utils;
// Export data format for JSON export
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyExportData {
pub version: String,
pub proxies: Vec<ExportedProxy>,
pub exported_at: String,
pub source: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportedProxy {
pub name: String,
#[serde(rename = "type")]
pub proxy_type: String,
pub host: String,
pub port: u16,
#[serde(skip_serializing_if = "Option::is_none")]
pub username: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub password: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyImportResult {
pub imported_count: usize,
pub skipped_count: usize,
pub errors: Vec<String>,
pub proxies: Vec<StoredProxy>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedProxyLine {
pub proxy_type: String,
pub host: String,
pub port: u16,
pub username: Option<String>,
pub password: Option<String>,
pub original_line: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "status")]
pub enum ProxyParseResult {
#[serde(rename = "parsed")]
Parsed(ParsedProxyLine),
#[serde(rename = "ambiguous")]
Ambiguous {
line: String,
possible_formats: Vec<String>,
},
#[serde(rename = "invalid")]
Invalid { line: String, reason: String },
}
// Store active proxy information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyInfo {
@@ -541,6 +596,331 @@ impl ProxyManager {
self.load_proxy_check_cache(proxy_id)
}
// Export all proxies as JSON
pub fn export_proxies_json(&self) -> Result<String, String> {
let stored_proxies = self.stored_proxies.lock().unwrap();
let proxies: Vec<ExportedProxy> = stored_proxies
.values()
.map(|p| ExportedProxy {
name: p.name.clone(),
proxy_type: p.proxy_settings.proxy_type.clone(),
host: p.proxy_settings.host.clone(),
port: p.proxy_settings.port,
username: p.proxy_settings.username.clone(),
password: p.proxy_settings.password.clone(),
})
.collect();
let export_data = ProxyExportData {
version: "1.0".to_string(),
proxies,
exported_at: Utc::now().to_rfc3339(),
source: "DonutBrowser".to_string(),
};
serde_json::to_string_pretty(&export_data).map_err(|e| format!("Failed to serialize: {e}"))
}
// Export all proxies as TXT (one per line: protocol://user:pass@host:port)
pub fn export_proxies_txt(&self) -> String {
let stored_proxies = self.stored_proxies.lock().unwrap();
stored_proxies
.values()
.map(|p| Self::build_proxy_url(&p.proxy_settings))
.collect::<Vec<_>>()
.join("\n")
}
// Parse TXT content with auto-detection of formats
pub fn parse_txt_proxies(content: &str) -> Vec<ProxyParseResult> {
content
.lines()
.filter(|line| !line.trim().is_empty() && !line.trim().starts_with('#'))
.map(|line| Self::parse_single_proxy_line(line.trim()))
.collect()
}
// Parse a single proxy line with format auto-detection
fn parse_single_proxy_line(line: &str) -> ProxyParseResult {
// Format 1: protocol://username:password@host:port (full URL)
if let Some(result) = Self::try_parse_url_format(line) {
return result;
}
// Try colon-separated formats
let parts: Vec<&str> = line.split(':').collect();
match parts.len() {
// host:port (no auth)
2 => {
if let Ok(port) = parts[1].parse::<u16>() {
return ProxyParseResult::Parsed(ParsedProxyLine {
proxy_type: "http".to_string(),
host: parts[0].to_string(),
port,
username: None,
password: None,
original_line: line.to_string(),
});
}
ProxyParseResult::Invalid {
line: line.to_string(),
reason: "Invalid port number".to_string(),
}
}
// Could be: host:port:user or user:pass@host (with @ in the middle)
3 => {
// Try username:password@host:port first
if let Some(result) = Self::try_parse_user_pass_at_host_port(line) {
return result;
}
ProxyParseResult::Invalid {
line: line.to_string(),
reason: "Could not determine format with 3 parts".to_string(),
}
}
// 4 parts: could be host:port:user:pass OR user:pass:host:port
4 => {
// Try to detect which format
let port_at_1 = parts[1].parse::<u16>().is_ok();
let port_at_3 = parts[3].parse::<u16>().is_ok();
match (port_at_1, port_at_3) {
// host:port:user:pass
(true, false) => {
let port = parts[1].parse::<u16>().unwrap();
ProxyParseResult::Parsed(ParsedProxyLine {
proxy_type: "http".to_string(),
host: parts[0].to_string(),
port,
username: Some(parts[2].to_string()),
password: Some(parts[3].to_string()),
original_line: line.to_string(),
})
}
// user:pass:host:port
(false, true) => {
let port = parts[3].parse::<u16>().unwrap();
ProxyParseResult::Parsed(ParsedProxyLine {
proxy_type: "http".to_string(),
host: parts[2].to_string(),
port,
username: Some(parts[0].to_string()),
password: Some(parts[1].to_string()),
original_line: line.to_string(),
})
}
// Both could be ports - ambiguous
(true, true) => ProxyParseResult::Ambiguous {
line: line.to_string(),
possible_formats: vec![
"host:port:username:password".to_string(),
"username:password:host:port".to_string(),
],
},
// Neither is a valid port
(false, false) => ProxyParseResult::Invalid {
line: line.to_string(),
reason: "No valid port number found".to_string(),
},
}
}
_ => ProxyParseResult::Invalid {
line: line.to_string(),
reason: format!("Unexpected format with {} parts", parts.len()),
},
}
}
// Try to parse URL format: protocol://username:password@host:port
fn try_parse_url_format(line: &str) -> Option<ProxyParseResult> {
// Check for protocol prefix using strip_prefix
let (protocol, rest) = if let Some(rest) = line.strip_prefix("http://") {
("http", rest)
} else if let Some(rest) = line.strip_prefix("https://") {
("https", rest)
} else if let Some(rest) = line.strip_prefix("socks4://") {
("socks4", rest)
} else if let Some(rest) = line.strip_prefix("socks5://") {
("socks5", rest)
} else if let Some(rest) = line.strip_prefix("socks://") {
("socks5", rest) // Default socks to socks5
} else {
return None;
};
// Check if there's auth (contains @)
if let Some(at_pos) = rest.rfind('@') {
let auth = &rest[..at_pos];
let host_port = &rest[at_pos + 1..];
// Parse auth (user:pass)
let (username, password) = if let Some(colon_pos) = auth.find(':') {
let user = urlencoding::decode(&auth[..colon_pos]).unwrap_or_default();
let pass = urlencoding::decode(&auth[colon_pos + 1..]).unwrap_or_default();
(Some(user.to_string()), Some(pass.to_string()))
} else {
(
Some(urlencoding::decode(auth).unwrap_or_default().to_string()),
None,
)
};
// Parse host:port
if let Some(colon_pos) = host_port.rfind(':') {
let host = &host_port[..colon_pos];
if let Ok(port) = host_port[colon_pos + 1..].parse::<u16>() {
return Some(ProxyParseResult::Parsed(ParsedProxyLine {
proxy_type: protocol.to_string(),
host: host.to_string(),
port,
username,
password,
original_line: line.to_string(),
}));
}
}
} else {
// No auth, just host:port
if let Some(colon_pos) = rest.rfind(':') {
let host = &rest[..colon_pos];
if let Ok(port) = rest[colon_pos + 1..].parse::<u16>() {
return Some(ProxyParseResult::Parsed(ParsedProxyLine {
proxy_type: protocol.to_string(),
host: host.to_string(),
port,
username: None,
password: None,
original_line: line.to_string(),
}));
}
}
}
Some(ProxyParseResult::Invalid {
line: line.to_string(),
reason: "Invalid URL format".to_string(),
})
}
// Try to parse: username:password@host:port format (no protocol)
fn try_parse_user_pass_at_host_port(line: &str) -> Option<ProxyParseResult> {
if let Some(at_pos) = line.rfind('@') {
let auth = &line[..at_pos];
let host_port = &line[at_pos + 1..];
// Parse auth
let (username, password) = if let Some(colon_pos) = auth.find(':') {
(
Some(auth[..colon_pos].to_string()),
Some(auth[colon_pos + 1..].to_string()),
)
} else {
return None;
};
// Parse host:port
if let Some(colon_pos) = host_port.rfind(':') {
let host = &host_port[..colon_pos];
if let Ok(port) = host_port[colon_pos + 1..].parse::<u16>() {
return Some(ProxyParseResult::Parsed(ParsedProxyLine {
proxy_type: "http".to_string(),
host: host.to_string(),
port,
username,
password,
original_line: line.to_string(),
}));
}
}
}
None
}
// Import proxies from JSON content
pub fn import_proxies_json(
&self,
app_handle: &tauri::AppHandle,
content: &str,
) -> Result<ProxyImportResult, String> {
let export_data: ProxyExportData =
serde_json::from_str(content).map_err(|e| format!("Invalid JSON format: {e}"))?;
let mut imported = Vec::new();
let mut skipped = 0;
let mut errors = Vec::new();
for exported in export_data.proxies {
let proxy_settings = ProxySettings {
proxy_type: exported.proxy_type,
host: exported.host,
port: exported.port,
username: exported.username,
password: exported.password,
};
match self.create_stored_proxy(app_handle, exported.name.clone(), proxy_settings) {
Ok(proxy) => imported.push(proxy),
Err(e) => {
if e.contains("already exists") {
skipped += 1;
} else {
errors.push(format!("Failed to import '{}': {}", exported.name, e));
}
}
}
}
Ok(ProxyImportResult {
imported_count: imported.len(),
skipped_count: skipped,
errors,
proxies: imported,
})
}
// Import proxies from already parsed proxy lines
pub fn import_proxies_from_parsed(
&self,
app_handle: &tauri::AppHandle,
parsed_proxies: Vec<ParsedProxyLine>,
name_prefix: Option<String>,
) -> Result<ProxyImportResult, String> {
let mut imported = Vec::new();
let mut skipped = 0;
let mut errors = Vec::new();
let prefix = name_prefix.unwrap_or_else(|| "Imported".to_string());
for (i, parsed) in parsed_proxies.into_iter().enumerate() {
let proxy_name = format!("{} Proxy {}", prefix, i + 1);
let proxy_settings = ProxySettings {
proxy_type: parsed.proxy_type,
host: parsed.host,
port: parsed.port,
username: parsed.username,
password: parsed.password,
};
match self.create_stored_proxy(app_handle, proxy_name.clone(), proxy_settings) {
Ok(proxy) => imported.push(proxy),
Err(e) => {
if e.contains("already exists") {
skipped += 1;
} else {
errors.push(format!("Failed to import '{}': {}", proxy_name, e));
}
}
}
}
Ok(ProxyImportResult {
imported_count: imported.len(),
skipped_count: skipped,
errors,
proxies: imported,
})
}
// Start a proxy for given proxy settings and associate it with a browser process ID
// If proxy_settings is None, starts a direct proxy for traffic monitoring
pub async fn start_proxy(
+29
View File
@@ -52,6 +52,8 @@ pub struct AppSettings {
pub mcp_token: Option<String>, // Displayed token for user to copy (not persisted, loaded from encrypted file)
#[serde(default)]
pub launch_on_login_declined: bool, // User permanently declined the launch-on-login prompt
#[serde(default)]
pub language: Option<String>, // ISO 639-1: "en", "es", "pt", "fr", "zh", "ja", "ru", or None for system default
}
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
@@ -84,6 +86,7 @@ impl Default for AppSettings {
mcp_port: None,
mcp_token: None,
launch_on_login_declined: false,
language: None,
}
}
}
@@ -809,6 +812,17 @@ pub async fn save_app_settings(
let mut persist_settings = settings.clone();
persist_settings.api_token = None;
persist_settings.mcp_token = None;
log::info!(
"[settings] Saving settings: theme={}, custom_theme_keys={}",
persist_settings.theme,
persist_settings
.custom_theme
.as_ref()
.map(|t| t.len())
.unwrap_or(0)
);
manager
.save_settings(&persist_settings)
.map_err(|e| format!("Failed to save settings: {e}"))?;
@@ -899,6 +913,20 @@ pub async fn save_sync_settings(
})
}
#[tauri::command]
pub fn get_system_language() -> String {
sys_locale::get_locale()
.map(|locale| {
// Extract just the language code (e.g., "en" from "en-US")
locale
.split(['-', '_'])
.next()
.unwrap_or("en")
.to_lowercase()
})
.unwrap_or_else(|| "en".to_string())
}
// Global singleton instance
lazy_static::lazy_static! {
static ref SETTINGS_MANAGER: SettingsManager = SettingsManager::new();
@@ -985,6 +1013,7 @@ mod tests {
mcp_port: None,
mcp_token: None,
launch_on_login_declined: false,
language: None,
};
// Save settings
+489
View File
@@ -0,0 +1,489 @@
//! VPN configuration types and parsing.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// VPN-related errors
#[derive(Error, Debug)]
pub enum VpnError {
#[error("Unknown VPN config format")]
UnknownFormat,
#[error("Invalid WireGuard config: {0}")]
InvalidWireGuard(String),
#[error("Invalid OpenVPN config: {0}")]
InvalidOpenVpn(String),
#[error("Storage error: {0}")]
Storage(String),
#[error("Connection error: {0}")]
Connection(String),
#[error("Encryption error: {0}")]
Encryption(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("VPN not found: {0}")]
NotFound(String),
#[error("Tunnel error: {0}")]
Tunnel(String),
}
/// The type of VPN configuration
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum VpnType {
WireGuard,
OpenVPN,
}
impl std::fmt::Display for VpnType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VpnType::WireGuard => write!(f, "WireGuard"),
VpnType::OpenVPN => write!(f, "OpenVPN"),
}
}
}
/// A stored VPN configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VpnConfig {
pub id: String,
pub name: String,
pub vpn_type: VpnType,
pub config_data: String, // Raw config content (encrypted at rest)
pub created_at: i64,
pub last_used: Option<i64>,
}
/// Parsed WireGuard configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WireGuardConfig {
pub private_key: String,
pub address: String,
pub dns: Option<String>,
pub mtu: Option<u16>,
pub peer_public_key: String,
pub peer_endpoint: String,
pub allowed_ips: Vec<String>,
pub persistent_keepalive: Option<u16>,
pub preshared_key: Option<String>,
}
/// Parsed OpenVPN configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenVpnConfig {
pub raw_config: String,
pub remote_host: String,
pub remote_port: u16,
pub protocol: String, // "udp" or "tcp"
pub dev_type: String, // "tun" or "tap"
pub has_inline_ca: bool,
pub has_inline_cert: bool,
pub has_inline_key: bool,
}
/// Result of importing a VPN configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VpnImportResult {
pub success: bool,
pub vpn_id: Option<String>,
pub vpn_type: Option<VpnType>,
pub name: String,
pub error: Option<String>,
}
/// VPN connection status
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VpnStatus {
pub connected: bool,
pub vpn_id: String,
pub connected_at: Option<i64>,
pub bytes_sent: Option<u64>,
pub bytes_received: Option<u64>,
pub last_handshake: Option<i64>,
}
/// Detect the VPN type from file content and filename
pub fn detect_vpn_type(content: &str, filename: &str) -> Result<VpnType, VpnError> {
let filename_lower = filename.to_lowercase();
// Check file extension first
if filename_lower.ends_with(".conf") {
// .conf could be WireGuard - check content
if content.contains("[Interface]") && content.contains("[Peer]") {
return Ok(VpnType::WireGuard);
}
}
if filename_lower.ends_with(".ovpn") {
return Ok(VpnType::OpenVPN);
}
// Check content patterns
if content.contains("[Interface]") && content.contains("PrivateKey") && content.contains("[Peer]")
{
return Ok(VpnType::WireGuard);
}
if content.contains("remote ") && (content.contains("client") || content.contains("dev tun")) {
return Ok(VpnType::OpenVPN);
}
Err(VpnError::UnknownFormat)
}
/// Parse a WireGuard configuration file
pub fn parse_wireguard_config(content: &str) -> Result<WireGuardConfig, VpnError> {
let mut interface: HashMap<String, String> = HashMap::new();
let mut peer: HashMap<String, String> = HashMap::new();
let mut current_section: Option<&str> = None;
for line in content.lines() {
let line = line.trim();
// Skip empty lines and comments
if line.is_empty() || line.starts_with('#') {
continue;
}
// Check for section headers
if line == "[Interface]" {
current_section = Some("interface");
continue;
}
if line == "[Peer]" {
current_section = Some("peer");
continue;
}
// Parse key-value pairs
if let Some((key, value)) = line.split_once('=') {
let key = key.trim().to_string();
let value = value.trim().to_string();
match current_section {
Some("interface") => {
interface.insert(key, value);
}
Some("peer") => {
peer.insert(key, value);
}
_ => {}
}
}
}
// Validate required fields
let private_key = interface
.get("PrivateKey")
.ok_or_else(|| VpnError::InvalidWireGuard("Missing PrivateKey in [Interface]".to_string()))?
.clone();
let address = interface
.get("Address")
.ok_or_else(|| VpnError::InvalidWireGuard("Missing Address in [Interface]".to_string()))?
.clone();
let peer_public_key = peer
.get("PublicKey")
.ok_or_else(|| VpnError::InvalidWireGuard("Missing PublicKey in [Peer]".to_string()))?
.clone();
let peer_endpoint = peer
.get("Endpoint")
.ok_or_else(|| VpnError::InvalidWireGuard("Missing Endpoint in [Peer]".to_string()))?
.clone();
let allowed_ips = peer
.get("AllowedIPs")
.map(|s| s.split(',').map(|ip| ip.trim().to_string()).collect())
.unwrap_or_else(|| vec!["0.0.0.0/0".to_string()]);
let persistent_keepalive = peer.get("PersistentKeepalive").and_then(|s| s.parse().ok());
let dns = interface.get("DNS").cloned();
let mtu = interface.get("MTU").and_then(|s| s.parse().ok());
let preshared_key = peer.get("PresharedKey").cloned();
Ok(WireGuardConfig {
private_key,
address,
dns,
mtu,
peer_public_key,
peer_endpoint,
allowed_ips,
persistent_keepalive,
preshared_key,
})
}
/// Parse an OpenVPN configuration file
pub fn parse_openvpn_config(content: &str) -> Result<OpenVpnConfig, VpnError> {
let mut remote_host = String::new();
let mut remote_port: u16 = 1194; // Default OpenVPN port
let mut protocol = "udp".to_string();
let mut dev_type = "tun".to_string();
let has_inline_ca = content.contains("<ca>") && content.contains("</ca>");
let has_inline_cert = content.contains("<cert>") && content.contains("</cert>");
let has_inline_key = content.contains("<key>") && content.contains("</key>");
for line in content.lines() {
let line = line.trim();
// Skip empty lines and comments
if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
continue;
}
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.is_empty() {
continue;
}
match parts[0] {
"remote" => {
if parts.len() >= 2 {
remote_host = parts[1].to_string();
}
if parts.len() >= 3 {
if let Ok(port) = parts[2].parse() {
remote_port = port;
}
}
if parts.len() >= 4 {
protocol = parts[3].to_string();
}
}
"proto" => {
if parts.len() >= 2 {
protocol = parts[1].to_string();
}
}
"port" => {
if parts.len() >= 2 {
if let Ok(port) = parts[1].parse() {
remote_port = port;
}
}
}
"dev" => {
if parts.len() >= 2 {
dev_type = parts[1].to_string();
}
}
_ => {}
}
}
if remote_host.is_empty() {
return Err(VpnError::InvalidOpenVpn(
"Missing 'remote' directive".to_string(),
));
}
Ok(OpenVpnConfig {
raw_config: content.to_string(),
remote_host,
remote_port,
protocol,
dev_type,
has_inline_ca,
has_inline_cert,
has_inline_key,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_detect_wireguard_by_extension() {
let content = "[Interface]\nPrivateKey = test\n[Peer]\nPublicKey = test";
assert_eq!(
detect_vpn_type(content, "test.conf").unwrap(),
VpnType::WireGuard
);
}
#[test]
fn test_detect_openvpn_by_extension() {
let content = "client\nremote vpn.example.com 1194";
assert_eq!(
detect_vpn_type(content, "test.ovpn").unwrap(),
VpnType::OpenVPN
);
}
#[test]
fn test_detect_wireguard_by_content() {
let content = "[Interface]\nPrivateKey = testkey123\nAddress = 10.0.0.2/24\n\n[Peer]\nPublicKey = peerkey456\nEndpoint = vpn.example.com:51820";
assert_eq!(
detect_vpn_type(content, "config").unwrap(),
VpnType::WireGuard
);
}
#[test]
fn test_detect_openvpn_by_content() {
let content = "client\ndev tun\nproto udp\nremote vpn.example.com 1194";
assert_eq!(
detect_vpn_type(content, "config").unwrap(),
VpnType::OpenVPN
);
}
#[test]
fn test_detect_unknown_format() {
let content = "random text that is not a vpn config";
assert!(detect_vpn_type(content, "random.txt").is_err());
}
#[test]
fn test_parse_wireguard_config() {
let content = r#"
[Interface]
PrivateKey = WGTestPrivateKey123456789012345678901234567890
Address = 10.0.0.2/24
DNS = 1.1.1.1
MTU = 1420
[Peer]
PublicKey = WGTestPublicKey1234567890123456789012345678901
Endpoint = vpn.example.com:51820
AllowedIPs = 0.0.0.0/0, ::/0
PersistentKeepalive = 25
"#;
let config = parse_wireguard_config(content).unwrap();
assert_eq!(
config.private_key,
"WGTestPrivateKey123456789012345678901234567890"
);
assert_eq!(config.address, "10.0.0.2/24");
assert_eq!(config.dns, Some("1.1.1.1".to_string()));
assert_eq!(config.mtu, Some(1420));
assert_eq!(
config.peer_public_key,
"WGTestPublicKey1234567890123456789012345678901"
);
assert_eq!(config.peer_endpoint, "vpn.example.com:51820");
assert_eq!(config.allowed_ips, vec!["0.0.0.0/0", "::/0"]);
assert_eq!(config.persistent_keepalive, Some(25));
}
#[test]
fn test_parse_wireguard_config_minimal() {
let content = r#"
[Interface]
PrivateKey = minimalkey
Address = 10.0.0.2/32
[Peer]
PublicKey = peerpubkey
Endpoint = 1.2.3.4:51820
"#;
let config = parse_wireguard_config(content).unwrap();
assert_eq!(config.private_key, "minimalkey");
assert_eq!(config.address, "10.0.0.2/32");
assert!(config.dns.is_none());
assert!(config.mtu.is_none());
assert_eq!(config.peer_public_key, "peerpubkey");
assert_eq!(config.peer_endpoint, "1.2.3.4:51820");
}
#[test]
fn test_parse_wireguard_missing_private_key() {
let content = r#"
[Interface]
Address = 10.0.0.2/24
[Peer]
PublicKey = key
Endpoint = 1.2.3.4:51820
"#;
let result = parse_wireguard_config(content);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("PrivateKey"));
}
#[test]
fn test_parse_openvpn_config() {
let content = r#"
client
dev tun
proto udp
remote vpn.example.com 1194
resolv-retry infinite
nobind
persist-key
persist-tun
<ca>
-----BEGIN CERTIFICATE-----
...certificate data...
-----END CERTIFICATE-----
</ca>
<cert>
-----BEGIN CERTIFICATE-----
...cert data...
-----END CERTIFICATE-----
</cert>
<key>
-----BEGIN PRIVATE KEY-----
...key data...
-----END PRIVATE KEY-----
</key>
"#;
let config = parse_openvpn_config(content).unwrap();
assert_eq!(config.remote_host, "vpn.example.com");
assert_eq!(config.remote_port, 1194);
assert_eq!(config.protocol, "udp");
assert_eq!(config.dev_type, "tun");
assert!(config.has_inline_ca);
assert!(config.has_inline_cert);
assert!(config.has_inline_key);
}
#[test]
fn test_parse_openvpn_config_minimal() {
let content = r#"
client
remote vpn.example.com
"#;
let config = parse_openvpn_config(content).unwrap();
assert_eq!(config.remote_host, "vpn.example.com");
assert_eq!(config.remote_port, 1194); // Default
assert_eq!(config.protocol, "udp"); // Default
}
#[test]
fn test_parse_openvpn_config_with_port_and_proto() {
let content = r#"
client
remote vpn.example.com 443 tcp
"#;
let config = parse_openvpn_config(content).unwrap();
assert_eq!(config.remote_host, "vpn.example.com");
assert_eq!(config.remote_port, 443);
assert_eq!(config.protocol, "tcp");
}
#[test]
fn test_parse_openvpn_missing_remote() {
let content = r#"
client
dev tun
proto udp
"#;
let result = parse_openvpn_config(content);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("remote"));
}
}
+31
View File
@@ -0,0 +1,31 @@
//! VPN support module for WireGuard and OpenVPN configurations.
//!
//! This module provides:
//! - VPN config parsing (WireGuard .conf and OpenVPN .ovpn files)
//! - Encrypted storage for VPN configurations
//! - Tunnel management with userspace WireGuard (boringtun) and OpenVPN process management
mod config;
mod openvpn;
mod storage;
mod tunnel;
mod wireguard;
pub use config::{
detect_vpn_type, parse_openvpn_config, parse_wireguard_config, OpenVpnConfig, VpnConfig,
VpnError, VpnImportResult, VpnStatus, VpnType, WireGuardConfig,
};
pub use openvpn::OpenVpnTunnel;
pub use storage::VpnStorage;
pub use tunnel::{TunnelManager, VpnTunnel};
pub use wireguard::WireGuardTunnel;
use once_cell::sync::Lazy;
use std::sync::Mutex;
/// Global VPN storage instance
pub static VPN_STORAGE: Lazy<Mutex<VpnStorage>> = Lazy::new(|| Mutex::new(VpnStorage::new()));
/// Global tunnel manager instance
pub static TUNNEL_MANAGER: Lazy<tokio::sync::Mutex<TunnelManager>> =
Lazy::new(|| tokio::sync::Mutex::new(TunnelManager::new()));
+343
View File
@@ -0,0 +1,343 @@
//! OpenVPN tunnel implementation using system openvpn binary.
use super::config::{OpenVpnConfig, VpnError, VpnStatus};
use super::tunnel::VpnTunnel;
use async_trait::async_trait;
use chrono::Utc;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tempfile::NamedTempFile;
use tokio::sync::Mutex;
/// OpenVPN tunnel implementation
pub struct OpenVpnTunnel {
vpn_id: String,
config: OpenVpnConfig,
process: Arc<Mutex<Option<Child>>>,
config_file: Option<NamedTempFile>,
connected: AtomicBool,
connected_at: Option<i64>,
bytes_sent: AtomicU64,
bytes_received: AtomicU64,
}
impl OpenVpnTunnel {
/// Create a new OpenVPN tunnel
pub fn new(vpn_id: String, config: OpenVpnConfig) -> Self {
Self {
vpn_id,
config,
process: Arc::new(Mutex::new(None)),
config_file: None,
connected: AtomicBool::new(false),
connected_at: None,
bytes_sent: AtomicU64::new(0),
bytes_received: AtomicU64::new(0),
}
}
/// Find the openvpn binary
fn find_openvpn_binary() -> Result<PathBuf, VpnError> {
// Check common locations
let locations = [
"/usr/sbin/openvpn",
"/usr/local/sbin/openvpn",
"/opt/homebrew/bin/openvpn",
"/usr/bin/openvpn",
"C:\\Program Files\\OpenVPN\\bin\\openvpn.exe",
"C:\\Program Files (x86)\\OpenVPN\\bin\\openvpn.exe",
];
for loc in &locations {
let path = PathBuf::from(loc);
if path.exists() {
return Ok(path);
}
}
// Try to find via which/where command
#[cfg(unix)]
{
if let Ok(output) = Command::new("which").arg("openvpn").output() {
if output.status.success() {
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !path.is_empty() {
return Ok(PathBuf::from(path));
}
}
}
}
#[cfg(windows)]
{
if let Ok(output) = Command::new("where").arg("openvpn").output() {
if output.status.success() {
let path = String::from_utf8_lossy(&output.stdout)
.lines()
.next()
.unwrap_or("")
.trim()
.to_string();
if !path.is_empty() {
return Ok(PathBuf::from(path));
}
}
}
}
Err(VpnError::Connection(
"OpenVPN binary not found. Please install OpenVPN.".to_string(),
))
}
/// Write config to temporary file
fn write_config_file(&mut self) -> Result<PathBuf, VpnError> {
let temp_file =
NamedTempFile::new().map_err(|e| VpnError::Io(std::io::Error::other(e.to_string())))?;
std::fs::write(temp_file.path(), &self.config.raw_config).map_err(VpnError::Io)?;
let path = temp_file.path().to_path_buf();
self.config_file = Some(temp_file);
Ok(path)
}
/// Start the OpenVPN process
async fn start_process(&mut self) -> Result<(), VpnError> {
let openvpn_bin = Self::find_openvpn_binary()?;
let config_path = self.write_config_file()?;
log::info!(
"[vpn] Starting OpenVPN with config: {}",
config_path.display()
);
// Build command with common options
let mut cmd = Command::new(&openvpn_bin);
cmd
.arg("--config")
.arg(&config_path)
.arg("--verb")
.arg("3") // Verbosity level
.stdout(Stdio::piped())
.stderr(Stdio::piped());
// On Unix, try to avoid requiring root if possible
#[cfg(unix)]
{
cmd.arg("--script-security").arg("2");
}
let child = cmd
.spawn()
.map_err(|e| VpnError::Connection(format!("Failed to start OpenVPN: {e}")))?;
*self.process.lock().await = Some(child);
// Wait a bit and check if process is still running
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
let mut process_guard = self.process.lock().await;
if let Some(ref mut child) = *process_guard {
match child.try_wait() {
Ok(Some(status)) => {
// Process exited early
let mut error_msg = format!("OpenVPN exited with status: {status}");
// Try to get stderr output
if let Some(stderr) = child.stderr.take() {
let reader = BufReader::new(stderr);
let lines: Vec<String> = reader.lines().map_while(Result::ok).take(5).collect();
if !lines.is_empty() {
error_msg.push_str(&format!("\nError: {}", lines.join("\n")));
}
}
return Err(VpnError::Connection(error_msg));
}
Ok(None) => {
// Still running, good
}
Err(e) => {
return Err(VpnError::Connection(format!(
"Failed to check process status: {e}"
)));
}
}
}
Ok(())
}
/// Kill the OpenVPN process
async fn kill_process(&mut self) -> Result<(), VpnError> {
let mut process_guard = self.process.lock().await;
if let Some(mut child) = process_guard.take() {
// Try graceful shutdown first
#[cfg(unix)]
{
use nix::sys::signal::{kill, Signal};
use nix::unistd::Pid;
if let Ok(pid) = child.id().try_into() {
let _ = kill(Pid::from_raw(pid), Signal::SIGTERM);
// Wait a bit for graceful shutdown
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
}
// Force kill if still running
let _ = child.kill();
let _ = child.wait();
}
// Clean up config file
self.config_file = None;
Ok(())
}
}
#[async_trait]
impl VpnTunnel for OpenVpnTunnel {
async fn connect(&mut self) -> Result<(), VpnError> {
if self.connected.load(Ordering::Relaxed) {
return Ok(());
}
// Start OpenVPN process
self.start_process().await?;
// Wait for connection to be established
// Note: In a real implementation, we'd monitor the OpenVPN management interface
// For now, we assume success if the process starts and runs for a bit
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
// Check if process is still running
let process_guard = self.process.lock().await;
if let Some(ref child) = *process_guard {
let id = child.id();
if id > 0 {
self.connected.store(true, Ordering::Release);
self.connected_at = Some(Utc::now().timestamp());
log::info!("[vpn] OpenVPN tunnel {} connected (PID: {id})", self.vpn_id);
return Ok(());
}
}
Err(VpnError::Connection(
"Failed to establish OpenVPN connection".to_string(),
))
}
async fn disconnect(&mut self) -> Result<(), VpnError> {
if !self.connected.load(Ordering::Relaxed) {
return Ok(());
}
self.kill_process().await?;
self.connected.store(false, Ordering::Release);
self.connected_at = None;
log::info!("[vpn] OpenVPN tunnel {} disconnected", self.vpn_id);
Ok(())
}
fn is_connected(&self) -> bool {
self.connected.load(Ordering::Acquire)
}
fn vpn_id(&self) -> &str {
&self.vpn_id
}
fn get_status(&self) -> VpnStatus {
VpnStatus {
connected: self.is_connected(),
vpn_id: self.vpn_id.clone(),
connected_at: self.connected_at,
bytes_sent: Some(self.bytes_sent.load(Ordering::Relaxed)),
bytes_received: Some(self.bytes_received.load(Ordering::Relaxed)),
last_handshake: None,
}
}
fn bytes_sent(&self) -> u64 {
self.bytes_sent.load(Ordering::Relaxed)
}
fn bytes_received(&self) -> u64 {
self.bytes_received.load(Ordering::Relaxed)
}
}
impl Drop for OpenVpnTunnel {
fn drop(&mut self) {
// Clean up process on drop (synchronously)
if let Ok(mut guard) = self.process.try_lock() {
if let Some(mut child) = guard.take() {
let _ = child.kill();
let _ = child.wait();
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_config() -> OpenVpnConfig {
OpenVpnConfig {
raw_config: "client\nremote test.example.com 1194\ndev tun".to_string(),
remote_host: "test.example.com".to_string(),
remote_port: 1194,
protocol: "udp".to_string(),
dev_type: "tun".to_string(),
has_inline_ca: false,
has_inline_cert: false,
has_inline_key: false,
}
}
#[test]
fn test_openvpn_tunnel_creation() {
let config = create_test_config();
let tunnel = OpenVpnTunnel::new("test-ovpn-1".to_string(), config);
assert_eq!(tunnel.vpn_id(), "test-ovpn-1");
assert!(!tunnel.is_connected());
assert_eq!(tunnel.bytes_sent(), 0);
assert_eq!(tunnel.bytes_received(), 0);
}
#[test]
fn test_openvpn_status() {
let config = create_test_config();
let tunnel = OpenVpnTunnel::new("test-ovpn-2".to_string(), config);
let status = tunnel.get_status();
assert!(!status.connected);
assert_eq!(status.vpn_id, "test-ovpn-2");
assert!(status.connected_at.is_none());
}
#[test]
fn test_find_openvpn_binary_format() {
// This test just checks that the function doesn't panic
// It may or may not find openvpn depending on the system
let result = OpenVpnTunnel::find_openvpn_binary();
// Just check that it returns a valid Result
match result {
Ok(path) => assert!(!path.as_os_str().is_empty()),
Err(e) => assert!(e.to_string().contains("not found")),
}
}
}
+415
View File
@@ -0,0 +1,415 @@
//! Encrypted storage for VPN configurations.
use super::config::{VpnConfig, VpnError, VpnType};
use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
};
use chrono::Utc;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use uuid::Uuid;
/// Storage format version for migration support
const STORAGE_VERSION: u32 = 1;
/// Stored VPN configs container
#[derive(Debug, Serialize, Deserialize)]
struct VpnStorageData {
version: u32,
configs: Vec<StoredVpnConfig>,
}
/// Encrypted VPN config as stored on disk
#[derive(Debug, Serialize, Deserialize)]
struct StoredVpnConfig {
id: String,
name: String,
vpn_type: VpnType,
encrypted_data: String, // Base64 encoded encrypted config
nonce: String, // Base64 encoded nonce
created_at: i64,
last_used: Option<i64>,
}
/// VPN storage manager with encryption
pub struct VpnStorage {
storage_path: PathBuf,
encryption_key: [u8; 32],
}
impl Default for VpnStorage {
fn default() -> Self {
Self::new()
}
}
impl VpnStorage {
/// Create a new VPN storage manager
pub fn new() -> Self {
let storage_path = Self::get_storage_path();
let encryption_key = Self::get_or_create_key();
Self {
storage_path,
encryption_key,
}
}
/// Get the storage file path
fn get_storage_path() -> PathBuf {
let data_dir = directories::ProjectDirs::from("com", "donut", "donutbrowser")
.map(|dirs| dirs.data_local_dir().to_path_buf())
.unwrap_or_else(|| PathBuf::from("."));
if !data_dir.exists() {
let _ = fs::create_dir_all(&data_dir);
}
data_dir.join("vpn_configs.json")
}
/// Get or create the encryption key
fn get_or_create_key() -> [u8; 32] {
let key_path = directories::ProjectDirs::from("com", "donut", "donutbrowser")
.map(|dirs| dirs.data_local_dir().join(".vpn_key"))
.unwrap_or_else(|| PathBuf::from(".vpn_key"));
if key_path.exists() {
if let Ok(key_data) = fs::read(&key_path) {
if key_data.len() == 32 {
let mut key = [0u8; 32];
key.copy_from_slice(&key_data);
return key;
}
}
}
// Generate a new key
let key: [u8; 32] = rand::rng().random();
let _ = fs::write(&key_path, key);
// Set restrictive permissions on Unix
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = fs::set_permissions(&key_path, fs::Permissions::from_mode(0o600));
}
key
}
/// Load storage data from disk
fn load_storage(&self) -> Result<VpnStorageData, VpnError> {
if !self.storage_path.exists() {
return Ok(VpnStorageData {
version: STORAGE_VERSION,
configs: Vec::new(),
});
}
let content = fs::read_to_string(&self.storage_path)
.map_err(|e| VpnError::Storage(format!("Failed to read storage file: {e}")))?;
serde_json::from_str(&content)
.map_err(|e| VpnError::Storage(format!("Failed to parse storage file: {e}")))
}
/// Save storage data to disk
fn save_storage(&self, data: &VpnStorageData) -> Result<(), VpnError> {
let content = serde_json::to_string_pretty(data)
.map_err(|e| VpnError::Storage(format!("Failed to serialize storage: {e}")))?;
fs::write(&self.storage_path, content)
.map_err(|e| VpnError::Storage(format!("Failed to write storage file: {e}")))?;
Ok(())
}
/// Encrypt config data
fn encrypt(&self, data: &str) -> Result<(String, String), VpnError> {
let cipher = Aes256Gcm::new_from_slice(&self.encryption_key)
.map_err(|e| VpnError::Encryption(format!("Failed to create cipher: {e}")))?;
let nonce_bytes: [u8; 12] = rand::rng().random();
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, data.as_bytes())
.map_err(|e| VpnError::Encryption(format!("Encryption failed: {e}")))?;
Ok((
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &ciphertext),
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, nonce_bytes),
))
}
/// Decrypt config data
fn decrypt(&self, encrypted_data: &str, nonce_str: &str) -> Result<String, VpnError> {
let cipher = Aes256Gcm::new_from_slice(&self.encryption_key)
.map_err(|e| VpnError::Encryption(format!("Failed to create cipher: {e}")))?;
let ciphertext =
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encrypted_data)
.map_err(|e| VpnError::Encryption(format!("Failed to decode ciphertext: {e}")))?;
let nonce_bytes = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, nonce_str)
.map_err(|e| VpnError::Encryption(format!("Failed to decode nonce: {e}")))?;
if nonce_bytes.len() != 12 {
return Err(VpnError::Encryption("Invalid nonce length".to_string()));
}
let nonce = Nonce::from_slice(&nonce_bytes);
let plaintext = cipher
.decrypt(nonce, ciphertext.as_ref())
.map_err(|e| VpnError::Encryption(format!("Decryption failed: {e}")))?;
String::from_utf8(plaintext)
.map_err(|e| VpnError::Encryption(format!("Failed to decode plaintext: {e}")))
}
/// Save a VPN configuration
pub fn save_config(&self, config: &VpnConfig) -> Result<(), VpnError> {
let mut storage = self.load_storage()?;
// Encrypt the config data
let (encrypted_data, nonce) = self.encrypt(&config.config_data)?;
let stored = StoredVpnConfig {
id: config.id.clone(),
name: config.name.clone(),
vpn_type: config.vpn_type,
encrypted_data,
nonce,
created_at: config.created_at,
last_used: config.last_used,
};
// Update existing or add new
if let Some(pos) = storage.configs.iter().position(|c| c.id == config.id) {
storage.configs[pos] = stored;
} else {
storage.configs.push(stored);
}
self.save_storage(&storage)
}
/// Load a VPN configuration by ID
pub fn load_config(&self, id: &str) -> Result<VpnConfig, VpnError> {
let storage = self.load_storage()?;
let stored = storage
.configs
.iter()
.find(|c| c.id == id)
.ok_or_else(|| VpnError::NotFound(id.to_string()))?;
let config_data = self.decrypt(&stored.encrypted_data, &stored.nonce)?;
Ok(VpnConfig {
id: stored.id.clone(),
name: stored.name.clone(),
vpn_type: stored.vpn_type,
config_data,
created_at: stored.created_at,
last_used: stored.last_used,
})
}
/// List all VPN configurations (without decrypted config data)
pub fn list_configs(&self) -> Result<Vec<VpnConfig>, VpnError> {
let storage = self.load_storage()?;
Ok(
storage
.configs
.iter()
.map(|stored| VpnConfig {
id: stored.id.clone(),
name: stored.name.clone(),
vpn_type: stored.vpn_type,
config_data: String::new(), // Don't include config data in list
created_at: stored.created_at,
last_used: stored.last_used,
})
.collect(),
)
}
/// Delete a VPN configuration
pub fn delete_config(&self, id: &str) -> Result<(), VpnError> {
let mut storage = self.load_storage()?;
let initial_len = storage.configs.len();
storage.configs.retain(|c| c.id != id);
if storage.configs.len() == initial_len {
return Err(VpnError::NotFound(id.to_string()));
}
self.save_storage(&storage)
}
/// Update last_used timestamp
pub fn update_last_used(&self, id: &str) -> Result<(), VpnError> {
let mut storage = self.load_storage()?;
if let Some(config) = storage.configs.iter_mut().find(|c| c.id == id) {
config.last_used = Some(Utc::now().timestamp());
self.save_storage(&storage)
} else {
Err(VpnError::NotFound(id.to_string()))
}
}
/// Import a VPN config from raw content
pub fn import_config(
&self,
content: &str,
filename: &str,
name: Option<String>,
) -> Result<VpnConfig, VpnError> {
let vpn_type = super::detect_vpn_type(content, filename)?;
// Validate the config by parsing it
match vpn_type {
VpnType::WireGuard => {
super::parse_wireguard_config(content)?;
}
VpnType::OpenVPN => {
super::parse_openvpn_config(content)?;
}
}
let id = Uuid::new_v4().to_string();
let display_name = name.unwrap_or_else(|| {
// Generate name from filename
let base = filename.trim_end_matches(".conf").trim_end_matches(".ovpn");
format!("{} ({})", base, vpn_type)
});
let config = VpnConfig {
id,
name: display_name,
vpn_type,
config_data: content.to_string(),
created_at: Utc::now().timestamp(),
last_used: None,
};
self.save_config(&config)?;
Ok(config)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn create_test_storage() -> (VpnStorage, TempDir) {
let temp_dir = TempDir::new().unwrap();
let mut storage = VpnStorage::new();
storage.storage_path = temp_dir.path().join("test_vpn_configs.json");
(storage, temp_dir)
}
#[test]
fn test_encrypt_decrypt_roundtrip() {
let (storage, _temp) = create_test_storage();
let original = "This is a secret VPN configuration";
let (encrypted, nonce) = storage.encrypt(original).unwrap();
let decrypted = storage.decrypt(&encrypted, &nonce).unwrap();
assert_eq!(original, decrypted);
}
#[test]
fn test_save_and_load_config() {
let (storage, _temp) = create_test_storage();
let config = VpnConfig {
id: "test-id-123".to_string(),
name: "Test VPN".to_string(),
vpn_type: VpnType::WireGuard,
config_data: "[Interface]\nPrivateKey = test\n[Peer]\nPublicKey = peer".to_string(),
created_at: 1234567890,
last_used: None,
};
storage.save_config(&config).unwrap();
let loaded = storage.load_config("test-id-123").unwrap();
assert_eq!(loaded.id, config.id);
assert_eq!(loaded.name, config.name);
assert_eq!(loaded.vpn_type, config.vpn_type);
assert_eq!(loaded.config_data, config.config_data);
}
#[test]
fn test_list_configs() {
let (storage, _temp) = create_test_storage();
let config1 = VpnConfig {
id: "id-1".to_string(),
name: "VPN 1".to_string(),
vpn_type: VpnType::WireGuard,
config_data: "secret1".to_string(),
created_at: 1000,
last_used: None,
};
let config2 = VpnConfig {
id: "id-2".to_string(),
name: "VPN 2".to_string(),
vpn_type: VpnType::OpenVPN,
config_data: "secret2".to_string(),
created_at: 2000,
last_used: Some(3000),
};
storage.save_config(&config1).unwrap();
storage.save_config(&config2).unwrap();
let configs = storage.list_configs().unwrap();
assert_eq!(configs.len(), 2);
// Config data should be empty in listing
assert!(configs[0].config_data.is_empty());
assert!(configs[1].config_data.is_empty());
}
#[test]
fn test_delete_config() {
let (storage, _temp) = create_test_storage();
let config = VpnConfig {
id: "delete-me".to_string(),
name: "To Delete".to_string(),
vpn_type: VpnType::WireGuard,
config_data: "data".to_string(),
created_at: 1000,
last_used: None,
};
storage.save_config(&config).unwrap();
assert!(storage.load_config("delete-me").is_ok());
storage.delete_config("delete-me").unwrap();
assert!(storage.load_config("delete-me").is_err());
}
#[test]
fn test_load_nonexistent_config() {
let (storage, _temp) = create_test_storage();
let result = storage.load_config("nonexistent");
assert!(result.is_err());
}
}
+256
View File
@@ -0,0 +1,256 @@
//! VPN tunnel trait and management.
use super::config::{VpnError, VpnStatus};
use async_trait::async_trait;
use std::collections::HashMap;
/// Trait for VPN tunnel implementations
#[async_trait]
pub trait VpnTunnel: Send + Sync {
/// Connect the VPN tunnel
async fn connect(&mut self) -> Result<(), VpnError>;
/// Disconnect the VPN tunnel
async fn disconnect(&mut self) -> Result<(), VpnError>;
/// Check if the tunnel is connected
fn is_connected(&self) -> bool;
/// Get the VPN config ID
fn vpn_id(&self) -> &str;
/// Get the current status of the tunnel
fn get_status(&self) -> VpnStatus;
/// Get bytes sent through the tunnel
fn bytes_sent(&self) -> u64;
/// Get bytes received through the tunnel
fn bytes_received(&self) -> u64;
}
/// Manager for active VPN tunnels
pub struct TunnelManager {
active_tunnels: HashMap<String, Box<dyn VpnTunnel>>,
}
impl Default for TunnelManager {
fn default() -> Self {
Self::new()
}
}
impl TunnelManager {
/// Create a new tunnel manager
pub fn new() -> Self {
Self {
active_tunnels: HashMap::new(),
}
}
/// Register an active tunnel
pub fn register_tunnel(&mut self, vpn_id: String, tunnel: Box<dyn VpnTunnel>) {
self.active_tunnels.insert(vpn_id, tunnel);
}
/// Remove a tunnel from management
pub fn remove_tunnel(&mut self, vpn_id: &str) -> Option<Box<dyn VpnTunnel>> {
self.active_tunnels.remove(vpn_id)
}
/// Get a reference to an active tunnel
pub fn get_tunnel(&self, vpn_id: &str) -> Option<&dyn VpnTunnel> {
self.active_tunnels.get(vpn_id).map(|t| t.as_ref())
}
/// Get a mutable reference to an active tunnel
pub fn get_tunnel_mut(&mut self, vpn_id: &str) -> Option<&mut Box<dyn VpnTunnel>> {
self.active_tunnels.get_mut(vpn_id)
}
/// Check if a tunnel is active
pub fn is_tunnel_active(&self, vpn_id: &str) -> bool {
self
.active_tunnels
.get(vpn_id)
.is_some_and(|t| t.is_connected())
}
/// Get status of all active tunnels
pub fn get_all_statuses(&self) -> Vec<VpnStatus> {
self
.active_tunnels
.values()
.map(|t| t.get_status())
.collect()
}
/// Disconnect all active tunnels
pub async fn disconnect_all(&mut self) -> Vec<Result<(), VpnError>> {
let mut results = Vec::new();
for tunnel in self.active_tunnels.values_mut() {
results.push(tunnel.disconnect().await);
}
self.active_tunnels.clear();
results
}
/// Get the number of active tunnels
pub fn active_count(&self) -> usize {
self
.active_tunnels
.values()
.filter(|t| t.is_connected())
.count()
}
/// List IDs of all active VPN connections
pub fn list_active_ids(&self) -> Vec<String> {
self
.active_tunnels
.iter()
.filter(|(_, t)| t.is_connected())
.map(|(id, _)| id.clone())
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
struct MockTunnel {
id: String,
connected: bool,
bytes_sent: u64,
bytes_received: u64,
}
#[async_trait]
impl VpnTunnel for MockTunnel {
async fn connect(&mut self) -> Result<(), VpnError> {
self.connected = true;
Ok(())
}
async fn disconnect(&mut self) -> Result<(), VpnError> {
self.connected = false;
Ok(())
}
fn is_connected(&self) -> bool {
self.connected
}
fn vpn_id(&self) -> &str {
&self.id
}
fn get_status(&self) -> VpnStatus {
VpnStatus {
connected: self.connected,
vpn_id: self.id.clone(),
connected_at: if self.connected { Some(1000) } else { None },
bytes_sent: Some(self.bytes_sent),
bytes_received: Some(self.bytes_received),
last_handshake: None,
}
}
fn bytes_sent(&self) -> u64 {
self.bytes_sent
}
fn bytes_received(&self) -> u64 {
self.bytes_received
}
}
#[test]
fn test_tunnel_manager_register() {
let mut manager = TunnelManager::new();
let tunnel = Box::new(MockTunnel {
id: "test-1".to_string(),
connected: true,
bytes_sent: 100,
bytes_received: 200,
});
manager.register_tunnel("test-1".to_string(), tunnel);
assert!(manager.is_tunnel_active("test-1"));
assert!(!manager.is_tunnel_active("test-2"));
}
#[test]
fn test_tunnel_manager_remove() {
let mut manager = TunnelManager::new();
let tunnel = Box::new(MockTunnel {
id: "test-1".to_string(),
connected: true,
bytes_sent: 0,
bytes_received: 0,
});
manager.register_tunnel("test-1".to_string(), tunnel);
assert!(manager.is_tunnel_active("test-1"));
let removed = manager.remove_tunnel("test-1");
assert!(removed.is_some());
assert!(!manager.is_tunnel_active("test-1"));
}
#[test]
fn test_tunnel_manager_active_count() {
let mut manager = TunnelManager::new();
let tunnel1 = Box::new(MockTunnel {
id: "t1".to_string(),
connected: true,
bytes_sent: 0,
bytes_received: 0,
});
let tunnel2 = Box::new(MockTunnel {
id: "t2".to_string(),
connected: false,
bytes_sent: 0,
bytes_received: 0,
});
manager.register_tunnel("t1".to_string(), tunnel1);
manager.register_tunnel("t2".to_string(), tunnel2);
assert_eq!(manager.active_count(), 1);
}
#[tokio::test]
async fn test_tunnel_manager_disconnect_all() {
let mut manager = TunnelManager::new();
let tunnel1 = Box::new(MockTunnel {
id: "t1".to_string(),
connected: true,
bytes_sent: 0,
bytes_received: 0,
});
let tunnel2 = Box::new(MockTunnel {
id: "t2".to_string(),
connected: true,
bytes_sent: 0,
bytes_received: 0,
});
manager.register_tunnel("t1".to_string(), tunnel1);
manager.register_tunnel("t2".to_string(), tunnel2);
assert_eq!(manager.active_count(), 2);
let results = manager.disconnect_all().await;
assert_eq!(results.len(), 2);
assert!(results.iter().all(|r| r.is_ok()));
assert_eq!(manager.active_count(), 0);
}
}
+413
View File
@@ -0,0 +1,413 @@
//! WireGuard tunnel implementation using boringtun.
use super::config::{VpnError, VpnStatus, WireGuardConfig};
use super::tunnel::VpnTunnel;
use async_trait::async_trait;
use boringtun::noise::{Tunn, TunnResult};
use boringtun::x25519::{PublicKey, StaticSecret};
use chrono::Utc;
use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::Mutex;
/// WireGuard tunnel implementation
pub struct WireGuardTunnel {
vpn_id: String,
config: WireGuardConfig,
tunnel: Option<Arc<Mutex<Box<Tunn>>>>,
socket: Option<Arc<UdpSocket>>,
connected: AtomicBool,
connected_at: Option<i64>,
bytes_sent: AtomicU64,
bytes_received: AtomicU64,
last_handshake: Option<i64>,
peer_addr: Option<SocketAddr>,
}
impl WireGuardTunnel {
/// Create a new WireGuard tunnel
pub fn new(vpn_id: String, config: WireGuardConfig) -> Self {
Self {
vpn_id,
config,
tunnel: None,
socket: None,
connected: AtomicBool::new(false),
connected_at: None,
bytes_sent: AtomicU64::new(0),
bytes_received: AtomicU64::new(0),
last_handshake: None,
peer_addr: None,
}
}
/// Parse base64 key to bytes
fn parse_key(key: &str) -> Result<[u8; 32], VpnError> {
let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, key)
.map_err(|e| VpnError::InvalidWireGuard(format!("Invalid key encoding: {e}")))?;
if decoded.len() != 32 {
return Err(VpnError::InvalidWireGuard(format!(
"Invalid key length: {} (expected 32)",
decoded.len()
)));
}
let mut key_bytes = [0u8; 32];
key_bytes.copy_from_slice(&decoded);
Ok(key_bytes)
}
/// Initialize the WireGuard tunnel
fn init_tunnel(&mut self) -> Result<(), VpnError> {
// Parse private key
let private_key_bytes = Self::parse_key(&self.config.private_key)?;
let static_private = StaticSecret::from(private_key_bytes);
// Parse peer public key
let peer_public_bytes = Self::parse_key(&self.config.peer_public_key)?;
let peer_public = PublicKey::from(peer_public_bytes);
// Parse optional preshared key
let preshared_key = if let Some(ref psk) = self.config.preshared_key {
Some(Self::parse_key(psk)?)
} else {
None
};
// Create the boringtun tunnel
let tunn = Tunn::new(
static_private,
peer_public,
preshared_key,
self.config.persistent_keepalive,
0, // index
None,
)
.map_err(|e| VpnError::Tunnel(format!("Failed to create tunnel: {e}")))?;
self.tunnel = Some(Arc::new(Mutex::new(Box::new(tunn))));
Ok(())
}
/// Resolve peer endpoint to socket address
fn resolve_endpoint(&mut self) -> Result<SocketAddr, VpnError> {
let endpoint = &self.config.peer_endpoint;
// Try to resolve the endpoint
let addrs: Vec<SocketAddr> = endpoint
.to_socket_addrs()
.map_err(|e| VpnError::Connection(format!("Failed to resolve endpoint '{endpoint}': {e}")))?
.collect();
addrs
.into_iter()
.next()
.ok_or_else(|| VpnError::Connection(format!("No addresses found for endpoint: {endpoint}")))
}
/// Perform WireGuard handshake
async fn handshake(&mut self) -> Result<(), VpnError> {
let tunnel = self
.tunnel
.as_ref()
.ok_or_else(|| VpnError::Tunnel("Tunnel not initialized".to_string()))?;
let socket = self
.socket
.as_ref()
.ok_or_else(|| VpnError::Tunnel("Socket not initialized".to_string()))?;
let peer_addr = self
.peer_addr
.ok_or_else(|| VpnError::Tunnel("Peer address not resolved".to_string()))?;
let mut tunnel_guard = tunnel.lock().await;
// Generate handshake initiation
let mut dst = vec![0u8; 2048];
let result = tunnel_guard.format_handshake_initiation(&mut dst, false);
match result {
TunnResult::WriteToNetwork(packet) => {
socket
.send_to(packet, peer_addr)
.map_err(|e| VpnError::Connection(format!("Failed to send handshake: {e}")))?;
self
.bytes_sent
.fetch_add(packet.len() as u64, Ordering::Relaxed);
}
TunnResult::Err(e) => {
return Err(VpnError::Tunnel(format!(
"Handshake initiation failed: {e:?}"
)));
}
_ => {}
}
// Wait for handshake response (with timeout)
socket
.set_read_timeout(Some(std::time::Duration::from_secs(10)))
.map_err(|e| VpnError::Connection(format!("Failed to set timeout: {e}")))?;
let mut recv_buf = vec![0u8; 2048];
match socket.recv_from(&mut recv_buf) {
Ok((len, _from)) => {
self.bytes_received.fetch_add(len as u64, Ordering::Relaxed);
let result = tunnel_guard.decapsulate(None, &recv_buf[..len], &mut dst);
match result {
TunnResult::WriteToNetwork(response) => {
socket
.send_to(response, peer_addr)
.map_err(|e| VpnError::Connection(format!("Failed to send response: {e}")))?;
self
.bytes_sent
.fetch_add(response.len() as u64, Ordering::Relaxed);
self.last_handshake = Some(Utc::now().timestamp());
}
TunnResult::Done => {
self.last_handshake = Some(Utc::now().timestamp());
}
TunnResult::Err(e) => {
return Err(VpnError::Tunnel(format!(
"Handshake response failed: {e:?}"
)));
}
_ => {}
}
}
Err(e) => {
return Err(VpnError::Connection(format!(
"Handshake timeout or error: {e}"
)));
}
}
Ok(())
}
/// Encrypt and send data through the tunnel
pub async fn send(&self, data: &[u8]) -> Result<(), VpnError> {
let tunnel = self
.tunnel
.as_ref()
.ok_or_else(|| VpnError::Tunnel("Tunnel not initialized".to_string()))?;
let socket = self
.socket
.as_ref()
.ok_or_else(|| VpnError::Tunnel("Socket not initialized".to_string()))?;
let peer_addr = self
.peer_addr
.ok_or_else(|| VpnError::Tunnel("Peer address not resolved".to_string()))?;
let mut tunnel_guard = tunnel.lock().await;
let mut dst = vec![0u8; data.len() + 256]; // Extra space for WireGuard overhead
let result = tunnel_guard.encapsulate(data, &mut dst);
match result {
TunnResult::WriteToNetwork(packet) => {
socket
.send_to(packet, peer_addr)
.map_err(|e| VpnError::Connection(format!("Failed to send data: {e}")))?;
self
.bytes_sent
.fetch_add(packet.len() as u64, Ordering::Relaxed);
}
TunnResult::Err(e) => {
return Err(VpnError::Tunnel(format!("Encryption failed: {e:?}")));
}
_ => {}
}
Ok(())
}
/// Receive and decrypt data from the tunnel
pub async fn receive(&self, buf: &mut [u8]) -> Result<usize, VpnError> {
let tunnel = self
.tunnel
.as_ref()
.ok_or_else(|| VpnError::Tunnel("Tunnel not initialized".to_string()))?;
let socket = self
.socket
.as_ref()
.ok_or_else(|| VpnError::Tunnel("Socket not initialized".to_string()))?;
let mut recv_buf = vec![0u8; 2048];
let (len, _from) = socket
.recv_from(&mut recv_buf)
.map_err(|e| VpnError::Connection(format!("Receive failed: {e}")))?;
self.bytes_received.fetch_add(len as u64, Ordering::Relaxed);
let mut tunnel_guard = tunnel.lock().await;
// decapsulate writes decrypted data directly to buf and returns a slice pointing to it
let result = tunnel_guard.decapsulate(None, &recv_buf[..len], buf);
match result {
// Data is already written to buf by decapsulate, just return the length
TunnResult::WriteToTunnelV4(decrypted, _) => Ok(decrypted.len()),
TunnResult::WriteToTunnelV6(decrypted, _) => Ok(decrypted.len()),
TunnResult::Done => Ok(0),
TunnResult::Err(e) => Err(VpnError::Tunnel(format!("Decryption failed: {e:?}"))),
_ => Ok(0),
}
}
}
#[async_trait]
impl VpnTunnel for WireGuardTunnel {
async fn connect(&mut self) -> Result<(), VpnError> {
if self.connected.load(Ordering::Relaxed) {
return Ok(());
}
// Initialize the tunnel
self.init_tunnel()?;
// Resolve endpoint
self.peer_addr = Some(self.resolve_endpoint()?);
// Create UDP socket
let socket = UdpSocket::bind("0.0.0.0:0")
.map_err(|e| VpnError::Connection(format!("Failed to create socket: {e}")))?;
socket
.set_nonblocking(false)
.map_err(|e| VpnError::Connection(format!("Failed to set socket options: {e}")))?;
self.socket = Some(Arc::new(socket));
// Perform handshake
self.handshake().await?;
self.connected.store(true, Ordering::Release);
self.connected_at = Some(Utc::now().timestamp());
log::info!("[vpn] WireGuard tunnel {} connected", self.vpn_id);
Ok(())
}
async fn disconnect(&mut self) -> Result<(), VpnError> {
if !self.connected.load(Ordering::Relaxed) {
return Ok(());
}
self.connected.store(false, Ordering::Release);
self.tunnel = None;
self.socket = None;
self.connected_at = None;
log::info!("[vpn] WireGuard tunnel {} disconnected", self.vpn_id);
Ok(())
}
fn is_connected(&self) -> bool {
self.connected.load(Ordering::Acquire)
}
fn vpn_id(&self) -> &str {
&self.vpn_id
}
fn get_status(&self) -> VpnStatus {
VpnStatus {
connected: self.is_connected(),
vpn_id: self.vpn_id.clone(),
connected_at: self.connected_at,
bytes_sent: Some(self.bytes_sent.load(Ordering::Relaxed)),
bytes_received: Some(self.bytes_received.load(Ordering::Relaxed)),
last_handshake: self.last_handshake,
}
}
fn bytes_sent(&self) -> u64 {
self.bytes_sent.load(Ordering::Relaxed)
}
fn bytes_received(&self) -> u64 {
self.bytes_received.load(Ordering::Relaxed)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_config() -> WireGuardConfig {
WireGuardConfig {
// These are test keys, not real ones
private_key: "YEocP0e2o1WT5GlvBvQzVF7EeR6z9aCk+ZdZ5NKEuXA=".to_string(),
address: "10.0.0.2/24".to_string(),
dns: Some("1.1.1.1".to_string()),
mtu: Some(1420),
peer_public_key: "aGnF7JlG+U5t0BqB1PVf1yOuELHrWLGGcUJb0eCK9Aw=".to_string(),
peer_endpoint: "127.0.0.1:51820".to_string(),
allowed_ips: vec!["0.0.0.0/0".to_string()],
persistent_keepalive: Some(25),
preshared_key: None,
}
}
#[test]
fn test_wireguard_tunnel_creation() {
let config = create_test_config();
let tunnel = WireGuardTunnel::new("test-wg-1".to_string(), config);
assert_eq!(tunnel.vpn_id(), "test-wg-1");
assert!(!tunnel.is_connected());
assert_eq!(tunnel.bytes_sent(), 0);
assert_eq!(tunnel.bytes_received(), 0);
}
#[test]
fn test_parse_key_valid() {
// Valid base64-encoded 32-byte key
let key = "YEocP0e2o1WT5GlvBvQzVF7EeR6z9aCk+ZdZ5NKEuXA=";
let result = WireGuardTunnel::parse_key(key);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 32);
}
#[test]
fn test_parse_key_invalid_base64() {
let key = "not-valid-base64!!!";
let result = WireGuardTunnel::parse_key(key);
assert!(result.is_err());
}
#[test]
fn test_parse_key_wrong_length() {
// Valid base64 but wrong length
let key = "YWJjZA=="; // "abcd" in base64
let result = WireGuardTunnel::parse_key(key);
assert!(result.is_err());
}
#[test]
fn test_wireguard_status() {
let config = create_test_config();
let tunnel = WireGuardTunnel::new("test-wg-2".to_string(), config);
let status = tunnel.get_status();
assert!(!status.connected);
assert_eq!(status.vpn_id, "test-wg-2");
assert!(status.connected_at.is_none());
assert_eq!(status.bytes_sent, Some(0));
assert_eq!(status.bytes_received, Some(0));
}
}
+6
View File
@@ -97,6 +97,12 @@ impl WayfernTermsManager {
timestamp >= MIN_VALID_TIMESTAMP
}
pub fn is_wayfern_downloaded(&self) -> bool {
let registry = DownloadedBrowsersRegistry::instance();
let versions = registry.get_downloaded_versions("wayfern");
!versions.is_empty()
}
fn get_any_wayfern_executable(&self) -> Option<PathBuf> {
// First try to get executable from any downloaded Wayfern version
let registry = DownloadedBrowsersRegistry::instance();