"""Unified device abstraction for CUDA, MPS (Apple Silicon), and CPU.
All device-specific queries (availability, memory, cache management) go through
this module so the rest of the codebase never calls ``torch.cuda.*`` directly.
"""
from __future__ import annotations
import gc
import logging
import os
import platform
from dataclasses import dataclass
import torch
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Device detection
# ---------------------------------------------------------------------------
def is_cuda() -> bool:
    """True when at least one NVIDIA CUDA GPU is visible."""
    return torch.cuda.is_available()


def is_mps() -> bool:
    """True when the Apple Metal Performance Shaders backend is usable."""
    return (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    )


def is_gpu_available() -> bool:
    """True if *any* GPU backend (CUDA or MPS) is available."""
    return is_cuda() or is_mps()

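
# Illustrative sketch (not executed): how calling code is expected to branch on
# these predicates instead of touching ``torch.cuda`` / ``torch.backends.mps``
# directly. The ``model`` variable here is a hypothetical caller-side object.
#
#     if is_gpu_available():
#         model = model.to(get_device())
#     else:
#         logger.warning("No GPU backend found; running on CPU")
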
def get_device(preference: str = "auto") -> str:
    """Resolve a device string.

    Parameters
    ----------
    preference : str
        ``"auto"`` picks the best available GPU; ``"cuda"``, ``"mps"``, or
        ``"cpu"`` forces that backend.

    Returns
    -------
    str
        A PyTorch device string (``"cuda"``, ``"mps"``, or ``"cpu"``).
    """
    if preference == "auto":
        if is_cuda():
            return "cuda"
        if is_mps():
            return "mps"
        return "cpu"
    return preference

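
# Usage sketch (illustrative): resolve the device once at startup and pass the
# string to ``Tensor.to`` / module constructors. ``MyModel`` is a placeholder.
#
#     device = get_device()            # e.g. "mps" on an M-series Mac
#     model = MyModel().to(device)
#     batch = batch.to(device)
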
def get_device_name() -> str:
    """Human-readable name of the current accelerator."""
    if is_cuda():
        return torch.cuda.get_device_name(0)
    if is_mps():
        # Apple doesn't expose a per-chip name via MPS; use platform info.
        chip = platform.processor() or "Apple Silicon"
        return f"Apple {chip} (MPS)"
    return "CPU"


def device_count() -> int:
    """Number of accelerator devices (CUDA GPUs, or 1 for MPS)."""
    if is_cuda():
        return torch.cuda.device_count()
    if is_mps():
        return 1  # MPS always exposes a single unified device
    return 0

# ---------------------------------------------------------------------------
# Memory information
# ---------------------------------------------------------------------------
@dataclass
class MemoryInfo:
    """Snapshot of accelerator memory (in GB)."""

    used_gb: float = 0.0
    reserved_gb: float = 0.0
    total_gb: float = 0.0
    free_gb: float = 0.0
    device_name: str = "CPU"

def _system_memory_gb() -> tuple[float, float]:
    """Return ``(total_gb, available_gb)`` of system RAM."""
    try:
        import psutil
        vm = psutil.virtual_memory()
        return vm.total / 1024 ** 3, vm.available / 1024 ** 3
    except ImportError:
        pass
    try:
        total = os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / 1024 ** 3
        # Rough estimate: assume 60 % is available if we can't query precisely.
        return total, total * 0.6
    except (AttributeError, ValueError, OSError):
        # os.sysconf is absent on Windows and can fail on exotic platforms.
        return 16.0, 8.0  # conservative fallback

def get_memory_info(device_index: int = 0) -> MemoryInfo:
    """Query memory for the given accelerator (or system RAM for MPS/CPU)."""
    name = get_device_name()
    if is_cuda():
        try:
            free, total = torch.cuda.mem_get_info(device_index)
            used = torch.cuda.memory_allocated(device_index)
            reserved = torch.cuda.memory_reserved(device_index)
            total_gb = total / 1024 ** 3
            return MemoryInfo(
                used_gb=used / 1024 ** 3,
                reserved_gb=reserved / 1024 ** 3,
                total_gb=total_gb,
                free_gb=free / 1024 ** 3,
                device_name=name,
            )
        except Exception:
            # mem_get_info can fail on some drivers; fall back to static totals.
            props = torch.cuda.get_device_properties(device_index)
            total_gb = props.total_memory / 1024 ** 3
            return MemoryInfo(total_gb=total_gb, free_gb=total_gb, device_name=name)
    if is_mps():
        # MPS uses unified memory, so report system RAM as a proxy.
        total, avail = _system_memory_gb()
        # Apple's unified memory is shared with the OS, so the usable fraction
        # is typically ~65-75 % of total.
        usable = total * 0.70
        return MemoryInfo(
            used_gb=max(usable - avail, 0.0),
            reserved_gb=0.0,
            total_gb=usable,
            free_gb=min(avail, usable),
            device_name=name,
        )
    # CPU-only
    total, avail = _system_memory_gb()
    return MemoryInfo(total_gb=total, free_gb=avail, device_name=name)

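
# Usage sketch (illustrative): gate a large allocation on a free-memory check.
# The 8 GB threshold is an arbitrary example value, not part of this module.
#
#     info = get_memory_info()
#     if info.free_gb < 8.0:
#         free_gpu_memory()            # try to reclaim cached blocks first
#     logger.info("%s: %.1f/%.1f GB free", info.device_name, info.free_gb, info.total_gb)
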
def get_total_free_gb() -> float:
    """Sum of free memory across all accelerator devices, in GB."""
    if is_cuda():
        total_free = 0.0
        for i in range(torch.cuda.device_count()):
            try:
                free, _ = torch.cuda.mem_get_info(i)
                total_free += free / 1024 ** 3
            except Exception:
                props = torch.cuda.get_device_properties(i)
                total_free += props.total_memory / 1024 ** 3
        return total_free
    if is_mps():
        _, avail = _system_memory_gb()
        return avail * 0.70  # usable fraction of unified memory
    return 0.0

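
# Sketch (illustrative): a caller might use the aggregate figure to decide
# whether a checkpoint fits before loading. ``checkpoint_size_gb`` is a
# hypothetical value supplied by the caller.
#
#     if get_total_free_gb() < checkpoint_size_gb * 1.2:   # ~20 % headroom
#         raise RuntimeError("Not enough accelerator memory for this model")
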
# ---------------------------------------------------------------------------
# Cache / memory management
# ---------------------------------------------------------------------------
def empty_cache() -> None:
    """Release cached allocations on the current accelerator."""
    if is_cuda():
        torch.cuda.empty_cache()
    elif is_mps():
        # torch.mps.empty_cache() available since PyTorch 2.1
        if hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()

def free_gpu_memory() -> None:
    """Aggressive memory cleanup: GC plus accelerator cache flush."""
    gc.collect()
    if is_cuda():
        try:
            torch.cuda.empty_cache()
        except Exception:
            try:
                torch.cuda.synchronize()
            except Exception:
                pass
        try:
            torch.cuda.reset_peak_memory_stats()
        except Exception:
            pass
    elif is_mps():
        if hasattr(torch.mps, "empty_cache"):
            try:
                torch.mps.empty_cache()
            except Exception:
                pass
        if hasattr(torch.mps, "synchronize"):
            try:
                torch.mps.synchronize()
            except Exception:
                pass

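
# Usage sketch (illustrative): call between unloading one model and loading the
# next so cached blocks are returned to the allocator. ``load_model`` is a
# hypothetical helper, not part of this module.
#
#     del model
#     free_gpu_memory()
#     model = load_model(next_checkpoint)
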
def set_seed_all(seed: int) -> None:
    """Set the random seed on all available accelerators."""
    torch.manual_seed(seed)
    if is_cuda():
        torch.cuda.manual_seed_all(seed)
    # MPS shares the CPU random state, so no separate seed call is needed.

# ---------------------------------------------------------------------------
# Dtype helpers
# ---------------------------------------------------------------------------
def default_dtype(device: str | None = None) -> torch.dtype:
    """Sensible default dtype for the given device."""
    dev = device or get_device()
    if dev == "cpu":
        return torch.float32
    return torch.float16

def supports_bfloat16(device: str | None = None) -> bool:
    """Whether *bfloat16* is natively supported on the target device."""
    dev = device or get_device()
    if dev.startswith("cuda"):
        if is_cuda():
            major, _ = torch.cuda.get_device_capability(0)
            return major >= 8  # Ampere+
        return False
    if dev == "mps":
        # MPS added bfloat16 support in PyTorch 2.3+.
        version = tuple(int(x) for x in torch.__version__.split(".")[:2])
        return version >= (2, 3)
    return True  # CPU supports bfloat16 on most modern platforms

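
# Sketch (illustrative): pick a compute dtype for mixed-precision work,
# preferring bfloat16 where the hardware handles it natively.
#
#     dev = get_device()
#     dtype = torch.bfloat16 if supports_bfloat16(dev) else default_dtype(dev)
#     model = model.to(device=dev, dtype=dtype)
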
def supports_float64(device: str | None = None) -> bool:
    """Whether *float64* is supported (MPS does NOT support it)."""
    dev = device or get_device()
    return not dev.startswith("mps")  # also catches "mps:0"

def safe_svd_dtype(tensor: torch.Tensor) -> torch.dtype:
    """Return a dtype safe for SVD on the tensor's device.

    MPS does not support float64, so we cap at float32 there.
    """
    if tensor.device.type == "mps":
        return torch.float32
    return torch.float64 if tensor.dtype in (torch.float64, torch.float32) else torch.float32

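
# Usage sketch (illustrative): upcast before the decomposition, then cast the
# factors back to the working dtype.
#
#     w = weight.detach()
#     u, s, vh = torch.linalg.svd(w.to(safe_svd_dtype(w)), full_matrices=False)
#     u = u.to(w.dtype)
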
# ---------------------------------------------------------------------------
# OOM exception matching
# ---------------------------------------------------------------------------
def is_oom_error(exc: BaseException) -> bool:
    """Return True if *exc* is an out-of-memory error on any backend."""
    if isinstance(exc, torch.cuda.OutOfMemoryError):
        return True
    # MPS raises a generic RuntimeError containing "out of memory".
    if isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower():
        return True
    return False

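
# Sketch (illustrative): a backend-agnostic OOM retry. ``run_batch`` and
# ``halve_batch`` are hypothetical callables owned by the caller.
#
#     try:
#         run_batch(batch)
#     except Exception as exc:
#         if not is_oom_error(exc):
#             raise
#         free_gpu_memory()
#         run_batch(halve_batch(batch))
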
# ---------------------------------------------------------------------------
# Quantization compatibility
# ---------------------------------------------------------------------------
def supports_bitsandbytes(device: str | None = None) -> bool:
    """True when bitsandbytes quantization can be used (it requires NVIDIA CUDA)."""
    dev = device or get_device()
    return dev.startswith("cuda")


def supports_device_map_auto(device: str | None = None) -> bool:
    """Accelerate's ``device_map="auto"`` is only reliable on CUDA."""
    dev = device or get_device()
    return dev.startswith("cuda")

# ---------------------------------------------------------------------------
# CUDA env setup (called once at import time of abliterate.py)
# ---------------------------------------------------------------------------
def configure_cuda_alloc() -> None:
    """Enable ``expandable_segments`` for the CUDA allocator unless the user
    already set ``PYTORCH_CUDA_ALLOC_CONF``."""
    if is_cuda() and "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"