# Mirror of https://github.com/aloshdenny/reverse-SynthID.git
# Synced 2026-04-30 10:37:49 +02:00
# 790 lines, 29 KiB, Python
"""
Robust SynthID Watermark Extractor

A comprehensive multi-stage watermark extraction pipeline that combines:
1. Multi-scale analysis (256, 512, 1024 pixels)
2. Multi-denoiser fusion (wavelet, bilateral, non-local means, Wiener)
3. Ensemble carrier frequency detection with voting
4. ICA/PCA-based watermark separation
5. Adaptive thresholding based on image content

This provides significantly more robust detection than single-scale approaches.
"""
|
|
|
|
import os
|
|
import numpy as np
|
|
import cv2
|
|
from scipy.fft import fft2, ifft2, fftshift, ifftshift
|
|
from scipy import ndimage
|
|
from scipy.stats import pearsonr
|
|
from collections import defaultdict
|
|
import pywt
|
|
import pickle
|
|
from typing import Optional, Dict, List, Tuple, Union
|
|
from dataclasses import dataclass
|
|
from sklearn.decomposition import PCA, FastICA
|
|
|
|
|
|
@dataclass
class DetectionResult:
    """Outcome of one watermark-detection run together with its raw scores."""

    # Final yes/no verdict.
    is_watermarked: bool
    # Weighted combination of the individual scores, capped at 1.0.
    confidence: float
    # Pearson correlation between the image noise and the reference noise.
    correlation: float
    # Mean phase agreement over the checked carrier frequencies.
    phase_match: float
    # std(noise) / mean(|noise|) of the extracted noise.
    structure_ratio: float
    # Mean FFT magnitude at the checked carrier frequencies.
    carrier_strength: float
    # Std-dev of the per-scale correlations (lower means more consistent).
    multi_scale_consistency: float
    # Intermediate values (threshold, sub-scores, per-scale correlations).
    details: Dict
|
|
|
|
|
|
class RobustSynthIDExtractor:
    """
    Robust SynthID watermark extractor using multi-stage analysis.

    Features:
    - Multi-scale processing for comprehensive watermark detection
    - Multiple denoising methods with intelligent fusion
    - Ensemble carrier frequency detection
    - Adaptive thresholds based on image content
    """

    def __init__(
        self,
        scales: Optional[List[int]] = None,
        wavelets: Optional[List[str]] = None,
        n_carriers: int = 100,
        codebook_path: Optional[str] = None
    ):
        """
        Initialize the robust extractor.

        Args:
            scales: Image scales for multi-scale analysis
                (defaults to ``[256, 512, 1024]``).
            wavelets: Wavelet families for denoising
                (defaults to ``['db4', 'sym8', 'coif3']``).
            n_carriers: Number of carrier frequencies to track.
            codebook_path: Path to pre-extracted codebook. It is loaded only
                when the path exists; a missing path is silently ignored and
                ``self.codebook`` stays ``None``.
        """
        # Build fresh lists per instance. The previous list-literal defaults
        # were shared between every instance (and with the caller's own list),
        # so mutating `extractor.scales` leaked into later instances.
        self.scales = list(scales) if scales is not None else [256, 512, 1024]
        self.wavelets = (
            list(wavelets) if wavelets is not None else ['db4', 'sym8', 'coif3']
        )
        self.n_carriers = n_carriers
        self.codebook = None

        # SynthID carriers from 288 Gemini ref images (48,88 grid at 512px).
        # Each entry is a (freq_y, freq_x) offset from the centre of the
        # shifted spectrum; every frequency is listed with its mirror.
        self.known_carriers = [
            (48, 0), (-48, 0),
            (96, 0), (-96, 0),
            (192, 0), (-192, 0),
            (210, 0), (-210, 0),
            (238, 0), (-238, 0),
            (0, 88), (0, -88),
            (0, 176), (0, -176),
            (0, 192), (0, -192),
            (48, 88), (-48, -88),
            (48, -88), (-48, 88),
            (96, 88), (-96, -88),
            (96, -88), (-96, 88),
            (96, 176), (-96, -176),
            (96, -176), (-96, 176),
        ]

        if codebook_path and os.path.exists(codebook_path):
            self.load_codebook(codebook_path)
|
|
|
|
def load_codebook(self, path: str) -> None:
    """Read a pickled codebook from ``path`` into ``self.codebook``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing, so only open codebook files from a trusted source.
    """
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)
    self.codebook = loaded
