class SpectralCodebook:
    # NOTE(review): excerpt — only the codebook serialization members are
    # visible here. `profiles`, `_PROFILE_ARRAYS` and `_PROFILE_SCALARS` are
    # referenced below but defined in the part of the class outside this view.

    # ------------------------------------------------------------------
    # rfft symmetry helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _rfft_to_full(rfft_half, H, W, negate):
        """Expand a half-width spectrum array to full width.

        Args:
            rfft_half: array of shape (H, W//2+1, ...) holding the left half.
            H: number of rows of the full array.
            W: number of columns of the full array.
            negate: when true, the mirrored right half is sign-flipped
                (anti-symmetric fill); otherwise it is copied as is.

        Returns:
            Array of shape (H, W, ...) with the same dtype as ``rfft_half``.
        """
        rw = W // 2 + 1
        full = np.zeros((H, W) + rfft_half.shape[2:], dtype=rfft_half.dtype)
        full[:, :rw] = rfft_half
        if rw < W:  # nothing to mirror when the half already spans W
            # Bin (ky, kx) in the right half mirrors bin (-ky mod H, W - kx).
            sky = (H - np.arange(H)) % H
            skx = W - np.arange(rw, W)
            mirrored = full[sky[:, None], skx[None, :]]
            full[:, rw:] = -mirrored if negate else mirrored
        return full

    @staticmethod
    def _rfft_to_full_sym(rfft_half, H, W):
        """(H, W//2+1, C) → (H, W, C) via conjugate symmetry."""
        return SpectralCodebook._rfft_to_full(rfft_half, H, W, negate=False)

    @staticmethod
    def _rfft_to_full_anti(rfft_half, H, W):
        """Anti-symmetric variant (for phase)."""
        return SpectralCodebook._rfft_to_full(rfft_half, H, W, negate=True)

    # ------------------------------------------------------------------
    # Save / Load — compact format
    #
    # Reduces codebook size ~25x through:
    #   1. rfft2 half-spectrum (conjugate symmetry → 2x)
    #   2. Drop derivable arrays (content_baseline, white_phase_*)
    #   3. float16 for magnitudes (log2-encoded) and phase
    #   4. uint8 for consistency and agreement ([0,1] → 0..255)
    #   5. Sparse storage for profiles where <50% of bins are active
    #      (bins with consistency < 0.15 have zero bypass contribution)
    # ------------------------------------------------------------------

    _CONS_THRESHOLD = 0.15  # minimum cons_floor across all strength levels

    @staticmethod
    def _pack_mag(mag):
        """Encode magnitudes as float16 ``log2(1 + mag)``."""
        return np.log2(1.0 + mag).astype(np.float16)

    @staticmethod
    def _unpack_mag(packed):
        """Invert `_pack_mag`, returning float64 magnitudes."""
        return np.power(2.0, packed.astype(np.float64)) - 1.0

    @staticmethod
    def _pack_unit(values):
        """Quantize values to uint8 (scaled by 255, rounded, clipped)."""
        return np.round(values * 255).clip(0, 255).astype(np.uint8)

    @staticmethod
    def _unpack_unit(packed):
        """Invert `_pack_unit`, returning float64 values in [0, 1]."""
        return packed.astype(np.float64) / 255.0

    @staticmethod
    def _read_int(d, key, default=0):
        """Return ``int(d[key])`` if the archive has ``key``, else ``default``."""
        return int(d[key]) if key in d else default

    def save(self, path: str) -> None:
        """Save in compact format (~20 MB for typical dual-resolution codebook).

        Each profile's magnitude, phase and consistency arrays are cut to
        their first ``w // 2 + 1`` columns and stored either densely or, when
        fewer than half of the consistency bins reach `_CONS_THRESHOLD`, as
        per-channel flat indices plus values. Bins below the threshold are not
        stored in sparse mode and come back as zero on load.

        Args:
            path: destination file. ``np.savez`` appends ``.npz`` when the
                name does not already end with it; the reported size is read
                from that final file name.

        Raises:
            KeyError / TypeError: if a profile lacks one of the three
                mandatory arrays (or holds ``None`` for it).
        """
        import os  # local import: the file-level imports are outside this view

        data = {
            'format_version': np.array(2),
            'resolutions': np.array(list(self.profiles.keys())),
        }

        for (h, w), prof in self.profiles.items():
            pfx = f'{h}x{w}/'
            rw = w // 2 + 1

            for key in self._PROFILE_SCALARS:
                data[pfx + key] = np.array(prof.get(key, 0))

            mag_r = prof['magnitude_profile'][:, :rw, :]
            phase_r = prof['phase_template'][:, :rw, :]
            cons_r = prof['phase_consistency'][:, :rw, :]

            # One mask drives both the sparse/dense decision and the bin
            # selection, so "active" means the same thing in both places.
            active = cons_r >= self._CONS_THRESHOLD
            use_sparse = float(np.mean(active)) < 0.50
            data[pfx + 'sparse'] = np.array(int(use_sparse))

            if use_sparse:
                for ch in range(3):
                    idx = np.flatnonzero(active[:, :, ch]).astype(np.uint32)
                    data[pfx + f'idx_{ch}'] = idx
                    data[pfx + f'mag_{ch}'] = self._pack_mag(
                        mag_r[:, :, ch].ravel()[idx])
                    data[pfx + f'phase_{ch}'] = (
                        phase_r[:, :, ch].ravel()[idx].astype(np.float16))
                    data[pfx + f'cons_{ch}'] = self._pack_unit(
                        cons_r[:, :, ch].ravel()[idx])
            else:
                data[pfx + 'mag'] = self._pack_mag(mag_r)
                data[pfx + 'phase'] = phase_r.astype(np.float16)
                data[pfx + 'cons'] = self._pack_unit(cons_r)

            wmag = prof.get('white_magnitude_profile')
            if wmag is not None:
                data[pfx + 'wmag'] = self._pack_mag(wmag[:, :rw, :])

            agree = prof.get('black_white_agreement')
            if agree is not None:
                data[pfx + 'agree'] = self._pack_unit(agree[:, :rw, :])

        # np.savez silently appends '.npz' to names that lack it; resolve the
        # final name up front so the size lookup hits the file actually written.
        target = os.fspath(path)
        if not target.endswith('.npz'):
            target += '.npz'
        np.savez(target, **data)

        sz = os.path.getsize(target) / 1e6
        res_str = ', '.join(f'{h}x{w}' for h, w in self.profiles)
        print(f"Codebook saved → {path} [{res_str}] {sz:.1f} MB")

    def load(self, path: str) -> None:
        """Load codebook (auto-detects format version).

        Dispatches on the archive contents: ``format_version >= 2`` goes to
        `_load_compact`, an archive with ``resolutions`` but no version goes
        to `_load_v1`, anything else to `_load_legacy`. Loaded profiles are
        added to ``self.profiles``; existing entries for other resolutions are
        left in place.

        Args:
            path: archive to read (passed unchanged to ``np.load``).
        """
        # The context manager closes the underlying zip file; every array is
        # read into memory by the loaders before the block exits.
        with np.load(path) as d:
            fmt = self._read_int(d, 'format_version', 0)

            if fmt >= 2:
                self._load_compact(d)
            elif 'resolutions' in d:
                self._load_v1(d)
            else:
                self._load_legacy(d)

        res_str = ', '.join(f'{h}x{w}' for h, w in self.profiles)
        print(f"Codebook loaded: [{res_str}]")
        # NOTE(review): the header of this loop falls between the two diff
        # hunks and is not visible; it is reconstructed from the names its
        # body uses (h, w, prof) — confirm against the full file.
        for (h, w), prof in self.profiles.items():
            nb, nw, nr = prof['n_black_refs'], prof['n_white_refs'], prof['n_random_refs']
            print(f" {h}x{w}: {nb}b+{nw}w+{nr}r")

    def _load_compact(self, d):
        """Load format_version >= 2 (rfft + mixed precision + optional sparse).

        Args:
            d: opened ``.npz`` archive written by `save`.
        """
        for res in d['resolutions']:
            h, w = int(res[0]), int(res[1])
            pfx = f'{h}x{w}/'
            rw = w // 2 + 1
            use_sparse = bool(int(d[pfx + 'sparse']))

            if use_sparse:
                # Bins that were not stored stay at zero.
                mag_r = np.zeros((h, rw, 3), dtype=np.float64)
                phase_r = np.zeros((h, rw, 3), dtype=np.float64)
                cons_r = np.zeros((h, rw, 3), dtype=np.float64)
                for ch in range(3):
                    idx = d[pfx + f'idx_{ch}']
                    rows, cols = np.unravel_index(idx, (h, rw))
                    mag_r[rows, cols, ch] = self._unpack_mag(d[pfx + f'mag_{ch}'])
                    phase_r[rows, cols, ch] = d[pfx + f'phase_{ch}'].astype(np.float64)
                    cons_r[rows, cols, ch] = self._unpack_unit(d[pfx + f'cons_{ch}'])
            else:
                mag_r = self._unpack_mag(d[pfx + 'mag'])
                phase_r = d[pfx + 'phase'].astype(np.float64)
                cons_r = self._unpack_unit(d[pfx + 'cons'])

            mag_full = self._rfft_to_full_sym(mag_r, h, w)
            phase_full = self._rfft_to_full_anti(phase_r, h, w)
            cons_full = self._rfft_to_full_sym(cons_r, h, w)

            wmag_full = None
            if pfx + 'wmag' in d:
                wmag_full = self._rfft_to_full_sym(
                    self._unpack_mag(d[pfx + 'wmag']), h, w)

            agree_full = None
            if pfx + 'agree' in d:
                agree_full = self._rfft_to_full_sym(
                    self._unpack_unit(d[pfx + 'agree']), h, w)

            # Reconstruct content_baseline for watermarked-only profiles
            content_base = None
            nb = self._read_int(d, pfx + 'n_black_refs')
            if nb == 0:
                safe = np.maximum(cons_full ** 2, 1e-10)  # avoid divide-by-zero
                content_base = mag_full / safe

            self.profiles[(h, w)] = {
                'magnitude_profile': mag_full,
                'phase_template': phase_full,
                'phase_consistency': cons_full,
                'content_magnitude_baseline': content_base,
                'white_magnitude_profile': wmag_full,
                'white_phase_template': None,
                'white_phase_consistency': None,
                'black_white_agreement': agree_full,
                'n_black_refs': nb,
                'n_white_refs': self._read_int(d, pfx + 'n_white_refs'),
                'n_random_refs': self._read_int(d, pfx + 'n_random_refs'),
            }

    def _load_v1(self, d):
        """Load v1 multi-resolution format (full-precision arrays).

        Args:
            d: opened ``.npz`` archive with a ``resolutions`` entry and
                ``'{h}x{w}/'``-prefixed keys. Missing arrays become ``None``,
                missing scalars become 0.
        """
        for res in d['resolutions']:
            h, w = int(res[0]), int(res[1])
            pfx = f'{h}x{w}/'
            prof = {}
            for key in self._PROFILE_ARRAYS:
                fk = pfx + key
                prof[key] = d[fk] if fk in d else None
            for key in self._PROFILE_SCALARS:
                prof[key] = self._read_int(d, pfx + key)
            self.profiles[(h, w)] = prof

    def _load_legacy(self, d):
        """Load original single-resolution format.

        Args:
            d: opened ``.npz`` archive with un-prefixed keys and a
                ``ref_shape`` entry giving (h, w).
        """
        h, w = int(d['ref_shape'][0]), int(d['ref_shape'][1])
        prof = {}
        for key in self._PROFILE_ARRAYS:
            prof[key] = d[key] if key in d else None
        for key in ('n_black_refs', 'n_white_refs', 'n_random_refs'):
            prof[key] = self._read_int(d, key)
        self.profiles[(h, w)] = prof


# ================================================================
# CLI INTERFACE