fix: channel callback never cleaned up from window (#13136)

* Fix channel cb never cleaned up from `window`

* Should be `_{id}`

* Still need to manually impl clone

* Regenerate bundle.global.js

* Remove current_index from ChannelInner

* Move phantom to `Channel`

* `Channel` not `Self`

* Clean up

* Clean up

* Fix missing end quote

* Add change file

* Rename id to index to match js side

* Improve channel speed on small data

* do the same perf check for IPC responses and raw bytes

---------

Co-authored-by: Lucas Nogueira <lucas@tauri.app>
This commit is contained in:
Tony
2025-04-12 21:31:25 +08:00
committed by GitHub
parent 0aa48fb9e4
commit 66e6325f43
10 changed files with 219 additions and 115 deletions

View File

@@ -0,0 +1,5 @@
---
"@tauri-apps/api": minor:feat
---
Allow passing the callback as the parameter of constructor of `Channel` so you can use it like this `new Channel((message) => console.log(message))`

View File

@@ -0,0 +1,6 @@
---
"tauri": minor:bug
"@tauri-apps/api": minor:bug
---
Fix `Channel`'s callback attached to `window` never cleaned up

View File

@@ -0,0 +1,5 @@
---
"tauri": minor:perf
---
Improve `Channel`'s performance when sending small amount of data (e.g. sending a number)

18
Cargo.lock generated
View File

@@ -1355,7 +1355,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c"
dependencies = [
"lazy_static",
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]
@@ -2347,7 +2347,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
dependencies = [
"libc",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -5792,7 +5792,7 @@ dependencies = [
"aes-gcm",
"aes-kw",
"argon2",
"base64 0.22.1",
"base64 0.21.7",
"bitfield",
"block-padding",
"blowfish",
@@ -6372,7 +6372,7 @@ dependencies = [
"once_cell",
"socket2",
"tracing",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -7057,7 +7057,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -8917,7 +8917,7 @@ dependencies = [
"getrandom 0.2.15",
"once_cell",
"rustix 0.38.43",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -10178,7 +10178,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]
@@ -10816,9 +10816,9 @@ checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51"
[[package]]
name = "wry"
version = "0.51.0"
version = "0.51.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c6bf9f41e2585a730fa981b40592ac22d5ca4c3fea549a21f2f97a648ad34f4"
checksum = "48846531c50ee2e209a396ddd24af04ca1584be814e750fb81b395c8e7983ff9"
dependencies = [
"base64 0.22.1",
"block2 0.6.0",

File diff suppressed because one or more lines are too long

View File

@@ -21,8 +21,8 @@
Object.defineProperty(window.__TAURI_INTERNALS__, 'transformCallback', {
value: function transformCallback(callback, once) {
var identifier = uid()
var prop = `_${identifier}`
const identifier = uid()
const prop = `_${identifier}`
Object.defineProperty(window, prop, {
value: (result) => {
@@ -56,15 +56,11 @@
Object.defineProperty(window.__TAURI_INTERNALS__, 'invoke', {
value: function (cmd, payload = {}, options) {
return new Promise(function (resolve, reject) {
const callback = window.__TAURI_INTERNALS__.transformCallback(function (
r
) {
const callback = window.__TAURI_INTERNALS__.transformCallback((r) => {
resolve(r)
delete window[`_${error}`]
}, true)
const error = window.__TAURI_INTERNALS__.transformCallback(function (
e
) {
const error = window.__TAURI_INTERNALS__.transformCallback((e) => {
reject(e)
delete window[`_${callback}`]
}, true)

View File

@@ -6,9 +6,9 @@
(function (message) {
if (
message instanceof ArrayBuffer ||
ArrayBuffer.isView(message) ||
Array.isArray(message)
message instanceof ArrayBuffer
|| ArrayBuffer.isView(message)
|| Array.isArray(message)
) {
return {
contentType: 'application/octet-stream',
@@ -27,15 +27,13 @@
return Array.from(val)
} else if (val instanceof ArrayBuffer) {
return Array.from(new Uint8Array(val))
} else if (typeof val === "object" && val !== null && SERIALIZE_TO_IPC_FN in val) {
return val[SERIALIZE_TO_IPC_FN]()
} else if (
val instanceof Object &&
'__TAURI_CHANNEL_MARKER__' in val &&
typeof val.id === 'number'
typeof val === 'object'
&& val !== null
&& SERIALIZE_TO_IPC_FN in val
) {
return `__CHANNEL__:${val.id}`
} else {
return val[SERIALIZE_TO_IPC_FN]()
} else {
return val
}
})

View File

@@ -27,19 +27,24 @@ pub const IPC_PAYLOAD_PREFIX: &str = "__CHANNEL__:";
pub const CHANNEL_PLUGIN_NAME: &str = "__TAURI_CHANNEL__";
// TODO: Change this to `plugin:channel|fetch` in v3
pub const FETCH_CHANNEL_DATA_COMMAND: &str = "plugin:__TAURI_CHANNEL__|fetch";
pub(crate) const CHANNEL_ID_HEADER_NAME: &str = "Tauri-Channel-Id";
const CHANNEL_ID_HEADER_NAME: &str = "Tauri-Channel-Id";
/// Maximum size a JSON we should send directly without going through the fetch process
// 8192 byte JSON payload runs roughly 2x faster through eval than through fetch on WebView2 v135
const MAX_JSON_DIRECT_EXECUTE_THRESHOLD: usize = 8192;
// 1024 byte payload runs roughly 30% faster through eval than through fetch on macOS
const MAX_RAW_DIRECT_EXECUTE_THRESHOLD: usize = 1024;
static CHANNEL_COUNTER: AtomicU32 = AtomicU32::new(0);
static CHANNEL_DATA_COUNTER: AtomicU32 = AtomicU32::new(0);
/// Maps a channel id to a pending data that must be send to the JavaScript side via the IPC.
#[derive(Default, Clone)]
pub struct ChannelDataIpcQueue(pub(crate) Arc<Mutex<HashMap<u32, InvokeResponseBody>>>);
pub struct ChannelDataIpcQueue(Arc<Mutex<HashMap<u32, InvokeResponseBody>>>);
/// An IPC channel.
pub struct Channel<TSend = InvokeResponseBody> {
id: u32,
on_message: Arc<dyn Fn(InvokeResponseBody) -> crate::Result<()> + Send + Sync>,
inner: Arc<ChannelInner>,
phantom: std::marker::PhantomData<TSend>,
}
@@ -53,9 +58,25 @@ const _: () = {
impl<TSend> Clone for Channel<TSend> {
fn clone(&self) -> Self {
Self {
id: self.id,
on_message: self.on_message.clone(),
phantom: Default::default(),
inner: self.inner.clone(),
phantom: self.phantom,
}
}
}
type OnDropFn = Option<Box<dyn Fn() + Send + Sync + 'static>>;
type OnMessageFn = Box<dyn Fn(InvokeResponseBody) -> crate::Result<()> + Send + Sync>;
struct ChannelInner {
id: u32,
on_message: OnMessageFn,
on_drop: OnDropFn,
}
impl Drop for ChannelInner {
fn drop(&mut self) {
if let Some(on_drop) = &self.on_drop {
on_drop();
}
}
}
@@ -65,7 +86,7 @@ impl<TSend> Serialize for Channel<TSend> {
where
S: Serializer,
{
serializer.serialize_str(&format!("{IPC_PAYLOAD_PREFIX}{}", self.id))
serializer.serialize_str(&format!("{IPC_PAYLOAD_PREFIX}{}", self.inner.id))
}
}
@@ -97,9 +118,9 @@ impl FromStr for JavaScriptChannelId {
type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.split_once(IPC_PAYLOAD_PREFIX)
s.strip_prefix(IPC_PAYLOAD_PREFIX)
.ok_or("invalid channel string")
.and_then(|(_prefix, id)| id.parse().map_err(|_| "invalid channel ID"))
.and_then(|id| id.parse().map_err(|_| "invalid channel ID"))
.map(|id| Self(CallbackFn(id)))
}
}
@@ -107,34 +128,63 @@ impl FromStr for JavaScriptChannelId {
impl JavaScriptChannelId {
/// Gets a [`Channel`] for this channel ID on the given [`Webview`].
pub fn channel_on<R: Runtime, TSend>(&self, webview: Webview<R>) -> Channel<TSend> {
let callback_id = self.0;
let counter = AtomicUsize::new(0);
let callback_fn = self.0;
let callback_id = callback_fn.0;
Channel::new_with_id(callback_id.0, move |body| {
let i = counter.fetch_add(1, Ordering::Relaxed);
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let webview_clone = webview.clone();
if let Some(interceptor) = &webview.manager.channel_interceptor {
if interceptor(&webview, callback_id, i, &body) {
return Ok(());
Channel::new_with_id(
callback_id,
Box::new(move |body| {
let current_index = counter.fetch_add(1, Ordering::Relaxed);
if let Some(interceptor) = &webview.manager.channel_interceptor {
if interceptor(&webview, callback_fn, current_index, &body) {
return Ok(());
}
}
}
let data_id = CHANNEL_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
match body {
// Don't go through the fetch process if the payload is small
InvokeResponseBody::Json(string) if string.len() < MAX_JSON_DIRECT_EXECUTE_THRESHOLD => {
webview.eval(format!(
"window['_{callback_id}']({{ message: {string}, index: {current_index} }})"
))?;
}
InvokeResponseBody::Raw(bytes) if bytes.len() < MAX_RAW_DIRECT_EXECUTE_THRESHOLD => {
webview.eval(format!(
"window['_{callback_id}']({{ message: {}, index: {current_index} }})",
serde_json::to_string(&bytes)?,
))?;
}
// use the fetch API to speed up larger response payloads
_ => {
let data_id = CHANNEL_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
webview
.state::<ChannelDataIpcQueue>()
.0
.lock()
.unwrap()
.insert(data_id, body);
webview
.state::<ChannelDataIpcQueue>()
.0
.lock()
.unwrap()
.insert(data_id, body);
webview.eval(format!(
"window.__TAURI_INTERNALS__.invoke('{FETCH_CHANNEL_DATA_COMMAND}', null, {{ headers: {{ '{CHANNEL_ID_HEADER_NAME}': '{data_id}' }} }}).then((response) => window['_' + {}]({{ message: response, id: {i} }})).catch(console.error)",
callback_id.0
))?;
webview.eval(format!(
"window.__TAURI_INTERNALS__.invoke('{FETCH_CHANNEL_DATA_COMMAND}', null, {{ headers: {{ '{CHANNEL_ID_HEADER_NAME}': '{data_id}' }} }}).then((response) => window['_{callback_id}']({{ message: response, index: {current_index} }})).catch(console.error)",
))?;
}
}
Ok(())
})
Ok(())
}),
Some(Box::new(move || {
let current_index = counter_clone.load(Ordering::Relaxed);
let _ = webview_clone.eval(format!(
"window['_{callback_id}']({{ end: true, index: {current_index} }})",
));
})),
)
}
}
@@ -157,53 +207,76 @@ impl<TSend> Channel<TSend> {
pub fn new<F: Fn(InvokeResponseBody) -> crate::Result<()> + Send + Sync + 'static>(
on_message: F,
) -> Self {
Self::new_with_id(CHANNEL_COUNTER.fetch_add(1, Ordering::Relaxed), on_message)
Self::new_with_id(
CHANNEL_COUNTER.fetch_add(1, Ordering::Relaxed),
Box::new(on_message),
None,
)
}
fn new_with_id<F: Fn(InvokeResponseBody) -> crate::Result<()> + Send + Sync + 'static>(
id: u32,
on_message: F,
) -> Self {
fn new_with_id(id: u32, on_message: OnMessageFn, on_drop: OnDropFn) -> Self {
#[allow(clippy::let_and_return)]
let channel = Self {
id,
on_message: Arc::new(on_message),
inner: Arc::new(ChannelInner {
id,
on_message,
on_drop,
}),
phantom: Default::default(),
};
#[cfg(mobile)]
crate::plugin::mobile::register_channel(Channel {
id,
on_message: channel.on_message.clone(),
inner: channel.inner.clone(),
phantom: Default::default(),
});
channel
}
// This is used from the IPC handler
pub(crate) fn from_callback_fn<R: Runtime>(webview: Webview<R>, callback: CallbackFn) -> Self {
Channel::new_with_id(callback.0, move |body| {
let data_id = CHANNEL_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
let callback_id = callback.0;
Channel::new_with_id(
callback_id,
Box::new(move |body| {
match body {
// Don't go through the fetch process if the payload is small
InvokeResponseBody::Json(string) if string.len() < MAX_JSON_DIRECT_EXECUTE_THRESHOLD => {
webview.eval(format!("window['_{callback_id}']({string})"))?;
}
InvokeResponseBody::Raw(bytes) if bytes.len() < MAX_RAW_DIRECT_EXECUTE_THRESHOLD => {
webview.eval(format!(
"window['_{callback_id}']({})",
serde_json::to_string(&bytes)?,
))?;
}
// use the fetch API to speed up larger response payloads
_ => {
let data_id = CHANNEL_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
webview
.state::<ChannelDataIpcQueue>()
.0
.lock()
.unwrap()
.insert(data_id, body);
webview
.state::<ChannelDataIpcQueue>()
.0
.lock()
.unwrap()
.insert(data_id, body);
webview.eval(format!(
"window.__TAURI_INTERNALS__.invoke('{FETCH_CHANNEL_DATA_COMMAND}', null, {{ headers: {{ '{CHANNEL_ID_HEADER_NAME}': '{data_id}' }} }}).then((response) => window['_' + {}](response)).catch(console.error)",
callback.0
))?;
webview.eval(format!(
"window.__TAURI_INTERNALS__.invoke('{FETCH_CHANNEL_DATA_COMMAND}', null, {{ headers: {{ '{CHANNEL_ID_HEADER_NAME}': '{data_id}' }} }}).then((response) => window['_{callback_id}'](response)).catch(console.error)",
))?;
}
}
Ok(())
})
Ok(())
}),
None,
)
}
/// The channel identifier.
pub fn id(&self) -> u32 {
self.id
self.inner.id
}
/// Sends the given data through the channel.
@@ -211,7 +284,7 @@ impl<TSend> Channel<TSend> {
where
TSend: IpcResponse,
{
(self.on_message)(data.body()?)
(self.inner.on_message)(data.body()?)
}
}

View File

@@ -1467,7 +1467,6 @@ fn main() {
let resolver = InvokeResolver::new(
self.clone(),
Arc::new(Mutex::new(Some(Box::new(
#[allow(unused_variables)]
move |webview: Webview<R>, cmd, response, callback, error| {
responder(webview, cmd, response, callback, error);
},

View File

@@ -49,7 +49,6 @@
* }
* }
*
*
* type UserId = UserIdString | UserIdNumber
* ```
*
@@ -75,40 +74,64 @@ function transformCallback<T = unknown>(
}
class Channel<T = unknown> {
/** The callback id returned from {@linkcode transformCallback} */
id: number
// @ts-expect-error field used by the IPC serializer
private readonly __TAURI_CHANNEL_MARKER__ = true
#onmessage: (response: T) => void = () => {
// no-op
}
// the id is used as a mechanism to preserve message order
#nextMessageId = 0
#onmessage: (response: T) => void
// the index is used as a mechanism to preserve message order
#nextMessageIndex = 0
#pendingMessages: T[] = []
#messageEndIndex: number | undefined
constructor() {
this.id = transformCallback(
({ message, id }: { message: T; id: number }) => {
// Process the message if we're at the right order
if (id == this.#nextMessageId) {
this.#onmessage(message)
this.#nextMessageId += 1
constructor(onmessage?: (response: T) => void) {
this.#onmessage = onmessage || (() => {})
// process pending messages
while (this.#nextMessageId in this.#pendingMessages) {
const message = this.#pendingMessages[this.#nextMessageId]
this.#onmessage(message)
// eslint-disable-next-line @typescript-eslint/no-array-delete
delete this.#pendingMessages[this.#nextMessageId]
this.#nextMessageId += 1
}
this.id = transformCallback<
// Normal message
| { message: T; index: number }
// Message when the channel gets dropped in the rust side
| { end: true; index: number }
>((rawMessage) => {
const index = rawMessage.index
if ('end' in rawMessage) {
if (index == this.#nextMessageIndex) {
this.cleanupCallback()
} else {
this.#messageEndIndex = index
}
// Queue the message if we're not
else {
// eslint-disable-next-line security/detect-object-injection
this.#pendingMessages[id] = message
return
}
const message = rawMessage.message
// Process the message if we're at the right order
if (index == this.#nextMessageIndex) {
this.#onmessage(message)
this.#nextMessageIndex += 1
// process pending messages
while (this.#nextMessageIndex in this.#pendingMessages) {
const message = this.#pendingMessages[this.#nextMessageIndex]
this.#onmessage(message)
// eslint-disable-next-line @typescript-eslint/no-array-delete
delete this.#pendingMessages[this.#nextMessageIndex]
this.#nextMessageIndex += 1
}
if (this.#nextMessageIndex === this.#messageEndIndex) {
this.cleanupCallback()
}
}
)
// Queue the message if we're not
else {
// eslint-disable-next-line security/detect-object-injection
this.#pendingMessages[index] = message
}
})
}
private cleanupCallback() {
Reflect.deleteProperty(window, `_${this.id}`)
}
set onmessage(handler: (response: T) => void) {
@@ -160,8 +183,7 @@ async function addPluginListener<T>(
event: string,
cb: (payload: T) => void
): Promise<PluginListener> {
const handler = new Channel<T>()
handler.onmessage = cb
const handler = new Channel<T>(cb)
return invoke(`plugin:${plugin}|registerListener`, { event, handler }).then(
() => new PluginListener(plugin, event, handler.id)
)