|
|
|
|
def save_codebook(self, path: str) -> None:
    """Pickle ``self.codebook`` to ``path``, creating the parent folder if needed."""
    # A bare filename has no directory part; fall back to the current folder.
    parent_dir = os.path.dirname(path) or '.'
    os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'wb') as handle:
        pickle.dump(self.codebook, handle)
|
|
|
|
# ================================================================
|
|
# DENOISING METHODS
|
|
# ================================================================
|
|
|
|
def wavelet_denoise(
    self,
    channel: np.ndarray,
    wavelet: str = 'db4',
    level: int = 3
) -> np.ndarray:
    """Denoise one 2-D channel by soft-thresholding its wavelet detail bands.

    Args:
        channel: 2-D array to denoise.
        wavelet: Wavelet family name passed to PyWavelets.
        level: Number of decomposition levels.

    Returns:
        The reconstructed channel, cropped back to ``channel``'s shape.
    """
    decomposition = pywt.wavedec2(channel, wavelet, level=level)
    approximation, detail_levels = decomposition[0], decomposition[1:]

    # Noise level from the median absolute value of the first detail band at
    # the finest level (the last entry of the decomposition), scaled by 0.6745.
    finest_band = decomposition[-1][0]
    noise_sigma = np.median(np.abs(finest_band)) / 0.6745
    cutoff = noise_sigma * np.sqrt(2 * np.log(channel.size))

    # The approximation band is kept untouched; every detail band at every
    # level is shrunk with the same cutoff.
    shrunk = [approximation] + [
        tuple(pywt.threshold(band, cutoff, mode='soft') for band in bands)
        for bands in detail_levels
    ]

    restored = pywt.waverec2(shrunk, wavelet)
    # The reconstruction can be larger than the input, so crop it.
    rows, cols = channel.shape[0], channel.shape[1]
    return restored[:rows, :cols]
|
|
|
|
def bilateral_denoise(
    self,
    image: np.ndarray,
    d: int = 9,
    sigma_color: float = 75,
    sigma_space: float = 75
) -> np.ndarray:
    """Bilateral filter denoising (edge-preserving).

    A 2-D image is filtered directly; any other input is treated as
    (rows, cols, channels) and filtered one channel at a time.

    Args:
        image: Input image, 2-D or with a trailing channel axis.
        d: Pixel neighbourhood diameter passed to ``cv2.bilateralFilter``.
        sigma_color: Range sigma passed to ``cv2.bilateralFilter``.
        sigma_space: Spatial sigma passed to ``cv2.bilateralFilter``.

    Returns:
        The filtered image as float32, for both 2-D and multi-channel input.
    """
    # NOTE(review): extract_noise_single passes images scaled to [0, 1] with
    # the default sigma_color=75, which looks tuned for 0-255 data - confirm
    # the intended range before relying on the edge-preserving behaviour.
    as_float = image.astype(np.float32)
    if len(as_float.shape) == 2:
        return cv2.bilateralFilter(as_float, d, sigma_color, sigma_space)

    # Always allocate float32. The previous `np.zeros_like(image)` inherited
    # the input dtype, so integer input truncated the filter output and
    # float64 input returned a different dtype than the 2-D branch.
    filtered = np.empty(as_float.shape, dtype=np.float32)
    for c in range(as_float.shape[2]):
        # A channel slice is a strided view; hand OpenCV a contiguous copy.
        plane = np.ascontiguousarray(as_float[:, :, c])
        filtered[:, :, c] = cv2.bilateralFilter(plane, d, sigma_color, sigma_space)
    return filtered
|
|
|
|
def nlm_denoise(
    self,
    image: np.ndarray,
    h: float = 10,
    template_size: int = 7,
    search_size: int = 21
) -> np.ndarray:
    """Non-local means denoising through OpenCV's 8-bit implementation.

    The input is multiplied by 255 and clipped to uint8 before filtering, so
    it is expected to be scaled to [0, 1]. The result is scaled back to
    float32.

    Args:
        image: 2-D grayscale image, or a colour image with a channel axis.
        h: Filter strength (also used as the colour strength for colour input).
        template_size: Template window size.
        search_size: Search window size.
    """
    quantized = np.clip(image * 255, 0, 255).astype(np.uint8)

    is_grayscale = len(image.shape) == 2
    if is_grayscale:
        cleaned = cv2.fastNlMeansDenoising(
            quantized, None, h, template_size, search_size
        )
    else:
        cleaned = cv2.fastNlMeansDenoisingColored(
            quantized, None, h, h, template_size, search_size
        )

    return cleaned.astype(np.float32) / 255.0
