diff --git a/src/extraction/robust_extractor.py b/src/extraction/robust_extractor.py
index 9119ece..cf06704 100644
--- a/src/extraction/robust_extractor.py
+++ b/src/extraction/robust_extractor.py
@@ -92,7 +92,24 @@ class RobustSynthIDExtractor:
             self.load_codebook(codebook_path)
 
     def load_codebook(self, path: str) -> None:
-        """Load pre-extracted codebook."""
+        """Load pre-extracted codebook. Supports .pkl files and auto-discovers
+        .pkl when given a .npz path (common user mistake)."""
+        if path.endswith('.npz'):
+            pkl_candidates = [
+                os.path.join(os.path.dirname(path), 'codebook', 'robust_codebook.pkl'),
+                os.path.join(os.path.dirname(path), 'robust_codebook.pkl'),
+                path[:-len('.npz')] + '.pkl',
+            ]
+            for pkl_path in pkl_candidates:
+                if os.path.exists(pkl_path):
+                    with open(pkl_path, 'rb') as f:
+                        self.codebook = pickle.load(f)
+                    return
+            raise FileNotFoundError(
+                "Cannot load .npz as pickle codebook. "
+                "Provide a .pkl file instead, e.g.: "
+                "--detector artifacts/codebook/robust_codebook.pkl"
+            )
         with open(path, 'rb') as f:
             self.codebook = pickle.load(f)
 
@@ -100,7 +117,7 @@ class RobustSynthIDExtractor:
         """Save extracted codebook."""
         os.makedirs(os.path.dirname(path) if os.path.dirname(path) else '.', exist_ok=True)
         with open(path, 'wb') as f:
-            pickle.dump(self.codebook, f)
+            pickle.dump(self.codebook, f, protocol=4)
 
     # ================================================================
     # DENOISING METHODS
diff --git a/src/extraction/synthid_codebook_extractor.py b/src/extraction/synthid_codebook_extractor.py
index ae190fa..d88f146 100644
--- a/src/extraction/synthid_codebook_extractor.py
+++ b/src/extraction/synthid_codebook_extractor.py
@@ -201,9 +201,9 @@ def extract_codebook(image_dir, output_path, max_images=250, size=512):
     # Save codebook
     os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
 
-    # Save as pickle (includes numpy arrays)
+    # Save as pickle (includes numpy arrays, protocol=4 for compatibility)
     with open(output_path, 'wb') as f:
-        pickle.dump(codebook, f)
+        pickle.dump(codebook, f, protocol=4)
 
     # Save metadata as JSON
     json_path = output_path.replace('.pkl', '_meta.json')