From 51f621d0a2b8719a9c6ff569ff0e5433dee28d56 Mon Sep 17 00:00:00 2001
From: Stella Biderman
Date: Fri, 13 Mar 2026 17:13:50 -0400
Subject: [PATCH] Save model snapshot to CPU to avoid OOM on multi-GPU setups

The snapshot() deepcopy was cloning tensors on their original GPU
devices, doubling VRAM usage. For a 234GB model sharded across 6
A100-80GB GPUs (~39GB each), this left no room for the copy.

Now snapshot stores tensors on CPU and restore() hands them straight to
load_state_dict(), which copies each tensor onto its parameter's current
device one at a time, so no second full copy is ever staged in VRAM.
---
v2: restore() no longer moves the whole snapshot to the GPUs before
loading. Indentation of this hunk was reconstructed, and the post-image
blob id on the "index" line predates this revision.

 obliteratus/models/loader.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/obliteratus/models/loader.py b/obliteratus/models/loader.py
index 5e98115..ed7af53 100644
--- a/obliteratus/models/loader.py
+++ b/obliteratus/models/loader.py
@@ -312,14 +312,26 @@ class ModelHandle:
         )
 
     def snapshot(self):
-        """Save a deep copy of the model state dict so we can restore after ablation."""
-        self._original_state = copy.deepcopy(self.model.state_dict())
+        """Save a copy of the model state dict so we can restore after ablation.
+
+        Tensors are moved to CPU to avoid doubling GPU memory usage on
+        multi-GPU (device_map) setups.
+        """
+        self._original_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
 
     def restore(self):
-        """Restore the model to the snapshot state."""
+        """Restore the model to the snapshot state.
+
+        The CPU snapshot is handed to load_state_dict() as-is; it copies each
+        tensor onto the device of the parameter it is loaded into.
+        """
         if self._original_state is None:
             raise RuntimeError("No snapshot to restore — call .snapshot() first.")
+        # Do not pre-move the snapshot to the GPUs: that would materialize a
+        # second full copy of the weights in VRAM, the OOM this snapshot avoids.
+        # load_state_dict() copies tensor-by-tensor into the existing
+        # parameters, and that in-place copy handles CPU -> device transfer.
         self.model.load_state_dict(self._original_state)
 
     def cleanup(self):
         """Remove temporary offload directory if one was auto-created."""