|
|
|
|
def wiener_filter(
    self,
    image: np.ndarray,
    noise_variance: Optional[float] = None
) -> np.ndarray:
    """Frequency-domain Wiener filter assuming additive white noise.

    Args:
        image: 2-D array to denoise.
        noise_variance: Per-pixel noise variance. When omitted it is estimated
            as the variance of the residual left after a Gaussian blur
            (sigma=2).

    Returns:
        The real part of the filtered image, same shape as ``image``.
    """
    if noise_variance is None:
        # Estimate noise variance from high-frequency components
        noise_variance = np.var(image - ndimage.gaussian_filter(image, sigma=2))

    spectrum = fft2(image)
    power = np.abs(spectrum) ** 2

    # fft2 is unnormalized in the forward direction, so white noise with
    # per-pixel variance s^2 has expected power N * s^2 in every bin
    # (N = number of pixels). Subtracting the bare variance, as was done
    # before, left the gain at ~1 everywhere and removed almost nothing.
    noise_power = noise_variance * image.size

    signal_power = np.maximum(power - noise_power, 0)
    # The epsilon avoids 0/0 in bins where both power estimates are zero.
    gain = signal_power / (signal_power + noise_power + 1e-10)

    return np.real(ifft2(spectrum * gain))
|
|
|
|
# ================================================================
|
|
# NOISE EXTRACTION
|
|
# ================================================================
|
|
|
|
def extract_noise_single(
    self,
    image: np.ndarray,
    method: str = 'wavelet',
    **kwargs
) -> np.ndarray:
    """Return the noise residual (image minus its denoised version).

    The image is converted to float32 and divided by 255 when its maximum
    exceeds 1, so both 0-255 and 0-1 inputs are accepted.

    Args:
        image: 2-D image or image with a trailing channel axis.
        method: One of ``'wavelet'``, ``'bilateral'``, ``'nlm'``, ``'wiener'``.
        **kwargs: ``wavelet`` selects the wavelet family for the
            ``'wavelet'`` method (default ``'db4'``).

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    scaled = image.astype(np.float32)
    if scaled.max() > 1:
        scaled = scaled / 255.0

    def channelwise(denoise_plane):
        # Apply a 2-D-only denoiser to the image, one channel at a time
        # when the image is not 2-D.
        if len(scaled.shape) == 2:
            return denoise_plane(scaled)
        out = np.zeros_like(scaled)
        for c in range(scaled.shape[2]):
            out[:, :, c] = denoise_plane(scaled[:, :, c])
        return out

    if method == 'wavelet':
        family = kwargs.get('wavelet', 'db4')
        clean = channelwise(lambda plane: self.wavelet_denoise(plane, family))
    elif method == 'bilateral':
        clean = self.bilateral_denoise(scaled)
    elif method == 'nlm':
        clean = self.nlm_denoise(scaled)
    elif method == 'wiener':
        clean = channelwise(self.wiener_filter)
    else:
        raise ValueError(f"Unknown denoising method: {method}")

    return scaled - clean
|
|
|
|
def extract_noise_fused(self, image: np.ndarray) -> np.ndarray:
    """
    Extract noise with several denoisers and fuse the residuals.

    One residual is computed per wavelet family in ``self.wavelets``
    (weight 1.0 each), plus bilateral (0.8), non-local means (0.7) and
    Wiener (0.6). The fused result is the weighted average of all
    residuals, with the weights normalized to sum to 1.
    """
    # (method, keyword options, weight) in the order the residuals are computed.
    plan = [('wavelet', {'wavelet': family}, 1.0) for family in self.wavelets]
    plan += [
        ('bilateral', {}, 0.8),
        ('nlm', {}, 0.7),
        ('wiener', {}, 0.6),
    ]

    residuals = []
    raw_weights = []
    for method, options, weight in plan:
        residuals.append(self.extract_noise_single(image, method, **options))
        raw_weights.append(weight)

    stacked = np.array(residuals)
    normalized = np.array(raw_weights) / sum(raw_weights)

    # Use weighted average (more stable than weighted median): contract the
    # weight vector against the leading "method" axis.
    return np.tensordot(normalized, stacked, axes=([0], [0]))
|
|
|
|
# ================================================================
|
|
# CARRIER FREQUENCY DETECTION
|
|
# ================================================================
|
|
|
|
def find_carrier_peaks(
    self,
    magnitude: np.ndarray,
    phase_coherence: np.ndarray,
    n_peaks: int = 100
) -> List[Tuple[int, int, float]]:
    """Find carrier frequency peaks using combined magnitude and coherence.

    Each bin is scored as ``log1p(magnitude) * phase_coherence``. Bins within
    5 bins of the spectrum centre (DC region) are ignored, and only bins
    above the 99th percentile of the remaining scores are kept.

    Args:
        magnitude: Centred (fft-shifted) magnitude spectrum, 2-D.
        phase_coherence: Per-bin phase coherence, same shape as ``magnitude``.
        n_peaks: Maximum number of peaks to return.

    Returns:
        ``(freq_y, freq_x, score)`` tuples of plain Python numbers, sorted by
        descending score. Frequencies are offsets from the spectrum centre.
        Empty when every bin falls inside the DC region.
    """
    # Centre each axis separately so non-square spectra are handled; using
    # shape[0] for both axes mislabelled freq_x on rectangular inputs.
    center_y = magnitude.shape[0] // 2
    center_x = magnitude.shape[1] // 2

    combined = np.log1p(magnitude) * phase_coherence

    rows, cols = np.ogrid[:combined.shape[0], :combined.shape[1]]
    outside_dc = ((rows - center_y) ** 2 + (cols - center_x) ** 2) >= 25
    if not outside_dc.any():
        # Spectrum too small: nothing left to take a percentile of.
        return []

    threshold = np.percentile(combined[outside_dc], 99)
    peak_rows, peak_cols = np.nonzero((combined > threshold) & outside_dc)
    scores = combined[peak_rows, peak_cols]

    # Stable sort on the negated scores: descending order, ties keep
    # row-major order.
    order = np.argsort(-scores, kind='stable')[:n_peaks]
    return [
        (
            int(peak_rows[i] - center_y),
            int(peak_cols[i] - center_x),
            float(scores[i]),
        )
        for i in order
    ]
|
|
|
|
def detect_carriers_single_scale(
    self,
    images: List[np.ndarray],
    size: int
) -> Dict[Tuple[int, int], Dict]:
    """Detect carriers at a single scale.

    Every image is resized to ``size x size``, converted to grayscale and
    Fourier-transformed. Magnitudes and unit phasors are averaged over all
    images, and the peaks of the averaged spectrum are returned.

    Args:
        images: Images (RGB or grayscale arrays) to analyse.
        size: Side length, in pixels, that images are resized to.

    Returns:
        Mapping ``(freq_y, freq_x) -> info`` where ``info`` holds
        ``position``, ``magnitude``, ``phase``, ``coherence`` and ``score``.
        Empty when ``images`` is empty.
    """
    n_images = len(images)
    if n_images == 0:
        # Previously this fell through to `None / 0` and raised TypeError.
        return {}

    side = int(size)
    # Accumulate in double precision; summing many float32 spectra in place
    # loses accuracy as the image count grows.
    magnitude_sum = np.zeros((side, side), dtype=np.float64)
    phasor_sum = np.zeros((side, side), dtype=np.complex128)

    for img in images:
        resized = cv2.resize(img, (side, side))
        if len(resized.shape) == 3:
            gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY).astype(np.float32)
        else:
            gray = resized.astype(np.float32)

        shifted = fftshift(fft2(gray))
        magnitude_sum += np.abs(shifted)
        # Unit phasors: their mean length measures how well phases agree.
        phasor_sum += np.exp(1j * np.angle(shifted))

    avg_magnitude = magnitude_sum / n_images
    phase_coherence = np.abs(phasor_sum) / n_images
    avg_phase = np.angle(phasor_sum)

    peaks = self.find_carrier_peaks(avg_magnitude, phase_coherence, self.n_carriers)

    # Peaks are centre-relative; shift back to array indices for lookup.
    center = side // 2
    carriers = {}
    for freq_y, freq_x, score in peaks:
        y, x = int(freq_y) + center, int(freq_x) + center
        carriers[(freq_y, freq_x)] = {
            'position': (y, x),
            'magnitude': float(avg_magnitude[y, x]),
            'phase': float(avg_phase[y, x]),
            'coherence': float(phase_coherence[y, x]),
            'score': float(score)
        }

    return carriers
|
|
|
|
def detect_carriers_multi_scale(
    self,
    images: List[np.ndarray]
) -> List[Dict]:
    """
    Detect carriers using multi-scale analysis with voting.

    Carriers are detected at every scale in ``self.scales``, rescaled to a
    512-pixel frequency grid and grouped into 2-bin-wide buckets. A bucket
    is kept when it is seen at two or more distinct scales, or when its
    summed score exceeds 100. If fewer than five buckets survive, the
    built-in ``self.known_carriers`` are appended with placeholder values.

    Returns:
        Up to ``self.n_carriers`` carrier dicts (``frequency``, ``votes``,
        ``avg_score``, ``scales``, ``coherence``, ``phase``, ``magnitude``),
        sorted by votes and then average score, both descending.
    """
    base_scale = 512
    buckets = defaultdict(lambda: {'total_score': 0, 'scales': [], 'infos': []})

    for scale in self.scales:
        carriers = self.detect_carriers_single_scale(images, scale)

        for freq, info in carriers.items():
            # Normalize frequency to base scale (512)
            norm_freq_y = int(round(freq[0] * base_scale / scale))
            norm_freq_x = int(round(freq[1] * base_scale / scale))

            # Floor each coordinate to the even value at or below it, so
            # neighbouring frequencies share a bucket.
            bin_freq = (norm_freq_y // 2 * 2, norm_freq_x // 2 * 2)

            bucket = buckets[bin_freq]
            bucket['total_score'] += info['score']
            bucket['infos'].append(info)
            # One vote per scale: two adjacent peaks from the same scale
            # used to count as two votes and pass the cross-scale check.
            if scale not in bucket['scales']:
                bucket['scales'].append(scale)

    reliable_carriers = []
    for freq, bucket in buckets.items():
        votes = len(bucket['scales'])
        # Accept if appears in 2+ scales OR has very high score
        if votes >= 2 or bucket['total_score'] > 100:
            infos = bucket['infos']
            phases = np.array([i.get('phase', 0) for i in infos], dtype=float)
            # Phases are angles: average them on the unit circle so values
            # on either side of the +/-pi wrap do not cancel to ~0.
            mean_phase = np.angle(np.mean(np.exp(1j * phases)))

            reliable_carriers.append({
                'frequency': freq,
                'votes': votes,
                'avg_score': bucket['total_score'] / len(infos),
                'scales': bucket['scales'],
                'coherence': float(np.mean([i.get('coherence', 0) for i in infos])),
                'phase': float(mean_phase),
                'magnitude': float(np.mean([i.get('magnitude', 0) for i in infos]))
            })

    # Sort by votes then score
    reliable_carriers.sort(key=lambda c: (c['votes'], c['avg_score']), reverse=True)

    # FALLBACK: If too few reliable carriers found, add the known carriers
    if len(reliable_carriers) < 5:
        print(f" Warning: Only {len(reliable_carriers)} carriers found, using known carriers as fallback")
        # A set replaces rebuilding the frequency list on every iteration.
        seen = {tuple(c['frequency']) for c in reliable_carriers}
        for freq in self.known_carriers:
            if tuple(freq) in seen:
                continue
            seen.add(tuple(freq))
            reliable_carriers.append({
                'frequency': freq,
                'votes': 0,
                'avg_score': 50,
                'scales': [],
                'coherence': 0.99,
                'phase': 0.0,  # Will be computed during detection
                'magnitude': 1000
            })

    return reliable_carriers[:self.n_carriers]
|
|
|
|
# ================================================================
|
|
# ICA/PCA SEPARATION
|
|
# ================================================================
|
|
|
|
def extract_watermark_ica(
    self,
    images: List[np.ndarray],
    n_components: int = 5
) -> np.ndarray:
    """
    Use ICA to separate watermark pattern from image content.

    Fused noise is extracted from up to the first 50 images (resized to
    512x512 and averaged over channels) and decomposed with FastICA, falling
    back to PCA if FastICA raises. The component with the largest summed FFT
    magnitude at ``self.known_carriers`` is returned.

    Args:
        images: Images to analyse.
        n_components: Requested number of components; clamped to the number
            of images actually used.

    Returns:
        The selected component reshaped to 512x512.

    Raises:
        ValueError: If ``images`` is empty.
    """
    target_size = 512
    subset = images[:50]  # Limit for performance
    if len(subset) == 0:
        raise ValueError("extract_watermark_ica requires at least one image")

    noise_vectors = []
    for img in subset:
        resized = cv2.resize(img, (target_size, target_size))
        noise = self.extract_noise_fused(resized)
        if len(noise.shape) == 3:
            noise = np.mean(noise, axis=2)
        noise_vectors.append(noise.flatten())

    noise_matrix = np.array(noise_vectors)

    # There cannot be more components than samples. Asking for more made
    # the old `components[i]` loop run past the components actually returned.
    n_components = max(1, min(n_components, noise_matrix.shape[0]))

    try:
        ica = FastICA(n_components=n_components, random_state=42, max_iter=500)
        ica.fit(noise_matrix)
        components = ica.components_
    except Exception:
        # Deliberate best-effort: any FastICA failure falls back to PCA.
        pca = PCA(n_components=n_components)
        pca.fit(noise_matrix)
        components = pca.components_

    # Carrier positions do not depend on the component: compute them once.
    center = target_size // 2
    in_bounds = [
        (freq_y + center, freq_x + center)
        for freq_y, freq_x in self.known_carriers
        if 0 <= freq_y + center < target_size and 0 <= freq_x + center < target_size
    ]
    carrier_rows = np.array([p[0] for p in in_bounds], dtype=int)
    carrier_cols = np.array([p[1] for p in in_bounds], dtype=int)

    # Score each component by its spectral energy at the known carriers.
    carrier_energy = []
    for component in components:
        spectrum = fftshift(fft2(component.reshape(target_size, target_size)))
        carrier_energy.append(float(np.abs(spectrum[carrier_rows, carrier_cols]).sum()))

    best_idx = int(np.argmax(carrier_energy))
    return components[best_idx].reshape(target_size, target_size)
|
|
|
|
# ================================================================
|
|
# CODEBOOK EXTRACTION
|
|
# ================================================================
|
|
|
|
def extract_codebook(
    self,
    image_dir: str,
    max_images: int = 250,
    save_path: Optional[str] = None
) -> Dict:
    """
    Extract comprehensive codebook from watermarked images.

    Loads up to ``max_images`` images (.png/.jpg/.jpeg/.webp, in sorted
    filename order), detects carriers across scales, averages the fused
    noise into a 512x512 reference, extracts an ICA watermark pattern and
    derives correlation statistics from the first 50 images.

    Args:
        image_dir: Directory containing watermarked images.
        max_images: Maximum number of images to load.
        save_path: If given, the codebook is also pickled to this path.

    Returns:
        The codebook dict, also stored on ``self.codebook``.

    Raises:
        ValueError: If no readable image is found in ``image_dir``.
    """
    print(f"Loading images from {image_dir}...")

    extensions = {'.png', '.jpg', '.jpeg', '.webp'}
    images = []
    for fname in sorted(os.listdir(image_dir)):
        if len(images) >= max_images:
            break
        if os.path.splitext(fname)[1].lower() not in extensions:
            continue
        img = cv2.imread(os.path.join(image_dir, fname))
        if img is not None:
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    print(f"Loaded {len(images)} images")
    if not images:
        # Without this the reference noise below is a division by zero.
        raise ValueError(f"No readable images found in {image_dir}")

    # Multi-scale carrier detection
    print("Detecting carriers (multi-scale)...")
    carriers = self.detect_carriers_multi_scale(images)

    # Extract reference noise pattern
    print("Extracting reference noise pattern...")
    target_size = 512
    n_samples = min(50, len(images))
    noise_sum = np.zeros((target_size, target_size, 3), dtype=np.float64)
    # Keep the flattened noise of the first `n_samples` images so the
    # pairwise correlations below reuse it. The previous code re-ran the
    # full fused denoising for both images of every pair.
    sample_noises = []
    for index, img in enumerate(images):
        img_resized = cv2.resize(img, (target_size, target_size))
        noise = self.extract_noise_fused(img_resized)
        noise_sum += noise
        if index < n_samples:
            # float32 halves the memory held; np.corrcoef promotes to float64.
            sample_noises.append(noise.ravel().astype(np.float32))

    reference_noise = noise_sum / len(images)

    # ICA-based watermark extraction
    print("Extracting watermark pattern via ICA...")
    watermark_pattern = self.extract_watermark_ica(images)

    # Compute correlation statistics over every unordered pair (i < j)
    print("Computing correlation statistics...")
    correlations = []
    for i in range(n_samples):
        for j in range(i + 1, n_samples):
            corr = np.corrcoef(sample_noises[i], sample_noises[j])[0, 1]
            correlations.append(corr)

    if correlations:
        correlation_mean = float(np.mean(correlations))
        correlation_std = float(np.std(correlations))
    else:
        # A single image has no pairs; keep the statistics undefined
        # instead of averaging an empty list.
        correlation_mean = float('nan')
        correlation_std = float('nan')
    detection_threshold = correlation_mean - 2.5 * correlation_std

    # Compute FFT statistics
    print("Computing FFT reference...")
    ref_gray = np.mean(reference_noise, axis=2)
    ref_fft = fftshift(fft2(ref_gray))
    ref_magnitude = np.abs(ref_fft)
    ref_phase = np.angle(ref_fft)

    # Build codebook
    self.codebook = {
        'version': '2.0',
        'source': 'Gemini/SynthID',
        'extractor': 'RobustSynthIDExtractor',
        'n_images_analyzed': len(images),
        'image_size': target_size,
        'scales_used': self.scales,

        # Reference patterns
        'reference_noise': reference_noise,
        'watermark_pattern': watermark_pattern,
        'reference_magnitude': ref_magnitude,
        'reference_phase': ref_phase,

        # Carriers
        'carriers': carriers,
        'known_carriers': self.known_carriers,

        # Detection thresholds
        'correlation_mean': correlation_mean,
        'correlation_std': correlation_std,
        'detection_threshold': detection_threshold,
        'noise_structure_ratio': 1.32,
    }

    if save_path:
        self.save_codebook(save_path)
        print(f"Codebook saved to {save_path}")

    return self.codebook
|
|
|
|
# ================================================================
|
|
# DETECTION
|
|
# ================================================================
|
|
|
|
def detect(self, image_path: str) -> DetectionResult:
    """Load an image file and run watermark detection on it.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Could not load image: {image_path}")

    # OpenCV loads BGR; the rest of the pipeline works on RGB.
    return self.detect_array(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
|
|
|
|
def detect_array(self, image: np.ndarray) -> "DetectionResult":
    """Detect SynthID watermark in a numpy array image.

    Four measurements are combined: correlation of the fused noise with the
    codebook reference noise, phase agreement at carrier frequencies, the
    noise structure ratio, and the spread of correlations across scales.

    Args:
        image: RGB image array. Grayscale and single-channel inputs are
            replicated to three channels and extra channels (e.g. alpha) are
            dropped, so the noise matches the 3-channel reference.

    Returns:
        A ``DetectionResult`` with the verdict, confidence and raw scores.

    Raises:
        ValueError: If no codebook has been loaded or extracted.
    """
    if self.codebook is None:
        raise ValueError("No codebook loaded. Call extract_codebook() or load_codebook() first.")

    # The reference noise built by extract_codebook has three channels; any
    # other channel count made the correlations fail on a size mismatch.
    if image.ndim == 2:
        image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    elif image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] > 3:
        image = np.ascontiguousarray(image[:, :, :3])

    target_size = self.codebook['image_size']
    img_resized = cv2.resize(image, (target_size, target_size))

    # Extract noise pattern
    noise = self.extract_noise_fused(img_resized)

    # Method 1: Correlation with reference noise
    ref_noise = self.codebook['reference_noise']
    correlation = float(np.corrcoef(noise.ravel(), ref_noise.ravel())[0, 1])

    # Method 2: Carrier frequency analysis using known + extracted carriers
    gray = np.mean(img_resized, axis=2).astype(np.float32)
    spectrum = fftshift(fft2(gray))
    magnitude = np.abs(spectrum)
    phase = np.angle(spectrum)
    center = target_size // 2

    # Up to 30 extracted carriers, then the known carriers for reliability.
    extracted = list(self.codebook.get('carriers') or [])[:30]
    known = [
        {'frequency': freq, 'phase': 0}
        for freq in self.codebook.get('known_carriers', self.known_carriers)
    ]

    # Use reference phase from codebook if available
    ref_phase = self.codebook.get('reference_phase')

    carrier_scores = []
    carrier_strengths = []
    seen = set()
    for carrier in extracted + known:
        freq_y = int(carrier['frequency'][0])
        freq_x = int(carrier['frequency'][1])
        # A frequency can be in both lists (the fallback copies known
        # carriers into the extracted list); score it only once.
        if (freq_y, freq_x) in seen:
            continue
        seen.add((freq_y, freq_x))

        y, x = freq_y + center, freq_x + center
        if not (0 <= y < target_size and 0 <= x < target_size):
            continue

        if ref_phase is not None:
            expected_phase = ref_phase[y, x]
        else:
            expected_phase = carrier.get('phase', 0)

        # Phase match (accounting for wrap-around): 1 = identical, 0 = opposite
        phase_diff = np.abs(np.angle(np.exp(1j * (phase[y, x] - expected_phase))))
        carrier_scores.append(1 - phase_diff / np.pi)
        carrier_strengths.append(magnitude[y, x])

    avg_phase_match = float(np.mean(carrier_scores)) if carrier_scores else 0
    avg_carrier_strength = float(np.mean(carrier_strengths)) if carrier_strengths else 0

    # Method 3: Noise structure ratio
    noise_gray = np.mean(noise, axis=2) if len(noise.shape) == 3 else noise
    structure_ratio = float(np.std(noise_gray) / (np.mean(np.abs(noise_gray)) + 1e-10))

    # Method 4: Multi-scale consistency
    scale_scores = []
    for scale in self.scales:
        img_scaled = cv2.resize(image, (scale, scale))
        noise_scaled = self.extract_noise_single(img_scaled, 'wavelet')
        ref_scaled = cv2.resize(ref_noise, (scale, scale))
        scale_scores.append(np.corrcoef(noise_scaled.ravel(), ref_scaled.ravel())[0, 1])

    multi_scale_consistency = float(np.std(scale_scores))  # Lower is more consistent

    # Detection decision
    threshold = self.codebook['detection_threshold']
    is_watermarked = (
        correlation > threshold and
        avg_phase_match > 0.45 and
        0.7 < structure_ratio < 2.0
    )

    # Confidence score: weighted sum of sub-scores, each limited to [0, 1].
    corr_span = self.codebook['correlation_mean'] - threshold + 1e-10
    # Clamp the upper end too: a correlation far above the reference mean
    # used to push this term past 1 and saturate the confidence by itself.
    corr_score = float(min(1.0, max(0, (correlation - threshold) / corr_span)))
    phase_score = avg_phase_match
    structure_score = max(0, 1 - abs(structure_ratio - 1.32) / 0.6)
    consistency_score = max(0, 1 - multi_scale_consistency * 5)

    confidence = min(1.0, (
        0.35 * corr_score +
        0.35 * phase_score +
        0.15 * structure_score +
        0.15 * consistency_score
    ))

    return DetectionResult(
        is_watermarked=bool(is_watermarked),
        confidence=float(confidence),
        correlation=correlation,
        phase_match=avg_phase_match,
        structure_ratio=structure_ratio,
        carrier_strength=avg_carrier_strength,
        multi_scale_consistency=multi_scale_consistency,
        details={
            'threshold': threshold,
            'corr_score': corr_score,
            'phase_score': phase_score,
            'structure_score': structure_score,
            'consistency_score': consistency_score,
            'scale_correlations': scale_scores
        }
    )
