diff --git a/.gitignore b/.gitignore index ecf49ac1..0bb6fc8d 100644 --- a/.gitignore +++ b/.gitignore @@ -60,7 +60,9 @@ ios/Flutter/Flutter.framework/ ios/Flutter/Flutter.podspec # Extension folder -extension/ +extension/* +extension/v2/ +extension/v2/** # Agent instructions AGENTS.md diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml index 4a652acd..d70b105e 100644 --- a/android/app/src/main/AndroidManifest.xml +++ b/android/app/src/main/AndroidManifest.xml @@ -100,6 +100,12 @@ + + + + + + diff --git a/android/app/src/main/kotlin/com/zarz/spotiflac/MainActivity.kt b/android/app/src/main/kotlin/com/zarz/spotiflac/MainActivity.kt index 06c4c889..341ce349 100644 --- a/android/app/src/main/kotlin/com/zarz/spotiflac/MainActivity.kt +++ b/android/app/src/main/kotlin/com/zarz/spotiflac/MainActivity.kt @@ -47,6 +47,8 @@ class MainActivity: FlutterFragmentActivity() { private val LARGE_JSON_RESULT_FILE_KEY = "__json_file" private val LARGE_JSON_RESULT_FILE_THRESHOLD_BYTES = 256 * 1024 private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Main) + private var backendChannel: MethodChannel? = null + private val pendingSessionGrantEvents = mutableListOf>() private var pendingSafTreeResult: MethodChannel.Result? = null private val safScanLock = Any() private val safDirLock = Any() @@ -2073,14 +2075,22 @@ class MainActivity: FlutterFragmentActivity() { } val host = (uri.host ?: "").lowercase(Locale.US) val path = (uri.path ?: "").lowercase(Locale.US) + val isSessionGrant = host == "session-grant" val isCallback = - host == "callback" || + isSessionGrant || + host == "callback" || host == "spotify-callback" || path.contains("callback") if (!isCallback) { return } - val code = uri.getQueryParameter("code")?.trim().orEmpty() + val code = ( + if (isSessionGrant) { + uri.getQueryParameter("grant") ?: uri.getQueryParameter("code") + } else { + uri.getQueryParameter("code") + } + )?.trim().orEmpty() if (code.isEmpty()) { return } @@ -2092,15 +2102,38 @@ class MainActivity: FlutterFragmentActivity() { intent.data = null scope.launch(Dispatchers.IO) { try { - Gobackend.setExtensionAuthCodeByID(extId, code) - val json = Gobackend.invokeExtensionActionJSON(extId, "completeSpotifyLogin") - android.util.Log.i("SpotiFLAC", "Extension OAuth complete for $extId: $json") + val json = if (isSessionGrant) { + Gobackend.setExtensionSessionGrantByID(extId, code) + Gobackend.invokeExtensionActionJSON(extId, "completeGrant") + } else { + Gobackend.setExtensionAuthCodeByID(extId, code) + Gobackend.invokeExtensionActionJSON(extId, "completeSpotifyLogin") + } + android.util.Log.i("SpotiFLAC", "Extension callback complete for $extId: $json") + if (isSessionGrant) { + withContext(Dispatchers.Main) { + notifySessionGrantCompleted(extId) + } + } } catch (e: Exception) { - android.util.Log.w("SpotiFLAC", "Extension OAuth failed: ${e.message}") + android.util.Log.w("SpotiFLAC", "Extension callback failed: ${e.message}") } } } + private fun notifySessionGrantCompleted(extensionId: String) { + val payload = mapOf( + "extension_id" to extensionId, + "success" to true, + ) + val channel = backendChannel + if (channel == null) { + pendingSessionGrantEvents.add(payload) + return + } + channel.invokeMethod("extensionSessionGrantCompleted", payload) + } + override fun onDestroy() { try { Gobackend.cleanupExtensions() @@ -2164,7 +2197,17 @@ class MainActivity: FlutterFragmentActivity() { }, ) - MethodChannel(messenger, CHANNEL).setMethodCallHandler { call, result -> + val channel = MethodChannel(messenger, CHANNEL) + backendChannel = channel + if (pendingSessionGrantEvents.isNotEmpty()) { + val events = pendingSessionGrantEvents.toList() + pendingSessionGrantEvents.clear() + for (event in events) { + channel.invokeMethod("extensionSessionGrantCompleted", event) + } + } + + channel.setMethodCallHandler { call, result -> scope.launch { try { when (call.method) { diff --git a/go_backend/exports.go b/go_backend/exports.go index f1d62f1b..ea9038c5 100644 --- a/go_backend/exports.go +++ b/go_backend/exports.go @@ -3263,6 +3263,10 @@ func SetExtensionAuthCodeByID(extensionID, authCode string) { SetExtensionAuthCode(extensionID, authCode) } +func SetExtensionSessionGrantByID(extensionID, grant string) { + setPendingSignedSessionGrant(extensionID, grant) +} + func SetExtensionTokensByID(extensionID, accessToken, refreshToken string, expiresIn int) { var expiresAt time.Time if expiresIn > 0 { @@ -3935,9 +3939,12 @@ func callExtensionFunctionJSONWithRequestID(extensionID, functionName string, ti if (typeof extension !== 'undefined' && typeof extension.%s === 'function') { return extension.%s(); } + if (typeof %s === 'function') { + return %s(); + } return null; })() - `, functionName, functionName) + `, functionName, functionName, functionName, functionName) jsStartedAt := time.Now() result, err := RunWithTimeoutContextAndRecover(requestCtx, vm, script, timeout) diff --git a/go_backend/extension_manager.go b/go_backend/extension_manager.go index 091e3d32..86347fda 100644 --- a/go_backend/extension_manager.go +++ b/go_backend/extension_manager.go @@ -44,6 +44,11 @@ func compareVersions(v1, v2 string) int { return 0 } +func isExtensionPackagePath(filePath string) bool { + lowerPath := strings.ToLower(filePath) + return strings.HasSuffix(lowerPath, ".spotiflac-ext") || strings.HasSuffix(lowerPath, ".sflx") +} + type loadedExtension struct { ID string `json:"id"` Manifest *ExtensionManifest `json:"manifest"` @@ -166,8 +171,8 @@ func (m *extensionManager) LoadExtensionFromFile(filePath string) (*loadedExtens } func (m *extensionManager) loadExtensionFromFileLocked(filePath string) (*loadedExtension, error) { - if !strings.HasSuffix(strings.ToLower(filePath), ".spotiflac-ext") { - return nil, fmt.Errorf("invalid file format: please select a .spotiflac-ext file") + if !isExtensionPackagePath(filePath) { + return nil, fmt.Errorf("invalid file format: please select a .spotiflac-ext or .sflx file") } zipReader, err := zip.OpenReader(filePath) @@ -673,7 +678,7 @@ func (m *extensionManager) LoadExtensionsFromDirectory(dirPath string) ([]string loaded = append(loaded, ext.ID) } } - } else if strings.HasSuffix(strings.ToLower(entry.Name()), ".spotiflac-ext") { + } else if isExtensionPackagePath(entry.Name()) { ext, err := m.LoadExtensionFromFile(filepath.Join(dirPath, entry.Name())) if err != nil { GoLog("[Extension] Failed to load %s: %v\n", entry.Name(), err) @@ -775,8 +780,8 @@ func (m *extensionManager) UpgradeExtension(filePath string) (*loadedExtension, } func (m *extensionManager) upgradeExtensionLocked(filePath string) (*loadedExtension, error) { - if !strings.HasSuffix(strings.ToLower(filePath), ".spotiflac-ext") { - return nil, fmt.Errorf("invalid file format: please select a .spotiflac-ext file") + if !isExtensionPackagePath(filePath) { + return nil, fmt.Errorf("invalid file format: please select a .spotiflac-ext or .sflx file") } zipReader, err := zip.OpenReader(filePath) @@ -924,8 +929,8 @@ type ExtensionUpgradeInfo struct { } func (m *extensionManager) checkExtensionUpgradeInternal(filePath string) (*ExtensionUpgradeInfo, error) { - if !strings.HasSuffix(strings.ToLower(filePath), ".spotiflac-ext") { - return nil, fmt.Errorf("invalid file format: please select a .spotiflac-ext file") + if !isExtensionPackagePath(filePath) { + return nil, fmt.Errorf("invalid file format: please select a .spotiflac-ext or .sflx file") } zipReader, err := zip.OpenReader(filePath) diff --git a/go_backend/extension_manifest.go b/go_backend/extension_manifest.go index d3b6cd64..740a3085 100644 --- a/go_backend/extension_manifest.go +++ b/go_backend/extension_manifest.go @@ -3,6 +3,7 @@ package gobackend import ( "encoding/json" "fmt" + "net/url" "strings" ) @@ -113,28 +114,49 @@ type ExtensionHealthCheck struct { Required bool `json:"required,omitempty"` } +type SignedSessionEndpoints struct { + Bootstrap string `json:"bootstrap,omitempty"` + Challenge string `json:"challenge,omitempty"` + Exchange string `json:"exchange,omitempty"` + Refresh string `json:"refresh,omitempty"` +} + +type SignedSessionConfig struct { + Namespace string `json:"namespace"` + BaseURL string `json:"baseUrl"` + AppVersion string `json:"appVersion,omitempty"` + Platform string `json:"platform,omitempty"` + CallbackURL string `json:"callbackUrl,omitempty"` + SchemeLabel string `json:"schemeLabel,omitempty"` + HeaderPrefix string `json:"headerPrefix,omitempty"` + TimeWindowSeconds int `json:"timeWindowSeconds,omitempty"` + Endpoints SignedSessionEndpoints `json:"endpoints,omitempty"` +} + type ExtensionManifest struct { - Name string `json:"name"` - DisplayName string `json:"displayName"` - Version string `json:"version"` - Description string `json:"description"` - Homepage string `json:"homepage,omitempty"` - Icon string `json:"icon,omitempty"` - Types []ExtensionType `json:"type"` - Permissions ExtensionPermissions `json:"permissions"` - Settings []ExtensionSetting `json:"settings,omitempty"` - QualityOptions []QualityOption `json:"qualityOptions,omitempty"` - MinAppVersion string `json:"minAppVersion,omitempty"` - SkipMetadataEnrichment bool `json:"skipMetadataEnrichment,omitempty"` - SkipLyrics bool `json:"skipLyrics,omitempty"` - StopProviderFallback bool `json:"stopProviderFallback,omitempty"` - SkipBuiltInFallback bool `json:"skipBuiltInFallback,omitempty"` - SearchBehavior *SearchBehaviorConfig `json:"searchBehavior,omitempty"` - URLHandler *URLHandlerConfig `json:"urlHandler,omitempty"` - TrackMatching *TrackMatchingConfig `json:"trackMatching,omitempty"` - PostProcessing *PostProcessingConfig `json:"postProcessing,omitempty"` - ServiceHealth []ExtensionHealthCheck `json:"serviceHealth,omitempty"` - Capabilities map[string]interface{} `json:"capabilities,omitempty"` + Name string `json:"name"` + DisplayName string `json:"displayName"` + Version string `json:"version"` + Description string `json:"description"` + Homepage string `json:"homepage,omitempty"` + Icon string `json:"icon,omitempty"` + Types []ExtensionType `json:"type"` + Permissions ExtensionPermissions `json:"permissions"` + Settings []ExtensionSetting `json:"settings,omitempty"` + QualityOptions []QualityOption `json:"qualityOptions,omitempty"` + MinAppVersion string `json:"minAppVersion,omitempty"` + SkipMetadataEnrichment bool `json:"skipMetadataEnrichment,omitempty"` + SkipLyrics bool `json:"skipLyrics,omitempty"` + StopProviderFallback bool `json:"stopProviderFallback,omitempty"` + SkipBuiltInFallback bool `json:"skipBuiltInFallback,omitempty"` + SearchBehavior *SearchBehaviorConfig `json:"searchBehavior,omitempty"` + URLHandler *URLHandlerConfig `json:"urlHandler,omitempty"` + TrackMatching *TrackMatchingConfig `json:"trackMatching,omitempty"` + PostProcessing *PostProcessingConfig `json:"postProcessing,omitempty"` + ServiceHealth []ExtensionHealthCheck `json:"serviceHealth,omitempty"` + SignedSession *SignedSessionConfig `json:"signedSession,omitempty"` + RequiredRuntimeFeatures []string `json:"requiredRuntimeFeatures,omitempty"` + Capabilities map[string]interface{} `json:"capabilities,omitempty"` } type ManifestValidationError struct { @@ -238,6 +260,26 @@ func (m *ExtensionManifest) Validate() error { } } + if m.SignedSession != nil { + if strings.TrimSpace(m.SignedSession.Namespace) == "" { + return &ManifestValidationError{Field: "signedSession.namespace", Message: "namespace is required"} + } + baseURL := strings.TrimSpace(m.SignedSession.BaseURL) + if baseURL == "" { + return &ManifestValidationError{Field: "signedSession.baseUrl", Message: "baseUrl is required"} + } + if !strings.HasPrefix(strings.ToLower(baseURL), "https://") { + return &ManifestValidationError{Field: "signedSession.baseUrl", Message: "baseUrl must use https"} + } + parsed, err := url.Parse(baseURL) + if err != nil || parsed.Hostname() == "" { + return &ManifestValidationError{Field: "signedSession.baseUrl", Message: "baseUrl is invalid"} + } + if !m.IsDomainAllowed(parsed.Hostname()) { + return &ManifestValidationError{Field: "signedSession.baseUrl", Message: "baseUrl host must be listed in permissions.network"} + } + } + return nil } diff --git a/go_backend/extension_providers.go b/go_backend/extension_providers.go index 81d53af8..b8842b89 100644 --- a/go_backend/extension_providers.go +++ b/go_backend/extension_providers.go @@ -2135,6 +2135,7 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro } var lastErr error + var lastErrType string var stopProviderFallback bool var sourceExtensionLocked bool var sourceExtensionAvailability *ExtAvailabilityResult @@ -2449,11 +2450,23 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro }, nil } lastErr = err + lastErrType = "" } else if result.ErrorMessage != "" { lastErr = fmt.Errorf("%s", result.ErrorMessage) + lastErrType = strings.TrimSpace(result.ErrorType) } GoLog("[DownloadWithExtensionFallback] Source extension %s failed: %v\n", req.Source, lastErr) + if strings.EqualFold(lastErrType, "verification_required") { + GoLog("[DownloadWithExtensionFallback] Source extension %s requires verification, not trying other providers\n", req.Source) + return &DownloadResponse{ + Success: false, + Error: "Download failed: " + lastErr.Error(), + ErrorType: "verification_required", + Service: req.Source, + }, nil + } + if stopProviderFallback || sourceExtensionLocked { if sourceExtensionLocked { GoLog("[DownloadWithExtensionFallback] Source extension %s requested skip_fallback, not trying other providers\n", req.Source) @@ -2463,7 +2476,7 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro return &DownloadResponse{ Success: false, Error: "Download failed: " + lastErr.Error(), - ErrorType: "extension_error", + ErrorType: firstNonEmptyString(lastErrType, "extension_error"), Service: req.Source, }, nil } @@ -2632,8 +2645,10 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro }, nil } lastErr = err + lastErrType = "" } else if result.ErrorMessage != "" { lastErr = fmt.Errorf("%s", result.ErrorMessage) + lastErrType = strings.TrimSpace(result.ErrorType) } GoLog("[DownloadWithExtensionFallback] %s failed: %v\n", providerID, lastErr) if terminalAvailability { @@ -2644,7 +2659,7 @@ func DownloadWithExtensionFallback(req DownloadRequest) (*DownloadResponse, erro } if lastErr != nil { - errorType := classifyDownloadErrorType(lastErr.Error()) + errorType := firstNonEmptyString(lastErrType, classifyDownloadErrorType(lastErr.Error())) if errorType == "unknown" { errorType = "not_found" } diff --git a/go_backend/extension_runtime.go b/go_backend/extension_runtime.go index 3eaf6eda..3f0b97a0 100644 --- a/go_backend/extension_runtime.go +++ b/go_backend/extension_runtime.go @@ -465,6 +465,15 @@ func (r *extensionRuntime) RegisterAPIs(vm *goja.Runtime) { authObj.Set("exchangeCodeWithPKCE", r.authExchangeCodeWithPKCE) vm.Set("auth", authObj) + if r.manifest != nil && r.manifest.SignedSession != nil { + sessionObj := vm.NewObject() + sessionObj.Set("signedFetch", r.signedSessionFetch) + sessionObj.Set("completeGrant", r.signedSessionCompleteGrant) + sessionObj.Set("status", r.signedSessionStatus) + sessionObj.Set("clear", r.signedSessionClear) + vm.Set("session", sessionObj) + } + fileObj := vm.NewObject() fileObj.Set("download", r.fileDownload) fileObj.Set("exists", r.fileExists) diff --git a/go_backend/extension_signed_session.go b/go_backend/extension_signed_session.go new file mode 100644 index 00000000..edad1914 --- /dev/null +++ b/go_backend/extension_signed_session.go @@ -0,0 +1,595 @@ +package gobackend + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/dop251/goja" +) + +const signedSessionRefreshSkew = time.Hour + +var ( + pendingSignedSessionGrants = make(map[string]string) + pendingSignedSessionGrantsMu sync.Mutex +) + +type signedSessionRecord struct { + InstallID string `json:"install_id"` + SessionID string `json:"session_id,omitempty"` + SessionSecret string `json:"session_secret,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +type signedSessionExchangeResponse struct { + SessionID string `json:"session_id,omitempty"` + SessionSecret string `json:"session_secret,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` + ChallengeID string `json:"challenge_id,omitempty"` + ChallengeURL string `json:"challenge_url,omitempty"` + AuthURL string `json:"auth_url,omitempty"` +} + +func signedSessionConfigWithDefaults(config *SignedSessionConfig) SignedSessionConfig { + if config == nil { + return SignedSessionConfig{} + } + resolved := *config + if resolved.AppVersion == "" { + resolved.AppVersion = "ext-1.0" + } + if resolved.Platform == "" { + resolved.Platform = "extension" + } + if resolved.CallbackURL == "" { + resolved.CallbackURL = "spotiflac://session-grant" + } + if resolved.SchemeLabel == "" { + resolved.SchemeLabel = "SPOTIFLAC-HMAC-V1" + } + if resolved.HeaderPrefix == "" { + resolved.HeaderPrefix = "X-Sig-" + } + if resolved.TimeWindowSeconds <= 0 { + resolved.TimeWindowSeconds = 300 + } + if resolved.Endpoints.Bootstrap == "" { + resolved.Endpoints.Bootstrap = "/bootstrap" + } + if resolved.Endpoints.Challenge == "" { + resolved.Endpoints.Challenge = "/challenge" + } + if resolved.Endpoints.Exchange == "" { + resolved.Endpoints.Exchange = "/session/exchange" + } + return resolved +} + +func (r *extensionRuntime) signedSessionFilePath(config SignedSessionConfig) (string, error) { + namespace := sanitizeSignedSessionNamespace(config.Namespace) + if namespace == "" { + return "", fmt.Errorf("signed session namespace is empty") + } + baseDir := filepath.Dir(r.dataDir) + if baseDir == "." || baseDir == "" { + baseDir = r.dataDir + } + dir := filepath.Join(baseDir, "signed_sessions") + if err := os.MkdirAll(dir, 0700); err != nil { + return "", err + } + return filepath.Join(dir, namespace+".json"), nil +} + +func sanitizeSignedSessionNamespace(namespace string) string { + namespace = strings.TrimSpace(strings.ToLower(namespace)) + var b strings.Builder + for _, ch := range namespace { + if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '-' || ch == '_' || ch == '.' { + b.WriteRune(ch) + } + } + return strings.Trim(b.String(), ".-_") +} + +func (r *extensionRuntime) loadSignedSession(config SignedSessionConfig) (*signedSessionRecord, error) { + path, err := r.signedSessionFilePath(config) + if err != nil { + return nil, err + } + record := &signedSessionRecord{} + if data, err := os.ReadFile(path); err == nil { + _ = json.Unmarshal(data, record) + } + if strings.TrimSpace(record.InstallID) == "" { + record.InstallID = randomHex(16) + if err := r.saveSignedSession(config, record); err != nil { + return nil, err + } + } + return record, nil +} + +func (r *extensionRuntime) saveSignedSession(config SignedSessionConfig, record *signedSessionRecord) error { + path, err := r.signedSessionFilePath(config) + if err != nil { + return err + } + data, err := json.MarshalIndent(record, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +func randomHex(bytesLen int) string { + buf := make([]byte, bytesLen) + if _, err := rand.Read(buf); err != nil { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return hex.EncodeToString(buf) +} + +func parseSignedSessionTime(value string) (time.Time, bool) { + value = strings.TrimSpace(value) + if value == "" { + return time.Time{}, false + } + layouts := []string{ + time.RFC3339Nano, + time.RFC3339, + "2006-01-02T15:04:05.000Z", + } + for _, layout := range layouts { + if parsed, err := time.Parse(layout, value); err == nil { + return parsed, true + } + } + return time.Time{}, false +} + +func (r *extensionRuntime) signedSessionStatus(call goja.FunctionCall) goja.Value { + config := signedSessionConfigWithDefaults(r.manifest.SignedSession) + if config.Namespace == "" || config.BaseURL == "" { + return r.vm.ToValue(map[string]interface{}{"authenticated": false, "error": "signedSession is not configured"}) + } + record, err := r.loadSignedSession(config) + if err != nil { + return r.vm.ToValue(map[string]interface{}{"authenticated": false, "error": err.Error()}) + } + authenticated := record.SessionID != "" && record.SessionSecret != "" + if expiresAt, ok := parseSignedSessionTime(record.ExpiresAt); ok && time.Now().After(expiresAt) { + authenticated = false + } + return r.vm.ToValue(map[string]interface{}{ + "authenticated": authenticated, + "expires_at": record.ExpiresAt, + "install_id": record.InstallID, + }) +} + +func (r *extensionRuntime) signedSessionClear(call goja.FunctionCall) goja.Value { + config := signedSessionConfigWithDefaults(r.manifest.SignedSession) + record, err := r.loadSignedSession(config) + if err != nil { + return r.vm.ToValue(map[string]interface{}{"success": false, "error": err.Error()}) + } + record.SessionID = "" + record.SessionSecret = "" + record.ExpiresAt = "" + if err := r.saveSignedSession(config, record); err != nil { + return r.vm.ToValue(map[string]interface{}{"success": false, "error": err.Error()}) + } + return r.vm.ToValue(map[string]interface{}{"success": true}) +} + +func (r *extensionRuntime) signedSessionCompleteGrant(call goja.FunctionCall) goja.Value { + grant := "" + if len(call.Arguments) > 0 { + grant = strings.TrimSpace(call.Arguments[0].String()) + } + if grant == "" { + pendingSignedSessionGrantsMu.Lock() + grant = pendingSignedSessionGrants[r.extensionID] + delete(pendingSignedSessionGrants, r.extensionID) + pendingSignedSessionGrantsMu.Unlock() + } + if grant == "" { + return r.vm.ToValue(map[string]interface{}{"success": false, "error": "no pending grant"}) + } + if err := r.exchangeSignedSessionGrant(grant); err != nil { + return r.vm.ToValue(map[string]interface{}{"success": false, "error": err.Error()}) + } + return r.vm.ToValue(map[string]interface{}{"success": true}) +} + +func (r *extensionRuntime) exchangeSignedSessionGrant(grant string) error { + config := signedSessionConfigWithDefaults(r.manifest.SignedSession) + record, err := r.loadSignedSession(config) + if err != nil { + return err + } + endpoint, err := signedSessionURL(config, config.Endpoints.Exchange) + if err != nil { + return err + } + payload := map[string]interface{}{ + "grant": grant, + "install_id": record.InstallID, + "app_version": config.AppVersion, + "platform": config.Platform, + } + body, _ := json.Marshal(payload) + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "SpotiFLAC-Mobile/"+config.AppVersion) + resp, err := r.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + respBody, err := readExtensionHTTPResponseBody(resp) + if err != nil { + return err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("session exchange failed: HTTP %d", resp.StatusCode) + } + var exchanged signedSessionExchangeResponse + if err := json.Unmarshal(respBody, &exchanged); err != nil { + return fmt.Errorf("invalid session exchange response: %w", err) + } + if exchanged.SessionID == "" || exchanged.SessionSecret == "" || exchanged.ExpiresAt == "" { + return fmt.Errorf("session exchange response missing session fields") + } + record.SessionID = exchanged.SessionID + record.SessionSecret = exchanged.SessionSecret + record.ExpiresAt = exchanged.ExpiresAt + return r.saveSignedSession(config, record) +} + +func (r *extensionRuntime) signedSessionFetch(call goja.FunctionCall) goja.Value { + if len(call.Arguments) < 2 { + return r.vm.ToValue(map[string]interface{}{"ok": false, "error": "method and path are required"}) + } + config := signedSessionConfigWithDefaults(r.manifest.SignedSession) + if config.Namespace == "" || config.BaseURL == "" { + return r.vm.ToValue(map[string]interface{}{"ok": false, "error": "signedSession is not configured"}) + } + method := strings.ToUpper(strings.TrimSpace(call.Arguments[0].String())) + requestPath := call.Arguments[1].String() + body := []byte{} + if len(call.Arguments) > 2 && !goja.IsUndefined(call.Arguments[2]) && !goja.IsNull(call.Arguments[2]) { + switch v := call.Arguments[2].Export().(type) { + case string: + body = []byte(v) + case map[string]interface{}, []interface{}: + encoded, err := json.Marshal(v) + if err != nil { + return r.vm.ToValue(map[string]interface{}{"ok": false, "error": err.Error()}) + } + body = encoded + default: + body = []byte(call.Arguments[2].String()) + } + } + extraHeaders := map[string]string{} + if len(call.Arguments) > 3 && !goja.IsUndefined(call.Arguments[3]) && !goja.IsNull(call.Arguments[3]) { + if h, ok := call.Arguments[3].Export().(map[string]interface{}); ok { + for k, v := range h { + extraHeaders[k] = fmt.Sprintf("%v", v) + } + } + } + + record, err := r.ensureSignedSession(config) + if err != nil { + if authURL := r.startSignedSessionVerification(config, ""); authURL != "" { + return r.signedSessionVerificationRequiredValue(authURL) + } + return r.vm.ToValue(map[string]interface{}{"ok": false, "error": err.Error()}) + } + + resp, respBody, respHeaders, err := r.doSignedSessionRequest(config, record, method, requestPath, body, extraHeaders) + if err != nil { + return r.vm.ToValue(map[string]interface{}{"ok": false, "error": err.Error()}) + } + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusPreconditionRequired { + record.SessionID = "" + record.SessionSecret = "" + record.ExpiresAt = "" + _ = r.saveSignedSession(config, record) + if authURL := r.startSignedSessionVerification(config, ""); authURL != "" { + return r.signedSessionVerificationRequiredValue(authURL) + } + } + return r.vm.ToValue(map[string]interface{}{ + "statusCode": resp.StatusCode, + "status": resp.StatusCode, + "ok": resp.StatusCode >= 200 && resp.StatusCode < 300, + "url": resp.Request.URL.String(), + "body": string(respBody), + "headers": respHeaders, + }) +} + +func (r *extensionRuntime) signedSessionVerificationRequiredValue(authURL string) goja.Value { + return r.vm.ToValue(map[string]interface{}{ + "ok": false, + "needsVerification": true, + "error": "VERIFY_REQUIRED", + "open_auth_url": authURL, + "auth_url": authURL, + }) +} + +func (r *extensionRuntime) ensureSignedSession(config SignedSessionConfig) (*signedSessionRecord, error) { + record, err := r.loadSignedSession(config) + if err != nil { + return nil, err + } + if record.SessionID == "" || record.SessionSecret == "" { + return nil, fmt.Errorf("signed session is not authenticated") + } + if expiresAt, ok := parseSignedSessionTime(record.ExpiresAt); ok { + if time.Now().After(expiresAt) { + record.SessionID = "" + record.SessionSecret = "" + record.ExpiresAt = "" + _ = r.saveSignedSession(config, record) + return nil, fmt.Errorf("signed session expired") + } + if config.Endpoints.Refresh != "" && time.Until(expiresAt) <= signedSessionRefreshSkew { + _ = r.refreshSignedSession(config, record) + } + } + return record, nil +} + +func (r *extensionRuntime) refreshSignedSession(config SignedSessionConfig, record *signedSessionRecord) error { + body, _ := json.Marshal(map[string]string{"install_id": record.InstallID}) + resp, respBody, _, err := r.doSignedSessionRequest(config, record, http.MethodPost, config.Endpoints.Refresh, body, nil) + if err != nil { + return err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("session refresh failed: HTTP %d", resp.StatusCode) + } + var refreshed signedSessionExchangeResponse + if err := json.Unmarshal(respBody, &refreshed); err != nil { + return err + } + changed := false + if refreshed.SessionID != "" { + record.SessionID = refreshed.SessionID + changed = true + } + if refreshed.SessionSecret != "" { + record.SessionSecret = refreshed.SessionSecret + changed = true + } + if refreshed.ExpiresAt != "" && refreshed.ExpiresAt != record.ExpiresAt { + record.ExpiresAt = refreshed.ExpiresAt + changed = true + } + if changed { + return r.saveSignedSession(config, record) + } + return nil +} + +func (r *extensionRuntime) startSignedSessionVerification(config SignedSessionConfig, reason string) string { + record, err := r.loadSignedSession(config) + if err != nil { + return "" + } + bootstrapURL, err := signedSessionURL(config, config.Endpoints.Bootstrap) + if err != nil { + return "" + } + parsed, _ := url.Parse(bootstrapURL) + query := parsed.Query() + query.Set("app_version", config.AppVersion) + query.Set("install_id", record.InstallID) + parsed.RawQuery = query.Encode() + req, err := http.NewRequest(http.MethodGet, parsed.String(), nil) + if err != nil { + return "" + } + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "SpotiFLAC-Mobile/"+config.AppVersion) + resp, err := r.httpClient.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + body, err := io.ReadAll(io.LimitReader(resp.Body, maxExtensionHTTPResponseBytes)) + if err != nil || resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "" + } + var boot signedSessionExchangeResponse + if err := json.Unmarshal(body, &boot); err != nil { + return "" + } + if boot.SessionID != "" && boot.SessionSecret != "" && boot.ExpiresAt != "" { + record.SessionID = boot.SessionID + record.SessionSecret = boot.SessionSecret + record.ExpiresAt = boot.ExpiresAt + _ = r.saveSignedSession(config, record) + return "" + } + authURL := boot.AuthURL + if authURL == "" && boot.ChallengeURL != "" { + authURL = boot.ChallengeURL + } + if authURL == "" && boot.ChallengeID != "" { + authURL = r.buildSignedSessionChallengeURL(config, boot.ChallengeID) + } + if authURL != "" { + pendingAuthRequestsMu.Lock() + pendingAuthRequests[r.extensionID] = &PendingAuthRequest{ + ExtensionID: r.extensionID, + AuthURL: authURL, + CallbackURL: config.CallbackURL, + } + pendingAuthRequestsMu.Unlock() + } + return authURL +} + +func (r *extensionRuntime) buildSignedSessionChallengeURL(config SignedSessionConfig, challengeID string) string { + challengeURL, err := signedSessionURL(config, config.Endpoints.Challenge) + if err != nil { + return "" + } + parsed, err := url.Parse(challengeURL) + if err != nil { + return "" + } + callback, err := url.Parse(config.CallbackURL) + if err != nil { + return "" + } + q := callback.Query() + q.Set("cb_version", "v2grant") + q.Set("state", r.extensionID) + callback.RawQuery = q.Encode() + + query := parsed.Query() + query.Set("id", challengeID) + query.Set("cb", callback.String()) + parsed.RawQuery = query.Encode() + return parsed.String() +} + +func signedSessionURL(config SignedSessionConfig, endpoint string) (string, error) { + base, err := url.Parse(strings.TrimRight(config.BaseURL, "/") + "/") + if err != nil || base.Scheme != "https" || base.Host == "" { + return "", fmt.Errorf("invalid signed session baseUrl") + } + endpoint = strings.TrimSpace(endpoint) + if endpoint == "" { + return "", fmt.Errorf("signed session endpoint is empty") + } + if strings.HasPrefix(endpoint, "https://") { + return endpoint, nil + } + endpoint = strings.TrimLeft(endpoint, "/") + ref, _ := url.Parse(endpoint) + return base.ResolveReference(ref).String(), nil +} + +func (r *extensionRuntime) doSignedSessionRequest( + config SignedSessionConfig, + record *signedSessionRecord, + method string, + requestPath string, + body []byte, + extraHeaders map[string]string, +) (*http.Response, []byte, map[string]interface{}, error) { + fullURL, err := signedSessionURL(config, requestPath) + if err != nil { + return nil, nil, nil, err + } + parsed, err := url.Parse(fullURL) + if err != nil { + return nil, nil, nil, err + } + ts := time.Now().UTC().Format("2006-01-02T15:04:05.000Z") + nonce := randomHex(12) + bodyHashBytes := sha256.Sum256(body) + bodyHash := hex.EncodeToString(bodyHashBytes[:]) + parsedTs, _ := time.Parse("2006-01-02T15:04:05.000Z", ts) + window := parsedTs.Unix() / int64(config.TimeWindowSeconds) + rollingInput := fmt.Sprintf("%d:%s", window, record.SessionID) + rk := base64.RawURLEncoding.EncodeToString(hmacSHA256Bytes([]byte(record.SessionSecret), []byte(rollingInput))) + signingInput := strings.Join([]string{ + config.SchemeLabel, + method, + parsed.EscapedPath(), + "", + bodyHash, + ts, + nonce, + record.SessionID, + config.AppVersion, + config.Platform, + }, "\n") + sig := base64.RawURLEncoding.EncodeToString(hmacSHA256Bytes([]byte(rk), []byte(signingInput))) + + req, err := http.NewRequest(method, fullURL, bytes.NewReader(body)) + if err != nil { + return nil, nil, nil, err + } + req = r.bindDownloadCancelContext(req) + req.Header.Set("Accept", "application/json") + if len(body) > 0 { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("User-Agent", "SpotiFLAC-Mobile/"+config.AppVersion) + prefix := config.HeaderPrefix + req.Header.Set(prefix+"Session", record.SessionID) + req.Header.Set(prefix+"Timestamp", ts) + req.Header.Set(prefix+"Nonce", nonce) + req.Header.Set(prefix+"Body-SHA256", bodyHash) + req.Header.Set(prefix+"Signature", sig) + req.Header.Set(prefix+"App-Version", config.AppVersion) + req.Header.Set(prefix+"Platform", config.Platform) + for k, v := range extraHeaders { + req.Header.Set(k, v) + } + + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, nil, nil, err + } + defer resp.Body.Close() + respBody, err := readExtensionHTTPResponseBody(resp) + if err != nil { + return nil, nil, nil, err + } + headers := make(map[string]interface{}) + for k, v := range resp.Header { + if len(v) == 1 { + headers[k] = v[0] + } else { + headers[k] = v + } + } + return resp, respBody, headers, nil +} + +func hmacSHA256Bytes(key, message []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write(message) + return mac.Sum(nil) +} + +func setPendingSignedSessionGrant(extensionID, grant string) { + extensionID = strings.TrimSpace(extensionID) + grant = strings.TrimSpace(grant) + if extensionID == "" || grant == "" { + return + } + pendingSignedSessionGrantsMu.Lock() + pendingSignedSessionGrants[extensionID] = grant + pendingSignedSessionGrantsMu.Unlock() +} diff --git a/ios/Runner/AppDelegate.swift b/ios/Runner/AppDelegate.swift index 09a0ba72..65f30cc5 100644 --- a/ios/Runner/AppDelegate.swift +++ b/ios/Runner/AppDelegate.swift @@ -17,6 +17,8 @@ import Gobackend // Import Go framework private var libraryScanProgressTimer: DispatchSourceTimer? private var libraryScanProgressEventSink: FlutterEventSink? private var lastLibraryScanProgressPayload: String? + private var backendChannel: FlutterMethodChannel? + private var pendingSessionGrantEvents: [[String: Any]] = [] /// Currently accessed security-scoped URL for library folder private var activeSecurityScopedURL: URL? @@ -39,6 +41,14 @@ import Gobackend // Import Go framework name: CHANNEL, binaryMessenger: controller.binaryMessenger ) + backendChannel = channel + if !pendingSessionGrantEvents.isEmpty { + let events = pendingSessionGrantEvents + pendingSessionGrantEvents.removeAll() + for event in events { + channel.invokeMethod("extensionSessionGrantCompleted", arguments: event) + } + } let downloadProgressEvents = FlutterEventChannel( name: DOWNLOAD_PROGRESS_STREAM_CHANNEL, binaryMessenger: controller.binaryMessenger @@ -83,20 +93,25 @@ import Gobackend // Import Go framework return super.application(application, didFinishLaunchingWithOptions: launchOptions) } - /// PKCE OAuth return URL: spotiflac://callback?code=...&state= + /// Extension return URLs: + /// - OAuth: spotiflac://callback?code=...&state= + /// - Signed session: spotiflac://session-grant?grant=...&state= @discardableResult private func handleExtensionOAuthRedirect(url: URL) -> Bool { guard let scheme = url.scheme?.lowercased(), scheme == "spotiflac" else { return false } let host = (url.host ?? "").lowercased() let path = url.path.lowercased() + let isSessionGrant = host == "session-grant" let ok = - host == "callback" || host == "spotify-callback" || path.contains("callback") + isSessionGrant || host == "callback" || host == "spotify-callback" || path.contains("callback") guard ok else { return false } guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { return false } let q = components.queryItems ?? [] let code = + q.first { $0.name == (isSessionGrant ? "grant" : "code") }?.value?.trimmingCharacters( + in: .whitespacesAndNewlines) ?? q.first { $0.name == "code" }?.value?.trimmingCharacters( in: .whitespacesAndNewlines) ?? "" let state = @@ -109,16 +124,37 @@ import Gobackend // Import Go framework } streamQueue.async { var err: NSError? - GobackendSetExtensionAuthCodeByID(state, code) - _ = GobackendInvokeExtensionActionJSON(state, "completeSpotifyLogin", &err) + if isSessionGrant { + GobackendSetExtensionSessionGrantByID(state, code) + _ = GobackendInvokeExtensionActionJSON(state, "completeGrant", &err) + } else { + GobackendSetExtensionAuthCodeByID(state, code) + _ = GobackendInvokeExtensionActionJSON(state, "completeSpotifyLogin", &err) + } if let err = err { NSLog( - "SpotiFLAC: Extension OAuth complete failed: \(err.localizedDescription)") + "SpotiFLAC: Extension callback complete failed: \(err.localizedDescription)") + } else if isSessionGrant { + DispatchQueue.main.async { [weak self] in + self?.notifySessionGrantCompleted(extensionId: state) + } } } return true } + private func notifySessionGrantCompleted(extensionId: String) { + let payload: [String: Any] = [ + "extension_id": extensionId, + "success": true, + ] + if let channel = backendChannel { + channel.invokeMethod("extensionSessionGrantCompleted", arguments: payload) + } else { + pendingSessionGrantEvents.append(payload) + } + } + override func application( _ app: UIApplication, open url: URL, diff --git a/lib/models/download_item.dart b/lib/models/download_item.dart index b4662360..8c9534f9 100644 --- a/lib/models/download_item.dart +++ b/lib/models/download_item.dart @@ -12,7 +12,14 @@ enum DownloadStatus { skipped, } -enum DownloadErrorType { unknown, notFound, rateLimit, network, permission } +enum DownloadErrorType { + unknown, + notFound, + rateLimit, + network, + permission, + verificationRequired, +} @JsonSerializable() class DownloadItem { @@ -94,6 +101,8 @@ class DownloadItem { return 'Connection failed, check your internet'; case DownloadErrorType.permission: return 'Cannot write to folder, check storage permission'; + case DownloadErrorType.verificationRequired: + return 'Verification required. Open the extension and complete the security check.'; default: return error ?? 'An error occurred'; } diff --git a/lib/models/download_item.g.dart b/lib/models/download_item.g.dart index 7aee835c..0cf5e622 100644 --- a/lib/models/download_item.g.dart +++ b/lib/models/download_item.g.dart @@ -58,4 +58,5 @@ const _$DownloadErrorTypeEnumMap = { DownloadErrorType.rateLimit: 'rateLimit', DownloadErrorType.network: 'network', DownloadErrorType.permission: 'permission', + DownloadErrorType.verificationRequired: 'verificationRequired', }; diff --git a/lib/providers/download_queue_provider.dart b/lib/providers/download_queue_provider.dart index 39602995..44058f77 100644 --- a/lib/providers/download_queue_provider.dart +++ b/lib/providers/download_queue_provider.dart @@ -22,6 +22,7 @@ import 'package:spotiflac_android/utils/file_access.dart'; import 'package:spotiflac_android/utils/string_utils.dart'; import 'package:spotiflac_android/utils/artist_utils.dart'; import 'package:spotiflac_android/utils/int_utils.dart'; +import 'package:spotiflac_android/utils/extension_auth_launcher.dart'; export 'package:spotiflac_android/services/history_database.dart' show HistoryLookupRequest, HistoryBatchLookupRequest; @@ -4902,7 +4903,8 @@ class DownloadQueueNotifier extends Notifier { 'totalDiscs': track.totalDiscs!.toString(), if (track.isrc != null) 'isrc': track.isrc!, if (label != null && label.isNotEmpty) 'label': label, - if (copyright != null && copyright.isNotEmpty) 'copyright': copyright, + if (copyright != null && copyright.isNotEmpty) + 'copyright': copyright, if (shouldEmbedLyrics) 'lyrics': ?lrcContent, }; final ac4Result = await PlatformBridge.writeAC4Metadata( @@ -6636,6 +6638,8 @@ class DownloadQueueNotifier extends Notifier { return DownloadErrorType.network; case 'permission': return DownloadErrorType.permission; + case 'verification_required': + return DownloadErrorType.verificationRequired; default: return DownloadErrorType.unknown; } @@ -6643,6 +6647,9 @@ class DownloadQueueNotifier extends Notifier { DownloadErrorType _downloadErrorTypeFromMessage(String errorMsg) { final lowerMsg = errorMsg.toLowerCase(); + if (isExtensionVerificationRequired(errorMsg)) { + return DownloadErrorType.verificationRequired; + } if (errorMsg.contains('429') || lowerMsg.contains('rate limit') || lowerMsg.contains('too many requests')) { @@ -7609,7 +7616,10 @@ class DownloadQueueNotifier extends Notifier { // Repair AC-4 (dac4 + ISO MP4) using the still-present encrypted // source. No-op for other codecs. try { - await PlatformBridge.ensureAC4Config(decryptedTempPath, tempPath); + await PlatformBridge.ensureAC4Config( + decryptedTempPath, + tempPath, + ); } catch (e) { _log.w('AC-4 container repair skipped: $e'); } @@ -7688,7 +7698,10 @@ class DownloadQueueNotifier extends Notifier { // Repair AC-4 (dac4 + ISO MP4) using the still-present encrypted // source before discarding it. No-op for other codecs. try { - await PlatformBridge.ensureAC4Config(decryptedPath, encryptedSource); + await PlatformBridge.ensureAC4Config( + decryptedPath, + encryptedSource, + ); } catch (e) { _log.w('AC-4 container repair skipped: $e'); } @@ -8862,6 +8875,9 @@ class DownloadQueueNotifier extends Notifier { case 'permission': errorType = DownloadErrorType.permission; break; + case 'verification_required': + errorType = DownloadErrorType.verificationRequired; + break; default: errorType = _downloadErrorTypeFromMessage(errorMsg); } @@ -8873,6 +8889,9 @@ class DownloadQueueNotifier extends Notifier { error: errorMsg, errorType: errorType, ); + if (errorType == DownloadErrorType.verificationRequired) { + unawaited(openPendingExtensionVerification(item.service)); + } _failedInSession++; try { @@ -8927,6 +8946,9 @@ class DownloadQueueNotifier extends Notifier { error: errorMsg, errorType: errorType, ); + if (errorType == DownloadErrorType.verificationRequired) { + unawaited(openPendingExtensionVerification(item.service)); + } _failedInSession++; try { diff --git a/lib/providers/track_provider.dart b/lib/providers/track_provider.dart index 63957a62..68daa7b0 100644 --- a/lib/providers/track_provider.dart +++ b/lib/providers/track_provider.dart @@ -1,8 +1,11 @@ +import 'dart:async'; + import 'package:flutter_riverpod/flutter_riverpod.dart'; import 'package:spotiflac_android/models/track.dart'; import 'package:spotiflac_android/services/platform_bridge.dart'; import 'package:spotiflac_android/utils/logger.dart'; import 'package:spotiflac_android/utils/string_utils.dart'; +import 'package:spotiflac_android/utils/extension_auth_launcher.dart'; import 'package:spotiflac_android/providers/settings_provider.dart'; import 'package:spotiflac_android/providers/extension_provider.dart'; @@ -195,9 +198,20 @@ class SearchPlaylist { class TrackNotifier extends Notifier { int _currentRequestId = 0; + StreamSubscription? _sessionGrantSub; + _PendingVerificationSearch? _pendingVerificationSearch; + bool _retryingPendingVerificationSearch = false; @override TrackState build() { + _sessionGrantSub ??= PlatformBridge.extensionSessionGrantEvents().listen( + _handleExtensionSessionGrantCompleted, + ); + ref.onDispose(() { + _sessionGrantSub?.cancel(); + _sessionGrantSub = null; + _pendingVerificationSearch = null; + }); return const TrackState(); } @@ -314,7 +328,8 @@ class TrackNotifier extends Notifier { .map((a) => _parseArtistAlbum(a as Map)) .toList(); - final topTracksList = artistData['top_tracks'] as List? ?? []; + final topTracksList = + artistData['top_tracks'] as List? ?? []; final topTracks = topTracksList .map( (t) => _parseSearchTrack( @@ -359,10 +374,7 @@ class TrackNotifier extends Notifier { } } - Future search( - String query, { - String? filterOverride, - }) async { + Future search(String query, {String? filterOverride}) async { final requestId = ++_currentRequestId; final currentFilter = filterOverride ?? state.selectedSearchFilter; final requestFilter = currentFilter == 'all' ? null : currentFilter; @@ -601,6 +613,7 @@ class TrackNotifier extends Notifier { _log.i( 'Custom search complete: ${tracks.length} tracks parsed (source=$extensionId)', ); + _clearPendingVerificationSearch(extensionId, query, currentFilter); state = TrackState( tracks: tracks, @@ -614,6 +627,18 @@ class TrackNotifier extends Notifier { } catch (e, stackTrace) { if (!_isRequestValid(requestId)) return; _log.e('Custom search failed: $e', e, stackTrace); + if (isExtensionVerificationRequired(e)) { + _pendingVerificationSearch = _PendingVerificationSearch( + extensionId: extensionId, + query: query, + options: Map.from( + options ?? const {}, + ), + selectedFilter: currentFilter, + createdAt: DateTime.now(), + ); + await openPendingExtensionVerification(extensionId); + } state = TrackState( isLoading: false, error: e.toString(), @@ -624,6 +649,49 @@ class TrackNotifier extends Notifier { } } + void _clearPendingVerificationSearch( + String extensionId, + String query, + String? selectedFilter, + ) { + final pending = _pendingVerificationSearch; + if (pending == null) return; + if (pending.extensionId == extensionId && + pending.query == query && + pending.selectedFilter == selectedFilter) { + _pendingVerificationSearch = null; + } + } + + void _handleExtensionSessionGrantCompleted(ExtensionSessionGrantEvent event) { + if (!event.success || _retryingPendingVerificationSearch) return; + final pending = _pendingVerificationSearch; + if (pending == null || pending.extensionId != event.extensionId) return; + if (DateTime.now().difference(pending.createdAt) > + const Duration(minutes: 10)) { + _pendingVerificationSearch = null; + return; + } + + _pendingVerificationSearch = null; + _retryingPendingVerificationSearch = true; + Future.delayed(const Duration(milliseconds: 300), () async { + try { + _log.i( + 'Retrying custom search after verification: extension=${pending.extensionId}', + ); + await customSearch( + pending.extensionId, + pending.query, + options: pending.options, + selectedFilter: pending.selectedFilter, + ); + } finally { + _retryingPendingVerificationSearch = false; + } + }); + } + Future checkAvailability(int index) async { if (index < 0 || index >= state.tracks.length) return; @@ -826,7 +894,22 @@ class TrackNotifier extends Notifier { totalTracks: data['total_tracks'] as int? ?? 0, ); } +} +class _PendingVerificationSearch { + final String extensionId; + final String query; + final Map options; + final String? selectedFilter; + final DateTime createdAt; + + const _PendingVerificationSearch({ + required this.extensionId, + required this.query, + required this.options, + required this.selectedFilter, + required this.createdAt, + }); } final trackProvider = NotifierProvider( diff --git a/lib/screens/settings/extensions_page.dart b/lib/screens/settings/extensions_page.dart index 89cd193a..e48958c3 100644 --- a/lib/screens/settings/extensions_page.dart +++ b/lib/screens/settings/extensions_page.dart @@ -293,9 +293,11 @@ class _ExtensionsPageState extends ConsumerState { .map((file) => file.path) .whereType() .toList(); - final extensionPaths = selectedPaths - .where((path) => path.toLowerCase().endsWith('.spotiflac-ext')) - .toList(); + final extensionPaths = selectedPaths.where((path) { + final lowerPath = path.toLowerCase(); + return lowerPath.endsWith('.spotiflac-ext') || + lowerPath.endsWith('.sflx'); + }).toList(); if (extensionPaths.length != selectedPaths.length) { if (mounted) { diff --git a/lib/services/platform_bridge.dart b/lib/services/platform_bridge.dart index d96bcfaf..29758ac1 100644 --- a/lib/services/platform_bridge.dart +++ b/lib/services/platform_bridge.dart @@ -12,6 +12,16 @@ final _log = AppLogger('PlatformBridge'); Object? _decodeJsonInBackground(String json) => jsonDecode(json); +class ExtensionSessionGrantEvent { + final String extensionId; + final bool success; + + const ExtensionSessionGrantEvent({ + required this.extensionId, + required this.success, + }); +} + class _BridgeCacheEntry { final Map value; final DateTime expiresAt; @@ -76,12 +86,46 @@ class PlatformBridge { static Future? _persistentLookupCacheLoadFuture; static int _lookupCacheGeneration = 0; static int _extensionRequestSequence = 0; + static final StreamController + _extensionSessionGrantEvents = + StreamController.broadcast(); + static bool _backendEventHandlerInstalled = false; static bool get supportsCoreBackend => Platform.isAndroid || Platform.isIOS; static bool get supportsExtensionSystem => Platform.isAndroid || Platform.isIOS; + static Stream extensionSessionGrantEvents() { + _ensureBackendEventHandler(); + return _extensionSessionGrantEvents.stream; + } + + static void _ensureBackendEventHandler() { + if (_backendEventHandlerInstalled) return; + _backendEventHandlerInstalled = true; + _channel.setMethodCallHandler((call) async { + switch (call.method) { + case 'extensionSessionGrantCompleted': + final args = call.arguments; + if (args is Map) { + final extensionId = args['extension_id']?.toString().trim() ?? ''; + if (extensionId.isNotEmpty) { + _extensionSessionGrantEvents.add( + ExtensionSessionGrantEvent( + extensionId: extensionId, + success: args['success'] != false, + ), + ); + } + } + return null; + default: + return null; + } + }); + } + static Future> checkAvailability( String spotifyId, String isrc, diff --git a/lib/utils/extension_auth_launcher.dart b/lib/utils/extension_auth_launcher.dart new file mode 100644 index 00000000..ed7b0ba8 --- /dev/null +++ b/lib/utils/extension_auth_launcher.dart @@ -0,0 +1,42 @@ +import 'package:spotiflac_android/services/platform_bridge.dart'; +import 'package:spotiflac_android/utils/logger.dart'; +import 'package:url_launcher/url_launcher.dart'; + +final _log = AppLogger('ExtensionAuthLauncher'); + +bool isExtensionVerificationRequired(Object error) { + final message = error.toString().toLowerCase(); + return message.contains('verify_required') || + message.contains('verification_required') || + message.contains('needsverification') || + message.contains('needs verification'); +} + +Future openPendingExtensionVerification(String extensionId) async { + final normalizedExtensionId = extensionId.trim(); + if (normalizedExtensionId.isEmpty) return; + + try { + final pending = await PlatformBridge.getExtensionPendingAuth( + normalizedExtensionId, + ); + final authUrl = pending?['auth_url']?.toString().trim() ?? ''; + if (authUrl.isEmpty) return; + + final uri = Uri.tryParse(authUrl); + if (uri == null) return; + + final launched = await launchUrl(uri, mode: LaunchMode.externalApplication); + if (launched) { + _log.i('Opened verification challenge for $normalizedExtensionId'); + } else { + _log.w( + 'Could not open verification challenge for $normalizedExtensionId', + ); + } + } catch (e) { + _log.w( + 'Failed to open verification challenge for $normalizedExtensionId: $e', + ); + } +}