diff --git a/go_backend/extension_providers.go b/go_backend/extension_providers.go index 5adc06f5..d832215c 100644 --- a/go_backend/extension_providers.go +++ b/go_backend/extension_providers.go @@ -83,9 +83,10 @@ type ExtSearchResult struct { } type ExtAvailabilityResult struct { - Available bool `json:"available"` - Reason string `json:"reason,omitempty"` - TrackID string `json:"track_id,omitempty"` + Available bool `json:"available"` + Reason string `json:"reason,omitempty"` + TrackID string `json:"track_id,omitempty"` + SkipFallback bool `json:"skip_fallback,omitempty"` } type ExtDownloadURLResult struct { @@ -95,6 +96,32 @@ type ExtDownloadURLResult struct { SampleRate int `json:"sample_rate,omitempty"` } +func shouldStopProviderFallback(availability *ExtAvailabilityResult) bool { + return availability != nil && availability.SkipFallback +} + +func resolveExtensionAvailabilityReason(availability *ExtAvailabilityResult, err error) string { + if availability != nil { + if reason := strings.TrimSpace(availability.Reason); reason != "" { + return reason + } + } + if err != nil { + return err.Error() + } + return "extension requested no further fallback" +} + +func buildExtensionFallbackStoppedResponse(providerID string, availability *ExtAvailabilityResult, err error) *DownloadResponse { + reason := resolveExtensionAvailabilityReason(availability, err) + return &DownloadResponse{ + Success: false, + Error: fmt.Sprintf("Fallback stopped by %s: %s", providerID, reason), + ErrorType: "extension_error", + Service: providerID, + } +} + type DownloadDecryptionInfo struct { Strategy string `json:"strategy,omitempty"` Key string `json:"key,omitempty"` @@ -1286,6 +1313,31 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro var lastErr error var skipBuiltIn bool + var sourceExtensionLocked bool + var sourceExtensionAvailability *ExtAvailabilityResult + var sourceExtensionTrackID string + + if req.Source != "" && + !isBuiltInProvider(strings.ToLower(req.Source)) && + selectedProvider != req.Source { + ext, err := extManager.GetExtension(req.Source) + if err == nil && ext.Enabled && ext.Error == "" && ext.Manifest.IsDownloadProvider() { + provider := newExtensionProviderWrapper(ext) + availability, availErr := provider.CheckAvailabilityForItemID(req.ISRC, req.TrackName, req.ArtistName, req.SpotifyID, req.DeezerID, req.ItemID) + if errors.Is(availErr, ErrDownloadCancelled) { + return nil, ErrDownloadCancelled + } + if availErr != nil { + GoLog("[DownloadWithExtensionFallback] Source extension %s preflight failed (non-fatal): %v\n", req.Source, availErr) + } else if shouldStopProviderFallback(availability) { + sourceExtensionLocked = true + sourceExtensionAvailability = availability + sourceExtensionTrackID = strings.TrimSpace(availability.TrackID) + selectedProvider = req.Source + GoLog("[DownloadWithExtensionFallback] Source extension %s requested skip_fallback (available=%v), locking download to source extension\n", req.Source, availability.Available) + } + } + } if req.Source != "" && !isBuiltInProvider(strings.ToLower(req.Source)) { ext, err := extManager.GetExtension(req.Source) @@ -1466,6 +1518,11 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro return nil, ErrDownloadCancelled } + if sourceExtensionLocked && (sourceExtensionAvailability == nil || !sourceExtensionAvailability.Available) { + GoLog("[DownloadWithExtensionFallback] Source extension %s stopped fallback before download (reason: %s)\n", req.Source, resolveExtensionAvailabilityReason(sourceExtensionAvailability, nil)) + return buildExtensionFallbackStoppedResponse(req.Source, sourceExtensionAvailability, nil), nil + } + GoLog("[DownloadWithExtensionFallback] Track source is extension '%s' matching selected provider, trying it first\n", req.Source) ext, err := extManager.GetExtension(req.Source) @@ -1475,6 +1532,9 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro provider := newExtensionProviderWrapper(ext) trackID := req.SpotifyID + if sourceExtensionTrackID != "" { + trackID = sourceExtensionTrackID + } GoLog("[DownloadWithExtensionFallback] Downloading from source extension with trackID: %s (skipBuiltInFallback: %v)\n", trackID, skipBuiltIn) @@ -1635,7 +1695,11 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro } GoLog("[DownloadWithExtensionFallback] Source extension %s failed: %v\n", req.Source, lastErr) - if skipBuiltIn { + if skipBuiltIn || sourceExtensionLocked { + if sourceExtensionLocked { + GoLog("[DownloadWithExtensionFallback] Source extension %s requested skip_fallback, not trying other providers\n", req.Source) + return buildExtensionFallbackStoppedResponse(req.Source, sourceExtensionAvailability, lastErr), nil + } GoLog("[DownloadWithExtensionFallback] skipBuiltInFallback is true, not trying other providers\n") return &DownloadResponse{ Success: false, @@ -1735,11 +1799,16 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro if errors.Is(err, ErrDownloadCancelled) { return nil, ErrDownloadCancelled } + terminalAvailability := shouldStopProviderFallback(availability) if err != nil || !availability.Available { GoLog("[DownloadWithExtensionFallback] %s: not available\n", providerID) if err != nil { lastErr = err } + if terminalAvailability { + GoLog("[DownloadWithExtensionFallback] %s requested skip_fallback after availability check\n", providerID) + return buildExtensionFallbackStoppedResponse(providerID, availability, err), nil + } continue } @@ -1864,6 +1933,10 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro lastErr = fmt.Errorf("%s", result.ErrorMessage) } GoLog("[DownloadWithExtensionFallback] %s failed: %v\n", providerID, lastErr) + if terminalAvailability { + GoLog("[DownloadWithExtensionFallback] %s requested skip_fallback after download failure\n", providerID) + return buildExtensionFallbackStoppedResponse(providerID, availability, lastErr), nil + } } } diff --git a/go_backend/extension_providers_test.go b/go_backend/extension_providers_test.go index 1d30be4b..cdc1621d 100644 --- a/go_backend/extension_providers_test.go +++ b/go_backend/extension_providers_test.go @@ -1,6 +1,7 @@ package gobackend import ( + "errors" "os" "path/filepath" "testing" @@ -180,6 +181,45 @@ func TestBuildOutputPathForExtensionUsesTempDirForFDOutput(t *testing.T) { } } +func TestShouldStopProviderFallback(t *testing.T) { + if shouldStopProviderFallback(nil) { + t.Fatal("nil availability should not stop fallback") + } + if shouldStopProviderFallback(&ExtAvailabilityResult{Available: false}) { + t.Fatal("availability without skip_fallback should not stop fallback") + } + if !shouldStopProviderFallback(&ExtAvailabilityResult{Available: false, SkipFallback: true}) { + t.Fatal("skip_fallback availability should stop fallback") + } +} + +func TestBuildExtensionFallbackStoppedResponsePrefersAvailabilityReason(t *testing.T) { + resp := buildExtensionFallbackStoppedResponse("soundcloud", &ExtAvailabilityResult{ + Reason: "direct SoundCloud track ID", + SkipFallback: true, + }, errors.New("ignored")) + + if resp.Service != "soundcloud" { + t.Fatalf("service = %q", resp.Service) + } + if resp.Error != "Fallback stopped by soundcloud: direct SoundCloud track ID" { + t.Fatalf("unexpected error message: %q", resp.Error) + } + if resp.ErrorType != "extension_error" { + t.Fatalf("error type = %q", resp.ErrorType) + } +} + +func TestBuildExtensionFallbackStoppedResponseFallsBackToError(t *testing.T) { + resp := buildExtensionFallbackStoppedResponse("soundcloud", &ExtAvailabilityResult{ + SkipFallback: true, + }, errors.New("lookup failed")) + + if resp.Error != "Fallback stopped by soundcloud: lookup failed" { + t.Fatalf("unexpected error message: %q", resp.Error) + } +} + func TestCanEmbedGenreLabelRequiresExistingAbsoluteLocalFile(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "track.flac") if err := os.WriteFile(tempFile, []byte("fLaC"), 0644); err != nil {