|
|
|
|
|
|
# ================================================================
|
|
# CLI INTERFACE
|
|
# ================================================================
|
|
|
|
if __name__ == '__main__':
    import argparse

    # Two sub-commands: `extract` builds a codebook, `detect` checks an image.
    cli = argparse.ArgumentParser(description='Robust SynthID Watermark Extractor')
    commands = cli.add_subparsers(dest='command', help='Commands')

    extract_cmd = commands.add_parser('extract', help='Extract codebook from images')
    extract_cmd.add_argument('image_dir', type=str, help='Directory with watermarked images')
    extract_cmd.add_argument('--output', type=str, default='./robust_codebook.pkl', help='Output path')
    extract_cmd.add_argument('--max-images', type=int, default=250, help='Max images to process')

    detect_cmd = commands.add_parser('detect', help='Detect watermark in image')
    detect_cmd.add_argument('image', type=str, help='Image to check')
    detect_cmd.add_argument('--codebook', type=str, required=True, help='Codebook path')

    options = cli.parse_args()
    extractor = RobustSynthIDExtractor()

    if options.command == 'detect':
        extractor.load_codebook(options.codebook)
        result = extractor.detect(options.image)

        rule = "=" * 50
        report_lines = [
            "\n" + rule,
            "ROBUST SYNTHID DETECTION RESULTS",
            rule,
            f" Watermarked: {result.is_watermarked}",
            f" Confidence: {result.confidence:.4f}",
            f" Correlation: {result.correlation:.4f}",
            f" Phase Match: {result.phase_match:.4f}",
            f" Structure: {result.structure_ratio:.4f}",
            f" Carrier Str: {result.carrier_strength:.2f}",
            f" Multi-Scale: {result.multi_scale_consistency:.4f}",
            rule,
        ]
        for line in report_lines:
            print(line)
    elif options.command == 'extract':
        extractor.extract_codebook(options.image_dir, options.max_images, options.output)
    else:
        # No sub-command given: show usage.
        cli.print_help()
|