mirror of
https://github.com/aloshdenny/reverse-SynthID.git
synced 2026-04-30 18:47:53 +02:00
fix cv2.resize scale bug
This commit is contained in:
@@ -0,0 +1,781 @@
|
||||
"""
|
||||
Robust SynthID Watermark Extractor
|
||||
|
||||
A comprehensive multi-stage watermark extraction pipeline that combines:
|
||||
1. Multi-scale analysis (256, 512, 1024 pixels)
|
||||
2. Multi-denoiser fusion (wavelet, bilateral, non-local means)
|
||||
3. Ensemble carrier frequency detection with voting
|
||||
4. ICA/PCA-based watermark separation
|
||||
5. Adaptive thresholding based on image content
|
||||
|
||||
This provides significantly more robust detection than single-scale approaches.
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
from scipy.fft import fft2, ifft2, fftshift, ifftshift
|
||||
from scipy import ndimage
|
||||
from scipy.stats import pearsonr
|
||||
from collections import defaultdict
|
||||
import pywt
|
||||
import pickle
|
||||
from typing import Optional, Dict, List, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from sklearn.decomposition import PCA, FastICA
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""Result of watermark detection."""
|
||||
is_watermarked: bool
|
||||
confidence: float
|
||||
correlation: float
|
||||
phase_match: float
|
||||
structure_ratio: float
|
||||
carrier_strength: float
|
||||
multi_scale_consistency: float
|
||||
details: Dict
|
||||
|
||||
|
||||
class RobustSynthIDExtractor:
|
||||
"""
|
||||
Robust SynthID watermark extractor using multi-stage analysis.
|
||||
|
||||
Features:
|
||||
- Multi-scale processing for comprehensive watermark detection
|
||||
- Multiple denoising methods with intelligent fusion
|
||||
- Ensemble carrier frequency detection
|
||||
- Adaptive thresholds based on image content
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scales: List[int] = [256, 512, 1024],
|
||||
wavelets: List[str] = ['db4', 'sym8', 'coif3'],
|
||||
n_carriers: int = 100,
|
||||
codebook_path: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize the robust extractor.
|
||||
|
||||
Args:
|
||||
scales: Image scales for multi-scale analysis
|
||||
wavelets: Wavelet families for denoising
|
||||
n_carriers: Number of carrier frequencies to track
|
||||
codebook_path: Path to pre-extracted codebook
|
||||
"""
|
||||
self.scales = scales
|
||||
self.wavelets = wavelets
|
||||
self.n_carriers = n_carriers
|
||||
self.codebook = None
|
||||
|
||||
# Known SynthID carrier frequencies (from previous analysis)
|
||||
self.known_carriers = [
|
||||
(14, 14), (-14, -14),
|
||||
(126, 14), (-126, -14),
|
||||
(98, -14), (-98, 14),
|
||||
(128, 128), (-128, -128),
|
||||
(210, -14), (-210, 14),
|
||||
(238, 14), (-238, -14),
|
||||
]
|
||||
|
||||
if codebook_path and os.path.exists(codebook_path):
|
||||
self.load_codebook(codebook_path)
|
||||
|
||||
def load_codebook(self, path: str) -> None:
|
||||
"""Load pre-extracted codebook."""
|
||||
with open(path, 'rb') as f:
|
||||
self.codebook = pickle.load(f)
|
||||
|
||||
def save_codebook(self, path: str) -> None:
|
||||
"""Save extracted codebook."""
|
||||
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else '.', exist_ok=True)
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(self.codebook, f)
|
||||
|
||||
# ================================================================
|
||||
# DENOISING METHODS
|
||||
# ================================================================
|
||||
|
||||
def wavelet_denoise(
|
||||
self,
|
||||
channel: np.ndarray,
|
||||
wavelet: str = 'db4',
|
||||
level: int = 3
|
||||
) -> np.ndarray:
|
||||
"""Wavelet-based denoising using soft thresholding."""
|
||||
coeffs = pywt.wavedec2(channel, wavelet, level=level)
|
||||
|
||||
# Estimate noise from finest detail coefficients
|
||||
detail = coeffs[-1][0]
|
||||
sigma = np.median(np.abs(detail)) / 0.6745
|
||||
threshold = sigma * np.sqrt(2 * np.log(channel.size))
|
||||
|
||||
# Apply soft thresholding to detail coefficients
|
||||
new_coeffs = [coeffs[0]]
|
||||
for details in coeffs[1:]:
|
||||
new_details = tuple(
|
||||
pywt.threshold(d, threshold, mode='soft') for d in details
|
||||
)
|
||||
new_coeffs.append(new_details)
|
||||
|
||||
denoised = pywt.waverec2(new_coeffs, wavelet)
|
||||
return denoised[:channel.shape[0], :channel.shape[1]]
|
||||
|
||||
def bilateral_denoise(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
d: int = 9,
|
||||
sigma_color: float = 75,
|
||||
sigma_space: float = 75
|
||||
) -> np.ndarray:
|
||||
"""Bilateral filter denoising (edge-preserving)."""
|
||||
if len(image.shape) == 2:
|
||||
return cv2.bilateralFilter(image.astype(np.float32), d, sigma_color, sigma_space)
|
||||
else:
|
||||
result = np.zeros_like(image)
|
||||
for c in range(image.shape[2]):
|
||||
result[:, :, c] = cv2.bilateralFilter(
|
||||
image[:, :, c].astype(np.float32), d, sigma_color, sigma_space
|
||||
)
|
||||
return result
|
||||
|
||||
def nlm_denoise(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
h: float = 10,
|
||||
template_size: int = 7,
|
||||
search_size: int = 21
|
||||
) -> np.ndarray:
|
||||
"""Non-local means denoising."""
|
||||
img_uint8 = (image * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
if len(image.shape) == 2:
|
||||
denoised = cv2.fastNlMeansDenoising(
|
||||
img_uint8, None, h, template_size, search_size
|
||||
)
|
||||
else:
|
||||
denoised = cv2.fastNlMeansDenoisingColored(
|
||||
img_uint8, None, h, h, template_size, search_size
|
||||
)
|
||||
|
||||
return denoised.astype(np.float32) / 255.0
|
||||
|
||||
def wiener_filter(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
noise_variance: Optional[float] = None
|
||||
) -> np.ndarray:
|
||||
"""Wiener filter for optimal noise estimation."""
|
||||
if noise_variance is None:
|
||||
# Estimate noise variance from high-frequency components
|
||||
noise_variance = np.var(image - ndimage.gaussian_filter(image, sigma=2))
|
||||
|
||||
# Simple Wiener filter in Fourier domain
|
||||
f = fft2(image)
|
||||
power = np.abs(f) ** 2
|
||||
signal_power = np.maximum(power - noise_variance, 0)
|
||||
wiener_ratio = signal_power / (signal_power + noise_variance + 1e-10)
|
||||
|
||||
denoised = np.real(ifft2(f * wiener_ratio))
|
||||
return denoised
|
||||
|
||||
# ================================================================
|
||||
# NOISE EXTRACTION
|
||||
# ================================================================
|
||||
|
||||
def extract_noise_single(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
method: str = 'wavelet',
|
||||
**kwargs
|
||||
) -> np.ndarray:
|
||||
"""Extract noise using a single denoising method."""
|
||||
img_f = image.astype(np.float32)
|
||||
if img_f.max() > 1:
|
||||
img_f = img_f / 255.0
|
||||
|
||||
if method == 'wavelet':
|
||||
wavelet = kwargs.get('wavelet', 'db4')
|
||||
if len(img_f.shape) == 2:
|
||||
denoised = self.wavelet_denoise(img_f, wavelet)
|
||||
else:
|
||||
denoised = np.zeros_like(img_f)
|
||||
for c in range(img_f.shape[2]):
|
||||
denoised[:, :, c] = self.wavelet_denoise(img_f[:, :, c], wavelet)
|
||||
|
||||
elif method == 'bilateral':
|
||||
denoised = self.bilateral_denoise(img_f)
|
||||
|
||||
elif method == 'nlm':
|
||||
denoised = self.nlm_denoise(img_f)
|
||||
|
||||
elif method == 'wiener':
|
||||
if len(img_f.shape) == 2:
|
||||
denoised = self.wiener_filter(img_f)
|
||||
else:
|
||||
denoised = np.zeros_like(img_f)
|
||||
for c in range(img_f.shape[2]):
|
||||
denoised[:, :, c] = self.wiener_filter(img_f[:, :, c])
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown denoising method: {method}")
|
||||
|
||||
return img_f - denoised
|
||||
|
||||
def extract_noise_fused(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Extract noise using multiple methods and fuse results.
|
||||
|
||||
Uses weighted median fusion for robustness against outliers.
|
||||
"""
|
||||
noises = []
|
||||
weights = []
|
||||
|
||||
# Wavelet denoising with multiple families
|
||||
for wavelet in self.wavelets:
|
||||
noise = self.extract_noise_single(image, 'wavelet', wavelet=wavelet)
|
||||
noises.append(noise)
|
||||
weights.append(1.0)
|
||||
|
||||
# Bilateral filter
|
||||
noise = self.extract_noise_single(image, 'bilateral')
|
||||
noises.append(noise)
|
||||
weights.append(0.8)
|
||||
|
||||
# Non-local means
|
||||
noise = self.extract_noise_single(image, 'nlm')
|
||||
noises.append(noise)
|
||||
weights.append(0.7)
|
||||
|
||||
# Wiener filter
|
||||
noise = self.extract_noise_single(image, 'wiener')
|
||||
noises.append(noise)
|
||||
weights.append(0.6)
|
||||
|
||||
# Weighted fusion
|
||||
noises = np.array(noises)
|
||||
weights = np.array(weights) / sum(weights)
|
||||
|
||||
# Use weighted average (more stable than weighted median)
|
||||
fused = np.tensordot(weights, noises, axes=([0], [0]))
|
||||
|
||||
return fused
|
||||
|
||||
# ================================================================
|
||||
# CARRIER FREQUENCY DETECTION
|
||||
# ================================================================
|
||||
|
||||
def find_carrier_peaks(
|
||||
self,
|
||||
magnitude: np.ndarray,
|
||||
phase_coherence: np.ndarray,
|
||||
n_peaks: int = 100
|
||||
) -> List[Tuple[int, int, float]]:
|
||||
"""Find carrier frequency peaks using combined magnitude and coherence."""
|
||||
center = magnitude.shape[0] // 2
|
||||
|
||||
# Combined score
|
||||
log_mag = np.log1p(magnitude)
|
||||
combined = log_mag * phase_coherence
|
||||
|
||||
# Find peaks (excluding DC region)
|
||||
dc_mask = np.ones_like(combined, dtype=bool)
|
||||
y_coords, x_coords = np.ogrid[:combined.shape[0], :combined.shape[1]]
|
||||
dc_mask[((y_coords - center) ** 2 + (x_coords - center) ** 2) < 25] = False
|
||||
|
||||
# Threshold and find peaks
|
||||
threshold = np.percentile(combined[dc_mask], 99)
|
||||
peak_mask = (combined > threshold) & dc_mask
|
||||
|
||||
# Get peak locations with scores
|
||||
peak_locs = np.where(peak_mask)
|
||||
peaks = []
|
||||
for y, x in zip(peak_locs[0], peak_locs[1]):
|
||||
freq_y, freq_x = y - center, x - center
|
||||
score = combined[y, x]
|
||||
peaks.append((freq_y, freq_x, score))
|
||||
|
||||
# Sort by score and return top N
|
||||
peaks.sort(key=lambda p: p[2], reverse=True)
|
||||
return peaks[:n_peaks]
|
||||
|
||||
def detect_carriers_single_scale(
|
||||
self,
|
||||
images: List[np.ndarray],
|
||||
size: int
|
||||
) -> Dict[Tuple[int, int], Dict]:
|
||||
"""Detect carriers at a single scale."""
|
||||
magnitude_sum = None
|
||||
phase_sum = None
|
||||
n_images = 0
|
||||
|
||||
for img in images:
|
||||
# Resize to target scale
|
||||
s = int(size)
|
||||
img_resized = cv2.resize(img, (s, s))
|
||||
if len(img_resized.shape) == 3:
|
||||
gray = cv2.cvtColor(img_resized, cv2.COLOR_RGB2GRAY).astype(np.float32)
|
||||
else:
|
||||
gray = img_resized.astype(np.float32)
|
||||
|
||||
# FFT
|
||||
f = fft2(gray)
|
||||
fshift = fftshift(f)
|
||||
|
||||
if magnitude_sum is None:
|
||||
magnitude_sum = np.abs(fshift)
|
||||
phase_sum = np.exp(1j * np.angle(fshift))
|
||||
else:
|
||||
magnitude_sum += np.abs(fshift)
|
||||
phase_sum += np.exp(1j * np.angle(fshift))
|
||||
|
||||
n_images += 1
|
||||
|
||||
avg_magnitude = magnitude_sum / n_images
|
||||
phase_coherence = np.abs(phase_sum) / n_images
|
||||
avg_phase = np.angle(phase_sum)
|
||||
|
||||
# Find peaks
|
||||
peaks = self.find_carrier_peaks(avg_magnitude, phase_coherence, self.n_carriers)
|
||||
|
||||
# Build carrier dictionary
|
||||
carriers = {}
|
||||
center = size // 2
|
||||
for freq_y, freq_x, score in peaks:
|
||||
y, x = freq_y + center, freq_x + center
|
||||
carriers[(freq_y, freq_x)] = {
|
||||
'position': (y, x),
|
||||
'magnitude': float(avg_magnitude[y, x]),
|
||||
'phase': float(avg_phase[y, x]),
|
||||
'coherence': float(phase_coherence[y, x]),
|
||||
'score': float(score)
|
||||
}
|
||||
|
||||
return carriers
|
||||
|
||||
def detect_carriers_multi_scale(
|
||||
self,
|
||||
images: List[np.ndarray]
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Detect carriers using multi-scale analysis with voting.
|
||||
|
||||
Carriers that appear consistently across scales are more reliable.
|
||||
Falls back to known carriers if voting doesn't find reliable ones.
|
||||
"""
|
||||
all_carriers = defaultdict(lambda: {'votes': 0, 'total_score': 0, 'scales': [], 'infos': []})
|
||||
base_scale = 512
|
||||
|
||||
for scale in self.scales:
|
||||
carriers = self.detect_carriers_single_scale(images, scale)
|
||||
|
||||
for freq, info in carriers.items():
|
||||
# Normalize frequency to base scale (512) with tolerance
|
||||
norm_freq_y = int(round(freq[0] * base_scale / scale))
|
||||
norm_freq_x = int(round(freq[1] * base_scale / scale))
|
||||
|
||||
# Use tolerance-based binning (frequencies within ±2 are considered the same)
|
||||
bin_freq = (norm_freq_y // 2 * 2, norm_freq_x // 2 * 2)
|
||||
|
||||
all_carriers[bin_freq]['votes'] += 1
|
||||
all_carriers[bin_freq]['total_score'] += info['score']
|
||||
all_carriers[bin_freq]['scales'].append(scale)
|
||||
all_carriers[bin_freq]['infos'].append(info)
|
||||
|
||||
# Filter carriers with multiple votes OR high score
|
||||
reliable_carriers = []
|
||||
for freq, info in all_carriers.items():
|
||||
# Accept if appears in 2+ scales OR has very high score
|
||||
if info['votes'] >= 2 or (info['votes'] >= 1 and info['total_score'] > 100):
|
||||
# Average the info from all scales
|
||||
avg_coherence = np.mean([i.get('coherence', 0) for i in info['infos']])
|
||||
avg_phase = np.mean([i.get('phase', 0) for i in info['infos']])
|
||||
avg_magnitude = np.mean([i.get('magnitude', 0) for i in info['infos']])
|
||||
|
||||
carrier = {
|
||||
'frequency': freq,
|
||||
'votes': info['votes'],
|
||||
'avg_score': info['total_score'] / info['votes'],
|
||||
'scales': info['scales'],
|
||||
'coherence': float(avg_coherence),
|
||||
'phase': float(avg_phase),
|
||||
'magnitude': float(avg_magnitude)
|
||||
}
|
||||
reliable_carriers.append(carrier)
|
||||
|
||||
# Sort by votes then score
|
||||
reliable_carriers.sort(key=lambda c: (c['votes'], c['avg_score']), reverse=True)
|
||||
|
||||
# FALLBACK: If no reliable carriers found, use known carriers
|
||||
if len(reliable_carriers) < 5:
|
||||
print(f" Warning: Only {len(reliable_carriers)} carriers found, using known carriers as fallback")
|
||||
# Add known carriers with default values
|
||||
for freq in self.known_carriers:
|
||||
if freq not in [c['frequency'] for c in reliable_carriers]:
|
||||
reliable_carriers.append({
|
||||
'frequency': freq,
|
||||
'votes': 0,
|
||||
'avg_score': 50,
|
||||
'scales': [],
|
||||
'coherence': 0.99,
|
||||
'phase': 0.0, # Will be computed during detection
|
||||
'magnitude': 1000
|
||||
})
|
||||
|
||||
return reliable_carriers[:self.n_carriers]
|
||||
|
||||
# ================================================================
|
||||
# ICA/PCA SEPARATION
|
||||
# ================================================================
|
||||
|
||||
def extract_watermark_ica(
|
||||
self,
|
||||
images: List[np.ndarray],
|
||||
n_components: int = 5
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Use ICA to separate watermark pattern from image content.
|
||||
|
||||
The watermark should appear as a consistent component across images.
|
||||
"""
|
||||
# Extract noise from all images
|
||||
noise_vectors = []
|
||||
target_size = 512
|
||||
|
||||
for img in images[:50]: # Limit for performance
|
||||
img_resized = cv2.resize(img, (target_size, target_size))
|
||||
noise = self.extract_noise_fused(img_resized)
|
||||
|
||||
if len(noise.shape) == 3:
|
||||
noise = np.mean(noise, axis=2)
|
||||
|
||||
noise_vectors.append(noise.flatten())
|
||||
|
||||
noise_matrix = np.array(noise_vectors)
|
||||
|
||||
# Apply ICA
|
||||
ica = FastICA(n_components=n_components, random_state=42, max_iter=500)
|
||||
try:
|
||||
sources = ica.fit_transform(noise_matrix)
|
||||
components = ica.components_
|
||||
except Exception:
|
||||
# Fall back to PCA if ICA fails to converge
|
||||
pca = PCA(n_components=n_components)
|
||||
sources = pca.fit_transform(noise_matrix)
|
||||
components = pca.components_
|
||||
|
||||
# Find the most consistent component (watermark)
|
||||
consistencies = []
|
||||
for i in range(n_components):
|
||||
component = components[i].reshape(target_size, target_size)
|
||||
# Watermark should have specific frequency structure
|
||||
f = fftshift(fft2(component))
|
||||
|
||||
# Check energy at known carrier frequencies
|
||||
center = target_size // 2
|
||||
carrier_energy = 0
|
||||
for freq_y, freq_x in self.known_carriers:
|
||||
y = freq_y + center
|
||||
x = freq_x + center
|
||||
if 0 <= y < target_size and 0 <= x < target_size:
|
||||
carrier_energy += np.abs(f[y, x])
|
||||
|
||||
consistencies.append(carrier_energy)
|
||||
|
||||
# Return the component with highest carrier energy
|
||||
best_idx = np.argmax(consistencies)
|
||||
watermark = components[best_idx].reshape(target_size, target_size)
|
||||
|
||||
return watermark
|
||||
|
||||
# ================================================================
|
||||
# CODEBOOK EXTRACTION
|
||||
# ================================================================
|
||||
|
||||
def extract_codebook(
|
||||
self,
|
||||
image_dir: str,
|
||||
max_images: int = 250,
|
||||
save_path: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Extract comprehensive codebook from watermarked images.
|
||||
|
||||
Uses multi-scale analysis and ensemble methods for robustness.
|
||||
"""
|
||||
print(f"Loading images from {image_dir}...")
|
||||
|
||||
# Load images
|
||||
extensions = {'.png', '.jpg', '.jpeg', '.webp'}
|
||||
images = []
|
||||
|
||||
for fname in sorted(os.listdir(image_dir)):
|
||||
if os.path.splitext(fname)[1].lower() in extensions:
|
||||
path = os.path.join(image_dir, fname)
|
||||
img = cv2.imread(path)
|
||||
if img is not None:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
images.append(img)
|
||||
if len(images) >= max_images:
|
||||
break
|
||||
|
||||
print(f"Loaded {len(images)} images")
|
||||
|
||||
# Multi-scale carrier detection
|
||||
print("Detecting carriers (multi-scale)...")
|
||||
carriers = self.detect_carriers_multi_scale(images)
|
||||
|
||||
# Extract reference noise pattern
|
||||
print("Extracting reference noise pattern...")
|
||||
target_size = 512
|
||||
noise_sum = np.zeros((target_size, target_size, 3), dtype=np.float64)
|
||||
|
||||
for img in images:
|
||||
img_resized = cv2.resize(img, (target_size, target_size))
|
||||
noise = self.extract_noise_fused(img_resized)
|
||||
noise_sum += noise
|
||||
|
||||
reference_noise = noise_sum / len(images)
|
||||
|
||||
# ICA-based watermark extraction
|
||||
print("Extracting watermark pattern via ICA...")
|
||||
watermark_pattern = self.extract_watermark_ica(images)
|
||||
|
||||
# Compute correlation statistics
|
||||
print("Computing correlation statistics...")
|
||||
correlations = []
|
||||
sample_images = images[:min(50, len(images))]
|
||||
|
||||
for i, img1 in enumerate(sample_images):
|
||||
for j, img2 in enumerate(sample_images):
|
||||
if i < j:
|
||||
img1_resized = cv2.resize(img1, (target_size, target_size))
|
||||
img2_resized = cv2.resize(img2, (target_size, target_size))
|
||||
|
||||
noise1 = self.extract_noise_fused(img1_resized)
|
||||
noise2 = self.extract_noise_fused(img2_resized)
|
||||
|
||||
corr = np.corrcoef(noise1.ravel(), noise2.ravel())[0, 1]
|
||||
correlations.append(corr)
|
||||
|
||||
correlation_mean = float(np.mean(correlations))
|
||||
correlation_std = float(np.std(correlations))
|
||||
detection_threshold = correlation_mean - 2.5 * correlation_std
|
||||
|
||||
# Compute FFT statistics
|
||||
print("Computing FFT reference...")
|
||||
ref_gray = np.mean(reference_noise, axis=2)
|
||||
ref_fft = fftshift(fft2(ref_gray))
|
||||
ref_magnitude = np.abs(ref_fft)
|
||||
ref_phase = np.angle(ref_fft)
|
||||
|
||||
# Build codebook
|
||||
self.codebook = {
|
||||
'version': '2.0',
|
||||
'source': 'Gemini/SynthID',
|
||||
'extractor': 'RobustSynthIDExtractor',
|
||||
'n_images_analyzed': len(images),
|
||||
'image_size': target_size,
|
||||
'scales_used': self.scales,
|
||||
|
||||
# Reference patterns
|
||||
'reference_noise': reference_noise,
|
||||
'watermark_pattern': watermark_pattern,
|
||||
'reference_magnitude': ref_magnitude,
|
||||
'reference_phase': ref_phase,
|
||||
|
||||
# Carriers
|
||||
'carriers': carriers,
|
||||
'known_carriers': self.known_carriers,
|
||||
|
||||
# Detection thresholds
|
||||
'correlation_mean': correlation_mean,
|
||||
'correlation_std': correlation_std,
|
||||
'detection_threshold': detection_threshold,
|
||||
'noise_structure_ratio': 1.32,
|
||||
}
|
||||
|
||||
if save_path:
|
||||
self.save_codebook(save_path)
|
||||
print(f"Codebook saved to {save_path}")
|
||||
|
||||
return self.codebook
|
||||
|
||||
# ================================================================
|
||||
# DETECTION
|
||||
# ================================================================
|
||||
|
||||
def detect(self, image_path: str) -> DetectionResult:
|
||||
"""Detect SynthID watermark in an image file."""
|
||||
img = cv2.imread(image_path)
|
||||
if img is None:
|
||||
raise ValueError(f"Could not load image: {image_path}")
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return self.detect_array(img)
|
||||
|
||||
def detect_array(self, image: np.ndarray) -> DetectionResult:
|
||||
"""Detect SynthID watermark in a numpy array image."""
|
||||
if self.codebook is None:
|
||||
raise ValueError("No codebook loaded. Call extract_codebook() or load_codebook() first.")
|
||||
|
||||
target_size = self.codebook['image_size']
|
||||
img_resized = cv2.resize(image, (target_size, target_size))
|
||||
|
||||
# Extract noise pattern
|
||||
noise = self.extract_noise_fused(img_resized)
|
||||
|
||||
# Method 1: Correlation with reference noise
|
||||
ref_noise = self.codebook['reference_noise']
|
||||
correlation = float(np.corrcoef(noise.ravel(), ref_noise.ravel())[0, 1])
|
||||
|
||||
# Method 2: Carrier frequency analysis using known carriers + extracted carriers
|
||||
gray = np.mean(img_resized, axis=2) if len(img_resized.shape) == 3 else img_resized
|
||||
gray = gray.astype(np.float32)
|
||||
f = fftshift(fft2(gray))
|
||||
magnitude = np.abs(f)
|
||||
phase = np.angle(f)
|
||||
|
||||
center = target_size // 2
|
||||
carrier_scores = []
|
||||
carrier_strengths = []
|
||||
|
||||
# Use extracted carriers if available, otherwise use known carriers
|
||||
carriers_to_check = self.codebook['carriers'][:30] if self.codebook['carriers'] else []
|
||||
|
||||
# Always also check known carriers for reliability
|
||||
known_carrier_dicts = [{'frequency': freq, 'phase': 0} for freq in self.codebook.get('known_carriers', self.known_carriers)]
|
||||
carriers_to_check = carriers_to_check + known_carrier_dicts
|
||||
|
||||
# Use reference phase from codebook if available
|
||||
ref_phase = self.codebook.get('reference_phase')
|
||||
|
||||
for carrier in carriers_to_check:
|
||||
freq = carrier['frequency']
|
||||
y = freq[0] + center
|
||||
x = freq[1] + center
|
||||
|
||||
if 0 <= y < target_size and 0 <= x < target_size:
|
||||
actual_phase = phase[y, x]
|
||||
|
||||
# Get expected phase from codebook reference if available
|
||||
if ref_phase is not None:
|
||||
expected_phase = ref_phase[y, x]
|
||||
else:
|
||||
expected_phase = carrier.get('phase', 0)
|
||||
|
||||
# Phase match (accounting for wrap-around)
|
||||
phase_diff = np.abs(np.angle(np.exp(1j * (actual_phase - expected_phase))))
|
||||
phase_match = 1 - phase_diff / np.pi
|
||||
carrier_scores.append(phase_match)
|
||||
|
||||
# Carrier strength
|
||||
carrier_strengths.append(magnitude[y, x])
|
||||
|
||||
avg_phase_match = float(np.mean(carrier_scores)) if carrier_scores else 0
|
||||
avg_carrier_strength = float(np.mean(carrier_strengths)) if carrier_strengths else 0
|
||||
|
||||
# Method 3: Noise structure ratio
|
||||
noise_gray = np.mean(noise, axis=2) if len(noise.shape) == 3 else noise
|
||||
structure_ratio = float(np.std(noise_gray) / (np.mean(np.abs(noise_gray)) + 1e-10))
|
||||
|
||||
# Method 4: Multi-scale consistency
|
||||
scale_scores = []
|
||||
for scale in self.scales:
|
||||
img_scaled = cv2.resize(image, (scale, scale))
|
||||
noise_scaled = self.extract_noise_single(img_scaled, 'wavelet')
|
||||
ref_scaled = cv2.resize(ref_noise, (scale, scale))
|
||||
|
||||
corr = np.corrcoef(noise_scaled.ravel(), ref_scaled.ravel())[0, 1]
|
||||
scale_scores.append(corr)
|
||||
|
||||
multi_scale_consistency = float(np.std(scale_scores)) # Lower is more consistent
|
||||
|
||||
# Detection decision
|
||||
threshold = self.codebook['detection_threshold']
|
||||
is_watermarked = (
|
||||
correlation > threshold and
|
||||
avg_phase_match > 0.45 and
|
||||
0.7 < structure_ratio < 2.0
|
||||
)
|
||||
|
||||
# Confidence score (Bayesian combination)
|
||||
corr_score = max(0, (correlation - threshold) / (self.codebook['correlation_mean'] - threshold + 1e-10))
|
||||
phase_score = avg_phase_match
|
||||
structure_score = max(0, 1 - abs(structure_ratio - 1.32) / 0.6)
|
||||
consistency_score = max(0, 1 - multi_scale_consistency * 5)
|
||||
|
||||
confidence = min(1.0, (
|
||||
0.35 * corr_score +
|
||||
0.35 * phase_score +
|
||||
0.15 * structure_score +
|
||||
0.15 * consistency_score
|
||||
))
|
||||
|
||||
return DetectionResult(
|
||||
is_watermarked=bool(is_watermarked),
|
||||
confidence=float(confidence),
|
||||
correlation=correlation,
|
||||
phase_match=avg_phase_match,
|
||||
structure_ratio=structure_ratio,
|
||||
carrier_strength=avg_carrier_strength,
|
||||
multi_scale_consistency=multi_scale_consistency,
|
||||
details={
|
||||
'threshold': threshold,
|
||||
'corr_score': corr_score,
|
||||
'phase_score': phase_score,
|
||||
'structure_score': structure_score,
|
||||
'consistency_score': consistency_score,
|
||||
'scale_correlations': scale_scores
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# CLI INTERFACE
|
||||
# ================================================================
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Robust SynthID Watermark Extractor')
|
||||
subparsers = parser.add_subparsers(dest='command', help='Commands')
|
||||
|
||||
# Extract command
|
||||
extract_parser = subparsers.add_parser('extract', help='Extract codebook from images')
|
||||
extract_parser.add_argument('image_dir', type=str, help='Directory with watermarked images')
|
||||
extract_parser.add_argument('--output', type=str, default='./robust_codebook.pkl', help='Output path')
|
||||
extract_parser.add_argument('--max-images', type=int, default=250, help='Max images to process')
|
||||
|
||||
# Detect command
|
||||
detect_parser = subparsers.add_parser('detect', help='Detect watermark in image')
|
||||
detect_parser.add_argument('image', type=str, help='Image to check')
|
||||
detect_parser.add_argument('--codebook', type=str, required=True, help='Codebook path')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
extractor = RobustSynthIDExtractor()
|
||||
|
||||
if args.command == 'extract':
|
||||
extractor.extract_codebook(args.image_dir, args.max_images, args.output)
|
||||
|
||||
elif args.command == 'detect':
|
||||
extractor.load_codebook(args.codebook)
|
||||
result = extractor.detect(args.image)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("ROBUST SYNTHID DETECTION RESULTS")
|
||||
print("=" * 50)
|
||||
print(f" Watermarked: {result.is_watermarked}")
|
||||
print(f" Confidence: {result.confidence:.4f}")
|
||||
print(f" Correlation: {result.correlation:.4f}")
|
||||
print(f" Phase Match: {result.phase_match:.4f}")
|
||||
print(f" Structure: {result.structure_ratio:.4f}")
|
||||
print(f" Carrier Str: {result.carrier_strength:.2f}")
|
||||
print(f" Multi-Scale: {result.multi_scale_consistency:.4f}")
|
||||
print("=" * 50)
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
Reference in New Issue
Block a user