mirror of
https://github.com/zhom/donutbrowser.git
synced 2026-06-09 16:33:58 +02:00
feat: daemon support, general improvement, and preparation for Windows release
This commit is contained in:
@@ -9,10 +9,10 @@ use std::path::PathBuf;
|
||||
use std::process;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::mpsc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use muda::MenuEvent;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use single_instance::SingleInstance;
|
||||
use tao::event::{Event, StartCause};
|
||||
use tao::event_loop::{ControlFlow, EventLoopBuilder};
|
||||
use tokio::runtime::Runtime;
|
||||
@@ -69,52 +69,6 @@ fn write_state(state: &DaemonState) -> std::io::Result<()> {
|
||||
fs::write(path, content)
|
||||
}
|
||||
|
||||
fn detach_from_parent() {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
unsafe {
|
||||
libc::setsid();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_detached() {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
match unsafe { libc::fork() } {
|
||||
-1 => {
|
||||
eprintln!("Fork failed");
|
||||
process::exit(1);
|
||||
}
|
||||
0 => {
|
||||
detach_from_parent();
|
||||
}
|
||||
_ => {
|
||||
process::exit(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
{
|
||||
use std::os::windows::process::CommandExt;
|
||||
use std::process::{Command, Stdio};
|
||||
const DETACHED_PROCESS: u32 = 0x00000008;
|
||||
const CREATE_NEW_PROCESS_GROUP: u32 = 0x00000200;
|
||||
let current_exe = env::current_exe().expect("Failed to get current exe path");
|
||||
|
||||
let _ = Command::new(current_exe)
|
||||
.arg("--daemon-internal")
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.creation_flags(DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP)
|
||||
.spawn();
|
||||
|
||||
process::exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_high_priority() {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
@@ -174,13 +128,6 @@ fn run_daemon() {
|
||||
process::exit(1);
|
||||
}
|
||||
|
||||
let instance =
|
||||
SingleInstance::new("donut-browser-daemon").expect("Failed to create single instance lock");
|
||||
if !instance.is_single() {
|
||||
eprintln!("Daemon is already running");
|
||||
process::exit(1);
|
||||
}
|
||||
|
||||
log::info!("[daemon] Starting with PID {}", process::id());
|
||||
|
||||
// Create tokio runtime for async operations
|
||||
@@ -231,7 +178,8 @@ fn run_daemon() {
|
||||
|
||||
// Run the event loop
|
||||
event_loop.run(move |event, _, control_flow| {
|
||||
*control_flow = ControlFlow::Poll;
|
||||
// Use WaitUntil to check for menu events periodically while staying low on CPU
|
||||
*control_flow = ControlFlow::WaitUntil(Instant::now() + Duration::from_millis(100));
|
||||
|
||||
match event {
|
||||
Event::NewEvents(StartCause::Init) => {
|
||||
@@ -290,7 +238,8 @@ fn run_daemon() {
|
||||
}
|
||||
}
|
||||
|
||||
if SHOULD_QUIT.load(Ordering::SeqCst) {
|
||||
// Use swap to only run cleanup once
|
||||
if SHOULD_QUIT.swap(false, Ordering::SeqCst) {
|
||||
// Cleanup
|
||||
let mut state = read_state();
|
||||
state.daemon_pid = None;
|
||||
@@ -405,8 +354,10 @@ fn main() {
|
||||
|
||||
match args[1].as_str() {
|
||||
"start" => {
|
||||
// "start" is now an alias for "run"
|
||||
// On macOS, the daemon should be started via launchctl (see daemon_spawn.rs)
|
||||
// This command is kept for backward compatibility
|
||||
eprintln!("Starting daemon...");
|
||||
spawn_detached();
|
||||
run_daemon();
|
||||
}
|
||||
"stop" => {
|
||||
@@ -418,9 +369,6 @@ fn main() {
|
||||
"run" => {
|
||||
run_daemon();
|
||||
}
|
||||
"--daemon-internal" => {
|
||||
run_daemon();
|
||||
}
|
||||
"autostart" => {
|
||||
if args.len() < 3 {
|
||||
eprintln!("Usage: donut-daemon autostart <enable|disable|status>");
|
||||
|
||||
@@ -80,6 +80,12 @@ pub fn enable_autostart() -> io::Result<()> {
|
||||
|
||||
let plist_path = plist_dir.join("com.donutbrowser.daemon.plist");
|
||||
|
||||
// Get log directory (use data directory instead of /tmp)
|
||||
let log_dir = get_data_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("/tmp"))
|
||||
.join("logs");
|
||||
fs::create_dir_all(&log_dir)?;
|
||||
|
||||
let plist_content = format!(
|
||||
r#"<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
@@ -89,21 +95,29 @@ pub fn enable_autostart() -> io::Result<()> {
|
||||
<string>com.donutbrowser.daemon</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{}</string>
|
||||
<string>start</string>
|
||||
<string>{daemon_path}</string>
|
||||
<string>run</string>
|
||||
</array>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>LimitLoadToSessionType</key>
|
||||
<string>Aqua</string>
|
||||
<key>KeepAlive</key>
|
||||
<false/>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
</dict>
|
||||
<key>ProcessType</key>
|
||||
<string>Interactive</string>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/tmp/donut-daemon.out.log</string>
|
||||
<string>{log_dir}/daemon.out.log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/tmp/donut-daemon.err.log</string>
|
||||
<string>{log_dir}/daemon.err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
"#,
|
||||
daemon_path.display()
|
||||
daemon_path = daemon_path.display(),
|
||||
log_dir = log_dir.display()
|
||||
);
|
||||
|
||||
fs::write(&plist_path, plist_content)?;
|
||||
@@ -112,13 +126,19 @@ pub fn enable_autostart() -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub fn get_plist_path() -> Option<PathBuf> {
|
||||
dirs::home_dir().map(|h| h.join("Library/LaunchAgents/com.donutbrowser.daemon.plist"))
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub fn disable_autostart() -> io::Result<()> {
|
||||
let plist_path = dirs::home_dir()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Home directory not found"))?
|
||||
.join("Library/LaunchAgents/com.donutbrowser.daemon.plist");
|
||||
let plist_path = get_plist_path()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Home directory not found"))?;
|
||||
|
||||
if plist_path.exists() {
|
||||
// First unload the launch agent if it's loaded
|
||||
let _ = unload_launch_agent();
|
||||
fs::remove_file(&plist_path)?;
|
||||
log::info!("Removed launch agent at {:?}", plist_path);
|
||||
}
|
||||
@@ -128,12 +148,71 @@ pub fn disable_autostart() -> io::Result<()> {
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub fn is_autostart_enabled() -> bool {
|
||||
dirs::home_dir()
|
||||
.map(|h| {
|
||||
h.join("Library/LaunchAgents/com.donutbrowser.daemon.plist")
|
||||
.exists()
|
||||
})
|
||||
.unwrap_or(false)
|
||||
get_plist_path().is_some_and(|p| p.exists())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub fn load_launch_agent() -> io::Result<()> {
|
||||
use std::process::Command;
|
||||
|
||||
let plist_path = get_plist_path()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not determine plist path"))?;
|
||||
|
||||
if !plist_path.exists() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"Launch agent plist does not exist",
|
||||
));
|
||||
}
|
||||
|
||||
// Use launchctl load to start the daemon via launchd
|
||||
// The -w flag writes the "disabled" key to the override plist
|
||||
let output = Command::new("launchctl")
|
||||
.args(["load", "-w"])
|
||||
.arg(&plist_path)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
// "already loaded" is not an error condition for us
|
||||
if !stderr.contains("already loaded") {
|
||||
return Err(io::Error::other(format!(
|
||||
"launchctl load failed: {}",
|
||||
stderr
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Loaded launch agent via launchctl");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub fn unload_launch_agent() -> io::Result<()> {
|
||||
use std::process::Command;
|
||||
|
||||
let plist_path = get_plist_path()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "Could not determine plist path"))?;
|
||||
|
||||
if !plist_path.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let output = Command::new("launchctl")
|
||||
.args(["unload"])
|
||||
.arg(&plist_path)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
// Not being loaded is not an error
|
||||
if !stderr.contains("Could not find specified service") {
|
||||
log::warn!("launchctl unload warning: {}", stderr);
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Unloaded launch agent via launchctl");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
|
||||
@@ -6,7 +6,9 @@ use tray_icon::{Icon, TrayIcon, TrayIconBuilder};
|
||||
static GUI_RUNNING: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
pub fn load_icon() -> Icon {
|
||||
let icon_bytes = include_bytes!("../../icons/32x32.png");
|
||||
// Use the generated template icon (44x44 for retina, macOS standard menu bar size)
|
||||
// This is the donut logo converted to template format (black with alpha)
|
||||
let icon_bytes = include_bytes!("../../icons/tray-icon-44.png");
|
||||
|
||||
let image = image::load_from_memory(icon_bytes)
|
||||
.expect("Failed to load icon")
|
||||
@@ -89,12 +91,16 @@ impl TrayMenu {
|
||||
}
|
||||
|
||||
pub fn create_tray_icon(icon: Icon, menu: &Menu) -> TrayIcon {
|
||||
TrayIconBuilder::new()
|
||||
let builder = TrayIconBuilder::new()
|
||||
.with_icon(icon)
|
||||
.with_tooltip("Donut Browser")
|
||||
.with_menu(Box::new(menu.clone()))
|
||||
.build()
|
||||
.expect("Failed to create tray icon")
|
||||
.with_menu(Box::new(menu.clone()));
|
||||
|
||||
// On macOS, template icons are automatically colored by the system for light/dark mode
|
||||
#[cfg(target_os = "macos")]
|
||||
let builder = builder.with_icon_as_template(true);
|
||||
|
||||
builder.build().expect("Failed to create tray icon")
|
||||
}
|
||||
|
||||
pub fn open_gui() {
|
||||
|
||||
+161
-53
@@ -60,6 +60,38 @@ fn is_daemon_running() -> bool {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_dev_mode() -> bool {
|
||||
if let Ok(current_exe) = std::env::current_exe() {
|
||||
let path_str = current_exe.to_string_lossy();
|
||||
path_str.contains("target/debug") || path_str.contains("target/release")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
fn get_daemon_path() -> Option<PathBuf> {
|
||||
// First try to find the daemon binary next to the current executable
|
||||
if let Ok(current_exe) = std::env::current_exe() {
|
||||
if let Some(exe_dir) = current_exe.parent() {
|
||||
let daemon_path = exe_dir.join("donut-daemon");
|
||||
if daemon_path.exists() {
|
||||
return Some(daemon_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try common installation paths
|
||||
let paths = [
|
||||
PathBuf::from("/Applications/Donut Browser.app/Contents/MacOS/donut-daemon"),
|
||||
dirs::home_dir()
|
||||
.map(|h| h.join("Applications/Donut Browser.app/Contents/MacOS/donut-daemon"))
|
||||
.unwrap_or_default(),
|
||||
];
|
||||
paths.into_iter().find(|path| path.exists())
|
||||
}
|
||||
|
||||
#[cfg(any(target_os = "linux", windows))]
|
||||
fn get_daemon_path() -> Option<PathBuf> {
|
||||
// First, try to find it next to the current executable
|
||||
if let Ok(current_exe) = std::env::current_exe() {
|
||||
@@ -68,25 +100,13 @@ fn get_daemon_path() -> Option<PathBuf> {
|
||||
// Check for daemon binary in same directory
|
||||
#[cfg(target_os = "windows")]
|
||||
let daemon_name = "donut-daemon.exe";
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
#[cfg(target_os = "linux")]
|
||||
let daemon_name = "donut-daemon";
|
||||
|
||||
let daemon_path = exe_dir.join(daemon_name);
|
||||
if daemon_path.exists() {
|
||||
return Some(daemon_path);
|
||||
}
|
||||
|
||||
// On macOS, check inside the app bundle
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
// If we're in Contents/MacOS, daemon should be there too
|
||||
if exe_dir.ends_with("Contents/MacOS") {
|
||||
let daemon_path = exe_dir.join(daemon_name);
|
||||
if daemon_path.exists() {
|
||||
return Some(daemon_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find it in PATH
|
||||
@@ -101,7 +121,7 @@ fn get_daemon_path() -> Option<PathBuf> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
if let Ok(output) = Command::new("which").arg("donut-daemon").output() {
|
||||
if output.status.success() {
|
||||
@@ -118,60 +138,39 @@ fn get_daemon_path() -> Option<PathBuf> {
|
||||
}
|
||||
|
||||
pub fn spawn_daemon() -> Result<(), String> {
|
||||
// Log the daemon state for debugging
|
||||
let state = read_state();
|
||||
log::info!("Daemon state before spawn: pid={:?}", state.daemon_pid);
|
||||
|
||||
// Check if already running
|
||||
if is_daemon_running() {
|
||||
log::info!("Daemon is already running");
|
||||
log::info!("Daemon is already running (verified by PID check)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::info!("Daemon is not running, attempting to start...");
|
||||
|
||||
// Log current exe location for debugging
|
||||
let current_exe = std::env::current_exe().ok();
|
||||
log::info!("Current exe: {:?}", current_exe);
|
||||
|
||||
let daemon_path = get_daemon_path().ok_or_else(|| {
|
||||
format!(
|
||||
"Could not find daemon binary. Current exe: {:?}",
|
||||
current_exe
|
||||
)
|
||||
})?;
|
||||
|
||||
log::info!("Spawning daemon from: {:?}", daemon_path);
|
||||
|
||||
// Use "run" instead of "start" - we handle detachment here
|
||||
#[cfg(unix)]
|
||||
// On macOS, use launchctl to start the daemon via launchd
|
||||
// This ensures the daemon runs in the user's Aqua session with WindowServer access
|
||||
// and survives app termination since it's managed by launchd, not as a child process
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
use std::os::unix::process::CommandExt;
|
||||
spawn_daemon_macos()?;
|
||||
}
|
||||
|
||||
// Create a new process group so daemon survives parent exit
|
||||
// Note: We don't call setsid() because on macOS that disconnects from the WindowServer
|
||||
// which prevents the tray icon from appearing. Instead, we just set a new process group.
|
||||
let mut cmd = Command::new(&daemon_path);
|
||||
cmd
|
||||
.arg("run")
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.process_group(0); // Create new process group without new session
|
||||
|
||||
cmd
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
|
||||
// On Linux, use direct spawn
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
spawn_daemon_unix()?;
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
{
|
||||
use std::os::windows::process::CommandExt;
|
||||
const DETACHED_PROCESS: u32 = 0x00000008;
|
||||
const CREATE_NEW_PROCESS_GROUP: u32 = 0x00000200;
|
||||
|
||||
Command::new(&daemon_path)
|
||||
.arg("run")
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.creation_flags(DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP)
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
|
||||
spawn_daemon_windows()?;
|
||||
}
|
||||
|
||||
// Wait for daemon to start (max 3 seconds)
|
||||
@@ -196,6 +195,115 @@ pub fn spawn_daemon() -> Result<(), String> {
|
||||
Err("Daemon did not start within timeout".to_string())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
fn spawn_daemon_macos() -> Result<(), String> {
|
||||
use std::os::unix::process::CommandExt;
|
||||
|
||||
// In dev mode, use direct spawn instead of launchctl
|
||||
// This avoids issues with plist paths pointing to wrong binaries
|
||||
if is_dev_mode() {
|
||||
log::info!("Dev mode detected, using direct spawn instead of launchctl");
|
||||
|
||||
let daemon_path = get_daemon_path().ok_or_else(|| {
|
||||
format!(
|
||||
"Could not find daemon binary. Current exe: {:?}",
|
||||
std::env::current_exe().ok()
|
||||
)
|
||||
})?;
|
||||
|
||||
log::info!("Spawning daemon from: {:?}", daemon_path);
|
||||
|
||||
// Create a new process group so daemon survives parent exit
|
||||
let mut cmd = Command::new(&daemon_path);
|
||||
cmd
|
||||
.arg("run")
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.process_group(0);
|
||||
|
||||
cmd
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Production mode: use launchctl for proper daemon management
|
||||
// First, ensure the LaunchAgent plist is installed
|
||||
let autostart_enabled = autostart::is_autostart_enabled();
|
||||
log::info!("LaunchAgent plist exists: {}", autostart_enabled);
|
||||
|
||||
if !autostart_enabled {
|
||||
log::info!("Installing LaunchAgent plist for daemon management");
|
||||
autostart::enable_autostart().map_err(|e| format!("Failed to install LaunchAgent: {}", e))?;
|
||||
log::info!("LaunchAgent plist installed successfully");
|
||||
}
|
||||
|
||||
// Load the launch agent via launchctl
|
||||
log::info!("Loading daemon via launchctl...");
|
||||
autostart::load_launch_agent().map_err(|e| format!("Failed to load LaunchAgent: {}", e))?;
|
||||
log::info!("launchctl load completed");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
fn spawn_daemon_unix() -> Result<(), String> {
|
||||
use std::os::unix::process::CommandExt;
|
||||
|
||||
let daemon_path = get_daemon_path().ok_or_else(|| {
|
||||
format!(
|
||||
"Could not find daemon binary. Current exe: {:?}",
|
||||
std::env::current_exe().ok()
|
||||
)
|
||||
})?;
|
||||
|
||||
log::info!("Spawning daemon from: {:?}", daemon_path);
|
||||
|
||||
// Create a new process group so daemon survives parent exit
|
||||
let mut cmd = Command::new(&daemon_path);
|
||||
cmd
|
||||
.arg("run")
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.process_group(0);
|
||||
|
||||
cmd
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
fn spawn_daemon_windows() -> Result<(), String> {
|
||||
use std::os::windows::process::CommandExt;
|
||||
const DETACHED_PROCESS: u32 = 0x00000008;
|
||||
const CREATE_NEW_PROCESS_GROUP: u32 = 0x00000200;
|
||||
|
||||
let daemon_path = get_daemon_path().ok_or_else(|| {
|
||||
format!(
|
||||
"Could not find daemon binary. Current exe: {:?}",
|
||||
std::env::current_exe().ok()
|
||||
)
|
||||
})?;
|
||||
|
||||
log::info!("Spawning daemon from: {:?}", daemon_path);
|
||||
|
||||
Command::new(&daemon_path)
|
||||
.arg("run")
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.creation_flags(DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP)
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn daemon: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn ensure_daemon_running() -> Result<(), String> {
|
||||
if !is_daemon_running() {
|
||||
spawn_daemon()?;
|
||||
|
||||
@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::api_client::ApiClient;
|
||||
use crate::browser::{create_browser, BrowserType};
|
||||
@@ -13,6 +14,8 @@ use crate::events;
|
||||
lazy_static::lazy_static! {
|
||||
static ref DOWNLOADING_BROWSERS: std::sync::Arc<Mutex<std::collections::HashSet<String>>> =
|
||||
std::sync::Arc::new(Mutex::new(std::collections::HashSet::new()));
|
||||
static ref DOWNLOAD_CANCELLATION_TOKENS: std::sync::Arc<Mutex<std::collections::HashMap<String, CancellationToken>>> =
|
||||
std::sync::Arc::new(Mutex::new(std::collections::HashMap::new()));
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
@@ -438,6 +441,7 @@ impl Downloader {
|
||||
version: &str,
|
||||
download_info: &DownloadInfo,
|
||||
dest_path: &Path,
|
||||
cancel_token: Option<&CancellationToken>,
|
||||
) -> Result<PathBuf, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let file_path = dest_path.join(&download_info.filename);
|
||||
|
||||
@@ -573,6 +577,13 @@ impl Downloader {
|
||||
|
||||
use futures_util::StreamExt;
|
||||
while let Some(chunk) = stream.next().await {
|
||||
if let Some(token) = cancel_token {
|
||||
if token.is_cancelled() {
|
||||
drop(file);
|
||||
let _ = std::fs::remove_file(&file_path);
|
||||
return Err("Download cancelled".into());
|
||||
}
|
||||
}
|
||||
let chunk = chunk?;
|
||||
io::copy(&mut chunk.as_ref(), &mut file)?;
|
||||
downloaded += chunk.len() as u64;
|
||||
@@ -635,21 +646,27 @@ impl Downloader {
|
||||
browser_str: String,
|
||||
version: String,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Check if Wayfern terms have been accepted before allowing any browser downloads
|
||||
if !crate::wayfern_terms::WayfernTermsManager::instance().is_terms_accepted() {
|
||||
// Only check Wayfern terms if Wayfern is already downloaded
|
||||
let terms_manager = crate::wayfern_terms::WayfernTermsManager::instance();
|
||||
if terms_manager.is_wayfern_downloaded() && !terms_manager.is_terms_accepted() {
|
||||
return Err("Please accept Wayfern Terms and Conditions before downloading browsers".into());
|
||||
}
|
||||
|
||||
// Check if this browser-version pair is already being downloaded
|
||||
let download_key = format!("{browser_str}-{version}");
|
||||
{
|
||||
let cancel_token = {
|
||||
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
|
||||
if downloading.contains(&download_key) {
|
||||
return Err(format!("Browser '{browser_str}' version '{version}' is already being downloaded. Please wait for the current download to complete.").into());
|
||||
}
|
||||
// Mark this browser-version pair as being downloaded
|
||||
downloading.insert(download_key.clone());
|
||||
}
|
||||
|
||||
let token = CancellationToken::new();
|
||||
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
|
||||
tokens.insert(download_key.clone(), token.clone());
|
||||
token
|
||||
};
|
||||
|
||||
let browser_type =
|
||||
BrowserType::from_str(&browser_str).map_err(|e| format!("Invalid browser type: {e}"))?;
|
||||
@@ -681,6 +698,9 @@ impl Downloader {
|
||||
// Remove from downloading set since it's already downloaded
|
||||
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
|
||||
downloading.remove(&download_key);
|
||||
drop(downloading);
|
||||
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
|
||||
tokens.remove(&download_key);
|
||||
return Ok(version);
|
||||
} else {
|
||||
// Registry says it's downloaded but files don't exist - clean up registry
|
||||
@@ -702,6 +722,9 @@ impl Downloader {
|
||||
// Remove from downloading set on error
|
||||
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
|
||||
downloading.remove(&download_key);
|
||||
drop(downloading);
|
||||
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
|
||||
tokens.remove(&download_key);
|
||||
return Err(
|
||||
format!(
|
||||
"Browser '{}' is not supported on your platform ({} {}). Supported browsers: {}",
|
||||
@@ -741,6 +764,7 @@ impl Downloader {
|
||||
&version,
|
||||
&download_info,
|
||||
&browser_dir,
|
||||
Some(&cancel_token),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -752,6 +776,25 @@ impl Downloader {
|
||||
let _ = self.registry.save();
|
||||
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
|
||||
downloading.remove(&download_key);
|
||||
drop(downloading);
|
||||
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
|
||||
tokens.remove(&download_key);
|
||||
|
||||
// Emit cancelled stage if the download was cancelled by user
|
||||
if cancel_token.is_cancelled() {
|
||||
let progress = DownloadProgress {
|
||||
browser: browser_str.clone(),
|
||||
version: version.clone(),
|
||||
downloaded_bytes: 0,
|
||||
total_bytes: None,
|
||||
percentage: 0.0,
|
||||
speed_bytes_per_sec: 0.0,
|
||||
eta_seconds: None,
|
||||
stage: "cancelled".to_string(),
|
||||
};
|
||||
let _ = events::emit("download-progress", &progress);
|
||||
}
|
||||
|
||||
return Err(format!("Failed to download browser: {e}").into());
|
||||
}
|
||||
};
|
||||
@@ -782,6 +825,10 @@ impl Downloader {
|
||||
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
|
||||
downloading.remove(&download_key);
|
||||
}
|
||||
{
|
||||
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
|
||||
tokens.remove(&download_key);
|
||||
}
|
||||
return Err(format!("Failed to extract browser: {e}").into());
|
||||
}
|
||||
}
|
||||
@@ -869,6 +916,10 @@ impl Downloader {
|
||||
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
|
||||
downloading.remove(&download_key);
|
||||
}
|
||||
{
|
||||
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
|
||||
tokens.remove(&download_key);
|
||||
}
|
||||
return Err(error_details.into());
|
||||
}
|
||||
|
||||
@@ -941,11 +992,15 @@ impl Downloader {
|
||||
};
|
||||
let _ = events::emit("download-progress", &progress);
|
||||
|
||||
// Remove browser-version pair from downloading set
|
||||
// Remove browser-version pair from downloading set and cancel token
|
||||
{
|
||||
let mut downloading = DOWNLOADING_BROWSERS.lock().unwrap();
|
||||
downloading.remove(&download_key);
|
||||
}
|
||||
{
|
||||
let mut tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
|
||||
tokens.remove(&download_key);
|
||||
}
|
||||
|
||||
Ok(version)
|
||||
}
|
||||
@@ -964,6 +1019,24 @@ pub async fn download_browser(
|
||||
.map_err(|e| format!("Failed to download browser: {e}"))
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn cancel_download(browser_str: String, version: String) -> Result<(), String> {
|
||||
let download_key = format!("{browser_str}-{version}");
|
||||
let token = {
|
||||
let tokens = DOWNLOAD_CANCELLATION_TOKENS.lock().unwrap();
|
||||
tokens.get(&download_key).cloned()
|
||||
};
|
||||
|
||||
if let Some(token) = token {
|
||||
token.cancel();
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!(
|
||||
"No active download found for {browser_str} {version}"
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -1074,6 +1147,7 @@ mod tests {
|
||||
"139.0",
|
||||
&download_info,
|
||||
dest_path,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1118,6 +1192,7 @@ mod tests {
|
||||
"139.0",
|
||||
&download_info,
|
||||
dest_path,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1163,6 +1238,7 @@ mod tests {
|
||||
"1465660",
|
||||
&download_info,
|
||||
dest_path,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
+248
-4
@@ -47,6 +47,7 @@ pub mod events;
|
||||
mod mcp_server;
|
||||
mod tag_manager;
|
||||
mod version_updater;
|
||||
pub mod vpn;
|
||||
|
||||
use browser_runner::{
|
||||
check_browser_exists, kill_browser_profile, launch_browser_profile, open_url_with_profile,
|
||||
@@ -68,12 +69,12 @@ use downloaded_browsers_registry::{
|
||||
check_missing_binaries, ensure_all_binaries_exist, get_downloaded_browser_versions,
|
||||
};
|
||||
|
||||
use downloader::download_browser;
|
||||
use downloader::{cancel_download, download_browser};
|
||||
|
||||
use settings_manager::{
|
||||
decline_launch_on_login, enable_launch_on_login, get_app_settings, get_sync_settings,
|
||||
get_table_sorting_settings, save_app_settings, save_sync_settings, save_table_sorting_settings,
|
||||
should_show_launch_on_login_prompt,
|
||||
get_system_language, get_table_sorting_settings, save_app_settings, save_sync_settings,
|
||||
save_table_sorting_settings, should_show_launch_on_login_prompt,
|
||||
};
|
||||
|
||||
use sync::{
|
||||
@@ -232,6 +233,41 @@ fn get_cached_proxy_check(proxy_id: String) -> Option<crate::proxy_manager::Prox
|
||||
crate::proxy_manager::PROXY_MANAGER.get_cached_proxy_check(&proxy_id)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn export_proxies(format: String) -> Result<String, String> {
|
||||
match format.as_str() {
|
||||
"json" => crate::proxy_manager::PROXY_MANAGER.export_proxies_json(),
|
||||
"txt" => Ok(crate::proxy_manager::PROXY_MANAGER.export_proxies_txt()),
|
||||
_ => Err(format!("Unsupported export format: {format}")),
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn import_proxies_json(
|
||||
app_handle: tauri::AppHandle,
|
||||
content: String,
|
||||
) -> Result<crate::proxy_manager::ProxyImportResult, String> {
|
||||
crate::proxy_manager::PROXY_MANAGER
|
||||
.import_proxies_json(&app_handle, &content)
|
||||
.map_err(|e| format!("Failed to import proxies: {e}"))
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn parse_txt_proxies(content: String) -> Vec<crate::proxy_manager::ProxyParseResult> {
|
||||
crate::proxy_manager::ProxyManager::parse_txt_proxies(&content)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn import_proxies_from_parsed(
|
||||
app_handle: tauri::AppHandle,
|
||||
parsed_proxies: Vec<crate::proxy_manager::ParsedProxyLine>,
|
||||
name_prefix: Option<String>,
|
||||
) -> Result<crate::proxy_manager::ProxyImportResult, String> {
|
||||
crate::proxy_manager::PROXY_MANAGER
|
||||
.import_proxies_from_parsed(&app_handle, parsed_proxies, name_prefix)
|
||||
.map_err(|e| format!("Failed to import proxies: {e}"))
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn read_profile_cookies(profile_id: String) -> Result<cookie_manager::CookieReadResult, String> {
|
||||
cookie_manager::CookieManager::read_cookies(&profile_id)
|
||||
@@ -250,6 +286,11 @@ fn check_wayfern_terms_accepted() -> bool {
|
||||
wayfern_terms::WayfernTermsManager::instance().is_terms_accepted()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn check_wayfern_downloaded() -> bool {
|
||||
wayfern_terms::WayfernTermsManager::instance().is_wayfern_downloaded()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn accept_wayfern_terms() -> Result<(), String> {
|
||||
wayfern_terms::WayfernTermsManager::instance()
|
||||
@@ -374,6 +415,172 @@ async fn download_geoip_database(app_handle: tauri::AppHandle) -> Result<(), Str
|
||||
.map_err(|e| format!("Failed to download GeoIP database: {e}"))
|
||||
}
|
||||
|
||||
// VPN commands
|
||||
#[tauri::command]
|
||||
async fn import_vpn_config(
|
||||
content: String,
|
||||
filename: String,
|
||||
name: Option<String>,
|
||||
) -> Result<vpn::VpnImportResult, String> {
|
||||
let storage = vpn::VPN_STORAGE
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
|
||||
|
||||
match storage.import_config(&content, &filename, name.clone()) {
|
||||
Ok(config) => Ok(vpn::VpnImportResult {
|
||||
success: true,
|
||||
vpn_id: Some(config.id),
|
||||
vpn_type: Some(config.vpn_type),
|
||||
name: config.name,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(vpn::VpnImportResult {
|
||||
success: false,
|
||||
vpn_id: None,
|
||||
vpn_type: None,
|
||||
name: name.unwrap_or_else(|| filename.clone()),
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn list_vpn_configs() -> Result<Vec<vpn::VpnConfig>, String> {
|
||||
let storage = vpn::VPN_STORAGE
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
|
||||
|
||||
storage
|
||||
.list_configs()
|
||||
.map_err(|e| format!("Failed to list VPN configs: {e}"))
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn get_vpn_config(vpn_id: String) -> Result<vpn::VpnConfig, String> {
|
||||
let storage = vpn::VPN_STORAGE
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
|
||||
|
||||
storage
|
||||
.load_config(&vpn_id)
|
||||
.map_err(|e| format!("Failed to load VPN config: {e}"))
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn delete_vpn_config(vpn_id: String) -> Result<(), String> {
|
||||
// First disconnect if connected
|
||||
{
|
||||
let mut manager = vpn::TUNNEL_MANAGER.lock().await;
|
||||
if manager.is_tunnel_active(&vpn_id) {
|
||||
if let Some(tunnel) = manager.get_tunnel_mut(&vpn_id) {
|
||||
let _ = tunnel.disconnect().await;
|
||||
}
|
||||
manager.remove_tunnel(&vpn_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Then delete from storage
|
||||
let storage = vpn::VPN_STORAGE
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
|
||||
|
||||
storage
|
||||
.delete_config(&vpn_id)
|
||||
.map_err(|e| format!("Failed to delete VPN config: {e}"))
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn connect_vpn(vpn_id: String) -> Result<(), String> {
|
||||
// Load config from storage
|
||||
let config = {
|
||||
let storage = vpn::VPN_STORAGE
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
|
||||
|
||||
storage
|
||||
.load_config(&vpn_id)
|
||||
.map_err(|e| format!("Failed to load VPN config: {e}"))?
|
||||
};
|
||||
|
||||
// Create and connect the appropriate tunnel
|
||||
let mut manager = vpn::TUNNEL_MANAGER.lock().await;
|
||||
|
||||
// Check if already connected
|
||||
if manager.is_tunnel_active(&vpn_id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut tunnel: Box<dyn vpn::VpnTunnel> = match config.vpn_type {
|
||||
vpn::VpnType::WireGuard => {
|
||||
let wg_config = vpn::parse_wireguard_config(&config.config_data)
|
||||
.map_err(|e| format!("Invalid WireGuard config: {e}"))?;
|
||||
Box::new(vpn::WireGuardTunnel::new(vpn_id.clone(), wg_config))
|
||||
}
|
||||
vpn::VpnType::OpenVPN => {
|
||||
let ovpn_config = vpn::parse_openvpn_config(&config.config_data)
|
||||
.map_err(|e| format!("Invalid OpenVPN config: {e}"))?;
|
||||
Box::new(vpn::OpenVpnTunnel::new(vpn_id.clone(), ovpn_config))
|
||||
}
|
||||
};
|
||||
|
||||
tunnel
|
||||
.connect()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to connect VPN: {e}"))?;
|
||||
|
||||
manager.register_tunnel(vpn_id.clone(), tunnel);
|
||||
|
||||
// Update last_used timestamp
|
||||
{
|
||||
let storage = vpn::VPN_STORAGE
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to lock VPN storage: {e}"))?;
|
||||
let _ = storage.update_last_used(&vpn_id);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn disconnect_vpn(vpn_id: String) -> Result<(), String> {
|
||||
let mut manager = vpn::TUNNEL_MANAGER.lock().await;
|
||||
|
||||
if let Some(tunnel) = manager.get_tunnel_mut(&vpn_id) {
|
||||
tunnel
|
||||
.disconnect()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to disconnect VPN: {e}"))?;
|
||||
}
|
||||
|
||||
manager.remove_tunnel(&vpn_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn get_vpn_status(vpn_id: String) -> Result<vpn::VpnStatus, String> {
|
||||
let manager = vpn::TUNNEL_MANAGER.lock().await;
|
||||
|
||||
if let Some(tunnel) = manager.get_tunnel(&vpn_id) {
|
||||
Ok(tunnel.get_status())
|
||||
} else {
|
||||
// Not connected
|
||||
Ok(vpn::VpnStatus {
|
||||
connected: false,
|
||||
vpn_id,
|
||||
connected_at: None,
|
||||
bytes_sent: None,
|
||||
bytes_received: None,
|
||||
last_handshake: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn list_active_vpn_connections() -> Result<Vec<vpn::VpnStatus>, String> {
|
||||
let manager = vpn::TUNNEL_MANAGER.lock().await;
|
||||
Ok(manager.get_all_statuses())
|
||||
}
|
||||
|
||||
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
||||
pub fn run() {
|
||||
let args: Vec<String> = env::args().collect();
|
||||
@@ -883,6 +1090,7 @@ pub fn run() {
|
||||
get_supported_browsers,
|
||||
is_browser_supported_on_platform,
|
||||
download_browser,
|
||||
cancel_download,
|
||||
delete_profile,
|
||||
check_browser_exists,
|
||||
create_browser_profile_new,
|
||||
@@ -907,6 +1115,7 @@ pub fn run() {
|
||||
decline_launch_on_login,
|
||||
get_table_sorting_settings,
|
||||
save_table_sorting_settings,
|
||||
get_system_language,
|
||||
clear_all_version_cache_and_refetch,
|
||||
is_default_browser,
|
||||
open_url_with_profile,
|
||||
@@ -931,6 +1140,10 @@ pub fn run() {
|
||||
delete_stored_proxy,
|
||||
check_proxy_validity,
|
||||
get_cached_proxy_check,
|
||||
export_proxies,
|
||||
import_proxies_json,
|
||||
parse_txt_proxies,
|
||||
import_proxies_from_parsed,
|
||||
update_camoufox_config,
|
||||
update_wayfern_config,
|
||||
get_profile_groups,
|
||||
@@ -959,6 +1172,7 @@ pub fn run() {
|
||||
read_profile_cookies,
|
||||
copy_profile_cookies,
|
||||
check_wayfern_terms_accepted,
|
||||
check_wayfern_downloaded,
|
||||
accept_wayfern_terms,
|
||||
get_commercial_trial_status,
|
||||
acknowledge_trial_expiration,
|
||||
@@ -966,7 +1180,16 @@ pub fn run() {
|
||||
start_mcp_server,
|
||||
stop_mcp_server,
|
||||
get_mcp_server_status,
|
||||
get_mcp_config
|
||||
get_mcp_config,
|
||||
// VPN commands
|
||||
import_vpn_config,
|
||||
list_vpn_configs,
|
||||
get_vpn_config,
|
||||
delete_vpn_config,
|
||||
connect_vpn,
|
||||
disconnect_vpn,
|
||||
get_vpn_status,
|
||||
list_active_vpn_connections
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
@@ -987,6 +1210,18 @@ mod tests {
|
||||
}
|
||||
|
||||
fn check_unused_commands(verbose: bool) {
|
||||
// Commands that are intentionally not used in the frontend
|
||||
// but are used via MCP server or other programmatic APIs
|
||||
let mcp_only_commands = [
|
||||
"list_vpn_configs",
|
||||
"get_vpn_config",
|
||||
"delete_vpn_config",
|
||||
"connect_vpn",
|
||||
"disconnect_vpn",
|
||||
"get_vpn_status",
|
||||
"list_active_vpn_connections",
|
||||
];
|
||||
|
||||
// Extract command names from the generate_handler! macro in this file
|
||||
let lib_rs_content = fs::read_to_string("src/lib.rs").expect("Failed to read lib.rs");
|
||||
let commands = extract_tauri_commands(&lib_rs_content);
|
||||
@@ -999,6 +1234,15 @@ mod tests {
|
||||
let mut used_commands = Vec::new();
|
||||
|
||||
for command in &commands {
|
||||
// Skip commands that are intentionally MCP-only
|
||||
if mcp_only_commands.contains(&command.as_str()) {
|
||||
used_commands.push(command.clone());
|
||||
if verbose {
|
||||
println!("✅ {command} (MCP-only)");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut is_used = false;
|
||||
|
||||
for file_content in &frontend_files {
|
||||
|
||||
+533
-2
@@ -553,6 +553,132 @@ impl McpServer {
|
||||
"required": ["proxy_id"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "export_proxies".to_string(),
|
||||
description: "Export all proxy configurations".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["json", "txt"],
|
||||
"description": "Export format (json for structured data, txt for URL format)"
|
||||
}
|
||||
},
|
||||
"required": ["format"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "import_proxies".to_string(),
|
||||
description: "Import proxy configurations from JSON or TXT content".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The proxy configuration content to import"
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["json", "txt"],
|
||||
"description": "Import format (json or txt)"
|
||||
},
|
||||
"name_prefix": {
|
||||
"type": "string",
|
||||
"description": "Optional prefix for imported proxy names (default: 'Imported')"
|
||||
}
|
||||
},
|
||||
"required": ["content", "format"]
|
||||
}),
|
||||
},
|
||||
// VPN management tools
|
||||
McpTool {
|
||||
name: "import_vpn".to_string(),
|
||||
description: "Import a WireGuard (.conf) or OpenVPN (.ovpn) configuration".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Raw VPN config file content"
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Original filename (.conf or .ovpn) for type detection"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional display name for the VPN config"
|
||||
}
|
||||
},
|
||||
"required": ["content", "filename"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "list_vpn_configs".to_string(),
|
||||
description: "List all stored VPN configurations".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "delete_vpn".to_string(),
|
||||
description: "Delete a VPN configuration".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"vpn_id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the VPN config to delete"
|
||||
}
|
||||
},
|
||||
"required": ["vpn_id"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "connect_vpn".to_string(),
|
||||
description: "Connect to a VPN configuration".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"vpn_id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the VPN config to connect"
|
||||
}
|
||||
},
|
||||
"required": ["vpn_id"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "disconnect_vpn".to_string(),
|
||||
description: "Disconnect from a VPN".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"vpn_id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the VPN to disconnect"
|
||||
}
|
||||
},
|
||||
"required": ["vpn_id"]
|
||||
}),
|
||||
},
|
||||
McpTool {
|
||||
name: "get_vpn_status".to_string(),
|
||||
description: "Get the connection status of a VPN".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"vpn_id": {
|
||||
"type": "string",
|
||||
"description": "The UUID of the VPN to check"
|
||||
}
|
||||
},
|
||||
"required": ["vpn_id"]
|
||||
}),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
@@ -641,6 +767,16 @@ impl McpServer {
|
||||
"create_proxy" => self.handle_create_proxy(&arguments).await,
|
||||
"update_proxy" => self.handle_update_proxy(&arguments).await,
|
||||
"delete_proxy" => self.handle_delete_proxy(&arguments).await,
|
||||
// Proxy import/export
|
||||
"export_proxies" => self.handle_export_proxies(&arguments).await,
|
||||
"import_proxies" => self.handle_import_proxies(&arguments).await,
|
||||
// VPN management
|
||||
"import_vpn" => self.handle_import_vpn(&arguments).await,
|
||||
"list_vpn_configs" => self.handle_list_vpn_configs().await,
|
||||
"delete_vpn" => self.handle_delete_vpn(&arguments).await,
|
||||
"connect_vpn" => self.handle_connect_vpn(&arguments).await,
|
||||
"disconnect_vpn" => self.handle_disconnect_vpn(&arguments).await,
|
||||
"get_vpn_status" => self.handle_get_vpn_status(&arguments).await,
|
||||
_ => Err(McpError {
|
||||
code: -32602,
|
||||
message: format!("Unknown tool: {tool_name}"),
|
||||
@@ -1361,6 +1497,391 @@ impl McpServer {
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_export_proxies(
|
||||
&self,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, McpError> {
|
||||
let format = arguments
|
||||
.get("format")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing format".to_string(),
|
||||
})?;
|
||||
|
||||
let content = match format {
|
||||
"json" => PROXY_MANAGER.export_proxies_json().map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to export proxies: {e}"),
|
||||
})?,
|
||||
"txt" => PROXY_MANAGER.export_proxies_txt(),
|
||||
_ => {
|
||||
return Err(McpError {
|
||||
code: -32602,
|
||||
message: format!("Invalid format '{}', must be 'json' or 'txt'", format),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": content
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_import_proxies(
|
||||
&self,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, McpError> {
|
||||
let content = arguments
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing content".to_string(),
|
||||
})?;
|
||||
|
||||
let format = arguments
|
||||
.get("format")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing format".to_string(),
|
||||
})?;
|
||||
|
||||
let name_prefix = arguments
|
||||
.get("name_prefix")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let inner = self.inner.lock().await;
|
||||
let app_handle = inner.app_handle.as_ref().ok_or_else(|| McpError {
|
||||
code: -32000,
|
||||
message: "MCP server not properly initialized".to_string(),
|
||||
})?;
|
||||
|
||||
let result = match format {
|
||||
"json" => PROXY_MANAGER
|
||||
.import_proxies_json(app_handle, content)
|
||||
.map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to import proxies: {e}"),
|
||||
})?,
|
||||
"txt" => {
|
||||
use crate::proxy_manager::{ProxyManager, ProxyParseResult};
|
||||
|
||||
let parse_results = ProxyManager::parse_txt_proxies(content);
|
||||
let parsed: Vec<_> = parse_results
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
if let ProxyParseResult::Parsed(p) = r {
|
||||
Some(p)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if parsed.is_empty() {
|
||||
return Err(McpError {
|
||||
code: -32000,
|
||||
message: "No valid proxies found in content".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
PROXY_MANAGER
|
||||
.import_proxies_from_parsed(app_handle, parsed, name_prefix)
|
||||
.map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to import proxies: {e}"),
|
||||
})?
|
||||
}
|
||||
_ => {
|
||||
return Err(McpError {
|
||||
code: -32602,
|
||||
message: format!("Invalid format '{}', must be 'json' or 'txt'", format),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": format!(
|
||||
"Import complete: {} imported, {} skipped, {} errors",
|
||||
result.imported_count,
|
||||
result.skipped_count,
|
||||
result.errors.len()
|
||||
)
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
// VPN management handlers
|
||||
async fn handle_import_vpn(
|
||||
&self,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, McpError> {
|
||||
let content = arguments
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing content".to_string(),
|
||||
})?;
|
||||
|
||||
let filename = arguments
|
||||
.get("filename")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing filename".to_string(),
|
||||
})?;
|
||||
|
||||
let name = arguments
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to lock VPN storage: {e}"),
|
||||
})?;
|
||||
|
||||
let config = storage
|
||||
.import_config(content, filename, name)
|
||||
.map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to import VPN config: {e}"),
|
||||
})?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": format!(
|
||||
"VPN '{}' ({}) imported successfully with ID: {}",
|
||||
config.name,
|
||||
config.vpn_type,
|
||||
config.id
|
||||
)
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_list_vpn_configs(&self) -> Result<serde_json::Value, McpError> {
|
||||
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to lock VPN storage: {e}"),
|
||||
})?;
|
||||
|
||||
let configs = storage.list_configs().map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to list VPN configs: {e}"),
|
||||
})?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": serde_json::to_string_pretty(&configs).unwrap_or_default()
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_delete_vpn(
|
||||
&self,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, McpError> {
|
||||
let vpn_id = arguments
|
||||
.get("vpn_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing vpn_id".to_string(),
|
||||
})?;
|
||||
|
||||
// First disconnect if connected
|
||||
{
|
||||
let mut manager = crate::vpn::TUNNEL_MANAGER.lock().await;
|
||||
if manager.is_tunnel_active(vpn_id) {
|
||||
if let Some(tunnel) = manager.get_tunnel_mut(vpn_id) {
|
||||
let _ = tunnel.disconnect().await;
|
||||
}
|
||||
manager.remove_tunnel(vpn_id);
|
||||
}
|
||||
}
|
||||
|
||||
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to lock VPN storage: {e}"),
|
||||
})?;
|
||||
|
||||
storage.delete_config(vpn_id).map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to delete VPN config: {e}"),
|
||||
})?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": format!("VPN '{}' deleted successfully", vpn_id)
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_connect_vpn(
|
||||
&self,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, McpError> {
|
||||
let vpn_id = arguments
|
||||
.get("vpn_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing vpn_id".to_string(),
|
||||
})?;
|
||||
|
||||
// Load config from storage
|
||||
let config = {
|
||||
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to lock VPN storage: {e}"),
|
||||
})?;
|
||||
|
||||
storage.load_config(vpn_id).map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to load VPN config: {e}"),
|
||||
})?
|
||||
};
|
||||
|
||||
let mut manager = crate::vpn::TUNNEL_MANAGER.lock().await;
|
||||
|
||||
// Check if already connected
|
||||
if manager.is_tunnel_active(vpn_id) {
|
||||
return Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": format!("VPN '{}' is already connected", config.name)
|
||||
}]
|
||||
}));
|
||||
}
|
||||
|
||||
let mut tunnel: Box<dyn crate::vpn::VpnTunnel> = match config.vpn_type {
|
||||
crate::vpn::VpnType::WireGuard => {
|
||||
let wg_config =
|
||||
crate::vpn::parse_wireguard_config(&config.config_data).map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Invalid WireGuard config: {e}"),
|
||||
})?;
|
||||
Box::new(crate::vpn::WireGuardTunnel::new(
|
||||
vpn_id.to_string(),
|
||||
wg_config,
|
||||
))
|
||||
}
|
||||
crate::vpn::VpnType::OpenVPN => {
|
||||
let ovpn_config =
|
||||
crate::vpn::parse_openvpn_config(&config.config_data).map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Invalid OpenVPN config: {e}"),
|
||||
})?;
|
||||
Box::new(crate::vpn::OpenVpnTunnel::new(
|
||||
vpn_id.to_string(),
|
||||
ovpn_config,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
tunnel.connect().await.map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to connect VPN: {e}"),
|
||||
})?;
|
||||
|
||||
manager.register_tunnel(vpn_id.to_string(), tunnel);
|
||||
|
||||
// Update last_used timestamp
|
||||
{
|
||||
let storage = crate::vpn::VPN_STORAGE.lock().map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to lock VPN storage: {e}"),
|
||||
})?;
|
||||
let _ = storage.update_last_used(vpn_id);
|
||||
}
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": format!("VPN '{}' connected successfully", config.name)
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_disconnect_vpn(
|
||||
&self,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, McpError> {
|
||||
let vpn_id = arguments
|
||||
.get("vpn_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing vpn_id".to_string(),
|
||||
})?;
|
||||
|
||||
let mut manager = crate::vpn::TUNNEL_MANAGER.lock().await;
|
||||
|
||||
if let Some(tunnel) = manager.get_tunnel_mut(vpn_id) {
|
||||
tunnel.disconnect().await.map_err(|e| McpError {
|
||||
code: -32000,
|
||||
message: format!("Failed to disconnect VPN: {e}"),
|
||||
})?;
|
||||
}
|
||||
|
||||
manager.remove_tunnel(vpn_id);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": format!("VPN '{}' disconnected successfully", vpn_id)
|
||||
}]
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_get_vpn_status(
|
||||
&self,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, McpError> {
|
||||
let vpn_id = arguments
|
||||
.get("vpn_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| McpError {
|
||||
code: -32602,
|
||||
message: "Missing vpn_id".to_string(),
|
||||
})?;
|
||||
|
||||
let manager = crate::vpn::TUNNEL_MANAGER.lock().await;
|
||||
|
||||
let status = if let Some(tunnel) = manager.get_tunnel(vpn_id) {
|
||||
tunnel.get_status()
|
||||
} else {
|
||||
crate::vpn::VpnStatus {
|
||||
connected: false,
|
||||
vpn_id: vpn_id.to_string(),
|
||||
connected_at: None,
|
||||
bytes_sent: None,
|
||||
bytes_received: None,
|
||||
last_handshake: None,
|
||||
}
|
||||
};
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": serde_json::to_string_pretty(&status).unwrap_or_default()
|
||||
}]
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
@@ -1376,8 +1897,8 @@ mod tests {
|
||||
let server = McpServer::new();
|
||||
let tools = server.get_tools();
|
||||
|
||||
// Should have all 16 tools
|
||||
assert!(tools.len() >= 16);
|
||||
// Should have at least 24 tools (18 + 6 VPN tools)
|
||||
assert!(tools.len() >= 24);
|
||||
|
||||
// Check tool names
|
||||
let tool_names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
|
||||
@@ -1400,6 +1921,16 @@ mod tests {
|
||||
assert!(tool_names.contains(&"create_proxy"));
|
||||
assert!(tool_names.contains(&"update_proxy"));
|
||||
assert!(tool_names.contains(&"delete_proxy"));
|
||||
// Proxy import/export tools
|
||||
assert!(tool_names.contains(&"export_proxies"));
|
||||
assert!(tool_names.contains(&"import_proxies"));
|
||||
// VPN tools
|
||||
assert!(tool_names.contains(&"import_vpn"));
|
||||
assert!(tool_names.contains(&"list_vpn_configs"));
|
||||
assert!(tool_names.contains(&"delete_vpn"));
|
||||
assert!(tool_names.contains(&"connect_vpn"));
|
||||
assert!(tool_names.contains(&"disconnect_vpn"));
|
||||
assert!(tool_names.contains(&"get_vpn_status"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use chrono::Utc;
|
||||
use directories::BaseDirs;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -12,6 +13,60 @@ use crate::browser::ProxySettings;
|
||||
use crate::events;
|
||||
use crate::ip_utils;
|
||||
|
||||
// Export data format for JSON export
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProxyExportData {
|
||||
pub version: String,
|
||||
pub proxies: Vec<ExportedProxy>,
|
||||
pub exported_at: String,
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExportedProxy {
|
||||
pub name: String,
|
||||
#[serde(rename = "type")]
|
||||
pub proxy_type: String,
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub username: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProxyImportResult {
|
||||
pub imported_count: usize,
|
||||
pub skipped_count: usize,
|
||||
pub errors: Vec<String>,
|
||||
pub proxies: Vec<StoredProxy>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ParsedProxyLine {
|
||||
pub proxy_type: String,
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub username: Option<String>,
|
||||
pub password: Option<String>,
|
||||
pub original_line: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "status")]
|
||||
pub enum ProxyParseResult {
|
||||
#[serde(rename = "parsed")]
|
||||
Parsed(ParsedProxyLine),
|
||||
#[serde(rename = "ambiguous")]
|
||||
Ambiguous {
|
||||
line: String,
|
||||
possible_formats: Vec<String>,
|
||||
},
|
||||
#[serde(rename = "invalid")]
|
||||
Invalid { line: String, reason: String },
|
||||
}
|
||||
|
||||
// Store active proxy information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProxyInfo {
|
||||
@@ -541,6 +596,331 @@ impl ProxyManager {
|
||||
self.load_proxy_check_cache(proxy_id)
|
||||
}
|
||||
|
||||
// Export all proxies as JSON
|
||||
pub fn export_proxies_json(&self) -> Result<String, String> {
|
||||
let stored_proxies = self.stored_proxies.lock().unwrap();
|
||||
let proxies: Vec<ExportedProxy> = stored_proxies
|
||||
.values()
|
||||
.map(|p| ExportedProxy {
|
||||
name: p.name.clone(),
|
||||
proxy_type: p.proxy_settings.proxy_type.clone(),
|
||||
host: p.proxy_settings.host.clone(),
|
||||
port: p.proxy_settings.port,
|
||||
username: p.proxy_settings.username.clone(),
|
||||
password: p.proxy_settings.password.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let export_data = ProxyExportData {
|
||||
version: "1.0".to_string(),
|
||||
proxies,
|
||||
exported_at: Utc::now().to_rfc3339(),
|
||||
source: "DonutBrowser".to_string(),
|
||||
};
|
||||
|
||||
serde_json::to_string_pretty(&export_data).map_err(|e| format!("Failed to serialize: {e}"))
|
||||
}
|
||||
|
||||
// Export all proxies as TXT (one per line: protocol://user:pass@host:port)
|
||||
pub fn export_proxies_txt(&self) -> String {
|
||||
let stored_proxies = self.stored_proxies.lock().unwrap();
|
||||
stored_proxies
|
||||
.values()
|
||||
.map(|p| Self::build_proxy_url(&p.proxy_settings))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
// Parse TXT content with auto-detection of formats
|
||||
pub fn parse_txt_proxies(content: &str) -> Vec<ProxyParseResult> {
|
||||
content
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty() && !line.trim().starts_with('#'))
|
||||
.map(|line| Self::parse_single_proxy_line(line.trim()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Parse a single proxy line with format auto-detection
|
||||
fn parse_single_proxy_line(line: &str) -> ProxyParseResult {
|
||||
// Format 1: protocol://username:password@host:port (full URL)
|
||||
if let Some(result) = Self::try_parse_url_format(line) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Try colon-separated formats
|
||||
let parts: Vec<&str> = line.split(':').collect();
|
||||
|
||||
match parts.len() {
|
||||
// host:port (no auth)
|
||||
2 => {
|
||||
if let Ok(port) = parts[1].parse::<u16>() {
|
||||
return ProxyParseResult::Parsed(ParsedProxyLine {
|
||||
proxy_type: "http".to_string(),
|
||||
host: parts[0].to_string(),
|
||||
port,
|
||||
username: None,
|
||||
password: None,
|
||||
original_line: line.to_string(),
|
||||
});
|
||||
}
|
||||
ProxyParseResult::Invalid {
|
||||
line: line.to_string(),
|
||||
reason: "Invalid port number".to_string(),
|
||||
}
|
||||
}
|
||||
// Could be: host:port:user or user:pass@host (with @ in the middle)
|
||||
3 => {
|
||||
// Try username:password@host:port first
|
||||
if let Some(result) = Self::try_parse_user_pass_at_host_port(line) {
|
||||
return result;
|
||||
}
|
||||
ProxyParseResult::Invalid {
|
||||
line: line.to_string(),
|
||||
reason: "Could not determine format with 3 parts".to_string(),
|
||||
}
|
||||
}
|
||||
// 4 parts: could be host:port:user:pass OR user:pass:host:port
|
||||
4 => {
|
||||
// Try to detect which format
|
||||
let port_at_1 = parts[1].parse::<u16>().is_ok();
|
||||
let port_at_3 = parts[3].parse::<u16>().is_ok();
|
||||
|
||||
match (port_at_1, port_at_3) {
|
||||
// host:port:user:pass
|
||||
(true, false) => {
|
||||
let port = parts[1].parse::<u16>().unwrap();
|
||||
ProxyParseResult::Parsed(ParsedProxyLine {
|
||||
proxy_type: "http".to_string(),
|
||||
host: parts[0].to_string(),
|
||||
port,
|
||||
username: Some(parts[2].to_string()),
|
||||
password: Some(parts[3].to_string()),
|
||||
original_line: line.to_string(),
|
||||
})
|
||||
}
|
||||
// user:pass:host:port
|
||||
(false, true) => {
|
||||
let port = parts[3].parse::<u16>().unwrap();
|
||||
ProxyParseResult::Parsed(ParsedProxyLine {
|
||||
proxy_type: "http".to_string(),
|
||||
host: parts[2].to_string(),
|
||||
port,
|
||||
username: Some(parts[0].to_string()),
|
||||
password: Some(parts[1].to_string()),
|
||||
original_line: line.to_string(),
|
||||
})
|
||||
}
|
||||
// Both could be ports - ambiguous
|
||||
(true, true) => ProxyParseResult::Ambiguous {
|
||||
line: line.to_string(),
|
||||
possible_formats: vec![
|
||||
"host:port:username:password".to_string(),
|
||||
"username:password:host:port".to_string(),
|
||||
],
|
||||
},
|
||||
// Neither is a valid port
|
||||
(false, false) => ProxyParseResult::Invalid {
|
||||
line: line.to_string(),
|
||||
reason: "No valid port number found".to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
_ => ProxyParseResult::Invalid {
|
||||
line: line.to_string(),
|
||||
reason: format!("Unexpected format with {} parts", parts.len()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse URL format: protocol://username:password@host:port
|
||||
fn try_parse_url_format(line: &str) -> Option<ProxyParseResult> {
|
||||
// Check for protocol prefix using strip_prefix
|
||||
let (protocol, rest) = if let Some(rest) = line.strip_prefix("http://") {
|
||||
("http", rest)
|
||||
} else if let Some(rest) = line.strip_prefix("https://") {
|
||||
("https", rest)
|
||||
} else if let Some(rest) = line.strip_prefix("socks4://") {
|
||||
("socks4", rest)
|
||||
} else if let Some(rest) = line.strip_prefix("socks5://") {
|
||||
("socks5", rest)
|
||||
} else if let Some(rest) = line.strip_prefix("socks://") {
|
||||
("socks5", rest) // Default socks to socks5
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
// Check if there's auth (contains @)
|
||||
if let Some(at_pos) = rest.rfind('@') {
|
||||
let auth = &rest[..at_pos];
|
||||
let host_port = &rest[at_pos + 1..];
|
||||
|
||||
// Parse auth (user:pass)
|
||||
let (username, password) = if let Some(colon_pos) = auth.find(':') {
|
||||
let user = urlencoding::decode(&auth[..colon_pos]).unwrap_or_default();
|
||||
let pass = urlencoding::decode(&auth[colon_pos + 1..]).unwrap_or_default();
|
||||
(Some(user.to_string()), Some(pass.to_string()))
|
||||
} else {
|
||||
(
|
||||
Some(urlencoding::decode(auth).unwrap_or_default().to_string()),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
// Parse host:port
|
||||
if let Some(colon_pos) = host_port.rfind(':') {
|
||||
let host = &host_port[..colon_pos];
|
||||
if let Ok(port) = host_port[colon_pos + 1..].parse::<u16>() {
|
||||
return Some(ProxyParseResult::Parsed(ParsedProxyLine {
|
||||
proxy_type: protocol.to_string(),
|
||||
host: host.to_string(),
|
||||
port,
|
||||
username,
|
||||
password,
|
||||
original_line: line.to_string(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No auth, just host:port
|
||||
if let Some(colon_pos) = rest.rfind(':') {
|
||||
let host = &rest[..colon_pos];
|
||||
if let Ok(port) = rest[colon_pos + 1..].parse::<u16>() {
|
||||
return Some(ProxyParseResult::Parsed(ParsedProxyLine {
|
||||
proxy_type: protocol.to_string(),
|
||||
host: host.to_string(),
|
||||
port,
|
||||
username: None,
|
||||
password: None,
|
||||
original_line: line.to_string(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(ProxyParseResult::Invalid {
|
||||
line: line.to_string(),
|
||||
reason: "Invalid URL format".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
// Try to parse: username:password@host:port format (no protocol)
|
||||
fn try_parse_user_pass_at_host_port(line: &str) -> Option<ProxyParseResult> {
|
||||
if let Some(at_pos) = line.rfind('@') {
|
||||
let auth = &line[..at_pos];
|
||||
let host_port = &line[at_pos + 1..];
|
||||
|
||||
// Parse auth
|
||||
let (username, password) = if let Some(colon_pos) = auth.find(':') {
|
||||
(
|
||||
Some(auth[..colon_pos].to_string()),
|
||||
Some(auth[colon_pos + 1..].to_string()),
|
||||
)
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
// Parse host:port
|
||||
if let Some(colon_pos) = host_port.rfind(':') {
|
||||
let host = &host_port[..colon_pos];
|
||||
if let Ok(port) = host_port[colon_pos + 1..].parse::<u16>() {
|
||||
return Some(ProxyParseResult::Parsed(ParsedProxyLine {
|
||||
proxy_type: "http".to_string(),
|
||||
host: host.to_string(),
|
||||
port,
|
||||
username,
|
||||
password,
|
||||
original_line: line.to_string(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Import proxies from JSON content
|
||||
pub fn import_proxies_json(
|
||||
&self,
|
||||
app_handle: &tauri::AppHandle,
|
||||
content: &str,
|
||||
) -> Result<ProxyImportResult, String> {
|
||||
let export_data: ProxyExportData =
|
||||
serde_json::from_str(content).map_err(|e| format!("Invalid JSON format: {e}"))?;
|
||||
|
||||
let mut imported = Vec::new();
|
||||
let mut skipped = 0;
|
||||
let mut errors = Vec::new();
|
||||
|
||||
for exported in export_data.proxies {
|
||||
let proxy_settings = ProxySettings {
|
||||
proxy_type: exported.proxy_type,
|
||||
host: exported.host,
|
||||
port: exported.port,
|
||||
username: exported.username,
|
||||
password: exported.password,
|
||||
};
|
||||
|
||||
match self.create_stored_proxy(app_handle, exported.name.clone(), proxy_settings) {
|
||||
Ok(proxy) => imported.push(proxy),
|
||||
Err(e) => {
|
||||
if e.contains("already exists") {
|
||||
skipped += 1;
|
||||
} else {
|
||||
errors.push(format!("Failed to import '{}': {}", exported.name, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ProxyImportResult {
|
||||
imported_count: imported.len(),
|
||||
skipped_count: skipped,
|
||||
errors,
|
||||
proxies: imported,
|
||||
})
|
||||
}
|
||||
|
||||
// Import proxies from already parsed proxy lines
|
||||
pub fn import_proxies_from_parsed(
|
||||
&self,
|
||||
app_handle: &tauri::AppHandle,
|
||||
parsed_proxies: Vec<ParsedProxyLine>,
|
||||
name_prefix: Option<String>,
|
||||
) -> Result<ProxyImportResult, String> {
|
||||
let mut imported = Vec::new();
|
||||
let mut skipped = 0;
|
||||
let mut errors = Vec::new();
|
||||
let prefix = name_prefix.unwrap_or_else(|| "Imported".to_string());
|
||||
|
||||
for (i, parsed) in parsed_proxies.into_iter().enumerate() {
|
||||
let proxy_name = format!("{} Proxy {}", prefix, i + 1);
|
||||
let proxy_settings = ProxySettings {
|
||||
proxy_type: parsed.proxy_type,
|
||||
host: parsed.host,
|
||||
port: parsed.port,
|
||||
username: parsed.username,
|
||||
password: parsed.password,
|
||||
};
|
||||
|
||||
match self.create_stored_proxy(app_handle, proxy_name.clone(), proxy_settings) {
|
||||
Ok(proxy) => imported.push(proxy),
|
||||
Err(e) => {
|
||||
if e.contains("already exists") {
|
||||
skipped += 1;
|
||||
} else {
|
||||
errors.push(format!("Failed to import '{}': {}", proxy_name, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ProxyImportResult {
|
||||
imported_count: imported.len(),
|
||||
skipped_count: skipped,
|
||||
errors,
|
||||
proxies: imported,
|
||||
})
|
||||
}
|
||||
|
||||
// Start a proxy for given proxy settings and associate it with a browser process ID
|
||||
// If proxy_settings is None, starts a direct proxy for traffic monitoring
|
||||
pub async fn start_proxy(
|
||||
|
||||
@@ -52,6 +52,8 @@ pub struct AppSettings {
|
||||
pub mcp_token: Option<String>, // Displayed token for user to copy (not persisted, loaded from encrypted file)
|
||||
#[serde(default)]
|
||||
pub launch_on_login_declined: bool, // User permanently declined the launch-on-login prompt
|
||||
#[serde(default)]
|
||||
pub language: Option<String>, // ISO 639-1: "en", "es", "pt", "fr", "zh", "ja", "ru", or None for system default
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Default)]
|
||||
@@ -84,6 +86,7 @@ impl Default for AppSettings {
|
||||
mcp_port: None,
|
||||
mcp_token: None,
|
||||
launch_on_login_declined: false,
|
||||
language: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -809,6 +812,17 @@ pub async fn save_app_settings(
|
||||
let mut persist_settings = settings.clone();
|
||||
persist_settings.api_token = None;
|
||||
persist_settings.mcp_token = None;
|
||||
|
||||
log::info!(
|
||||
"[settings] Saving settings: theme={}, custom_theme_keys={}",
|
||||
persist_settings.theme,
|
||||
persist_settings
|
||||
.custom_theme
|
||||
.as_ref()
|
||||
.map(|t| t.len())
|
||||
.unwrap_or(0)
|
||||
);
|
||||
|
||||
manager
|
||||
.save_settings(&persist_settings)
|
||||
.map_err(|e| format!("Failed to save settings: {e}"))?;
|
||||
@@ -899,6 +913,20 @@ pub async fn save_sync_settings(
|
||||
})
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn get_system_language() -> String {
|
||||
sys_locale::get_locale()
|
||||
.map(|locale| {
|
||||
// Extract just the language code (e.g., "en" from "en-US")
|
||||
locale
|
||||
.split(['-', '_'])
|
||||
.next()
|
||||
.unwrap_or("en")
|
||||
.to_lowercase()
|
||||
})
|
||||
.unwrap_or_else(|| "en".to_string())
|
||||
}
|
||||
|
||||
// Global singleton instance
|
||||
lazy_static::lazy_static! {
|
||||
static ref SETTINGS_MANAGER: SettingsManager = SettingsManager::new();
|
||||
@@ -985,6 +1013,7 @@ mod tests {
|
||||
mcp_port: None,
|
||||
mcp_token: None,
|
||||
launch_on_login_declined: false,
|
||||
language: None,
|
||||
};
|
||||
|
||||
// Save settings
|
||||
|
||||
@@ -0,0 +1,489 @@
|
||||
//! VPN configuration types and parsing.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
/// VPN-related errors
|
||||
#[derive(Error, Debug)]
|
||||
pub enum VpnError {
|
||||
#[error("Unknown VPN config format")]
|
||||
UnknownFormat,
|
||||
#[error("Invalid WireGuard config: {0}")]
|
||||
InvalidWireGuard(String),
|
||||
#[error("Invalid OpenVPN config: {0}")]
|
||||
InvalidOpenVpn(String),
|
||||
#[error("Storage error: {0}")]
|
||||
Storage(String),
|
||||
#[error("Connection error: {0}")]
|
||||
Connection(String),
|
||||
#[error("Encryption error: {0}")]
|
||||
Encryption(String),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("VPN not found: {0}")]
|
||||
NotFound(String),
|
||||
#[error("Tunnel error: {0}")]
|
||||
Tunnel(String),
|
||||
}
|
||||
|
||||
/// The type of VPN configuration
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum VpnType {
|
||||
WireGuard,
|
||||
OpenVPN,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for VpnType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
VpnType::WireGuard => write!(f, "WireGuard"),
|
||||
VpnType::OpenVPN => write!(f, "OpenVPN"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A stored VPN configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VpnConfig {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub vpn_type: VpnType,
|
||||
pub config_data: String, // Raw config content (encrypted at rest)
|
||||
pub created_at: i64,
|
||||
pub last_used: Option<i64>,
|
||||
}
|
||||
|
||||
/// Parsed WireGuard configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WireGuardConfig {
|
||||
pub private_key: String,
|
||||
pub address: String,
|
||||
pub dns: Option<String>,
|
||||
pub mtu: Option<u16>,
|
||||
pub peer_public_key: String,
|
||||
pub peer_endpoint: String,
|
||||
pub allowed_ips: Vec<String>,
|
||||
pub persistent_keepalive: Option<u16>,
|
||||
pub preshared_key: Option<String>,
|
||||
}
|
||||
|
||||
/// Parsed OpenVPN configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenVpnConfig {
|
||||
pub raw_config: String,
|
||||
pub remote_host: String,
|
||||
pub remote_port: u16,
|
||||
pub protocol: String, // "udp" or "tcp"
|
||||
pub dev_type: String, // "tun" or "tap"
|
||||
pub has_inline_ca: bool,
|
||||
pub has_inline_cert: bool,
|
||||
pub has_inline_key: bool,
|
||||
}
|
||||
|
||||
/// Result of importing a VPN configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VpnImportResult {
|
||||
pub success: bool,
|
||||
pub vpn_id: Option<String>,
|
||||
pub vpn_type: Option<VpnType>,
|
||||
pub name: String,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// VPN connection status
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VpnStatus {
|
||||
pub connected: bool,
|
||||
pub vpn_id: String,
|
||||
pub connected_at: Option<i64>,
|
||||
pub bytes_sent: Option<u64>,
|
||||
pub bytes_received: Option<u64>,
|
||||
pub last_handshake: Option<i64>,
|
||||
}
|
||||
|
||||
/// Detect the VPN type from file content and filename
|
||||
pub fn detect_vpn_type(content: &str, filename: &str) -> Result<VpnType, VpnError> {
|
||||
let filename_lower = filename.to_lowercase();
|
||||
|
||||
// Check file extension first
|
||||
if filename_lower.ends_with(".conf") {
|
||||
// .conf could be WireGuard - check content
|
||||
if content.contains("[Interface]") && content.contains("[Peer]") {
|
||||
return Ok(VpnType::WireGuard);
|
||||
}
|
||||
}
|
||||
|
||||
if filename_lower.ends_with(".ovpn") {
|
||||
return Ok(VpnType::OpenVPN);
|
||||
}
|
||||
|
||||
// Check content patterns
|
||||
if content.contains("[Interface]") && content.contains("PrivateKey") && content.contains("[Peer]")
|
||||
{
|
||||
return Ok(VpnType::WireGuard);
|
||||
}
|
||||
|
||||
if content.contains("remote ") && (content.contains("client") || content.contains("dev tun")) {
|
||||
return Ok(VpnType::OpenVPN);
|
||||
}
|
||||
|
||||
Err(VpnError::UnknownFormat)
|
||||
}
|
||||
|
||||
/// Parse a WireGuard configuration file
|
||||
pub fn parse_wireguard_config(content: &str) -> Result<WireGuardConfig, VpnError> {
|
||||
let mut interface: HashMap<String, String> = HashMap::new();
|
||||
let mut peer: HashMap<String, String> = HashMap::new();
|
||||
let mut current_section: Option<&str> = None;
|
||||
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for section headers
|
||||
if line == "[Interface]" {
|
||||
current_section = Some("interface");
|
||||
continue;
|
||||
}
|
||||
if line == "[Peer]" {
|
||||
current_section = Some("peer");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse key-value pairs
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
let key = key.trim().to_string();
|
||||
let value = value.trim().to_string();
|
||||
|
||||
match current_section {
|
||||
Some("interface") => {
|
||||
interface.insert(key, value);
|
||||
}
|
||||
Some("peer") => {
|
||||
peer.insert(key, value);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
let private_key = interface
|
||||
.get("PrivateKey")
|
||||
.ok_or_else(|| VpnError::InvalidWireGuard("Missing PrivateKey in [Interface]".to_string()))?
|
||||
.clone();
|
||||
|
||||
let address = interface
|
||||
.get("Address")
|
||||
.ok_or_else(|| VpnError::InvalidWireGuard("Missing Address in [Interface]".to_string()))?
|
||||
.clone();
|
||||
|
||||
let peer_public_key = peer
|
||||
.get("PublicKey")
|
||||
.ok_or_else(|| VpnError::InvalidWireGuard("Missing PublicKey in [Peer]".to_string()))?
|
||||
.clone();
|
||||
|
||||
let peer_endpoint = peer
|
||||
.get("Endpoint")
|
||||
.ok_or_else(|| VpnError::InvalidWireGuard("Missing Endpoint in [Peer]".to_string()))?
|
||||
.clone();
|
||||
|
||||
let allowed_ips = peer
|
||||
.get("AllowedIPs")
|
||||
.map(|s| s.split(',').map(|ip| ip.trim().to_string()).collect())
|
||||
.unwrap_or_else(|| vec!["0.0.0.0/0".to_string()]);
|
||||
|
||||
let persistent_keepalive = peer.get("PersistentKeepalive").and_then(|s| s.parse().ok());
|
||||
|
||||
let dns = interface.get("DNS").cloned();
|
||||
let mtu = interface.get("MTU").and_then(|s| s.parse().ok());
|
||||
let preshared_key = peer.get("PresharedKey").cloned();
|
||||
|
||||
Ok(WireGuardConfig {
|
||||
private_key,
|
||||
address,
|
||||
dns,
|
||||
mtu,
|
||||
peer_public_key,
|
||||
peer_endpoint,
|
||||
allowed_ips,
|
||||
persistent_keepalive,
|
||||
preshared_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse an OpenVPN configuration file
|
||||
pub fn parse_openvpn_config(content: &str) -> Result<OpenVpnConfig, VpnError> {
|
||||
let mut remote_host = String::new();
|
||||
let mut remote_port: u16 = 1194; // Default OpenVPN port
|
||||
let mut protocol = "udp".to_string();
|
||||
let mut dev_type = "tun".to_string();
|
||||
|
||||
let has_inline_ca = content.contains("<ca>") && content.contains("</ca>");
|
||||
let has_inline_cert = content.contains("<cert>") && content.contains("</cert>");
|
||||
let has_inline_key = content.contains("<key>") && content.contains("</key>");
|
||||
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
|
||||
continue;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = line.split_whitespace().collect();
|
||||
if parts.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match parts[0] {
|
||||
"remote" => {
|
||||
if parts.len() >= 2 {
|
||||
remote_host = parts[1].to_string();
|
||||
}
|
||||
if parts.len() >= 3 {
|
||||
if let Ok(port) = parts[2].parse() {
|
||||
remote_port = port;
|
||||
}
|
||||
}
|
||||
if parts.len() >= 4 {
|
||||
protocol = parts[3].to_string();
|
||||
}
|
||||
}
|
||||
"proto" => {
|
||||
if parts.len() >= 2 {
|
||||
protocol = parts[1].to_string();
|
||||
}
|
||||
}
|
||||
"port" => {
|
||||
if parts.len() >= 2 {
|
||||
if let Ok(port) = parts[1].parse() {
|
||||
remote_port = port;
|
||||
}
|
||||
}
|
||||
}
|
||||
"dev" => {
|
||||
if parts.len() >= 2 {
|
||||
dev_type = parts[1].to_string();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if remote_host.is_empty() {
|
||||
return Err(VpnError::InvalidOpenVpn(
|
||||
"Missing 'remote' directive".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(OpenVpnConfig {
|
||||
raw_config: content.to_string(),
|
||||
remote_host,
|
||||
remote_port,
|
||||
protocol,
|
||||
dev_type,
|
||||
has_inline_ca,
|
||||
has_inline_cert,
|
||||
has_inline_key,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_detect_wireguard_by_extension() {
|
||||
let content = "[Interface]\nPrivateKey = test\n[Peer]\nPublicKey = test";
|
||||
assert_eq!(
|
||||
detect_vpn_type(content, "test.conf").unwrap(),
|
||||
VpnType::WireGuard
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openvpn_by_extension() {
|
||||
let content = "client\nremote vpn.example.com 1194";
|
||||
assert_eq!(
|
||||
detect_vpn_type(content, "test.ovpn").unwrap(),
|
||||
VpnType::OpenVPN
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_wireguard_by_content() {
|
||||
let content = "[Interface]\nPrivateKey = testkey123\nAddress = 10.0.0.2/24\n\n[Peer]\nPublicKey = peerkey456\nEndpoint = vpn.example.com:51820";
|
||||
assert_eq!(
|
||||
detect_vpn_type(content, "config").unwrap(),
|
||||
VpnType::WireGuard
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_openvpn_by_content() {
|
||||
let content = "client\ndev tun\nproto udp\nremote vpn.example.com 1194";
|
||||
assert_eq!(
|
||||
detect_vpn_type(content, "config").unwrap(),
|
||||
VpnType::OpenVPN
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_unknown_format() {
|
||||
let content = "random text that is not a vpn config";
|
||||
assert!(detect_vpn_type(content, "random.txt").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_wireguard_config() {
|
||||
let content = r#"
|
||||
[Interface]
|
||||
PrivateKey = WGTestPrivateKey123456789012345678901234567890
|
||||
Address = 10.0.0.2/24
|
||||
DNS = 1.1.1.1
|
||||
MTU = 1420
|
||||
|
||||
[Peer]
|
||||
PublicKey = WGTestPublicKey1234567890123456789012345678901
|
||||
Endpoint = vpn.example.com:51820
|
||||
AllowedIPs = 0.0.0.0/0, ::/0
|
||||
PersistentKeepalive = 25
|
||||
"#;
|
||||
|
||||
let config = parse_wireguard_config(content).unwrap();
|
||||
assert_eq!(
|
||||
config.private_key,
|
||||
"WGTestPrivateKey123456789012345678901234567890"
|
||||
);
|
||||
assert_eq!(config.address, "10.0.0.2/24");
|
||||
assert_eq!(config.dns, Some("1.1.1.1".to_string()));
|
||||
assert_eq!(config.mtu, Some(1420));
|
||||
assert_eq!(
|
||||
config.peer_public_key,
|
||||
"WGTestPublicKey1234567890123456789012345678901"
|
||||
);
|
||||
assert_eq!(config.peer_endpoint, "vpn.example.com:51820");
|
||||
assert_eq!(config.allowed_ips, vec!["0.0.0.0/0", "::/0"]);
|
||||
assert_eq!(config.persistent_keepalive, Some(25));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_wireguard_config_minimal() {
|
||||
let content = r#"
|
||||
[Interface]
|
||||
PrivateKey = minimalkey
|
||||
Address = 10.0.0.2/32
|
||||
|
||||
[Peer]
|
||||
PublicKey = peerpubkey
|
||||
Endpoint = 1.2.3.4:51820
|
||||
"#;
|
||||
|
||||
let config = parse_wireguard_config(content).unwrap();
|
||||
assert_eq!(config.private_key, "minimalkey");
|
||||
assert_eq!(config.address, "10.0.0.2/32");
|
||||
assert!(config.dns.is_none());
|
||||
assert!(config.mtu.is_none());
|
||||
assert_eq!(config.peer_public_key, "peerpubkey");
|
||||
assert_eq!(config.peer_endpoint, "1.2.3.4:51820");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_wireguard_missing_private_key() {
|
||||
let content = r#"
|
||||
[Interface]
|
||||
Address = 10.0.0.2/24
|
||||
|
||||
[Peer]
|
||||
PublicKey = key
|
||||
Endpoint = 1.2.3.4:51820
|
||||
"#;
|
||||
|
||||
let result = parse_wireguard_config(content);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("PrivateKey"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_openvpn_config() {
|
||||
let content = r#"
|
||||
client
|
||||
dev tun
|
||||
proto udp
|
||||
remote vpn.example.com 1194
|
||||
resolv-retry infinite
|
||||
nobind
|
||||
persist-key
|
||||
persist-tun
|
||||
<ca>
|
||||
-----BEGIN CERTIFICATE-----
|
||||
...certificate data...
|
||||
-----END CERTIFICATE-----
|
||||
</ca>
|
||||
<cert>
|
||||
-----BEGIN CERTIFICATE-----
|
||||
...cert data...
|
||||
-----END CERTIFICATE-----
|
||||
</cert>
|
||||
<key>
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
...key data...
|
||||
-----END PRIVATE KEY-----
|
||||
</key>
|
||||
"#;
|
||||
|
||||
let config = parse_openvpn_config(content).unwrap();
|
||||
assert_eq!(config.remote_host, "vpn.example.com");
|
||||
assert_eq!(config.remote_port, 1194);
|
||||
assert_eq!(config.protocol, "udp");
|
||||
assert_eq!(config.dev_type, "tun");
|
||||
assert!(config.has_inline_ca);
|
||||
assert!(config.has_inline_cert);
|
||||
assert!(config.has_inline_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_openvpn_config_minimal() {
|
||||
let content = r#"
|
||||
client
|
||||
remote vpn.example.com
|
||||
"#;
|
||||
|
||||
let config = parse_openvpn_config(content).unwrap();
|
||||
assert_eq!(config.remote_host, "vpn.example.com");
|
||||
assert_eq!(config.remote_port, 1194); // Default
|
||||
assert_eq!(config.protocol, "udp"); // Default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_openvpn_config_with_port_and_proto() {
|
||||
let content = r#"
|
||||
client
|
||||
remote vpn.example.com 443 tcp
|
||||
"#;
|
||||
|
||||
let config = parse_openvpn_config(content).unwrap();
|
||||
assert_eq!(config.remote_host, "vpn.example.com");
|
||||
assert_eq!(config.remote_port, 443);
|
||||
assert_eq!(config.protocol, "tcp");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_openvpn_missing_remote() {
|
||||
let content = r#"
|
||||
client
|
||||
dev tun
|
||||
proto udp
|
||||
"#;
|
||||
|
||||
let result = parse_openvpn_config(content);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("remote"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
//! VPN support module for WireGuard and OpenVPN configurations.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - VPN config parsing (WireGuard .conf and OpenVPN .ovpn files)
|
||||
//! - Encrypted storage for VPN configurations
|
||||
//! - Tunnel management with userspace WireGuard (boringtun) and OpenVPN process management
|
||||
|
||||
mod config;
|
||||
mod openvpn;
|
||||
mod storage;
|
||||
mod tunnel;
|
||||
mod wireguard;
|
||||
|
||||
pub use config::{
|
||||
detect_vpn_type, parse_openvpn_config, parse_wireguard_config, OpenVpnConfig, VpnConfig,
|
||||
VpnError, VpnImportResult, VpnStatus, VpnType, WireGuardConfig,
|
||||
};
|
||||
pub use openvpn::OpenVpnTunnel;
|
||||
pub use storage::VpnStorage;
|
||||
pub use tunnel::{TunnelManager, VpnTunnel};
|
||||
pub use wireguard::WireGuardTunnel;
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Global VPN storage instance
|
||||
pub static VPN_STORAGE: Lazy<Mutex<VpnStorage>> = Lazy::new(|| Mutex::new(VpnStorage::new()));
|
||||
|
||||
/// Global tunnel manager instance
|
||||
pub static TUNNEL_MANAGER: Lazy<tokio::sync::Mutex<TunnelManager>> =
|
||||
Lazy::new(|| tokio::sync::Mutex::new(TunnelManager::new()));
|
||||
@@ -0,0 +1,343 @@
|
||||
//! OpenVPN tunnel implementation using system openvpn binary.
|
||||
|
||||
use super::config::{OpenVpnConfig, VpnError, VpnStatus};
|
||||
use super::tunnel::VpnTunnel;
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tempfile::NamedTempFile;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// OpenVPN tunnel implementation
|
||||
pub struct OpenVpnTunnel {
|
||||
vpn_id: String,
|
||||
config: OpenVpnConfig,
|
||||
process: Arc<Mutex<Option<Child>>>,
|
||||
config_file: Option<NamedTempFile>,
|
||||
connected: AtomicBool,
|
||||
connected_at: Option<i64>,
|
||||
bytes_sent: AtomicU64,
|
||||
bytes_received: AtomicU64,
|
||||
}
|
||||
|
||||
impl OpenVpnTunnel {
|
||||
/// Create a new OpenVPN tunnel
|
||||
pub fn new(vpn_id: String, config: OpenVpnConfig) -> Self {
|
||||
Self {
|
||||
vpn_id,
|
||||
config,
|
||||
process: Arc::new(Mutex::new(None)),
|
||||
config_file: None,
|
||||
connected: AtomicBool::new(false),
|
||||
connected_at: None,
|
||||
bytes_sent: AtomicU64::new(0),
|
||||
bytes_received: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the openvpn binary
|
||||
fn find_openvpn_binary() -> Result<PathBuf, VpnError> {
|
||||
// Check common locations
|
||||
let locations = [
|
||||
"/usr/sbin/openvpn",
|
||||
"/usr/local/sbin/openvpn",
|
||||
"/opt/homebrew/bin/openvpn",
|
||||
"/usr/bin/openvpn",
|
||||
"C:\\Program Files\\OpenVPN\\bin\\openvpn.exe",
|
||||
"C:\\Program Files (x86)\\OpenVPN\\bin\\openvpn.exe",
|
||||
];
|
||||
|
||||
for loc in &locations {
|
||||
let path = PathBuf::from(loc);
|
||||
if path.exists() {
|
||||
return Ok(path);
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find via which/where command
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if let Ok(output) = Command::new("which").arg("openvpn").output() {
|
||||
if output.status.success() {
|
||||
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !path.is_empty() {
|
||||
return Ok(PathBuf::from(path));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
{
|
||||
if let Ok(output) = Command::new("where").arg("openvpn").output() {
|
||||
if output.status.success() {
|
||||
let path = String::from_utf8_lossy(&output.stdout)
|
||||
.lines()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.trim()
|
||||
.to_string();
|
||||
if !path.is_empty() {
|
||||
return Ok(PathBuf::from(path));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(VpnError::Connection(
|
||||
"OpenVPN binary not found. Please install OpenVPN.".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Write config to temporary file
|
||||
fn write_config_file(&mut self) -> Result<PathBuf, VpnError> {
|
||||
let temp_file =
|
||||
NamedTempFile::new().map_err(|e| VpnError::Io(std::io::Error::other(e.to_string())))?;
|
||||
|
||||
std::fs::write(temp_file.path(), &self.config.raw_config).map_err(VpnError::Io)?;
|
||||
|
||||
let path = temp_file.path().to_path_buf();
|
||||
self.config_file = Some(temp_file);
|
||||
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
/// Start the OpenVPN process
|
||||
async fn start_process(&mut self) -> Result<(), VpnError> {
|
||||
let openvpn_bin = Self::find_openvpn_binary()?;
|
||||
let config_path = self.write_config_file()?;
|
||||
|
||||
log::info!(
|
||||
"[vpn] Starting OpenVPN with config: {}",
|
||||
config_path.display()
|
||||
);
|
||||
|
||||
// Build command with common options
|
||||
let mut cmd = Command::new(&openvpn_bin);
|
||||
cmd
|
||||
.arg("--config")
|
||||
.arg(&config_path)
|
||||
.arg("--verb")
|
||||
.arg("3") // Verbosity level
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
// On Unix, try to avoid requiring root if possible
|
||||
#[cfg(unix)]
|
||||
{
|
||||
cmd.arg("--script-security").arg("2");
|
||||
}
|
||||
|
||||
let child = cmd
|
||||
.spawn()
|
||||
.map_err(|e| VpnError::Connection(format!("Failed to start OpenVPN: {e}")))?;
|
||||
|
||||
*self.process.lock().await = Some(child);
|
||||
|
||||
// Wait a bit and check if process is still running
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
|
||||
let mut process_guard = self.process.lock().await;
|
||||
if let Some(ref mut child) = *process_guard {
|
||||
match child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
// Process exited early
|
||||
let mut error_msg = format!("OpenVPN exited with status: {status}");
|
||||
|
||||
// Try to get stderr output
|
||||
if let Some(stderr) = child.stderr.take() {
|
||||
let reader = BufReader::new(stderr);
|
||||
let lines: Vec<String> = reader.lines().map_while(Result::ok).take(5).collect();
|
||||
if !lines.is_empty() {
|
||||
error_msg.push_str(&format!("\nError: {}", lines.join("\n")));
|
||||
}
|
||||
}
|
||||
|
||||
return Err(VpnError::Connection(error_msg));
|
||||
}
|
||||
Ok(None) => {
|
||||
// Still running, good
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(VpnError::Connection(format!(
|
||||
"Failed to check process status: {e}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Kill the OpenVPN process
|
||||
async fn kill_process(&mut self) -> Result<(), VpnError> {
|
||||
let mut process_guard = self.process.lock().await;
|
||||
|
||||
if let Some(mut child) = process_guard.take() {
|
||||
// Try graceful shutdown first
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use nix::sys::signal::{kill, Signal};
|
||||
use nix::unistd::Pid;
|
||||
|
||||
if let Ok(pid) = child.id().try_into() {
|
||||
let _ = kill(Pid::from_raw(pid), Signal::SIGTERM);
|
||||
// Wait a bit for graceful shutdown
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Force kill if still running
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
}
|
||||
|
||||
// Clean up config file
|
||||
self.config_file = None;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VpnTunnel for OpenVpnTunnel {
|
||||
async fn connect(&mut self) -> Result<(), VpnError> {
|
||||
if self.connected.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Start OpenVPN process
|
||||
self.start_process().await?;
|
||||
|
||||
// Wait for connection to be established
|
||||
// Note: In a real implementation, we'd monitor the OpenVPN management interface
|
||||
// For now, we assume success if the process starts and runs for a bit
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||
|
||||
// Check if process is still running
|
||||
let process_guard = self.process.lock().await;
|
||||
if let Some(ref child) = *process_guard {
|
||||
let id = child.id();
|
||||
if id > 0 {
|
||||
self.connected.store(true, Ordering::Release);
|
||||
self.connected_at = Some(Utc::now().timestamp());
|
||||
log::info!("[vpn] OpenVPN tunnel {} connected (PID: {id})", self.vpn_id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Err(VpnError::Connection(
|
||||
"Failed to establish OpenVPN connection".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn disconnect(&mut self) -> Result<(), VpnError> {
|
||||
if !self.connected.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.kill_process().await?;
|
||||
|
||||
self.connected.store(false, Ordering::Release);
|
||||
self.connected_at = None;
|
||||
|
||||
log::info!("[vpn] OpenVPN tunnel {} disconnected", self.vpn_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_connected(&self) -> bool {
|
||||
self.connected.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
fn vpn_id(&self) -> &str {
|
||||
&self.vpn_id
|
||||
}
|
||||
|
||||
fn get_status(&self) -> VpnStatus {
|
||||
VpnStatus {
|
||||
connected: self.is_connected(),
|
||||
vpn_id: self.vpn_id.clone(),
|
||||
connected_at: self.connected_at,
|
||||
bytes_sent: Some(self.bytes_sent.load(Ordering::Relaxed)),
|
||||
bytes_received: Some(self.bytes_received.load(Ordering::Relaxed)),
|
||||
last_handshake: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_sent(&self) -> u64 {
|
||||
self.bytes_sent.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn bytes_received(&self) -> u64 {
|
||||
self.bytes_received.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OpenVpnTunnel {
|
||||
fn drop(&mut self) {
|
||||
// Clean up process on drop (synchronously)
|
||||
if let Ok(mut guard) = self.process.try_lock() {
|
||||
if let Some(mut child) = guard.take() {
|
||||
let _ = child.kill();
|
||||
let _ = child.wait();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn create_test_config() -> OpenVpnConfig {
|
||||
OpenVpnConfig {
|
||||
raw_config: "client\nremote test.example.com 1194\ndev tun".to_string(),
|
||||
remote_host: "test.example.com".to_string(),
|
||||
remote_port: 1194,
|
||||
protocol: "udp".to_string(),
|
||||
dev_type: "tun".to_string(),
|
||||
has_inline_ca: false,
|
||||
has_inline_cert: false,
|
||||
has_inline_key: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openvpn_tunnel_creation() {
|
||||
let config = create_test_config();
|
||||
let tunnel = OpenVpnTunnel::new("test-ovpn-1".to_string(), config);
|
||||
|
||||
assert_eq!(tunnel.vpn_id(), "test-ovpn-1");
|
||||
assert!(!tunnel.is_connected());
|
||||
assert_eq!(tunnel.bytes_sent(), 0);
|
||||
assert_eq!(tunnel.bytes_received(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openvpn_status() {
|
||||
let config = create_test_config();
|
||||
let tunnel = OpenVpnTunnel::new("test-ovpn-2".to_string(), config);
|
||||
|
||||
let status = tunnel.get_status();
|
||||
assert!(!status.connected);
|
||||
assert_eq!(status.vpn_id, "test-ovpn-2");
|
||||
assert!(status.connected_at.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_openvpn_binary_format() {
|
||||
// This test just checks that the function doesn't panic
|
||||
// It may or may not find openvpn depending on the system
|
||||
let result = OpenVpnTunnel::find_openvpn_binary();
|
||||
// Just check that it returns a valid Result
|
||||
match result {
|
||||
Ok(path) => assert!(!path.as_os_str().is_empty()),
|
||||
Err(e) => assert!(e.to_string().contains("not found")),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,415 @@
|
||||
//! Encrypted storage for VPN configurations.
|
||||
|
||||
use super::config::{VpnConfig, VpnError, VpnType};
|
||||
use aes_gcm::{
|
||||
aead::{Aead, KeyInit},
|
||||
Aes256Gcm, Nonce,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Storage format version for migration support
|
||||
const STORAGE_VERSION: u32 = 1;
|
||||
|
||||
/// Stored VPN configs container
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct VpnStorageData {
|
||||
version: u32,
|
||||
configs: Vec<StoredVpnConfig>,
|
||||
}
|
||||
|
||||
/// Encrypted VPN config as stored on disk
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct StoredVpnConfig {
|
||||
id: String,
|
||||
name: String,
|
||||
vpn_type: VpnType,
|
||||
encrypted_data: String, // Base64 encoded encrypted config
|
||||
nonce: String, // Base64 encoded nonce
|
||||
created_at: i64,
|
||||
last_used: Option<i64>,
|
||||
}
|
||||
|
||||
/// VPN storage manager with encryption
|
||||
pub struct VpnStorage {
|
||||
storage_path: PathBuf,
|
||||
encryption_key: [u8; 32],
|
||||
}
|
||||
|
||||
impl Default for VpnStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl VpnStorage {
|
||||
/// Create a new VPN storage manager
|
||||
pub fn new() -> Self {
|
||||
let storage_path = Self::get_storage_path();
|
||||
let encryption_key = Self::get_or_create_key();
|
||||
|
||||
Self {
|
||||
storage_path,
|
||||
encryption_key,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the storage file path
|
||||
fn get_storage_path() -> PathBuf {
|
||||
let data_dir = directories::ProjectDirs::from("com", "donut", "donutbrowser")
|
||||
.map(|dirs| dirs.data_local_dir().to_path_buf())
|
||||
.unwrap_or_else(|| PathBuf::from("."));
|
||||
|
||||
if !data_dir.exists() {
|
||||
let _ = fs::create_dir_all(&data_dir);
|
||||
}
|
||||
|
||||
data_dir.join("vpn_configs.json")
|
||||
}
|
||||
|
||||
/// Get or create the encryption key
|
||||
fn get_or_create_key() -> [u8; 32] {
|
||||
let key_path = directories::ProjectDirs::from("com", "donut", "donutbrowser")
|
||||
.map(|dirs| dirs.data_local_dir().join(".vpn_key"))
|
||||
.unwrap_or_else(|| PathBuf::from(".vpn_key"));
|
||||
|
||||
if key_path.exists() {
|
||||
if let Ok(key_data) = fs::read(&key_path) {
|
||||
if key_data.len() == 32 {
|
||||
let mut key = [0u8; 32];
|
||||
key.copy_from_slice(&key_data);
|
||||
return key;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new key
|
||||
let key: [u8; 32] = rand::rng().random();
|
||||
let _ = fs::write(&key_path, key);
|
||||
|
||||
// Set restrictive permissions on Unix
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = fs::set_permissions(&key_path, fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
key
|
||||
}
|
||||
|
||||
/// Load storage data from disk
|
||||
fn load_storage(&self) -> Result<VpnStorageData, VpnError> {
|
||||
if !self.storage_path.exists() {
|
||||
return Ok(VpnStorageData {
|
||||
version: STORAGE_VERSION,
|
||||
configs: Vec::new(),
|
||||
});
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&self.storage_path)
|
||||
.map_err(|e| VpnError::Storage(format!("Failed to read storage file: {e}")))?;
|
||||
|
||||
serde_json::from_str(&content)
|
||||
.map_err(|e| VpnError::Storage(format!("Failed to parse storage file: {e}")))
|
||||
}
|
||||
|
||||
/// Save storage data to disk
|
||||
fn save_storage(&self, data: &VpnStorageData) -> Result<(), VpnError> {
|
||||
let content = serde_json::to_string_pretty(data)
|
||||
.map_err(|e| VpnError::Storage(format!("Failed to serialize storage: {e}")))?;
|
||||
|
||||
fs::write(&self.storage_path, content)
|
||||
.map_err(|e| VpnError::Storage(format!("Failed to write storage file: {e}")))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Encrypt config data
|
||||
fn encrypt(&self, data: &str) -> Result<(String, String), VpnError> {
|
||||
let cipher = Aes256Gcm::new_from_slice(&self.encryption_key)
|
||||
.map_err(|e| VpnError::Encryption(format!("Failed to create cipher: {e}")))?;
|
||||
|
||||
let nonce_bytes: [u8; 12] = rand::rng().random();
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, data.as_bytes())
|
||||
.map_err(|e| VpnError::Encryption(format!("Encryption failed: {e}")))?;
|
||||
|
||||
Ok((
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &ciphertext),
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, nonce_bytes),
|
||||
))
|
||||
}
|
||||
|
||||
/// Decrypt config data
|
||||
fn decrypt(&self, encrypted_data: &str, nonce_str: &str) -> Result<String, VpnError> {
|
||||
let cipher = Aes256Gcm::new_from_slice(&self.encryption_key)
|
||||
.map_err(|e| VpnError::Encryption(format!("Failed to create cipher: {e}")))?;
|
||||
|
||||
let ciphertext =
|
||||
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encrypted_data)
|
||||
.map_err(|e| VpnError::Encryption(format!("Failed to decode ciphertext: {e}")))?;
|
||||
|
||||
let nonce_bytes = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, nonce_str)
|
||||
.map_err(|e| VpnError::Encryption(format!("Failed to decode nonce: {e}")))?;
|
||||
|
||||
if nonce_bytes.len() != 12 {
|
||||
return Err(VpnError::Encryption("Invalid nonce length".to_string()));
|
||||
}
|
||||
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, ciphertext.as_ref())
|
||||
.map_err(|e| VpnError::Encryption(format!("Decryption failed: {e}")))?;
|
||||
|
||||
String::from_utf8(plaintext)
|
||||
.map_err(|e| VpnError::Encryption(format!("Failed to decode plaintext: {e}")))
|
||||
}
|
||||
|
||||
/// Save a VPN configuration
|
||||
pub fn save_config(&self, config: &VpnConfig) -> Result<(), VpnError> {
|
||||
let mut storage = self.load_storage()?;
|
||||
|
||||
// Encrypt the config data
|
||||
let (encrypted_data, nonce) = self.encrypt(&config.config_data)?;
|
||||
|
||||
let stored = StoredVpnConfig {
|
||||
id: config.id.clone(),
|
||||
name: config.name.clone(),
|
||||
vpn_type: config.vpn_type,
|
||||
encrypted_data,
|
||||
nonce,
|
||||
created_at: config.created_at,
|
||||
last_used: config.last_used,
|
||||
};
|
||||
|
||||
// Update existing or add new
|
||||
if let Some(pos) = storage.configs.iter().position(|c| c.id == config.id) {
|
||||
storage.configs[pos] = stored;
|
||||
} else {
|
||||
storage.configs.push(stored);
|
||||
}
|
||||
|
||||
self.save_storage(&storage)
|
||||
}
|
||||
|
||||
/// Load a VPN configuration by ID
|
||||
pub fn load_config(&self, id: &str) -> Result<VpnConfig, VpnError> {
|
||||
let storage = self.load_storage()?;
|
||||
|
||||
let stored = storage
|
||||
.configs
|
||||
.iter()
|
||||
.find(|c| c.id == id)
|
||||
.ok_or_else(|| VpnError::NotFound(id.to_string()))?;
|
||||
|
||||
let config_data = self.decrypt(&stored.encrypted_data, &stored.nonce)?;
|
||||
|
||||
Ok(VpnConfig {
|
||||
id: stored.id.clone(),
|
||||
name: stored.name.clone(),
|
||||
vpn_type: stored.vpn_type,
|
||||
config_data,
|
||||
created_at: stored.created_at,
|
||||
last_used: stored.last_used,
|
||||
})
|
||||
}
|
||||
|
||||
/// List all VPN configurations (without decrypted config data)
|
||||
pub fn list_configs(&self) -> Result<Vec<VpnConfig>, VpnError> {
|
||||
let storage = self.load_storage()?;
|
||||
|
||||
Ok(
|
||||
storage
|
||||
.configs
|
||||
.iter()
|
||||
.map(|stored| VpnConfig {
|
||||
id: stored.id.clone(),
|
||||
name: stored.name.clone(),
|
||||
vpn_type: stored.vpn_type,
|
||||
config_data: String::new(), // Don't include config data in list
|
||||
created_at: stored.created_at,
|
||||
last_used: stored.last_used,
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Delete a VPN configuration
|
||||
pub fn delete_config(&self, id: &str) -> Result<(), VpnError> {
|
||||
let mut storage = self.load_storage()?;
|
||||
|
||||
let initial_len = storage.configs.len();
|
||||
storage.configs.retain(|c| c.id != id);
|
||||
|
||||
if storage.configs.len() == initial_len {
|
||||
return Err(VpnError::NotFound(id.to_string()));
|
||||
}
|
||||
|
||||
self.save_storage(&storage)
|
||||
}
|
||||
|
||||
/// Update last_used timestamp
|
||||
pub fn update_last_used(&self, id: &str) -> Result<(), VpnError> {
|
||||
let mut storage = self.load_storage()?;
|
||||
|
||||
if let Some(config) = storage.configs.iter_mut().find(|c| c.id == id) {
|
||||
config.last_used = Some(Utc::now().timestamp());
|
||||
self.save_storage(&storage)
|
||||
} else {
|
||||
Err(VpnError::NotFound(id.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Import a VPN config from raw content
|
||||
pub fn import_config(
|
||||
&self,
|
||||
content: &str,
|
||||
filename: &str,
|
||||
name: Option<String>,
|
||||
) -> Result<VpnConfig, VpnError> {
|
||||
let vpn_type = super::detect_vpn_type(content, filename)?;
|
||||
|
||||
// Validate the config by parsing it
|
||||
match vpn_type {
|
||||
VpnType::WireGuard => {
|
||||
super::parse_wireguard_config(content)?;
|
||||
}
|
||||
VpnType::OpenVPN => {
|
||||
super::parse_openvpn_config(content)?;
|
||||
}
|
||||
}
|
||||
|
||||
let id = Uuid::new_v4().to_string();
|
||||
let display_name = name.unwrap_or_else(|| {
|
||||
// Generate name from filename
|
||||
let base = filename.trim_end_matches(".conf").trim_end_matches(".ovpn");
|
||||
format!("{} ({})", base, vpn_type)
|
||||
});
|
||||
|
||||
let config = VpnConfig {
|
||||
id,
|
||||
name: display_name,
|
||||
vpn_type,
|
||||
config_data: content.to_string(),
|
||||
created_at: Utc::now().timestamp(),
|
||||
last_used: None,
|
||||
};
|
||||
|
||||
self.save_config(&config)?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn create_test_storage() -> (VpnStorage, TempDir) {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let mut storage = VpnStorage::new();
|
||||
storage.storage_path = temp_dir.path().join("test_vpn_configs.json");
|
||||
(storage, temp_dir)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_decrypt_roundtrip() {
|
||||
let (storage, _temp) = create_test_storage();
|
||||
let original = "This is a secret VPN configuration";
|
||||
|
||||
let (encrypted, nonce) = storage.encrypt(original).unwrap();
|
||||
let decrypted = storage.decrypt(&encrypted, &nonce).unwrap();
|
||||
|
||||
assert_eq!(original, decrypted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load_config() {
|
||||
let (storage, _temp) = create_test_storage();
|
||||
|
||||
let config = VpnConfig {
|
||||
id: "test-id-123".to_string(),
|
||||
name: "Test VPN".to_string(),
|
||||
vpn_type: VpnType::WireGuard,
|
||||
config_data: "[Interface]\nPrivateKey = test\n[Peer]\nPublicKey = peer".to_string(),
|
||||
created_at: 1234567890,
|
||||
last_used: None,
|
||||
};
|
||||
|
||||
storage.save_config(&config).unwrap();
|
||||
let loaded = storage.load_config("test-id-123").unwrap();
|
||||
|
||||
assert_eq!(loaded.id, config.id);
|
||||
assert_eq!(loaded.name, config.name);
|
||||
assert_eq!(loaded.vpn_type, config.vpn_type);
|
||||
assert_eq!(loaded.config_data, config.config_data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_configs() {
|
||||
let (storage, _temp) = create_test_storage();
|
||||
|
||||
let config1 = VpnConfig {
|
||||
id: "id-1".to_string(),
|
||||
name: "VPN 1".to_string(),
|
||||
vpn_type: VpnType::WireGuard,
|
||||
config_data: "secret1".to_string(),
|
||||
created_at: 1000,
|
||||
last_used: None,
|
||||
};
|
||||
|
||||
let config2 = VpnConfig {
|
||||
id: "id-2".to_string(),
|
||||
name: "VPN 2".to_string(),
|
||||
vpn_type: VpnType::OpenVPN,
|
||||
config_data: "secret2".to_string(),
|
||||
created_at: 2000,
|
||||
last_used: Some(3000),
|
||||
};
|
||||
|
||||
storage.save_config(&config1).unwrap();
|
||||
storage.save_config(&config2).unwrap();
|
||||
|
||||
let configs = storage.list_configs().unwrap();
|
||||
assert_eq!(configs.len(), 2);
|
||||
|
||||
// Config data should be empty in listing
|
||||
assert!(configs[0].config_data.is_empty());
|
||||
assert!(configs[1].config_data.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete_config() {
|
||||
let (storage, _temp) = create_test_storage();
|
||||
|
||||
let config = VpnConfig {
|
||||
id: "delete-me".to_string(),
|
||||
name: "To Delete".to_string(),
|
||||
vpn_type: VpnType::WireGuard,
|
||||
config_data: "data".to_string(),
|
||||
created_at: 1000,
|
||||
last_used: None,
|
||||
};
|
||||
|
||||
storage.save_config(&config).unwrap();
|
||||
assert!(storage.load_config("delete-me").is_ok());
|
||||
|
||||
storage.delete_config("delete-me").unwrap();
|
||||
assert!(storage.load_config("delete-me").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_nonexistent_config() {
|
||||
let (storage, _temp) = create_test_storage();
|
||||
let result = storage.load_config("nonexistent");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
//! VPN tunnel trait and management.
|
||||
|
||||
use super::config::{VpnError, VpnStatus};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Trait for VPN tunnel implementations
|
||||
#[async_trait]
|
||||
pub trait VpnTunnel: Send + Sync {
|
||||
/// Connect the VPN tunnel
|
||||
async fn connect(&mut self) -> Result<(), VpnError>;
|
||||
|
||||
/// Disconnect the VPN tunnel
|
||||
async fn disconnect(&mut self) -> Result<(), VpnError>;
|
||||
|
||||
/// Check if the tunnel is connected
|
||||
fn is_connected(&self) -> bool;
|
||||
|
||||
/// Get the VPN config ID
|
||||
fn vpn_id(&self) -> &str;
|
||||
|
||||
/// Get the current status of the tunnel
|
||||
fn get_status(&self) -> VpnStatus;
|
||||
|
||||
/// Get bytes sent through the tunnel
|
||||
fn bytes_sent(&self) -> u64;
|
||||
|
||||
/// Get bytes received through the tunnel
|
||||
fn bytes_received(&self) -> u64;
|
||||
}
|
||||
|
||||
/// Manager for active VPN tunnels
|
||||
pub struct TunnelManager {
|
||||
active_tunnels: HashMap<String, Box<dyn VpnTunnel>>,
|
||||
}
|
||||
|
||||
impl Default for TunnelManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelManager {
|
||||
/// Create a new tunnel manager
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active_tunnels: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register an active tunnel
|
||||
pub fn register_tunnel(&mut self, vpn_id: String, tunnel: Box<dyn VpnTunnel>) {
|
||||
self.active_tunnels.insert(vpn_id, tunnel);
|
||||
}
|
||||
|
||||
/// Remove a tunnel from management
|
||||
pub fn remove_tunnel(&mut self, vpn_id: &str) -> Option<Box<dyn VpnTunnel>> {
|
||||
self.active_tunnels.remove(vpn_id)
|
||||
}
|
||||
|
||||
/// Get a reference to an active tunnel
|
||||
pub fn get_tunnel(&self, vpn_id: &str) -> Option<&dyn VpnTunnel> {
|
||||
self.active_tunnels.get(vpn_id).map(|t| t.as_ref())
|
||||
}
|
||||
|
||||
/// Get a mutable reference to an active tunnel
|
||||
pub fn get_tunnel_mut(&mut self, vpn_id: &str) -> Option<&mut Box<dyn VpnTunnel>> {
|
||||
self.active_tunnels.get_mut(vpn_id)
|
||||
}
|
||||
|
||||
/// Check if a tunnel is active
|
||||
pub fn is_tunnel_active(&self, vpn_id: &str) -> bool {
|
||||
self
|
||||
.active_tunnels
|
||||
.get(vpn_id)
|
||||
.is_some_and(|t| t.is_connected())
|
||||
}
|
||||
|
||||
/// Get status of all active tunnels
|
||||
pub fn get_all_statuses(&self) -> Vec<VpnStatus> {
|
||||
self
|
||||
.active_tunnels
|
||||
.values()
|
||||
.map(|t| t.get_status())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Disconnect all active tunnels
|
||||
pub async fn disconnect_all(&mut self) -> Vec<Result<(), VpnError>> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for tunnel in self.active_tunnels.values_mut() {
|
||||
results.push(tunnel.disconnect().await);
|
||||
}
|
||||
|
||||
self.active_tunnels.clear();
|
||||
results
|
||||
}
|
||||
|
||||
/// Get the number of active tunnels
|
||||
pub fn active_count(&self) -> usize {
|
||||
self
|
||||
.active_tunnels
|
||||
.values()
|
||||
.filter(|t| t.is_connected())
|
||||
.count()
|
||||
}
|
||||
|
||||
/// List IDs of all active VPN connections
|
||||
pub fn list_active_ids(&self) -> Vec<String> {
|
||||
self
|
||||
.active_tunnels
|
||||
.iter()
|
||||
.filter(|(_, t)| t.is_connected())
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
struct MockTunnel {
|
||||
id: String,
|
||||
connected: bool,
|
||||
bytes_sent: u64,
|
||||
bytes_received: u64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VpnTunnel for MockTunnel {
|
||||
async fn connect(&mut self) -> Result<(), VpnError> {
|
||||
self.connected = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn disconnect(&mut self) -> Result<(), VpnError> {
|
||||
self.connected = false;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_connected(&self) -> bool {
|
||||
self.connected
|
||||
}
|
||||
|
||||
fn vpn_id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
|
||||
fn get_status(&self) -> VpnStatus {
|
||||
VpnStatus {
|
||||
connected: self.connected,
|
||||
vpn_id: self.id.clone(),
|
||||
connected_at: if self.connected { Some(1000) } else { None },
|
||||
bytes_sent: Some(self.bytes_sent),
|
||||
bytes_received: Some(self.bytes_received),
|
||||
last_handshake: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_sent(&self) -> u64 {
|
||||
self.bytes_sent
|
||||
}
|
||||
|
||||
fn bytes_received(&self) -> u64 {
|
||||
self.bytes_received
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tunnel_manager_register() {
|
||||
let mut manager = TunnelManager::new();
|
||||
let tunnel = Box::new(MockTunnel {
|
||||
id: "test-1".to_string(),
|
||||
connected: true,
|
||||
bytes_sent: 100,
|
||||
bytes_received: 200,
|
||||
});
|
||||
|
||||
manager.register_tunnel("test-1".to_string(), tunnel);
|
||||
assert!(manager.is_tunnel_active("test-1"));
|
||||
assert!(!manager.is_tunnel_active("test-2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tunnel_manager_remove() {
|
||||
let mut manager = TunnelManager::new();
|
||||
let tunnel = Box::new(MockTunnel {
|
||||
id: "test-1".to_string(),
|
||||
connected: true,
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
});
|
||||
|
||||
manager.register_tunnel("test-1".to_string(), tunnel);
|
||||
assert!(manager.is_tunnel_active("test-1"));
|
||||
|
||||
let removed = manager.remove_tunnel("test-1");
|
||||
assert!(removed.is_some());
|
||||
assert!(!manager.is_tunnel_active("test-1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tunnel_manager_active_count() {
|
||||
let mut manager = TunnelManager::new();
|
||||
|
||||
let tunnel1 = Box::new(MockTunnel {
|
||||
id: "t1".to_string(),
|
||||
connected: true,
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
});
|
||||
|
||||
let tunnel2 = Box::new(MockTunnel {
|
||||
id: "t2".to_string(),
|
||||
connected: false,
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
});
|
||||
|
||||
manager.register_tunnel("t1".to_string(), tunnel1);
|
||||
manager.register_tunnel("t2".to_string(), tunnel2);
|
||||
|
||||
assert_eq!(manager.active_count(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tunnel_manager_disconnect_all() {
|
||||
let mut manager = TunnelManager::new();
|
||||
|
||||
let tunnel1 = Box::new(MockTunnel {
|
||||
id: "t1".to_string(),
|
||||
connected: true,
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
});
|
||||
|
||||
let tunnel2 = Box::new(MockTunnel {
|
||||
id: "t2".to_string(),
|
||||
connected: true,
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
});
|
||||
|
||||
manager.register_tunnel("t1".to_string(), tunnel1);
|
||||
manager.register_tunnel("t2".to_string(), tunnel2);
|
||||
|
||||
assert_eq!(manager.active_count(), 2);
|
||||
|
||||
let results = manager.disconnect_all().await;
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results.iter().all(|r| r.is_ok()));
|
||||
assert_eq!(manager.active_count(), 0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
//! WireGuard tunnel implementation using boringtun.
|
||||
|
||||
use super::config::{VpnError, VpnStatus, WireGuardConfig};
|
||||
use super::tunnel::VpnTunnel;
|
||||
use async_trait::async_trait;
|
||||
use boringtun::noise::{Tunn, TunnResult};
|
||||
use boringtun::x25519::{PublicKey, StaticSecret};
|
||||
use chrono::Utc;
|
||||
use std::net::{SocketAddr, ToSocketAddrs, UdpSocket};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// WireGuard tunnel implementation
|
||||
pub struct WireGuardTunnel {
|
||||
vpn_id: String,
|
||||
config: WireGuardConfig,
|
||||
tunnel: Option<Arc<Mutex<Box<Tunn>>>>,
|
||||
socket: Option<Arc<UdpSocket>>,
|
||||
connected: AtomicBool,
|
||||
connected_at: Option<i64>,
|
||||
bytes_sent: AtomicU64,
|
||||
bytes_received: AtomicU64,
|
||||
last_handshake: Option<i64>,
|
||||
peer_addr: Option<SocketAddr>,
|
||||
}
|
||||
|
||||
impl WireGuardTunnel {
|
||||
/// Create a new WireGuard tunnel
|
||||
pub fn new(vpn_id: String, config: WireGuardConfig) -> Self {
|
||||
Self {
|
||||
vpn_id,
|
||||
config,
|
||||
tunnel: None,
|
||||
socket: None,
|
||||
connected: AtomicBool::new(false),
|
||||
connected_at: None,
|
||||
bytes_sent: AtomicU64::new(0),
|
||||
bytes_received: AtomicU64::new(0),
|
||||
last_handshake: None,
|
||||
peer_addr: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse base64 key to bytes
|
||||
fn parse_key(key: &str) -> Result<[u8; 32], VpnError> {
|
||||
let decoded = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, key)
|
||||
.map_err(|e| VpnError::InvalidWireGuard(format!("Invalid key encoding: {e}")))?;
|
||||
|
||||
if decoded.len() != 32 {
|
||||
return Err(VpnError::InvalidWireGuard(format!(
|
||||
"Invalid key length: {} (expected 32)",
|
||||
decoded.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut key_bytes = [0u8; 32];
|
||||
key_bytes.copy_from_slice(&decoded);
|
||||
Ok(key_bytes)
|
||||
}
|
||||
|
||||
/// Initialize the WireGuard tunnel
|
||||
fn init_tunnel(&mut self) -> Result<(), VpnError> {
|
||||
// Parse private key
|
||||
let private_key_bytes = Self::parse_key(&self.config.private_key)?;
|
||||
let static_private = StaticSecret::from(private_key_bytes);
|
||||
|
||||
// Parse peer public key
|
||||
let peer_public_bytes = Self::parse_key(&self.config.peer_public_key)?;
|
||||
let peer_public = PublicKey::from(peer_public_bytes);
|
||||
|
||||
// Parse optional preshared key
|
||||
let preshared_key = if let Some(ref psk) = self.config.preshared_key {
|
||||
Some(Self::parse_key(psk)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Create the boringtun tunnel
|
||||
let tunn = Tunn::new(
|
||||
static_private,
|
||||
peer_public,
|
||||
preshared_key,
|
||||
self.config.persistent_keepalive,
|
||||
0, // index
|
||||
None,
|
||||
)
|
||||
.map_err(|e| VpnError::Tunnel(format!("Failed to create tunnel: {e}")))?;
|
||||
|
||||
self.tunnel = Some(Arc::new(Mutex::new(Box::new(tunn))));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resolve peer endpoint to socket address
|
||||
fn resolve_endpoint(&mut self) -> Result<SocketAddr, VpnError> {
|
||||
let endpoint = &self.config.peer_endpoint;
|
||||
|
||||
// Try to resolve the endpoint
|
||||
let addrs: Vec<SocketAddr> = endpoint
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| VpnError::Connection(format!("Failed to resolve endpoint '{endpoint}': {e}")))?
|
||||
.collect();
|
||||
|
||||
addrs
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| VpnError::Connection(format!("No addresses found for endpoint: {endpoint}")))
|
||||
}
|
||||
|
||||
/// Perform WireGuard handshake
|
||||
async fn handshake(&mut self) -> Result<(), VpnError> {
|
||||
let tunnel = self
|
||||
.tunnel
|
||||
.as_ref()
|
||||
.ok_or_else(|| VpnError::Tunnel("Tunnel not initialized".to_string()))?;
|
||||
|
||||
let socket = self
|
||||
.socket
|
||||
.as_ref()
|
||||
.ok_or_else(|| VpnError::Tunnel("Socket not initialized".to_string()))?;
|
||||
|
||||
let peer_addr = self
|
||||
.peer_addr
|
||||
.ok_or_else(|| VpnError::Tunnel("Peer address not resolved".to_string()))?;
|
||||
|
||||
let mut tunnel_guard = tunnel.lock().await;
|
||||
|
||||
// Generate handshake initiation
|
||||
let mut dst = vec![0u8; 2048];
|
||||
let result = tunnel_guard.format_handshake_initiation(&mut dst, false);
|
||||
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
socket
|
||||
.send_to(packet, peer_addr)
|
||||
.map_err(|e| VpnError::Connection(format!("Failed to send handshake: {e}")))?;
|
||||
|
||||
self
|
||||
.bytes_sent
|
||||
.fetch_add(packet.len() as u64, Ordering::Relaxed);
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
return Err(VpnError::Tunnel(format!(
|
||||
"Handshake initiation failed: {e:?}"
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Wait for handshake response (with timeout)
|
||||
socket
|
||||
.set_read_timeout(Some(std::time::Duration::from_secs(10)))
|
||||
.map_err(|e| VpnError::Connection(format!("Failed to set timeout: {e}")))?;
|
||||
|
||||
let mut recv_buf = vec![0u8; 2048];
|
||||
|
||||
match socket.recv_from(&mut recv_buf) {
|
||||
Ok((len, _from)) => {
|
||||
self.bytes_received.fetch_add(len as u64, Ordering::Relaxed);
|
||||
|
||||
let result = tunnel_guard.decapsulate(None, &recv_buf[..len], &mut dst);
|
||||
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(response) => {
|
||||
socket
|
||||
.send_to(response, peer_addr)
|
||||
.map_err(|e| VpnError::Connection(format!("Failed to send response: {e}")))?;
|
||||
|
||||
self
|
||||
.bytes_sent
|
||||
.fetch_add(response.len() as u64, Ordering::Relaxed);
|
||||
self.last_handshake = Some(Utc::now().timestamp());
|
||||
}
|
||||
TunnResult::Done => {
|
||||
self.last_handshake = Some(Utc::now().timestamp());
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
return Err(VpnError::Tunnel(format!(
|
||||
"Handshake response failed: {e:?}"
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(VpnError::Connection(format!(
|
||||
"Handshake timeout or error: {e}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Encrypt and send data through the tunnel
|
||||
pub async fn send(&self, data: &[u8]) -> Result<(), VpnError> {
|
||||
let tunnel = self
|
||||
.tunnel
|
||||
.as_ref()
|
||||
.ok_or_else(|| VpnError::Tunnel("Tunnel not initialized".to_string()))?;
|
||||
|
||||
let socket = self
|
||||
.socket
|
||||
.as_ref()
|
||||
.ok_or_else(|| VpnError::Tunnel("Socket not initialized".to_string()))?;
|
||||
|
||||
let peer_addr = self
|
||||
.peer_addr
|
||||
.ok_or_else(|| VpnError::Tunnel("Peer address not resolved".to_string()))?;
|
||||
|
||||
let mut tunnel_guard = tunnel.lock().await;
|
||||
let mut dst = vec![0u8; data.len() + 256]; // Extra space for WireGuard overhead
|
||||
|
||||
let result = tunnel_guard.encapsulate(data, &mut dst);
|
||||
|
||||
match result {
|
||||
TunnResult::WriteToNetwork(packet) => {
|
||||
socket
|
||||
.send_to(packet, peer_addr)
|
||||
.map_err(|e| VpnError::Connection(format!("Failed to send data: {e}")))?;
|
||||
|
||||
self
|
||||
.bytes_sent
|
||||
.fetch_add(packet.len() as u64, Ordering::Relaxed);
|
||||
}
|
||||
TunnResult::Err(e) => {
|
||||
return Err(VpnError::Tunnel(format!("Encryption failed: {e:?}")));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive and decrypt data from the tunnel
|
||||
pub async fn receive(&self, buf: &mut [u8]) -> Result<usize, VpnError> {
|
||||
let tunnel = self
|
||||
.tunnel
|
||||
.as_ref()
|
||||
.ok_or_else(|| VpnError::Tunnel("Tunnel not initialized".to_string()))?;
|
||||
|
||||
let socket = self
|
||||
.socket
|
||||
.as_ref()
|
||||
.ok_or_else(|| VpnError::Tunnel("Socket not initialized".to_string()))?;
|
||||
|
||||
let mut recv_buf = vec![0u8; 2048];
|
||||
|
||||
let (len, _from) = socket
|
||||
.recv_from(&mut recv_buf)
|
||||
.map_err(|e| VpnError::Connection(format!("Receive failed: {e}")))?;
|
||||
|
||||
self.bytes_received.fetch_add(len as u64, Ordering::Relaxed);
|
||||
|
||||
let mut tunnel_guard = tunnel.lock().await;
|
||||
// decapsulate writes decrypted data directly to buf and returns a slice pointing to it
|
||||
let result = tunnel_guard.decapsulate(None, &recv_buf[..len], buf);
|
||||
|
||||
match result {
|
||||
// Data is already written to buf by decapsulate, just return the length
|
||||
TunnResult::WriteToTunnelV4(decrypted, _) => Ok(decrypted.len()),
|
||||
TunnResult::WriteToTunnelV6(decrypted, _) => Ok(decrypted.len()),
|
||||
TunnResult::Done => Ok(0),
|
||||
TunnResult::Err(e) => Err(VpnError::Tunnel(format!("Decryption failed: {e:?}"))),
|
||||
_ => Ok(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl VpnTunnel for WireGuardTunnel {
|
||||
async fn connect(&mut self) -> Result<(), VpnError> {
|
||||
if self.connected.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Initialize the tunnel
|
||||
self.init_tunnel()?;
|
||||
|
||||
// Resolve endpoint
|
||||
self.peer_addr = Some(self.resolve_endpoint()?);
|
||||
|
||||
// Create UDP socket
|
||||
let socket = UdpSocket::bind("0.0.0.0:0")
|
||||
.map_err(|e| VpnError::Connection(format!("Failed to create socket: {e}")))?;
|
||||
|
||||
socket
|
||||
.set_nonblocking(false)
|
||||
.map_err(|e| VpnError::Connection(format!("Failed to set socket options: {e}")))?;
|
||||
|
||||
self.socket = Some(Arc::new(socket));
|
||||
|
||||
// Perform handshake
|
||||
self.handshake().await?;
|
||||
|
||||
self.connected.store(true, Ordering::Release);
|
||||
self.connected_at = Some(Utc::now().timestamp());
|
||||
|
||||
log::info!("[vpn] WireGuard tunnel {} connected", self.vpn_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn disconnect(&mut self) -> Result<(), VpnError> {
|
||||
if !self.connected.load(Ordering::Relaxed) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.connected.store(false, Ordering::Release);
|
||||
self.tunnel = None;
|
||||
self.socket = None;
|
||||
self.connected_at = None;
|
||||
|
||||
log::info!("[vpn] WireGuard tunnel {} disconnected", self.vpn_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_connected(&self) -> bool {
|
||||
self.connected.load(Ordering::Acquire)
|
||||
}
|
||||
|
||||
fn vpn_id(&self) -> &str {
|
||||
&self.vpn_id
|
||||
}
|
||||
|
||||
fn get_status(&self) -> VpnStatus {
|
||||
VpnStatus {
|
||||
connected: self.is_connected(),
|
||||
vpn_id: self.vpn_id.clone(),
|
||||
connected_at: self.connected_at,
|
||||
bytes_sent: Some(self.bytes_sent.load(Ordering::Relaxed)),
|
||||
bytes_received: Some(self.bytes_received.load(Ordering::Relaxed)),
|
||||
last_handshake: self.last_handshake,
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_sent(&self) -> u64 {
|
||||
self.bytes_sent.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn bytes_received(&self) -> u64 {
|
||||
self.bytes_received.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn create_test_config() -> WireGuardConfig {
|
||||
WireGuardConfig {
|
||||
// These are test keys, not real ones
|
||||
private_key: "YEocP0e2o1WT5GlvBvQzVF7EeR6z9aCk+ZdZ5NKEuXA=".to_string(),
|
||||
address: "10.0.0.2/24".to_string(),
|
||||
dns: Some("1.1.1.1".to_string()),
|
||||
mtu: Some(1420),
|
||||
peer_public_key: "aGnF7JlG+U5t0BqB1PVf1yOuELHrWLGGcUJb0eCK9Aw=".to_string(),
|
||||
peer_endpoint: "127.0.0.1:51820".to_string(),
|
||||
allowed_ips: vec!["0.0.0.0/0".to_string()],
|
||||
persistent_keepalive: Some(25),
|
||||
preshared_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wireguard_tunnel_creation() {
|
||||
let config = create_test_config();
|
||||
let tunnel = WireGuardTunnel::new("test-wg-1".to_string(), config);
|
||||
|
||||
assert_eq!(tunnel.vpn_id(), "test-wg-1");
|
||||
assert!(!tunnel.is_connected());
|
||||
assert_eq!(tunnel.bytes_sent(), 0);
|
||||
assert_eq!(tunnel.bytes_received(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_key_valid() {
|
||||
// Valid base64-encoded 32-byte key
|
||||
let key = "YEocP0e2o1WT5GlvBvQzVF7EeR6z9aCk+ZdZ5NKEuXA=";
|
||||
let result = WireGuardTunnel::parse_key(key);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_key_invalid_base64() {
|
||||
let key = "not-valid-base64!!!";
|
||||
let result = WireGuardTunnel::parse_key(key);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_key_wrong_length() {
|
||||
// Valid base64 but wrong length
|
||||
let key = "YWJjZA=="; // "abcd" in base64
|
||||
let result = WireGuardTunnel::parse_key(key);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wireguard_status() {
|
||||
let config = create_test_config();
|
||||
let tunnel = WireGuardTunnel::new("test-wg-2".to_string(), config);
|
||||
|
||||
let status = tunnel.get_status();
|
||||
assert!(!status.connected);
|
||||
assert_eq!(status.vpn_id, "test-wg-2");
|
||||
assert!(status.connected_at.is_none());
|
||||
assert_eq!(status.bytes_sent, Some(0));
|
||||
assert_eq!(status.bytes_received, Some(0));
|
||||
}
|
||||
}
|
||||
@@ -97,6 +97,12 @@ impl WayfernTermsManager {
|
||||
timestamp >= MIN_VALID_TIMESTAMP
|
||||
}
|
||||
|
||||
pub fn is_wayfern_downloaded(&self) -> bool {
|
||||
let registry = DownloadedBrowsersRegistry::instance();
|
||||
let versions = registry.get_downloaded_versions("wayfern");
|
||||
!versions.is_empty()
|
||||
}
|
||||
|
||||
fn get_any_wayfern_executable(&self) -> Option<PathBuf> {
|
||||
// First try to get executable from any downloaded Wayfern version
|
||||
let registry = DownloadedBrowsersRegistry::instance();
|
||||
|
||||
Reference in New Issue
Block a user