From dbba4d663011158efb9b2a0f8a51c6e7c25171ce Mon Sep 17 00:00:00 2001 From: zarzet Date: Mon, 13 Apr 2026 22:20:17 +0700 Subject: [PATCH] feat: propagate download cancel to extension HTTP requests and fix SAF filename extension mismatch - Bind cancel context to all extension HTTP calls (fetch, httpGet, httpPost, httpRequest, fileDownload, authExchangeCodeWithPKCE) so in-flight requests are aborted when user cancels a download - Make initDownloadCancel idempotent: return existing context if entry already exists and preserve pre-cancelled state - Force SAF output filename to match actual file extension when extension returns a different format than requested (e.g. FLAC requested but M4A produced) - Map ALAC/AAC quality to .m4a instead of falling through to default .flac --- .../kotlin/com/zarz/spotiflac/MainActivity.kt | 33 ++++++++- go_backend/cancel.go | 14 ++++ go_backend/extension_providers.go | 4 ++ go_backend/extension_runtime.go | 13 ++++ go_backend/extension_runtime_auth.go | 1 + go_backend/extension_runtime_file.go | 1 + go_backend/extension_runtime_http.go | 4 ++ go_backend/extension_runtime_polyfills.go | 1 + go_backend/extension_test.go | 72 +++++++++++++++++++ lib/providers/download_queue_provider.dart | 1 + 10 files changed, 141 insertions(+), 3 deletions(-) diff --git a/android/app/src/main/kotlin/com/zarz/spotiflac/MainActivity.kt b/android/app/src/main/kotlin/com/zarz/spotiflac/MainActivity.kt index 5b3e1ae9..520d90cd 100644 --- a/android/app/src/main/kotlin/com/zarz/spotiflac/MainActivity.kt +++ b/android/app/src/main/kotlin/com/zarz/spotiflac/MainActivity.kt @@ -308,6 +308,21 @@ class MainActivity: FlutterFragmentActivity() { } } + private fun forceFilenameExt(name: String, outputExt: String): String { + val normalizedExt = normalizeExt(outputExt) + if (normalizedExt.isBlank()) return sanitizeFilename(name) + + val safeName = sanitizeFilename(name) + val lower = safeName.lowercase(Locale.ROOT) + val knownExts = listOf(".flac", ".m4a", ".mp3", ".opus", ".lrc") + for (knownExt in knownExts) { + if (lower.endsWith(knownExt)) { + return safeName.dropLast(knownExt.length) + normalizedExt + } + } + return safeName + normalizedExt + } + private fun sanitizeFilename(name: String): String { var sanitized = name .replace("/", " ") @@ -617,12 +632,12 @@ class MainActivity: FlutterFragmentActivity() { private fun buildSafFileName(req: JSONObject, outputExt: String): String { val provided = req.optString("saf_file_name", "") - if (provided.isNotBlank()) return sanitizeFilename(provided) + if (provided.isNotBlank()) return forceFilenameExt(provided, outputExt) val trackName = req.optString("track_name", "track") val artistName = req.optString("artist_name", "") val baseName = if (artistName.isNotBlank()) "$artistName - $trackName" else trackName - return sanitizeFilename(baseName) + outputExt + return forceFilenameExt(baseName, outputExt) } private fun errorJson(message: String): String { @@ -937,7 +952,7 @@ class MainActivity: FlutterFragmentActivity() { ?: return errorJson("Failed to access SAF directory") val existingFile = targetDir.findFile(fileName) - val document = existingFile ?: targetDir.createFile(mimeType, fileName) + var document = existingFile ?: targetDir.createFile(mimeType, fileName) ?: return errorJson("Failed to create SAF file") val pfd = contentResolver.openFileDescriptor(document.uri, "rw") @@ -965,6 +980,18 @@ class MainActivity: FlutterFragmentActivity() { if (!srcFile.exists() || srcFile.length() <= 0) { throw IllegalStateException("extension output missing or empty: $goFilePath") } + val actualExt = normalizeExt(srcFile.extension) + if (actualExt.isNotBlank() && actualExt != outputExt) { + val actualFileName = buildSafFileName(req, actualExt) + val actualMimeType = mimeTypeForExt(actualExt) + val replacement = targetDir.findFile(actualFileName) + ?: targetDir.createFile(actualMimeType, actualFileName) + ?: throw IllegalStateException("failed to create SAF output with actual extension") + if (replacement.uri != document.uri) { + document.delete() + document = replacement + } + } contentResolver.openOutputStream(document.uri, "wt")?.use { output -> srcFile.inputStream().use { input -> input.copyTo(output) diff --git a/go_backend/cancel.go b/go_backend/cancel.go index 9dc3c28e..25f69cca 100644 --- a/go_backend/cancel.go +++ b/go_backend/cancel.go @@ -10,6 +10,7 @@ import ( var ErrDownloadCancelled = errors.New("download cancelled") type cancelEntry struct { + ctx context.Context cancel context.CancelFunc canceled bool } @@ -27,8 +28,21 @@ func initDownloadCancel(itemID string) context.Context { cancelMu.Lock() defer cancelMu.Unlock() + if entry, ok := cancelMap[itemID]; ok { + if entry.ctx == nil { + ctx, cancel := context.WithCancel(context.Background()) + entry.ctx = ctx + entry.cancel = cancel + if entry.canceled && entry.cancel != nil { + entry.cancel() + } + } + return entry.ctx + } + ctx, cancel := context.WithCancel(context.Background()) cancelMap[itemID] = &cancelEntry{ + ctx: ctx, cancel: cancel, canceled: false, } diff --git a/go_backend/extension_providers.go b/go_backend/extension_providers.go index 7741ae17..2137de82 100644 --- a/go_backend/extension_providers.go +++ b/go_backend/extension_providers.go @@ -615,6 +615,10 @@ func (p *extensionProviderWrapper) Download(trackID, quality, outputPath, itemID p.extension.runtime.setActiveDownloadItemID(itemID) defer p.extension.runtime.clearActiveDownloadItemID() } + if itemID != "" { + initDownloadCancel(itemID) + defer clearDownloadCancel(itemID) + } p.vm.Set("__onProgress", func(call goja.FunctionCall) goja.Value { if len(call.Arguments) > 0 { diff --git a/go_backend/extension_runtime.go b/go_backend/extension_runtime.go index 7c77a43b..0d8b4fdf 100644 --- a/go_backend/extension_runtime.go +++ b/go_backend/extension_runtime.go @@ -160,6 +160,19 @@ func (r *extensionRuntime) getActiveDownloadItemID() string { return r.activeDownloadItemID } +func (r *extensionRuntime) bindDownloadCancelContext(req *http.Request) *http.Request { + if req == nil { + return nil + } + + itemID := r.getActiveDownloadItemID() + if itemID == "" { + return req + } + + return req.WithContext(initDownloadCancel(itemID)) +} + func newExtensionHTTPClient(ext *loadedExtension, jar http.CookieJar, timeout time.Duration) *http.Client { // Extension sandbox enforces HTTPS-only domains. Do not apply global // allow_http scheme downgrade here, because some extension APIs (e.g. diff --git a/go_backend/extension_runtime_auth.go b/go_backend/extension_runtime_auth.go index 4be66422..afa76d52 100644 --- a/go_backend/extension_runtime_auth.go +++ b/go_backend/extension_runtime_auth.go @@ -458,6 +458,7 @@ func (r *extensionRuntime) authExchangeCodeWithPKCE(call goja.FunctionCall) goja "error": err.Error(), }) } + req = r.bindDownloadCancelContext(req) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("User-Agent", "SpotiFLAC-Extension/1.0") diff --git a/go_backend/extension_runtime_file.go b/go_backend/extension_runtime_file.go index 1ea0f9b3..7a54e26d 100644 --- a/go_backend/extension_runtime_file.go +++ b/go_backend/extension_runtime_file.go @@ -166,6 +166,7 @@ func (r *extensionRuntime) fileDownload(call goja.FunctionCall) goja.Value { "error": err.Error(), }) } + req = r.bindDownloadCancelContext(req) for k, v := range headers { req.Header.Set(k, v) diff --git a/go_backend/extension_runtime_http.go b/go_backend/extension_runtime_http.go index 35552b3d..c78fffea 100644 --- a/go_backend/extension_runtime_http.go +++ b/go_backend/extension_runtime_http.go @@ -81,6 +81,7 @@ func (r *extensionRuntime) httpGet(call goja.FunctionCall) goja.Value { "error": err.Error(), }) } + req = r.bindDownloadCancelContext(req) for k, v := range headers { req.Header.Set(k, v) @@ -175,6 +176,7 @@ func (r *extensionRuntime) httpPost(call goja.FunctionCall) goja.Value { "error": err.Error(), }) } + req = r.bindDownloadCancelContext(req) for k, v := range headers { req.Header.Set(k, v) @@ -284,6 +286,7 @@ func (r *extensionRuntime) httpRequest(call goja.FunctionCall) goja.Value { "error": err.Error(), }) } + req = r.bindDownloadCancelContext(req) for k, v := range headers { req.Header.Set(k, v) @@ -410,6 +413,7 @@ func (r *extensionRuntime) httpMethodShortcut(method string, call goja.FunctionC "error": err.Error(), }) } + req = r.bindDownloadCancelContext(req) for k, v := range headers { req.Header.Set(k, v) diff --git a/go_backend/extension_runtime_polyfills.go b/go_backend/extension_runtime_polyfills.go index d70467aa..5f5cb4a5 100644 --- a/go_backend/extension_runtime_polyfills.go +++ b/go_backend/extension_runtime_polyfills.go @@ -69,6 +69,7 @@ func (r *extensionRuntime) fetchPolyfill(call goja.FunctionCall) goja.Value { if err != nil { return r.createFetchError(err.Error()) } + req = r.bindDownloadCancelContext(req) for k, v := range headers { req.Header.Set(k, v) diff --git a/go_backend/extension_test.go b/go_backend/extension_test.go index c7d7074d..e80b14e3 100644 --- a/go_backend/extension_test.go +++ b/go_backend/extension_test.go @@ -1,8 +1,10 @@ package gobackend import ( + "net/http" "path/filepath" "testing" + "time" "github.com/dop251/goja" ) @@ -290,6 +292,76 @@ func TestExtensionRuntime_UtilityFunctions(t *testing.T) { } } +func TestExtensionRuntime_BindDownloadCancelContext(t *testing.T) { + ext := &loadedExtension{ + ID: "test-ext", + Manifest: &ExtensionManifest{ + Name: "test-ext", + }, + DataDir: t.TempDir(), + } + + runtime := newExtensionRuntime(ext) + runtime.setActiveDownloadItemID("test-item") + t.Cleanup(func() { + clearDownloadCancel("test-item") + runtime.clearActiveDownloadItemID() + }) + + req, err := http.NewRequest("GET", "https://api.example.com/test", nil) + if err != nil { + t.Fatalf("NewRequest failed: %v", err) + } + + req = runtime.bindDownloadCancelContext(req) + cancelDownload("test-item") + + select { + case <-req.Context().Done(): + case <-time.After(500 * time.Millisecond): + t.Fatal("Expected bound request context to be cancelled") + } + + if req.Context().Err() == nil { + t.Fatal("Expected request context error after cancellation") + } +} + +func TestExtensionRuntime_BindDownloadCancelContextPreservesPreCancelledState(t *testing.T) { + ext := &loadedExtension{ + ID: "test-ext", + Manifest: &ExtensionManifest{ + Name: "test-ext", + }, + DataDir: t.TempDir(), + } + + runtime := newExtensionRuntime(ext) + runtime.setActiveDownloadItemID("test-item") + cancelDownload("test-item") + t.Cleanup(func() { + clearDownloadCancel("test-item") + runtime.clearActiveDownloadItemID() + }) + + req, err := http.NewRequest("GET", "https://api.example.com/test", nil) + if err != nil { + t.Fatalf("NewRequest failed: %v", err) + } + + req = runtime.bindDownloadCancelContext(req) + + select { + case <-req.Context().Done(): + case <-time.After(500 * time.Millisecond): + t.Fatal("Expected pre-cancelled request context to stay cancelled") + } + + if req.Context().Err() == nil { + t.Fatal("Expected request context error for pre-cancelled item") + } +} + func TestExtensionRuntime_SSRFProtection(t *testing.T) { // Create extension with limited network permissions ext := &loadedExtension{ diff --git a/lib/providers/download_queue_provider.dart b/lib/providers/download_queue_provider.dart index d2506bf6..63da24db 100644 --- a/lib/providers/download_queue_provider.dart +++ b/lib/providers/download_queue_provider.dart @@ -2378,6 +2378,7 @@ class DownloadQueueNotifier extends Notifier { return '.m4a'; } final q = quality.toLowerCase(); + if (q == 'alac' || q.startsWith('aac')) return '.m4a'; if (q.startsWith('opus')) return '.opus'; if (q.startsWith('mp3')) return '.mp3'; return '.flac';