diff --git a/go_backend/extension_providers.go b/go_backend/extension_providers.go index ecddfbcf..df7467f9 100644 --- a/go_backend/extension_providers.go +++ b/go_backend/extension_providers.go @@ -510,7 +510,7 @@ func (p *ExtensionProviderWrapper) GetDownloadURL(trackID, quality string) (*Ext const ExtDownloadTimeout = DownloadTimeout -func (p *ExtensionProviderWrapper) Download(trackID, quality, outputPath string, onProgress func(percent int)) (*ExtDownloadResult, error) { +func (p *ExtensionProviderWrapper) Download(trackID, quality, outputPath, itemID string, onProgress func(percent int)) (*ExtDownloadResult, error) { if !p.extension.Manifest.IsDownloadProvider() { return nil, fmt.Errorf("extension '%s' is not a download provider", p.extension.ID) } @@ -526,6 +526,10 @@ func (p *ExtensionProviderWrapper) Download(trackID, quality, outputPath string, }, nil } defer p.extension.VMMu.Unlock() + if p.extension.runtime != nil { + p.extension.runtime.setActiveDownloadItemID(itemID) + defer p.extension.runtime.clearActiveDownloadItemID() + } p.vm.Set("__onProgress", func(call goja.FunctionCall) goja.Value { if len(call.Arguments) > 0 { @@ -1128,7 +1132,7 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro StartItemProgress(req.ItemID) } - result, err := provider.Download(trackID, req.Quality, outputPath, func(percent int) { + result, err := provider.Download(trackID, req.Quality, outputPath, req.ItemID, func(percent int) { if req.ItemID != "" { normalized := float64(percent) / 100.0 if normalized < 0 { @@ -1356,7 +1360,7 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro StartItemProgress(req.ItemID) } - result, err := provider.Download(availability.TrackID, req.Quality, outputPath, func(percent int) { + result, err := provider.Download(availability.TrackID, req.Quality, outputPath, req.ItemID, func(percent int) { if req.ItemID != "" { normalized := float64(percent) / 100.0 if normalized < 0 { diff --git a/go_backend/extension_runtime.go b/go_backend/extension_runtime.go index 58068a42..7f2c0848 100644 --- a/go_backend/extension_runtime.go +++ b/go_backend/extension_runtime.go @@ -90,6 +90,9 @@ type ExtensionRuntime struct { dataDir string vm *goja.Runtime + activeDownloadMu sync.RWMutex + activeDownloadItemID string + storageMu sync.RWMutex storageCache map[string]interface{} storageLoaded bool @@ -139,6 +142,24 @@ func NewExtensionRuntime(ext *LoadedExtension) *ExtensionRuntime { return runtime } +func (r *ExtensionRuntime) setActiveDownloadItemID(itemID string) { + r.activeDownloadMu.Lock() + defer r.activeDownloadMu.Unlock() + r.activeDownloadItemID = strings.TrimSpace(itemID) +} + +func (r *ExtensionRuntime) clearActiveDownloadItemID() { + r.activeDownloadMu.Lock() + defer r.activeDownloadMu.Unlock() + r.activeDownloadItemID = "" +} + +func (r *ExtensionRuntime) getActiveDownloadItemID() string { + r.activeDownloadMu.RLock() + defer r.activeDownloadMu.RUnlock() + return r.activeDownloadItemID +} + func newExtensionHTTPClient(ext *LoadedExtension, jar http.CookieJar, timeout time.Duration) *http.Client { // Extension sandbox enforces HTTPS-only domains. Do not apply global // allow_http scheme downgrade here, because some extension APIs (e.g. diff --git a/go_backend/extension_runtime_file.go b/go_backend/extension_runtime_file.go index 92312c98..9e442aea 100644 --- a/go_backend/extension_runtime_file.go +++ b/go_backend/extension_runtime_file.go @@ -205,13 +205,22 @@ func (r *ExtensionRuntime) fileDownload(call goja.FunctionCall) goja.Value { defer out.Close() contentLength := resp.ContentLength + activeItemID := r.getActiveDownloadItemID() + if activeItemID != "" && contentLength > 0 { + SetItemBytesTotal(activeItemID, contentLength) + } + + var progressWriter interface{ Write([]byte) (int, error) } = out + if activeItemID != "" { + progressWriter = NewItemProgressWriter(out, activeItemID) + } var written int64 buf := make([]byte, 32*1024) for { nr, er := resp.Body.Read(buf) if nr > 0 { - nw, ew := out.Write(buf[0:nr]) + nw, ew := progressWriter.Write(buf[0:nr]) if nw < 0 || nr < nw { nw = 0 if ew == nil { @@ -220,6 +229,12 @@ func (r *ExtensionRuntime) fileDownload(call goja.FunctionCall) goja.Value { } written += int64(nw) if ew != nil { + if ew == ErrDownloadCancelled { + return r.vm.ToValue(map[string]interface{}{ + "success": false, + "error": "download cancelled", + }) + } return r.vm.ToValue(map[string]interface{}{ "success": false, "error": fmt.Sprintf("failed to write file: %v", ew), diff --git a/lib/models/download_item.dart b/lib/models/download_item.dart index db55029b..d8ac97cb 100644 --- a/lib/models/download_item.dart +++ b/lib/models/download_item.dart @@ -12,13 +12,7 @@ enum DownloadStatus { skipped, } -enum DownloadErrorType { - unknown, - notFound, - rateLimit, - network, - permission, -} +enum DownloadErrorType { unknown, notFound, rateLimit, network, permission } @JsonSerializable() class DownloadItem { @@ -28,7 +22,8 @@ class DownloadItem { final DownloadStatus status; final double progress; final double speedMBps; - final int bytesReceived; // Bytes downloaded so far (for unknown size downloads) + final int bytesReceived; // Bytes downloaded so far + final int bytesTotal; // Total bytes when the server provides content length final String? filePath; final String? error; final DownloadErrorType? errorType; @@ -44,6 +39,7 @@ class DownloadItem { this.progress = 0.0, this.speedMBps = 0.0, this.bytesReceived = 0, + this.bytesTotal = 0, this.filePath, this.error, this.errorType, @@ -60,6 +56,7 @@ class DownloadItem { double? progress, double? speedMBps, int? bytesReceived, + int? bytesTotal, String? filePath, String? error, DownloadErrorType? errorType, @@ -75,6 +72,7 @@ class DownloadItem { progress: progress ?? this.progress, speedMBps: speedMBps ?? this.speedMBps, bytesReceived: bytesReceived ?? this.bytesReceived, + bytesTotal: bytesTotal ?? this.bytesTotal, filePath: filePath ?? this.filePath, error: error ?? this.error, errorType: errorType ?? this.errorType, @@ -86,7 +84,7 @@ class DownloadItem { String get errorMessage { if (error == null) return ''; - + switch (errorType) { case DownloadErrorType.notFound: return 'Song not found on any service'; diff --git a/lib/models/download_item.g.dart b/lib/models/download_item.g.dart index 961e6d6d..7aee835c 100644 --- a/lib/models/download_item.g.dart +++ b/lib/models/download_item.g.dart @@ -16,6 +16,7 @@ DownloadItem _$DownloadItemFromJson(Map json) => DownloadItem( progress: (json['progress'] as num?)?.toDouble() ?? 0.0, speedMBps: (json['speedMBps'] as num?)?.toDouble() ?? 0.0, bytesReceived: (json['bytesReceived'] as num?)?.toInt() ?? 0, + bytesTotal: (json['bytesTotal'] as num?)?.toInt() ?? 0, filePath: json['filePath'] as String?, error: json['error'] as String?, errorType: $enumDecodeNullable(_$DownloadErrorTypeEnumMap, json['errorType']), @@ -33,6 +34,7 @@ Map _$DownloadItemToJson(DownloadItem instance) => 'progress': instance.progress, 'speedMBps': instance.speedMBps, 'bytesReceived': instance.bytesReceived, + 'bytesTotal': instance.bytesTotal, 'filePath': instance.filePath, 'error': instance.error, 'errorType': _$DownloadErrorTypeEnumMap[instance.errorType], diff --git a/lib/providers/download_queue_provider.dart b/lib/providers/download_queue_provider.dart index e157bec6..123a0c97 100644 --- a/lib/providers/download_queue_provider.dart +++ b/lib/providers/download_queue_provider.dart @@ -1166,12 +1166,14 @@ class _ProgressUpdate { final double progress; final double? speedMBps; final int? bytesReceived; + final int? bytesTotal; const _ProgressUpdate({ required this.status, required this.progress, this.speedMBps, this.bytesReceived, + this.bytesTotal, }); } @@ -1587,6 +1589,7 @@ class DownloadQueueNotifier extends Notifier { progress: normalizedProgress, speedMBps: normalizedSpeed, bytesReceived: normalizedBytes, + bytesTotal: bytesTotal, ); if (LogBuffer.loggingEnabled) { @@ -1624,11 +1627,13 @@ class DownloadQueueNotifier extends Notifier { progress: update.progress, speedMBps: update.speedMBps ?? current.speedMBps, bytesReceived: update.bytesReceived ?? current.bytesReceived, + bytesTotal: update.bytesTotal ?? current.bytesTotal, ); if (current.status != next.status || current.progress != next.progress || current.speedMBps != next.speedMBps || - current.bytesReceived != next.bytesReceived) { + current.bytesReceived != next.bytesReceived || + current.bytesTotal != next.bytesTotal) { if (!changed) { updatedItems = List.from(updatedItems); changed = true; @@ -2408,6 +2413,7 @@ class DownloadQueueNotifier extends Notifier { progress: 0, speedMBps: 0, bytesReceived: 0, + bytesTotal: 0, ); }) .toList(growable: false); diff --git a/lib/screens/queue_tab.dart b/lib/screens/queue_tab.dart index 5e3f9fcb..7a4465ed 100644 --- a/lib/screens/queue_tab.dart +++ b/lib/screens/queue_tab.dart @@ -5537,17 +5537,29 @@ class _QueueTabState extends ConsumerState { ), const SizedBox(width: 8), Text( - // When progress is 0 (unknown size, e.g. YouTube tunnel mode), - // show bytes downloaded instead of percentage - item.progress > 0 - ? (item.speedMBps > 0 - ? '${(item.progress * 100).toStringAsFixed(0)}% • ${item.speedMBps.toStringAsFixed(1)} MB/s' - : '${(item.progress * 100).toStringAsFixed(0)}%') + item.bytesTotal > 0 && item.bytesReceived > 0 + ? (() { + final receivedMB = + item.bytesReceived / (1024 * 1024); + final totalMB = + item.bytesTotal / (1024 * 1024); + final progressLabel = item.progress > 0 + ? '${(item.progress * 100).toStringAsFixed(0)}% • ' + : ''; + final speedLabel = item.speedMBps > 0 + ? ' • ${item.speedMBps.toStringAsFixed(1)} MB/s' + : ''; + return '$progressLabel${receivedMB.toStringAsFixed(1)} / ${totalMB.toStringAsFixed(1)} MB$speedLabel'; + })() : (item.bytesReceived > 0 - ? '${(item.bytesReceived / (1024 * 1024)).toStringAsFixed(1)} MB • ${item.speedMBps.toStringAsFixed(1)} MB/s' - : (item.speedMBps > 0 - ? 'Downloading • ${item.speedMBps.toStringAsFixed(1)} MB/s' - : 'Starting...')), + ? '${(item.bytesReceived / (1024 * 1024)).toStringAsFixed(1)} MB${item.speedMBps > 0 ? ' • ${item.speedMBps.toStringAsFixed(1)} MB/s' : ''}' + : (item.progress > 0 + ? (item.speedMBps > 0 + ? '${(item.progress * 100).toStringAsFixed(0)}% • ${item.speedMBps.toStringAsFixed(1)} MB/s' + : '${(item.progress * 100).toStringAsFixed(0)}%') + : (item.speedMBps > 0 + ? 'Downloading • ${item.speedMBps.toStringAsFixed(1)} MB/s' + : 'Starting...'))), style: Theme.of(context).textTheme.labelSmall ?.copyWith( color: colorScheme.primary,