diff --git a/obliteratus/models/loader.py b/obliteratus/models/loader.py
index 5e98115..ed7af53 100644
--- a/obliteratus/models/loader.py
+++ b/obliteratus/models/loader.py
@@ -312,14 +312,29 @@ class ModelHandle:
         )
 
     def snapshot(self):
-        """Save a deep copy of the model state dict so we can restore after ablation."""
-        self._original_state = copy.deepcopy(self.model.state_dict())
+        """Save a CPU copy of the model state dict so we can restore after ablation.
+
+        Every tensor is copied to host memory exactly once, so the snapshot
+        never occupies additional GPU memory on multi-GPU (device_map) setups.
+        """
+        self._original_state = {
+            name: tensor.to("cpu", copy=True)
+            for name, tensor in self.model.state_dict().items()
+        }
 
     def restore(self):
-        """Restore the model to the snapshot state."""
+        """Restore the model to the snapshot state.
+
+        The CPU-held snapshot is passed straight to ``load_state_dict``, which
+        copies each tensor into the existing parameter on whatever device it
+        currently lives on, so no second on-device copy of the model is built.
+
+        Raises:
+            RuntimeError: If ``snapshot()`` has not been called yet.
+        """
         if self._original_state is None:
             raise RuntimeError("No snapshot to restore — call .snapshot() first.")
         self.model.load_state_dict(self._original_state)
 
     def cleanup(self):
         """Remove temporary offload directory if one was auto-created."""