feat: track byte-level download progress for extension downloads

Pass active download item ID through extension download pipeline so
fileDownload can report bytes received/total via ItemProgressWriter.
Add bytesTotal field to DownloadItem model and show X/Y MB progress
in queue tab when total size is known.
This commit is contained in:
zarzet
2026-03-27 21:58:01 +07:00
parent f29177216d
commit 8ee2919934
7 changed files with 82 additions and 24 deletions
+7 -3
View File
@@ -510,7 +510,7 @@ func (p *ExtensionProviderWrapper) GetDownloadURL(trackID, quality string) (*Ext
const ExtDownloadTimeout = DownloadTimeout
func (p *ExtensionProviderWrapper) Download(trackID, quality, outputPath string, onProgress func(percent int)) (*ExtDownloadResult, error) {
func (p *ExtensionProviderWrapper) Download(trackID, quality, outputPath, itemID string, onProgress func(percent int)) (*ExtDownloadResult, error) {
if !p.extension.Manifest.IsDownloadProvider() {
return nil, fmt.Errorf("extension '%s' is not a download provider", p.extension.ID)
}
@@ -526,6 +526,10 @@ func (p *ExtensionProviderWrapper) Download(trackID, quality, outputPath string,
}, nil
}
defer p.extension.VMMu.Unlock()
if p.extension.runtime != nil {
p.extension.runtime.setActiveDownloadItemID(itemID)
defer p.extension.runtime.clearActiveDownloadItemID()
}
p.vm.Set("__onProgress", func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) > 0 {
@@ -1128,7 +1132,7 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro
StartItemProgress(req.ItemID)
}
result, err := provider.Download(trackID, req.Quality, outputPath, func(percent int) {
result, err := provider.Download(trackID, req.Quality, outputPath, req.ItemID, func(percent int) {
if req.ItemID != "" {
normalized := float64(percent) / 100.0
if normalized < 0 {
@@ -1356,7 +1360,7 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro
StartItemProgress(req.ItemID)
}
result, err := provider.Download(availability.TrackID, req.Quality, outputPath, func(percent int) {
result, err := provider.Download(availability.TrackID, req.Quality, outputPath, req.ItemID, func(percent int) {
if req.ItemID != "" {
normalized := float64(percent) / 100.0
if normalized < 0 {
+21
View File
@@ -90,6 +90,9 @@ type ExtensionRuntime struct {
dataDir string
vm *goja.Runtime
activeDownloadMu sync.RWMutex
activeDownloadItemID string
storageMu sync.RWMutex
storageCache map[string]interface{}
storageLoaded bool
@@ -139,6 +142,24 @@ func NewExtensionRuntime(ext *LoadedExtension) *ExtensionRuntime {
return runtime
}
func (r *ExtensionRuntime) setActiveDownloadItemID(itemID string) {
r.activeDownloadMu.Lock()
defer r.activeDownloadMu.Unlock()
r.activeDownloadItemID = strings.TrimSpace(itemID)
}
func (r *ExtensionRuntime) clearActiveDownloadItemID() {
r.activeDownloadMu.Lock()
defer r.activeDownloadMu.Unlock()
r.activeDownloadItemID = ""
}
func (r *ExtensionRuntime) getActiveDownloadItemID() string {
r.activeDownloadMu.RLock()
defer r.activeDownloadMu.RUnlock()
return r.activeDownloadItemID
}
func newExtensionHTTPClient(ext *LoadedExtension, jar http.CookieJar, timeout time.Duration) *http.Client {
// Extension sandbox enforces HTTPS-only domains. Do not apply global
// allow_http scheme downgrade here, because some extension APIs (e.g.
+16 -1
View File
@@ -205,13 +205,22 @@ func (r *ExtensionRuntime) fileDownload(call goja.FunctionCall) goja.Value {
defer out.Close()
contentLength := resp.ContentLength
activeItemID := r.getActiveDownloadItemID()
if activeItemID != "" && contentLength > 0 {
SetItemBytesTotal(activeItemID, contentLength)
}
var progressWriter interface{ Write([]byte) (int, error) } = out
if activeItemID != "" {
progressWriter = NewItemProgressWriter(out, activeItemID)
}
var written int64
buf := make([]byte, 32*1024)
for {
nr, er := resp.Body.Read(buf)
if nr > 0 {
nw, ew := out.Write(buf[0:nr])
nw, ew := progressWriter.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
@@ -220,6 +229,12 @@ func (r *ExtensionRuntime) fileDownload(call goja.FunctionCall) goja.Value {
}
written += int64(nw)
if ew != nil {
if ew == ErrDownloadCancelled {
return r.vm.ToValue(map[string]interface{}{
"success": false,
"error": "download cancelled",
})
}
return r.vm.ToValue(map[string]interface{}{
"success": false,
"error": fmt.Sprintf("failed to write file: %v", ew),
+7 -9
View File
@@ -12,13 +12,7 @@ enum DownloadStatus {
skipped,
}
enum DownloadErrorType {
unknown,
notFound,
rateLimit,
network,
permission,
}
enum DownloadErrorType { unknown, notFound, rateLimit, network, permission }
@JsonSerializable()
class DownloadItem {
@@ -28,7 +22,8 @@ class DownloadItem {
final DownloadStatus status;
final double progress;
final double speedMBps;
final int bytesReceived; // Bytes downloaded so far (for unknown size downloads)
final int bytesReceived; // Bytes downloaded so far
final int bytesTotal; // Total bytes when the server provides content length
final String? filePath;
final String? error;
final DownloadErrorType? errorType;
@@ -44,6 +39,7 @@ class DownloadItem {
this.progress = 0.0,
this.speedMBps = 0.0,
this.bytesReceived = 0,
this.bytesTotal = 0,
this.filePath,
this.error,
this.errorType,
@@ -60,6 +56,7 @@ class DownloadItem {
double? progress,
double? speedMBps,
int? bytesReceived,
int? bytesTotal,
String? filePath,
String? error,
DownloadErrorType? errorType,
@@ -75,6 +72,7 @@ class DownloadItem {
progress: progress ?? this.progress,
speedMBps: speedMBps ?? this.speedMBps,
bytesReceived: bytesReceived ?? this.bytesReceived,
bytesTotal: bytesTotal ?? this.bytesTotal,
filePath: filePath ?? this.filePath,
error: error ?? this.error,
errorType: errorType ?? this.errorType,
@@ -86,7 +84,7 @@ class DownloadItem {
String get errorMessage {
if (error == null) return '';
switch (errorType) {
case DownloadErrorType.notFound:
return 'Song not found on any service';
+2
View File
@@ -16,6 +16,7 @@ DownloadItem _$DownloadItemFromJson(Map<String, dynamic> json) => DownloadItem(
progress: (json['progress'] as num?)?.toDouble() ?? 0.0,
speedMBps: (json['speedMBps'] as num?)?.toDouble() ?? 0.0,
bytesReceived: (json['bytesReceived'] as num?)?.toInt() ?? 0,
bytesTotal: (json['bytesTotal'] as num?)?.toInt() ?? 0,
filePath: json['filePath'] as String?,
error: json['error'] as String?,
errorType: $enumDecodeNullable(_$DownloadErrorTypeEnumMap, json['errorType']),
@@ -33,6 +34,7 @@ Map<String, dynamic> _$DownloadItemToJson(DownloadItem instance) =>
'progress': instance.progress,
'speedMBps': instance.speedMBps,
'bytesReceived': instance.bytesReceived,
'bytesTotal': instance.bytesTotal,
'filePath': instance.filePath,
'error': instance.error,
'errorType': _$DownloadErrorTypeEnumMap[instance.errorType],
+7 -1
View File
@@ -1166,12 +1166,14 @@ class _ProgressUpdate {
final double progress;
final double? speedMBps;
final int? bytesReceived;
final int? bytesTotal;
const _ProgressUpdate({
required this.status,
required this.progress,
this.speedMBps,
this.bytesReceived,
this.bytesTotal,
});
}
@@ -1587,6 +1589,7 @@ class DownloadQueueNotifier extends Notifier<DownloadQueueState> {
progress: normalizedProgress,
speedMBps: normalizedSpeed,
bytesReceived: normalizedBytes,
bytesTotal: bytesTotal,
);
if (LogBuffer.loggingEnabled) {
@@ -1624,11 +1627,13 @@ class DownloadQueueNotifier extends Notifier<DownloadQueueState> {
progress: update.progress,
speedMBps: update.speedMBps ?? current.speedMBps,
bytesReceived: update.bytesReceived ?? current.bytesReceived,
bytesTotal: update.bytesTotal ?? current.bytesTotal,
);
if (current.status != next.status ||
current.progress != next.progress ||
current.speedMBps != next.speedMBps ||
current.bytesReceived != next.bytesReceived) {
current.bytesReceived != next.bytesReceived ||
current.bytesTotal != next.bytesTotal) {
if (!changed) {
updatedItems = List<DownloadItem>.from(updatedItems);
changed = true;
@@ -2408,6 +2413,7 @@ class DownloadQueueNotifier extends Notifier<DownloadQueueState> {
progress: 0,
speedMBps: 0,
bytesReceived: 0,
bytesTotal: 0,
);
})
.toList(growable: false);
+22 -10
View File
@@ -5537,17 +5537,29 @@ class _QueueTabState extends ConsumerState<QueueTab> {
),
const SizedBox(width: 8),
Text(
// When progress is 0 (unknown size, e.g. YouTube tunnel mode),
// show bytes downloaded instead of percentage
item.progress > 0
? (item.speedMBps > 0
? '${(item.progress * 100).toStringAsFixed(0)}% • ${item.speedMBps.toStringAsFixed(1)} MB/s'
: '${(item.progress * 100).toStringAsFixed(0)}%')
item.bytesTotal > 0 && item.bytesReceived > 0
? (() {
final receivedMB =
item.bytesReceived / (1024 * 1024);
final totalMB =
item.bytesTotal / (1024 * 1024);
final progressLabel = item.progress > 0
? '${(item.progress * 100).toStringAsFixed(0)}% • '
: '';
final speedLabel = item.speedMBps > 0
? '${item.speedMBps.toStringAsFixed(1)} MB/s'
: '';
return '$progressLabel${receivedMB.toStringAsFixed(1)} / ${totalMB.toStringAsFixed(1)} MB$speedLabel';
})()
: (item.bytesReceived > 0
? '${(item.bytesReceived / (1024 * 1024)).toStringAsFixed(1)} MB${item.speedMBps.toStringAsFixed(1)} MB/s'
: (item.speedMBps > 0
? 'Downloading • ${item.speedMBps.toStringAsFixed(1)} MB/s'
: 'Starting...')),
? '${(item.bytesReceived / (1024 * 1024)).toStringAsFixed(1)} MB${item.speedMBps > 0 ? '${item.speedMBps.toStringAsFixed(1)} MB/s' : ''}'
: (item.progress > 0
? (item.speedMBps > 0
? '${(item.progress * 100).toStringAsFixed(0)}% • ${item.speedMBps.toStringAsFixed(1)} MB/s'
: '${(item.progress * 100).toStringAsFixed(0)}%')
: (item.speedMBps > 0
? 'Downloading • ${item.speedMBps.toStringAsFixed(1)} MB/s'
: 'Starting...'))),
style: Theme.of(context).textTheme.labelSmall
?.copyWith(
color: colorScheme.primary,