mirror of
https://github.com/aloshdenny/reverse-SynthID.git
synced 2026-04-30 02:27:49 +02:00
sparsified spectral codebook
This commit is contained in:
@@ -1975,44 +1975,114 @@ class SpectralCodebook:
|
||||
# Save / Load
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# rfft symmetry helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _rfft_to_full_sym(rfft_half, H, W):
|
||||
"""(H, W//2+1, C) → (H, W, C) via conjugate symmetry."""
|
||||
rw = W // 2 + 1
|
||||
full = np.zeros((H, W) + rfft_half.shape[2:], dtype=rfft_half.dtype)
|
||||
full[:, :rw] = rfft_half
|
||||
if W > 2:
|
||||
sky = (H - np.arange(H)) % H
|
||||
skx = W - np.arange(rw, W)
|
||||
full[:, rw:] = full[sky[:, None], skx[None, :]]
|
||||
return full
|
||||
|
||||
@staticmethod
|
||||
def _rfft_to_full_anti(rfft_half, H, W):
|
||||
"""Anti-symmetric variant (for phase)."""
|
||||
rw = W // 2 + 1
|
||||
full = np.zeros((H, W) + rfft_half.shape[2:], dtype=rfft_half.dtype)
|
||||
full[:, :rw] = rfft_half
|
||||
if W > 2:
|
||||
sky = (H - np.arange(H)) % H
|
||||
skx = W - np.arange(rw, W)
|
||||
full[:, rw:] = -full[sky[:, None], skx[None, :]]
|
||||
return full
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Save / Load — compact format
|
||||
#
|
||||
# Reduces codebook size ~25x through:
|
||||
# 1. rfft2 half-spectrum (conjugate symmetry → 2x)
|
||||
# 2. Drop derivable arrays (content_baseline, white_phase_*)
|
||||
# 3. float16 for magnitudes (log2-encoded) and phase
|
||||
# 4. uint8 for consistency and agreement ([0,1] → 0..255)
|
||||
# 5. Sparse storage for profiles where <50% of bins are active
|
||||
# (bins with consistency < 0.15 have zero bypass contribution)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_CONS_THRESHOLD = 0.15 # minimum cons_floor across all strength levels
|
||||
|
||||
def save(self, path: str):
|
||||
"""Save all profiles to a single .npz file."""
|
||||
data = {'resolutions': np.array(list(self.profiles.keys()))}
|
||||
"""Save in compact format (~20 MB for typical dual-resolution codebook)."""
|
||||
data = {
|
||||
'format_version': np.array(2),
|
||||
'resolutions': np.array(list(self.profiles.keys())),
|
||||
}
|
||||
|
||||
for (h, w), prof in self.profiles.items():
|
||||
pfx = f'{h}x{w}/'
|
||||
for key in self._PROFILE_ARRAYS:
|
||||
if prof.get(key) is not None:
|
||||
data[pfx + key] = prof[key]
|
||||
rw = w // 2 + 1
|
||||
|
||||
for key in self._PROFILE_SCALARS:
|
||||
data[pfx + key] = np.array(prof.get(key, 0))
|
||||
np.savez_compressed(path, **data)
|
||||
|
||||
mag = prof['magnitude_profile']
|
||||
phase = prof['phase_template']
|
||||
cons = prof['phase_consistency']
|
||||
|
||||
mag_r = mag[:, :rw, :]
|
||||
phase_r = phase[:, :rw, :]
|
||||
cons_r = cons[:, :rw, :]
|
||||
|
||||
active_frac = float(np.mean(cons_r > self._CONS_THRESHOLD))
|
||||
use_sparse = active_frac < 0.50
|
||||
data[pfx + 'sparse'] = np.array(int(use_sparse))
|
||||
|
||||
if use_sparse:
|
||||
for ch in range(3):
|
||||
mask = cons_r[:, :, ch] >= self._CONS_THRESHOLD
|
||||
idx = np.where(mask.ravel())[0].astype(np.uint32)
|
||||
vals_m = mag_r[:, :, ch].ravel()[idx]
|
||||
vals_p = phase_r[:, :, ch].ravel()[idx]
|
||||
vals_c = cons_r[:, :, ch].ravel()[idx]
|
||||
data[pfx + f'idx_{ch}'] = idx
|
||||
data[pfx + f'mag_{ch}'] = np.log2(1.0 + vals_m).astype(np.float16)
|
||||
data[pfx + f'phase_{ch}'] = vals_p.astype(np.float16)
|
||||
data[pfx + f'cons_{ch}'] = np.round(vals_c * 255).clip(0, 255).astype(np.uint8)
|
||||
else:
|
||||
data[pfx + 'mag'] = np.log2(1.0 + mag_r).astype(np.float16)
|
||||
data[pfx + 'phase'] = phase_r.astype(np.float16)
|
||||
data[pfx + 'cons'] = np.round(cons_r * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
wmag = prof.get('white_magnitude_profile')
|
||||
if wmag is not None:
|
||||
data[pfx + 'wmag'] = np.log2(1.0 + wmag[:, :rw, :]).astype(np.float16)
|
||||
|
||||
agree = prof.get('black_white_agreement')
|
||||
if agree is not None:
|
||||
data[pfx + 'agree'] = np.round(agree[:, :rw, :] * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
np.savez(path, **data)
|
||||
sz = os.path.getsize(path) / 1e6
|
||||
res_str = ', '.join(f'{h}x{w}' for h, w in self.profiles)
|
||||
print(f"Codebook saved → {path} [{res_str}]")
|
||||
print(f"Codebook saved → {path} [{res_str}] {sz:.1f} MB")
|
||||
|
||||
def load(self, path: str):
|
||||
"""Load profiles. Supports multi-res and legacy single-res files."""
|
||||
"""Load codebook (auto-detects format version)."""
|
||||
d = np.load(path)
|
||||
if 'resolutions' in d:
|
||||
for res in d['resolutions']:
|
||||
h, w = int(res[0]), int(res[1])
|
||||
pfx = f'{h}x{w}/'
|
||||
prof = {}
|
||||
for key in self._PROFILE_ARRAYS:
|
||||
fk = pfx + key
|
||||
prof[key] = d[fk] if fk in d else None
|
||||
for key in self._PROFILE_SCALARS:
|
||||
fk = pfx + key
|
||||
prof[key] = int(d[fk]) if fk in d else 0
|
||||
self.profiles[(h, w)] = prof
|
||||
fmt = int(d['format_version']) if 'format_version' in d else 0
|
||||
|
||||
if fmt >= 2:
|
||||
self._load_compact(d)
|
||||
elif 'resolutions' in d:
|
||||
self._load_v1(d)
|
||||
else:
|
||||
h, w = int(d['ref_shape'][0]), int(d['ref_shape'][1])
|
||||
prof = {}
|
||||
for key in self._PROFILE_ARRAYS:
|
||||
prof[key] = d[key] if key in d else None
|
||||
prof['n_black_refs'] = int(d['n_black_refs']) if 'n_black_refs' in d else 0
|
||||
prof['n_white_refs'] = int(d['n_white_refs']) if 'n_white_refs' in d else 0
|
||||
prof['n_random_refs'] = int(d['n_random_refs']) if 'n_random_refs' in d else 0
|
||||
self.profiles[(h, w)] = prof
|
||||
self._load_legacy(d)
|
||||
|
||||
res_str = ', '.join(f'{h}x{w}' for h, w in self.profiles)
|
||||
print(f"Codebook loaded: [{res_str}]")
|
||||
@@ -2020,6 +2090,89 @@ class SpectralCodebook:
|
||||
nb, nw, nr = prof['n_black_refs'], prof['n_white_refs'], prof['n_random_refs']
|
||||
print(f" {h}x{w}: {nb}b+{nw}w+{nr}r")
|
||||
|
||||
def _load_compact(self, d):
|
||||
"""Load format_version >= 2 (rfft + mixed precision + optional sparse)."""
|
||||
for res in d['resolutions']:
|
||||
h, w = int(res[0]), int(res[1])
|
||||
pfx = f'{h}x{w}/'
|
||||
rw = w // 2 + 1
|
||||
use_sparse = bool(int(d[pfx + 'sparse']))
|
||||
|
||||
if use_sparse:
|
||||
mag_r = np.zeros((h, rw, 3), dtype=np.float64)
|
||||
phase_r = np.zeros((h, rw, 3), dtype=np.float64)
|
||||
cons_r = np.zeros((h, rw, 3), dtype=np.float64)
|
||||
for ch in range(3):
|
||||
idx = d[pfx + f'idx_{ch}']
|
||||
rows, cols = np.unravel_index(idx, (h, rw))
|
||||
mag_r[rows, cols, ch] = np.power(2.0, d[pfx + f'mag_{ch}'].astype(np.float64)) - 1.0
|
||||
phase_r[rows, cols, ch] = d[pfx + f'phase_{ch}'].astype(np.float64)
|
||||
cons_r[rows, cols, ch] = d[pfx + f'cons_{ch}'].astype(np.float64) / 255.0
|
||||
else:
|
||||
mag_r = np.power(2.0, d[pfx + 'mag'].astype(np.float64)) - 1.0
|
||||
phase_r = d[pfx + 'phase'].astype(np.float64)
|
||||
cons_r = d[pfx + 'cons'].astype(np.float64) / 255.0
|
||||
|
||||
mag_full = self._rfft_to_full_sym(mag_r, h, w)
|
||||
phase_full = self._rfft_to_full_anti(phase_r, h, w)
|
||||
cons_full = self._rfft_to_full_sym(cons_r, h, w)
|
||||
|
||||
wmag_full = None
|
||||
if pfx + 'wmag' in d:
|
||||
wmag_r = np.power(2.0, d[pfx + 'wmag'].astype(np.float64)) - 1.0
|
||||
wmag_full = self._rfft_to_full_sym(wmag_r, h, w)
|
||||
|
||||
agree_full = None
|
||||
if pfx + 'agree' in d:
|
||||
agree_r = d[pfx + 'agree'].astype(np.float64) / 255.0
|
||||
agree_full = self._rfft_to_full_sym(agree_r, h, w)
|
||||
|
||||
# Reconstruct content_baseline for watermarked-only profiles
|
||||
content_base = None
|
||||
nb = int(d.get(pfx + 'n_black_refs', 0))
|
||||
if nb == 0:
|
||||
safe = np.maximum(cons_full ** 2, 1e-10)
|
||||
content_base = mag_full / safe
|
||||
|
||||
self.profiles[(h, w)] = {
|
||||
'magnitude_profile': mag_full,
|
||||
'phase_template': phase_full,
|
||||
'phase_consistency': cons_full,
|
||||
'content_magnitude_baseline': content_base,
|
||||
'white_magnitude_profile': wmag_full,
|
||||
'white_phase_template': None,
|
||||
'white_phase_consistency': None,
|
||||
'black_white_agreement': agree_full,
|
||||
'n_black_refs': nb,
|
||||
'n_white_refs': int(d.get(pfx + 'n_white_refs', 0)),
|
||||
'n_random_refs': int(d.get(pfx + 'n_random_refs', 0)),
|
||||
}
|
||||
|
||||
def _load_v1(self, d):
|
||||
"""Load v1 multi-resolution format (full-precision arrays)."""
|
||||
for res in d['resolutions']:
|
||||
h, w = int(res[0]), int(res[1])
|
||||
pfx = f'{h}x{w}/'
|
||||
prof = {}
|
||||
for key in self._PROFILE_ARRAYS:
|
||||
fk = pfx + key
|
||||
prof[key] = d[fk] if fk in d else None
|
||||
for key in self._PROFILE_SCALARS:
|
||||
fk = pfx + key
|
||||
prof[key] = int(d[fk]) if fk in d else 0
|
||||
self.profiles[(h, w)] = prof
|
||||
|
||||
def _load_legacy(self, d):
|
||||
"""Load original single-resolution format."""
|
||||
h, w = int(d['ref_shape'][0]), int(d['ref_shape'][1])
|
||||
prof = {}
|
||||
for key in self._PROFILE_ARRAYS:
|
||||
prof[key] = d[key] if key in d else None
|
||||
prof['n_black_refs'] = int(d['n_black_refs']) if 'n_black_refs' in d else 0
|
||||
prof['n_white_refs'] = int(d['n_white_refs']) if 'n_white_refs' in d else 0
|
||||
prof['n_random_refs'] = int(d['n_random_refs']) if 'n_random_refs' in d else 0
|
||||
self.profiles[(h, w)] = prof
|
||||
|
||||
|
||||
# ================================================================
|
||||
# CLI INTERFACE
|
||||
|
||||
Reference in New Issue
Block a user