sparsified spectral codebook

This commit is contained in:
Alosh Denny
2026-03-28 18:40:00 +05:30
parent 9eadacaced
commit 4e6a9987bb
+181 -28
View File
@@ -1975,44 +1975,114 @@ class SpectralCodebook:
# Save / Load
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# rfft symmetry helpers
# ------------------------------------------------------------------
@staticmethod
def _rfft_to_full_sym(rfft_half, H, W):
"""(H, W//2+1, C) → (H, W, C) via conjugate symmetry."""
rw = W // 2 + 1
full = np.zeros((H, W) + rfft_half.shape[2:], dtype=rfft_half.dtype)
full[:, :rw] = rfft_half
if W > 2:
sky = (H - np.arange(H)) % H
skx = W - np.arange(rw, W)
full[:, rw:] = full[sky[:, None], skx[None, :]]
return full
@staticmethod
def _rfft_to_full_anti(rfft_half, H, W):
"""Anti-symmetric variant (for phase)."""
rw = W // 2 + 1
full = np.zeros((H, W) + rfft_half.shape[2:], dtype=rfft_half.dtype)
full[:, :rw] = rfft_half
if W > 2:
sky = (H - np.arange(H)) % H
skx = W - np.arange(rw, W)
full[:, rw:] = -full[sky[:, None], skx[None, :]]
return full
# ------------------------------------------------------------------
# Save / Load — compact format
#
# Reduces codebook size ~25x through:
#   1. rfft2 half-spectrum (conjugate symmetry → 2x)
#   2. Drop derivable arrays (content_baseline, white_phase_*)
#   3. float16 for magnitudes (log2-encoded) and phase
#   4. uint8 for consistency and agreement ([0,1] → 0..255)
#   5. Sparse storage for profiles where <50% of bins are active
#      (bins with consistency < 0.15 have zero bypass contribution)
# ------------------------------------------------------------------

# Bins whose phase consistency falls below this floor are dropped by the
# sparse save path and restored as zeros on load.
_CONS_THRESHOLD = 0.15  # minimum cons_floor across all strength levels
def save(self, path: str):
    """Save the codebook in the compact v2 format (single ``.npz`` file).

    Per resolution, keys are prefixed ``'{h}x{w}/'`` and hold only the
    non-redundant rfft half-spectrum slices:

    * magnitudes as float16 of ``log2(1 + m)`` (uniform relative error
      across a large dynamic range),
    * phase as float16, consistency/agreement as uint8 in ``0..255``,
    * a per-channel sparse layout (``idx_*``/``mag_*``/``phase_*``/
      ``cons_*``) when fewer than half of the bins clear
      ``_CONS_THRESHOLD``, dense ``mag``/``phase``/``cons`` otherwise.

    Derivable arrays (content baseline, white-phase data) are not
    written; the loader reconstructs what it needs.
    """
    # np.savez appends '.npz' when missing; normalize up front so the
    # os.path.getsize() call below looks at the file actually written.
    if not str(path).endswith('.npz'):
        path = str(path) + '.npz'
    data = {
        'format_version': np.array(2),
        'resolutions': np.array(list(self.profiles.keys())),
    }
    for (h, w), prof in self.profiles.items():
        pfx = f'{h}x{w}/'
        rw = w // 2 + 1  # rfft half-spectrum width
        for key in self._PROFILE_SCALARS:
            data[pfx + key] = np.array(prof.get(key, 0))
        mag = prof['magnitude_profile']
        phase = prof['phase_template']
        cons = prof['phase_consistency']
        # Keep only the non-redundant half; the mirrored columns are
        # rebuilt on load via conjugate symmetry.
        mag_r = mag[:, :rw, :]
        phase_r = phase[:, :rw, :]
        cons_r = cons[:, :rw, :]
        # Sparse storage only pays off when most bins are inactive.
        # (>= matches the per-channel mask below; the original mixed
        # > and >=, which disagreed exactly at the threshold.)
        active_frac = float(np.mean(cons_r >= self._CONS_THRESHOLD))
        use_sparse = active_frac < 0.50
        data[pfx + 'sparse'] = np.array(int(use_sparse))
        if use_sparse:
            for ch in range(3):
                mask = cons_r[:, :, ch] >= self._CONS_THRESHOLD
                idx = np.where(mask.ravel())[0].astype(np.uint32)
                vals_m = mag_r[:, :, ch].ravel()[idx]
                vals_p = phase_r[:, :, ch].ravel()[idx]
                vals_c = cons_r[:, :, ch].ravel()[idx]
                data[pfx + f'idx_{ch}'] = idx
                data[pfx + f'mag_{ch}'] = np.log2(1.0 + vals_m).astype(np.float16)
                data[pfx + f'phase_{ch}'] = vals_p.astype(np.float16)
                data[pfx + f'cons_{ch}'] = np.round(vals_c * 255).clip(0, 255).astype(np.uint8)
        else:
            data[pfx + 'mag'] = np.log2(1.0 + mag_r).astype(np.float16)
            data[pfx + 'phase'] = phase_r.astype(np.float16)
            data[pfx + 'cons'] = np.round(cons_r * 255).clip(0, 255).astype(np.uint8)
        wmag = prof.get('white_magnitude_profile')
        if wmag is not None:
            data[pfx + 'wmag'] = np.log2(1.0 + wmag[:, :rw, :]).astype(np.float16)
        agree = prof.get('black_white_agreement')
        if agree is not None:
            data[pfx + 'agree'] = np.round(agree[:, :rw, :] * 255).clip(0, 255).astype(np.uint8)
    # Arrays are already low-precision/sparse, so plain savez (no
    # compression) keeps saving fast for little size cost.
    np.savez(path, **data)
    sz = os.path.getsize(path) / 1e6
    res_str = ', '.join(f'{h}x{w}' for h, w in self.profiles)
    print(f"Codebook saved → {path} [{res_str}] {sz:.1f} MB")
