Save model snapshot to CPU to avoid OOM on multi-GPU setups

The snapshot() deepcopy was cloning tensors on their original GPU
devices, doubling VRAM usage. For a 234GB model sharded across 6
A100-80GB GPUs (~39GB each), this left no room for the copy.

Now snapshot stores tensors on CPU and restore() moves them back
to each parameter's current device.
This commit is contained in:
Stella Biderman
2026-03-13 17:13:50 -04:00
parent a2bb748f1b
commit 51f621d0a2
+17 -4
View File
@@ -312,14 +312,27 @@ class ModelHandle:
)
def snapshot(self):
"""Save a deep copy of the model state dict so we can restore after ablation."""
self._original_state = copy.deepcopy(self.model.state_dict())
"""Save a copy of the model state dict so we can restore after ablation.
Tensors are moved to CPU to avoid doubling GPU memory usage on
multi-GPU (device_map) setups.
"""
self._original_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
def restore(self):
"""Restore the model to the snapshot state."""
"""Restore the model to the snapshot state.
Moves CPU-saved tensors back to each parameter's current device.
"""
if self._original_state is None:
raise RuntimeError("No snapshot to restore — call .snapshot() first.")
self.model.load_state_dict(self._original_state)
# Map each key to the device where the model currently holds it
current_state = self.model.state_dict()
restored = {}
for k, v in self._original_state.items():
target = current_state[k].device if k in current_state else None
restored[k] = v.to(target) if target is not None else v
self.model.load_state_dict(restored)
def cleanup(self):
"""Remove temporary offload directory if one was auto-created."""