mirror of
https://github.com/aloshdenny/reverse-SynthID.git
synced 2026-04-30 18:47:53 +02:00
Merge pull request #18 from mrbeandev/fix/codebook-loading-npz-compat
Fix codebook loading error when .npz path is passed to extractor
This commit is contained in:
@@ -92,7 +92,24 @@ class RobustSynthIDExtractor:
|
||||
self.load_codebook(codebook_path)
|
||||
|
||||
def load_codebook(self, path: str) -> None:
|
||||
"""Load pre-extracted codebook."""
|
||||
"""Load pre-extracted codebook. Supports .pkl files and auto-discovers
|
||||
.pkl when given a .npz path (common user mistake)."""
|
||||
if path.endswith('.npz'):
|
||||
pkl_candidates = [
|
||||
os.path.join(os.path.dirname(path), 'codebook', 'robust_codebook.pkl'),
|
||||
os.path.join(os.path.dirname(path), 'robust_codebook.pkl'),
|
||||
path.replace('.npz', '.pkl'),
|
||||
]
|
||||
for pkl_path in pkl_candidates:
|
||||
if os.path.exists(pkl_path):
|
||||
with open(pkl_path, 'rb') as f:
|
||||
self.codebook = pickle.load(f)
|
||||
return
|
||||
raise FileNotFoundError(
|
||||
f"Cannot load .npz as pickle codebook. "
|
||||
f"Provide a .pkl file instead, e.g.: "
|
||||
f"--detector artifacts/codebook/robust_codebook.pkl"
|
||||
)
|
||||
with open(path, 'rb') as f:
|
||||
self.codebook = pickle.load(f)
|
||||
|
||||
@@ -100,7 +117,7 @@ class RobustSynthIDExtractor:
|
||||
"""Save extracted codebook."""
|
||||
os.makedirs(os.path.dirname(path) if os.path.dirname(path) else '.', exist_ok=True)
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(self.codebook, f)
|
||||
pickle.dump(self.codebook, f, protocol=4)
|
||||
|
||||
# ================================================================
|
||||
# DENOISING METHODS
|
||||
|
||||
@@ -201,9 +201,9 @@ def extract_codebook(image_dir, output_path, max_images=250, size=512):
|
||||
# Save codebook
|
||||
os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
|
||||
|
||||
# Save as pickle (includes numpy arrays)
|
||||
# Save as pickle (includes numpy arrays, protocol=4 for compatibility)
|
||||
with open(output_path, 'wb') as f:
|
||||
pickle.dump(codebook, f)
|
||||
pickle.dump(codebook, f, protocol=4)
|
||||
|
||||
# Save metadata as JSON
|
||||
json_path = output_path.replace('.pkl', '_meta.json')
|
||||
|
||||
Reference in New Issue
Block a user