def load(self, path: str):
    """Load a codebook file, auto-detecting its on-disk format.

    Dispatch order:

    * ``format_version >= 2``  → compact rfft format (``_load_compact``)
    * has a ``resolutions`` key → v1 multi-resolution (``_load_v1``)
    * otherwise                 → original single-resolution (``_load_legacy``)
    """
    d = np.load(path)
    # v0/v1 files predate the format_version key; treat absence as 0.
    fmt = int(d['format_version']) if 'format_version' in d else 0
    if fmt >= 2:
        self._load_compact(d)
    elif 'resolutions' in d:
        self._load_v1(d)
    else:
        self._load_legacy(d)
    res_str = ', '.join(f'{h}x{w}' for h, w in self.profiles)
    print(f"Codebook loaded: [{res_str}]")
@@ -2020,6 +2090,89 @@ class SpectralCodebook:
nb, nw, nr = prof['n_black_refs'], prof['n_white_refs'], prof['n_random_refs']
print(f" {h}x{w}: {nb}b+{nw}w+{nr}r")
def _load_compact(self, d):
    """Load format_version >= 2 (rfft + mixed precision + optional sparse).

    Inverts every encoding step of the compact save path per resolution:
    undoes the log2(1+m) float16 magnitude encoding, rescales uint8
    consistency/agreement back to [0, 1], scatters sparse channels into
    dense (h, rw, 3) half-spectra, then mirrors the rfft half back to
    the full (h, w, 3) spectrum via the symmetry helpers.
    """
    for res in d['resolutions']:
        h, w = int(res[0]), int(res[1])
        pfx = f'{h}x{w}/'
        rw = w // 2 + 1  # stored rfft half-spectrum width
        use_sparse = bool(int(d[pfx + 'sparse']))
        if use_sparse:
            # Sparse layout: per-channel flat indices into the (h, rw)
            # grid; bins dropped at save time stay zero.
            mag_r = np.zeros((h, rw, 3), dtype=np.float64)
            phase_r = np.zeros((h, rw, 3), dtype=np.float64)
            cons_r = np.zeros((h, rw, 3), dtype=np.float64)
            for ch in range(3):
                idx = d[pfx + f'idx_{ch}']
                rows, cols = np.unravel_index(idx, (h, rw))
                # Invert log2(1 + m) encoding back to linear magnitude.
                mag_r[rows, cols, ch] = np.power(2.0, d[pfx + f'mag_{ch}'].astype(np.float64)) - 1.0
                phase_r[rows, cols, ch] = d[pfx + f'phase_{ch}'].astype(np.float64)
                cons_r[rows, cols, ch] = d[pfx + f'cons_{ch}'].astype(np.float64) / 255.0
        else:
            mag_r = np.power(2.0, d[pfx + 'mag'].astype(np.float64)) - 1.0
            phase_r = d[pfx + 'phase'].astype(np.float64)
            cons_r = d[pfx + 'cons'].astype(np.float64) / 255.0
        # Mirror back to full spectra: magnitude/consistency are even
        # (symmetric), phase is odd (sign flips on the mirrored half).
        mag_full = self._rfft_to_full_sym(mag_r, h, w)
        phase_full = self._rfft_to_full_anti(phase_r, h, w)
        cons_full = self._rfft_to_full_sym(cons_r, h, w)
        wmag_full = None
        if pfx + 'wmag' in d:
            wmag_r = np.power(2.0, d[pfx + 'wmag'].astype(np.float64)) - 1.0
            wmag_full = self._rfft_to_full_sym(wmag_r, h, w)
        agree_full = None
        if pfx + 'agree' in d:
            agree_r = d[pfx + 'agree'].astype(np.float64) / 255.0
            agree_full = self._rfft_to_full_sym(agree_r, h, w)
        # Reconstruct content_baseline for watermarked-only profiles
        # NOTE(review): this inversion assumes magnitude_profile ≈
        # baseline * consistency² — TODO confirm against the profile
        # builder; the 1e-10 floor only guards the divide-by-zero.
        content_base = None
        nb = int(d.get(pfx + 'n_black_refs', 0))
        if nb == 0:
            safe = np.maximum(cons_full ** 2, 1e-10)
            content_base = mag_full / safe
        # white_phase_* were dropped at save time and are not derivable,
        # so they load as None.
        self.profiles[(h, w)] = {
            'magnitude_profile': mag_full,
            'phase_template': phase_full,
            'phase_consistency': cons_full,
            'content_magnitude_baseline': content_base,
            'white_magnitude_profile': wmag_full,
            'white_phase_template': None,
            'white_phase_consistency': None,
            'black_white_agreement': agree_full,
            'n_black_refs': nb,
            'n_white_refs': int(d.get(pfx + 'n_white_refs', 0)),
            'n_random_refs': int(d.get(pfx + 'n_random_refs', 0)),
        }
def _load_v1(self, d):
"""Load v1 multi-resolution format (full-precision arrays)."""
for res in d['resolutions']:
h, w = int(res[0]), int(res[1])
pfx = f'{h}x{w}/'
prof = {}
for key in self._PROFILE_ARRAYS:
fk = pfx + key
prof[key] = d[fk] if fk in d else None
for key in self._PROFILE_SCALARS:
fk = pfx + key
prof[key] = int(d[fk]) if fk in d else 0
self.profiles[(h, w)] = prof
def _load_legacy(self, d):
"""Load original single-resolution format."""
h, w = int(d['ref_shape'][0]), int(d['ref_shape'][1])
prof = {}
for key in self._PROFILE_ARRAYS:
prof[key] = d[key] if key in d else None
prof['n_black_refs'] = int(d['n_black_refs']) if 'n_black_refs' in d else 0
prof['n_white_refs'] = int(d['n_white_refs']) if 'n_white_refs' in d else 0
prof['n_random_refs'] = int(d['n_random_refs']) if 'n_random_refs' in d else 0
self.profiles[(h, w)] = prof
# ================================================================
# CLI INTERFACE