mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-29 22:47:50 +02:00
306 lines
9.7 KiB
Python
306 lines
9.7 KiB
Python
"""Unified device abstraction for CUDA, MPS (Apple Silicon), and CPU.
|
|
|
|
All device-specific queries (availability, memory, cache management) go through
|
|
this module so the rest of the codebase never calls ``torch.cuda.*`` directly.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import gc
|
|
import logging
|
|
import os
|
|
import platform
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Device detection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def is_cuda() -> bool:
    """Report whether PyTorch can see at least one NVIDIA CUDA GPU."""
    cuda_visible = torch.cuda.is_available()
    return cuda_visible
|
|
|
|
|
|
def is_mps() -> bool:
    """Report whether the Apple Metal Performance Shaders backend is usable.

    The backend counts as usable only when this PyTorch build exposes
    ``torch.backends.mps`` and that backend reports itself as both available
    and built.
    """
    if not hasattr(torch.backends, "mps"):
        return False
    mps_backend = torch.backends.mps
    if not mps_backend.is_available():
        return False
    return mps_backend.is_built()
|
|
|
|
|
|
def is_gpu_available() -> bool:
    """Report whether *any* GPU backend is available.

    CUDA is checked first; MPS is only consulted when CUDA is absent.
    """
    if is_cuda():
        return True
    return is_mps()
|
|
|
|
|
|
def get_device(preference: str = "auto") -> str:
    """Turn a device preference into a concrete device string.

    Parameters
    ----------
    preference : str
        ``"auto"`` selects CUDA when available, then MPS, then CPU. Any other
        value is treated as an explicit choice and returned unchanged; it is
        not validated and its availability is not checked.

    Returns
    -------
    str
        ``"cuda"``, ``"mps"`` or ``"cpu"`` for ``"auto"``; otherwise the
        ``preference`` string itself.
    """
    if preference != "auto":
        # Explicit request: hand it back exactly as given.
        return preference
    if is_cuda():
        return "cuda"
    return "mps" if is_mps() else "cpu"
|
|
|
|
|
|
def get_device_name() -> str:
    """Human-readable name of the current accelerator.

    Returns
    -------
    str
        The name of CUDA device 0 when CUDA is available, an
        ``"Apple ... (MPS)"`` label when only MPS is available, and
        ``"CPU"`` otherwise.
    """
    if is_cuda():
        return torch.cuda.get_device_name(0)
    if is_mps():
        # Apple doesn't expose a per-chip name via MPS; use platform info.
        chip = platform.processor()
        if chip:
            return f"Apple {chip} (MPS)"
        # No processor string: use the generic label on its own, so the
        # vendor word is not duplicated ("Apple Apple Silicon (MPS)").
        return "Apple Silicon (MPS)"
    return "CPU"
|
|
|
|
|
|
def device_count() -> int:
    """Count the accelerator devices visible to this process.

    Returns the CUDA device count when CUDA is available, ``1`` when only MPS
    is available, and ``0`` when there is no accelerator.
    """
    if is_cuda():
        return torch.cuda.device_count()
    # MPS always exposes a single unified device.
    return 1 if is_mps() else 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Memory information
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class MemoryInfo:
    """Point-in-time memory figures for one device, all sizes in GB.

    Every field has a default, so a bare ``MemoryInfo()`` describes an
    unnamed-accelerator-free state: zero for every size and ``"CPU"`` as the
    device name. Field order is part of the positional constructor signature.
    """

    # Memory currently in use.
    used_gb: float = 0.0
    # Memory held in reserve (stays 0.0 when the backend has no such figure).
    reserved_gb: float = 0.0
    # Total capacity considered usable.
    total_gb: float = 0.0
    # Memory still free.
    free_gb: float = 0.0
    # Human-readable label of the device the figures belong to.
    device_name: str = "CPU"
|
|
|
|
|
|
def _system_memory_gb() -> tuple[float, float]:
    """Return (total_gb, available_gb) of system RAM.

    Three sources are tried in order:

    1. ``psutil.virtual_memory()`` when ``psutil`` is importable.
    2. ``os.sysconf`` page count and page size; available memory is then
       estimated as 60 % of the total because it cannot be queried this way.
    3. A fixed ``(16.0, 8.0)`` fallback when neither source works.

    Returns
    -------
    tuple[float, float]
        Total and available system memory in GB.
    """
    gib = 1024 ** 3

    try:
        import psutil
    except ImportError:
        psutil = None

    if psutil is not None:
        vm = psutil.virtual_memory()
        return vm.total / gib, vm.available / gib

    try:
        pages = os.sysconf("SC_PHYS_PAGES")
        page_size = os.sysconf("SC_PAGE_SIZE")
    except (AttributeError, ValueError, OSError):
        # AttributeError: platform has no os.sysconf.
        # ValueError: the configuration name is unknown.
        # OSError: the name is known but not supported by the host.
        return 16.0, 8.0  # conservative fallback

    if pages <= 0 or page_size <= 0:
        # sysconf returns -1 for values that are not defined; multiplying
        # those through would yield a meaningless total.
        return 16.0, 8.0  # conservative fallback

    total = pages * page_size / gib
    # Rough estimate: assume 60 % available if we can't query
    return total, total * 0.6
|
|
|
|
|
|
def get_memory_info(device_index: int = 0) -> MemoryInfo:
    """Query memory for the given accelerator (or system RAM for MPS/CPU).

    Parameters
    ----------
    device_index : int
        CUDA device to inspect. Not used on MPS or CPU.

    Returns
    -------
    MemoryInfo
        All sizes in GB.

        * CUDA: free/total from ``torch.cuda.mem_get_info``, used from
          ``torch.cuda.memory_allocated`` and reserved from
          ``torch.cuda.memory_reserved``. If any of those calls fails, only
          the total from the device properties is reported and the device is
          treated as entirely free.
        * MPS: system RAM is used as a proxy, with 70 % of the total taken
          as the usable capacity.
        * CPU: total and available system RAM.
    """
    gib = 1024 ** 3

    if is_cuda():
        # Label the device that was actually asked about rather than always
        # reporting the name of device 0.
        name = torch.cuda.get_device_name(device_index)
        try:
            free, total = torch.cuda.mem_get_info(device_index)
            used = torch.cuda.memory_allocated(device_index)
            reserved = torch.cuda.memory_reserved(device_index)
        except Exception:
            # Live query failed: fall back to the static device properties.
            props = torch.cuda.get_device_properties(device_index)
            total_gb = props.total_memory / gib
            return MemoryInfo(total_gb=total_gb, free_gb=total_gb, device_name=name)
        return MemoryInfo(
            used_gb=used / gib,
            reserved_gb=reserved / gib,
            total_gb=total / gib,
            free_gb=free / gib,
            device_name=name,
        )

    name = get_device_name()

    if is_mps():
        # MPS uses unified memory — report system RAM as a proxy.
        total, avail = _system_memory_gb()
        # Apple's unified memory is shared with the OS, so usable fraction
        # is typically ~65-75 % of total.
        usable = total * 0.70
        return MemoryInfo(
            used_gb=max(usable - avail, 0.0),
            reserved_gb=0.0,
            total_gb=usable,
            free_gb=min(avail, usable),
            device_name=name,
        )

    # CPU-only
    total, avail = _system_memory_gb()
    return MemoryInfo(total_gb=total, free_gb=avail, device_name=name)
|
|
|
|
|
|
def get_total_free_gb() -> float:
    """Return the free accelerator memory summed over all devices, in GB.

    On CUDA each device contributes its free memory, or its full capacity
    when the free-memory query raises. On MPS the result is 70 % of the
    available system RAM. Without an accelerator the result is ``0.0``.
    """
    if not is_cuda():
        if is_mps():
            _, available = _system_memory_gb()
            return available * 0.70  # usable fraction
        return 0.0

    free_sum = 0.0
    for index in range(torch.cuda.device_count()):
        try:
            free_bytes = torch.cuda.mem_get_info(index)[0]
        except Exception:
            # Query failed: count the device's whole capacity instead.
            free_bytes = torch.cuda.get_device_properties(index).total_memory
        free_sum += free_bytes / 1024 ** 3
    return free_sum
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Cache / memory management
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def empty_cache() -> None:
    """Release cached allocations on the current accelerator.

    Calls ``torch.cuda.empty_cache()`` on CUDA and ``torch.mps.empty_cache()``
    on MPS. On MPS the call is skipped silently when this PyTorch build does
    not provide it. Does nothing when no accelerator is available.
    """
    if is_cuda():
        torch.cuda.empty_cache()
    elif is_mps():
        # Look the function up defensively: a build can report MPS as
        # available without having a ``torch.mps`` module at all, in which
        # case ``hasattr(torch.mps, ...)`` would itself raise AttributeError.
        mps_module = getattr(torch, "mps", None)
        release = getattr(mps_module, "empty_cache", None)
        if release is not None:
            release()
|
|
|
|
|
|
def free_gpu_memory() -> None:
    """Aggressive memory cleanup: GC + accelerator cache flush.

    Runs Python's garbage collector, then asks the active backend (CUDA or
    MPS) to release its cached allocations. Exceptions raised by the backend
    calls are swallowed, so those calls are best-effort.
    """
    # Drop unreachable Python objects before flushing the backend cache.
    gc.collect()
    if is_cuda():
        try:
            torch.cuda.empty_cache()
        except Exception:
            # NOTE(review): synchronize/reset are treated here as a recovery
            # path that runs only when empty_cache() raised — confirm this
            # nesting against the upstream file.
            try:
                torch.cuda.synchronize()
            except Exception:
                pass
            try:
                torch.cuda.reset_peak_memory_stats()
            except Exception:
                pass
    elif is_mps():
        # Both MPS helpers are looked up first and skipped when missing.
        if hasattr(torch.mps, "empty_cache"):
            try:
                torch.mps.empty_cache()
            except Exception:
                pass
        if hasattr(torch.mps, "synchronize"):
            try:
                torch.mps.synchronize()
            except Exception:
                pass
|
|
|
|
|
|
def set_seed_all(seed: int) -> None:
    """Seed PyTorch's random number generators with ``seed``.

    The global generator is always seeded; when CUDA is available, every
    CUDA device is seeded as well.
    """
    torch.manual_seed(seed)
    if not is_cuda():
        # MPS shares the CPU random state — no separate seed call needed.
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dtype helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def default_dtype(device: str | None = None) -> torch.dtype:
    """Pick a default dtype for ``device``.

    ``device`` falls back to ``get_device()`` when omitted or empty. The
    result is ``torch.float32`` for exactly ``"cpu"`` and ``torch.float16``
    for every other device string.
    """
    target = device or get_device()
    return torch.float32 if target == "cpu" else torch.float16
|
|
|
|
|
|
def supports_bfloat16(device: str | None = None) -> bool:
    """Whether *bfloat16* is natively supported on the target device.

    Parameters
    ----------
    device : str | None
        Device string such as ``"cuda"``, ``"cuda:1"``, ``"mps"``,
        ``"mps:0"`` or ``"cpu"``. Falls back to ``get_device()`` when
        omitted or empty.

    Returns
    -------
    bool
        * CUDA: ``True`` when the addressed GPU has compute capability
          major version 8 or higher; ``False`` when CUDA is unavailable.
          A missing or out-of-range index falls back to device 0.
        * MPS: ``True`` when the installed PyTorch is version 2.3 or newer.
        * Anything else (CPU): ``True``.
    """
    dev = device or get_device()
    backend, _, suffix = dev.partition(":")

    if dev.startswith("cuda"):
        if not is_cuda():
            return False
        # Honour an explicit "cuda:N" index instead of always asking GPU 0.
        index = int(suffix) if suffix.isdigit() else 0
        if index >= torch.cuda.device_count():
            index = 0
        major, _ = torch.cuda.get_device_capability(index)
        return major >= 8  # Ampere+

    # Compare the backend part only, so "mps:0" (the form produced by
    # ``str(tensor.device)``) is recognised as MPS too.
    if backend == "mps":
        # MPS added bfloat16 support in PyTorch 2.3+
        version = getattr(torch, "__version__", None)
        if not version:
            return False
        try:
            major_minor = tuple(int(part) for part in str(version).split(".")[:2])
        except ValueError:
            # Unparsable version string: do not claim support.
            return False
        return major_minor >= (2, 3)

    return True  # CPU supports bfloat16 on most modern platforms
|
|
|
|
|
|
def supports_float64(device: str | None = None) -> bool:
    """Whether *float64* is supported (MPS does NOT support it).

    Parameters
    ----------
    device : str | None
        Device string, optionally with an index (``"mps"``, ``"mps:0"``,
        ``"cuda:1"``, ``"cpu"``). Falls back to ``get_device()`` when
        omitted or empty.

    Returns
    -------
    bool
        ``False`` for any MPS device, ``True`` otherwise.
    """
    dev = device or get_device()
    # Compare the backend part only: ``str(tensor.device)`` yields "mps:0",
    # which a plain equality test against "mps" would miss.
    backend = dev.partition(":")[0]
    return backend != "mps"
|
|
|
|
|
|
def safe_svd_dtype(tensor: torch.Tensor) -> torch.dtype:
    """Choose the dtype to run SVD in for ``tensor``.

    MPS does not support float64, so tensors on MPS always get float32.
    Elsewhere, float32 and float64 tensors get float64 and every other
    dtype gets float32.
    """
    float64_allowed = tensor.device.type != "mps"
    if float64_allowed and tensor.dtype in (torch.float64, torch.float32):
        return torch.float64
    return torch.float32
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OOM exception matching
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def is_oom_error(exc: BaseException) -> bool:
    """Tell whether ``exc`` represents an out-of-memory failure.

    Matches ``torch.cuda.OutOfMemoryError`` by type, and any other
    ``RuntimeError`` whose message contains "out of memory" in any letter
    case.
    """
    if isinstance(exc, torch.cuda.OutOfMemoryError):
        return True
    # MPS raises a generic RuntimeError containing "out of memory"
    message_matches = "out of memory" in str(exc).lower()
    return isinstance(exc, RuntimeError) and message_matches
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Quantization compatibility
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def supports_bitsandbytes(device: str | None = None) -> bool:
    """Tell whether BitsAndBytes can be used on ``device``.

    BitsAndBytes requires NVIDIA CUDA, so this is true only for device
    strings that start with ``"cuda"``. ``device`` falls back to
    ``get_device()`` when omitted or empty.
    """
    target = device if device else get_device()
    on_cuda = target.startswith("cuda")
    return on_cuda
|
|
|
|
|
|
def supports_device_map_auto(device: str | None = None) -> bool:
    """Tell whether Accelerate's ``device_map='auto'`` should be used.

    It is only considered reliable on CUDA, so this is true only for device
    strings that start with ``"cuda"``. ``device`` falls back to
    ``get_device()`` when omitted or empty.
    """
    if not device:
        device = get_device()
    return device.startswith("cuda")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CUDA env setup (called once at import time of abliterate.py)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def configure_cuda_alloc() -> None:
    """Default ``PYTORCH_CUDA_ALLOC_CONF`` to ``expandable_segments:True``.

    Only acts when CUDA is available, and leaves the variable untouched if
    it is already present in the environment.
    """
    if not is_cuda():
        return
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|