mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-06-07 14:53:53 +02:00
Save model snapshot to CPU to avoid OOM on multi-GPU setups
The snapshot() deepcopy was cloning tensors on their original GPU devices, doubling VRAM usage. For a 234GB model sharded across 6 A100-80GB GPUs (~39GB each), this left no room for the copy. Now snapshot stores tensors on CPU and restore() moves them back to each parameter's current device.
This commit is contained in:
@@ -312,14 +312,27 @@ class ModelHandle:
|
||||
)
|
||||
|
||||
def snapshot(self):
|
||||
"""Save a deep copy of the model state dict so we can restore after ablation."""
|
||||
self._original_state = copy.deepcopy(self.model.state_dict())
|
||||
"""Save a copy of the model state dict so we can restore after ablation.
|
||||
|
||||
Tensors are moved to CPU to avoid doubling GPU memory usage on
|
||||
multi-GPU (device_map) setups.
|
||||
"""
|
||||
self._original_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
|
||||
|
||||
def restore(self):
|
||||
"""Restore the model to the snapshot state."""
|
||||
"""Restore the model to the snapshot state.
|
||||
|
||||
Moves CPU-saved tensors back to each parameter's current device.
|
||||
"""
|
||||
if self._original_state is None:
|
||||
raise RuntimeError("No snapshot to restore — call .snapshot() first.")
|
||||
self.model.load_state_dict(self._original_state)
|
||||
# Map each key to the device where the model currently holds it
|
||||
current_state = self.model.state_dict()
|
||||
restored = {}
|
||||
for k, v in self._original_state.items():
|
||||
target = current_state[k].device if k in current_state else None
|
||||
restored[k] = v.to(target) if target is not None else v
|
||||
self.model.load_state_dict(restored)
|
||||
|
||||
def cleanup(self):
|
||||
"""Remove temporary offload directory if one was auto-created."""
|
||||
|
||||
Reference in New Issue
Block a user