mirror of
https://github.com/tauri-apps/plugins-workspace.git
synced 2026-04-21 11:26:15 +02:00
feat(websocket): Add proxy configuration (#1536)
Co-authored-by: FabianLars <github@fabianlars.de>
This commit is contained in:
Generated
+44
-14
@@ -376,9 +376,9 @@ checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.0"
|
||||
version = "0.22.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51"
|
||||
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
|
||||
|
||||
[[package]]
|
||||
name = "base64ct"
|
||||
@@ -1989,9 +1989,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.2.0"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
|
||||
checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
@@ -2025,9 +2025,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.3"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
|
||||
checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
@@ -3644,7 +3644,7 @@ version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19"
|
||||
dependencies = [
|
||||
"base64 0.22.0",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
@@ -3797,13 +3797,26 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pemfile",
|
||||
"rustls-pki-types",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "2.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d"
|
||||
dependencies = [
|
||||
"base64 0.22.0",
|
||||
"base64 0.22.1",
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
@@ -4367,7 +4380,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "936cac0ab331b14cb3921c62156d913e4c15b74fb6ec0f3146bd4ef6e4fb3c12"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64 0.22.0",
|
||||
"base64 0.22.1",
|
||||
"bitflags 2.4.1",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
@@ -4410,7 +4423,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9734dbce698c67ecf67c442f768a5e90a49b2a4d61a9f1d59f73874bd4cf0710"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64 0.22.0",
|
||||
"base64 0.22.1",
|
||||
"bitflags 2.4.1",
|
||||
"byteorder",
|
||||
"crc",
|
||||
@@ -5021,8 +5034,11 @@ dependencies = [
|
||||
name = "tauri-plugin-websocket"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"futures-util",
|
||||
"http 1.0.0",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
@@ -5294,6 +5310,17 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.14"
|
||||
@@ -5313,10 +5340,13 @@ checksum = "489a59b6730eda1b0171fcfda8b121f4bee2b35cba8645ca35c5f7ba3eb736c1"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"native-tls",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-rustls",
|
||||
"tungstenite",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5389,7 +5419,6 @@ dependencies = [
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5484,8 +5513,9 @@ dependencies = [
|
||||
"http 1.0.0",
|
||||
"httparse",
|
||||
"log",
|
||||
"native-tls",
|
||||
"rand 0.9.1",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"sha1",
|
||||
"thiserror 2.0.9",
|
||||
"utf-8",
|
||||
|
||||
@@ -19,4 +19,10 @@ http = "1"
|
||||
rand = "0.8"
|
||||
futures-util = "0.3"
|
||||
tokio = { version = "1", features = ["net", "sync"] }
|
||||
tokio-tungstenite = { version = "0.27", features = ["native-tls"] }
|
||||
tokio-tungstenite = { version = "0.27", features = ["rustls-tls-webpki-roots"] }
|
||||
hyper = { version = "1", features = ["client"] }
|
||||
hyper-util = { version = "0.1", features = ["tokio", "http1"] }
|
||||
base64 = "0.22"
|
||||
|
||||
[features]
|
||||
rustls-tls-native-roots = ["tokio-tungstenite/rustls-tls-native-roots"]
|
||||
+156
-13
@@ -1,16 +1,23 @@
|
||||
use base64::prelude::{Engine, BASE64_STANDARD};
|
||||
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
|
||||
use http::header::{HeaderName, HeaderValue};
|
||||
use http::{
|
||||
header::{HeaderName, HeaderValue},
|
||||
Request,
|
||||
};
|
||||
use hyper::client::conn;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use serde::{ser::Serializer, Deserialize, Serialize};
|
||||
use tauri::{
|
||||
api::ipc::{format_callback, CallbackFn},
|
||||
plugin::{Builder as PluginBuilder, TauriPlugin},
|
||||
Manager, Runtime, State, Window,
|
||||
AppHandle, Manager, Runtime, State, Window,
|
||||
};
|
||||
use tokio::{net::TcpStream, sync::Mutex};
|
||||
use tokio_tungstenite::{
|
||||
connect_async_tls_with_config,
|
||||
client_async_tls_with_config, connect_async_tls_with_config,
|
||||
tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
error::UrlError,
|
||||
protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
|
||||
Message,
|
||||
},
|
||||
@@ -19,10 +26,12 @@ use tokio_tungstenite::{
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
|
||||
type Id = u32;
|
||||
type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
type WebSocketWriter = SplitSink<WebSocket, Message>;
|
||||
type WebSocketWriter =
|
||||
SplitSink<WebSocketStream<tokio_tungstenite::MaybeTlsStream<TcpStream>>, Message>;
|
||||
type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -35,6 +44,16 @@ enum Error {
|
||||
InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue),
|
||||
#[error(transparent)]
|
||||
InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName),
|
||||
#[error(transparent)]
|
||||
ProxyConnection(#[from] hyper::Error),
|
||||
#[error("proxy returned status code: {0}")]
|
||||
ProxyStatus(u16),
|
||||
#[error(transparent)]
|
||||
ProxyIo(#[from] std::io::Error),
|
||||
#[error(transparent)]
|
||||
ProxyHttp(#[from] http::Error),
|
||||
#[error(transparent)]
|
||||
ProxyJoinHandle(#[from] tokio::task::JoinError),
|
||||
}
|
||||
|
||||
impl Serialize for Error {
|
||||
@@ -49,7 +68,27 @@ impl Serialize for Error {
|
||||
#[derive(Default)]
|
||||
struct ConnectionManager(Mutex<HashMap<Id, WebSocketWriter>>);
|
||||
|
||||
struct TlsConnector(Mutex<Option<Connector>>);
|
||||
struct TlsConnector(StdMutex<Option<Connector>>);
|
||||
struct ProxyConfigurationInternal(StdMutex<Option<ProxyConfiguration>>);
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyAuth {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl ProxyAuth {
|
||||
pub fn encode(&self) -> String {
|
||||
BASE64_STANDARD.encode(format!("{}:{}", self.username, self.password))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProxyConfiguration {
|
||||
pub proxy_url: String,
|
||||
pub proxy_port: u16,
|
||||
pub auth: Option<ProxyAuth>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged, rename_all = "camelCase")]
|
||||
@@ -133,10 +172,6 @@ async fn connect<R: Runtime>(
|
||||
) -> Result<Id> {
|
||||
let id = rand::random();
|
||||
let mut request = url.into_client_request()?;
|
||||
let tls_connector = match window.try_state::<TlsConnector>() {
|
||||
Some(tls_connector) => tls_connector.0.lock().await.clone(),
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) {
|
||||
for (k, v) in headers {
|
||||
@@ -146,9 +181,21 @@ async fn connect<R: Runtime>(
|
||||
}
|
||||
}
|
||||
|
||||
let (ws_stream, _) =
|
||||
let tls_connector = window
|
||||
.try_state::<TlsConnector>()
|
||||
.and_then(|c| c.0.lock().unwrap().clone());
|
||||
|
||||
let proxy_config = window
|
||||
.try_state::<ProxyConfigurationInternal>()
|
||||
.and_then(|c| c.0.lock().unwrap().clone());
|
||||
|
||||
let ws_stream = if let Some(proxy_config) = proxy_config {
|
||||
connect_using_proxy(request, config, proxy_config, tls_connector).await?
|
||||
} else {
|
||||
connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector)
|
||||
.await?;
|
||||
.await?
|
||||
.0
|
||||
};
|
||||
|
||||
tauri::async_runtime::spawn(async move {
|
||||
let (write, read) = ws_stream.split();
|
||||
@@ -182,7 +229,7 @@ async fn connect<R: Runtime>(
|
||||
})))
|
||||
.unwrap()
|
||||
}
|
||||
Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be recieved.
|
||||
Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be received.
|
||||
Err(e) => serde_json::to_value(Error::from(e)).unwrap(),
|
||||
};
|
||||
let js = format_callback(callback_function, &response)
|
||||
@@ -196,6 +243,62 @@ async fn connect<R: Runtime>(
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
async fn connect_using_proxy(
|
||||
request: Request<()>,
|
||||
config: Option<ConnectionConfig>,
|
||||
proxy_config: ProxyConfiguration,
|
||||
tls_connector: Option<Connector>,
|
||||
) -> Result<WebSocket> {
|
||||
let domain = domain(&request)?;
|
||||
let port = request
|
||||
.uri()
|
||||
.port_u16()
|
||||
.or_else(|| match request.uri().scheme_str() {
|
||||
Some("wss") => Some(443),
|
||||
Some("ws") => Some(80),
|
||||
_ => None,
|
||||
})
|
||||
.ok_or(Error::Websocket(
|
||||
tokio_tungstenite::tungstenite::Error::Url(UrlError::UnsupportedUrlScheme),
|
||||
))?;
|
||||
|
||||
let tcp = TcpStream::connect(format!(
|
||||
"{}:{}",
|
||||
proxy_config.proxy_url, proxy_config.proxy_port
|
||||
))
|
||||
.await?;
|
||||
let io = TokioIo::new(tcp);
|
||||
|
||||
let (mut request_sender, proxy_connection) =
|
||||
conn::http1::handshake::<TokioIo<tokio::net::TcpStream>, String>(io).await?;
|
||||
let proxy_connection_task = tokio::spawn(proxy_connection.without_shutdown());
|
||||
|
||||
let addr = format!("{domain}:{port}");
|
||||
let mut req_builder = Request::connect(addr);
|
||||
|
||||
if let Some(auth) = proxy_config.auth {
|
||||
req_builder = req_builder.header("Proxy-Authorization", format!("Basic {}", auth.encode()));
|
||||
}
|
||||
|
||||
// TODO: This looks super fishy
|
||||
let req = req_builder.body("".to_string())?;
|
||||
let res = request_sender.send_request(req).await?;
|
||||
if !res.status().is_success() {
|
||||
return Err(Error::ProxyStatus(res.status().as_u16()));
|
||||
}
|
||||
|
||||
let proxied_tcp_socket = proxy_connection_task.await??.io.into_inner();
|
||||
let (ws_stream, _) = client_async_tls_with_config(
|
||||
request,
|
||||
proxied_tcp_socket,
|
||||
config.map(Into::into),
|
||||
tls_connector,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(ws_stream)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn send(
|
||||
manager: State<'_, ConnectionManager>,
|
||||
@@ -228,12 +331,14 @@ pub fn init<R: Runtime>() -> TauriPlugin<R> {
|
||||
#[derive(Default)]
|
||||
pub struct Builder {
|
||||
tls_connector: Option<Connector>,
|
||||
proxy_configuration: Option<ProxyConfiguration>,
|
||||
}
|
||||
|
||||
impl Builder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tls_connector: None,
|
||||
proxy_configuration: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,14 +347,52 @@ impl Builder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn proxy_configuration(mut self, proxy_configuration: ProxyConfiguration) -> Self {
|
||||
self.proxy_configuration.replace(proxy_configuration);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
|
||||
PluginBuilder::new("websocket")
|
||||
.invoke_handler(tauri::generate_handler![connect, send])
|
||||
.setup(|app| {
|
||||
app.manage(ConnectionManager::default());
|
||||
app.manage(TlsConnector(Mutex::new(self.tls_connector)));
|
||||
app.manage(TlsConnector(StdMutex::new(self.tls_connector)));
|
||||
app.manage(ProxyConfigurationInternal(StdMutex::new(
|
||||
self.proxy_configuration,
|
||||
)));
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.build()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn reconfigure_proxy(app: &AppHandle, proxy_config: Option<ProxyConfiguration>) {
|
||||
if let Some(state) = app.try_state::<ProxyConfigurationInternal>() {
|
||||
*state.0.lock().unwrap() = proxy_config;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn reconfigure_tls_connector(app: &AppHandle, tls_connector: Option<Connector>) {
|
||||
if let Some(state) = app.try_state::<TlsConnector>() {
|
||||
*state.0.lock().unwrap() = tls_connector;
|
||||
}
|
||||
}
|
||||
|
||||
// Copied from tokio-tungstenite internal function (tokio-tungstenite/src/lib.rs) with the same name
|
||||
// Get a domain from an URL.
|
||||
#[allow(clippy::result_large_err)]
|
||||
#[inline]
|
||||
fn domain(
|
||||
request: &tokio_tungstenite::tungstenite::handshake::client::Request,
|
||||
) -> tokio_tungstenite::tungstenite::Result<String, tokio_tungstenite::tungstenite::Error> {
|
||||
match request.uri().host() {
|
||||
// rustls expects IPv6 addresses without the surrounding [] brackets
|
||||
Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
|
||||
Some(d) => Ok(d.to_string()),
|
||||
None => Err(tokio_tungstenite::tungstenite::Error::Url(
|
||||
tokio_tungstenite::tungstenite::error::UrlError::NoHostName,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user