mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-29 14:46:15 +02:00
51f621d0a2
The snapshot() deepcopy was cloning tensors on their original GPU devices, doubling VRAM usage. For a 234GB model sharded across 6 A100-80GB GPUs (~39GB each), this left no room for the copy. Now snapshot stores tensors on CPU and restore() moves them back to each parameter's current device.
700 lines
29 KiB
Python
700 lines
29 KiB
Python
"""Load HuggingFace models and wrap them for ablation."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import copy
|
||
import logging
|
||
import os
|
||
import tempfile
|
||
from dataclasses import dataclass, field
|
||
from typing import Optional
|
||
|
||
import sys as _sys
|
||
|
||
import torch
|
||
from obliteratus import device as dev
|
||
from transformers import (
|
||
AutoConfig,
|
||
AutoModelForCausalLM,
|
||
AutoModelForSequenceClassification,
|
||
AutoTokenizer,
|
||
PreTrainedModel,
|
||
PreTrainedTokenizerBase,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Compat shims for transformers ≥5.0 breaking changes.
|
||
#
|
||
# Many HuggingFace model repos ship custom modeling code (loaded via
|
||
# trust_remote_code=True) that imports symbols from their pre-5.x locations.
|
||
# We monkey-patch the old module paths so loading works without downgrading.
|
||
#
|
||
# Every section is wrapped in try/except so a failure in one shim never
|
||
# breaks unrelated functionality. Patches are purely additive — we never
|
||
# remove attributes that already exist.
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# ── 1. utils.generic → utils.output_capturing ──────────────────────
|
||
# OutputRecorder, check_model_inputs, _CAN_RECORD_REGISTRY moved.
|
||
# Affected: MiniMax-M2.x, DeepSeek-V3
|
||
try:
|
||
import transformers.utils.generic as _tfu_generic
|
||
try:
|
||
from transformers.utils import output_capturing as _oc
|
||
for _old, _new in [
|
||
("OutputRecorder", "OutputRecorder"),
|
||
("check_model_inputs", "capture_outputs"),
|
||
("_CAN_RECORD_REGISTRY", "_CAN_RECORD_REGISTRY"),
|
||
]:
|
||
if not hasattr(_tfu_generic, _old) and hasattr(_oc, _new):
|
||
setattr(_tfu_generic, _old, getattr(_oc, _new))
|
||
except ImportError:
|
||
pass
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 2. utils.generic.working_or_temp_dir ───────────────────────────
|
||
# Removed in 5.x. Trivial contextmanager replacement.
|
||
# Affected: GLM-4 / ChatGLM custom code
|
||
try:
|
||
import transformers.utils.generic as _tfu_generic # noqa: F811 – may already be imported
|
||
if not hasattr(_tfu_generic, "working_or_temp_dir"):
|
||
import contextlib as _ctxlib
|
||
import tempfile as _tmpmod
|
||
|
||
@_ctxlib.contextmanager
|
||
def _working_or_temp_dir(working_dir=None):
|
||
if working_dir is not None:
|
||
yield working_dir
|
||
else:
|
||
with _tmpmod.TemporaryDirectory() as tmp:
|
||
yield tmp
|
||
|
||
_tfu_generic.working_or_temp_dir = _working_or_temp_dir
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 3. utils.import_utils: removed availability checks ─────────────
|
||
# is_torch_fx_available → removed (torch.fx always present in torch≥2.0)
|
||
# is_tf_available → removed (TF backend dropped in v5)
|
||
# is_flax_available → removed (Flax backend dropped in v5)
|
||
# is_safetensors_available→ removed (safetensors is now mandatory)
|
||
# Affected: various model repos that defensively check backends
|
||
try:
|
||
import transformers.utils.import_utils as _tfu_imports
|
||
_import_shims = {
|
||
"is_torch_fx_available": lambda: True,
|
||
"is_tf_available": lambda: False,
|
||
"is_flax_available": lambda: False,
|
||
"is_safetensors_available": lambda: True,
|
||
}
|
||
for _name, _fn in _import_shims.items():
|
||
if not hasattr(_tfu_imports, _name):
|
||
setattr(_tfu_imports, _name, _fn)
|
||
# Also patch the top-level transformers.utils re-export so both
|
||
# ``from transformers.utils import is_tf_available`` and
|
||
# ``from transformers.utils.import_utils import is_tf_available`` work.
|
||
try:
|
||
import transformers.utils as _tu
|
||
for _name, _fn in _import_shims.items():
|
||
if not hasattr(_tu, _name):
|
||
setattr(_tu, _name, _fn)
|
||
except Exception:
|
||
pass
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 4. pytorch_utils: removed version-check constants ──────────────
|
||
# ``is_torch_greater_or_equal_than_X_Y`` constants removed in v4.48+.
|
||
# Affected: DeepSeek-V3/R1/V2-Lite, MiniCPM3, older custom code
|
||
try:
|
||
import transformers.pytorch_utils as _pt_utils
|
||
# transformers ≥5.0 requires torch ≥2.0, so every historical gate is True.
|
||
for _ver in [
|
||
"is_torch_greater_or_equal_than_2_4",
|
||
"is_torch_greater_or_equal_than_2_3",
|
||
"is_torch_greater_or_equal_than_2_2",
|
||
"is_torch_greater_or_equal_than_2_1",
|
||
"is_torch_greater_or_equal_than_2_0",
|
||
"is_torch_greater_or_equal_than_1_13",
|
||
"is_torch_greater_or_equal_than_1_12",
|
||
"is_torch_greater_or_equal_than_1_11",
|
||
"is_torch_greater_or_equal_than_1_10",
|
||
"is_torch_greater_or_equal_than_1_9",
|
||
"is_torch_greater_or_equal_than_1_8",
|
||
"is_torch_greater_or_equal_than_1_6",
|
||
]:
|
||
if not hasattr(_pt_utils, _ver):
|
||
setattr(_pt_utils, _ver, True)
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 5. generation_utils module → transformers.generation ────────────
|
||
# Entire module removed; old custom code does
|
||
# ``from transformers.generation_utils import GenerationMixin``
|
||
# Affected: older generation-customising model repos
|
||
try:
|
||
import transformers.generation_utils # noqa: F401 – already exists
|
||
except ModuleNotFoundError:
|
||
try:
|
||
import transformers.generation as _gen
|
||
_sys.modules["transformers.generation_utils"] = _gen
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 6. deepspeed module → transformers.integrations.deepspeed ───────
|
||
# Affected: model repos with DeepSpeed training code
|
||
try:
|
||
import transformers.deepspeed # noqa: F401 – already exists
|
||
except ModuleNotFoundError:
|
||
try:
|
||
import transformers.integrations.deepspeed as _ds
|
||
_sys.modules["transformers.deepspeed"] = _ds
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 7. DynamicCache.get_max_length → get_max_cache_shape ───────────
|
||
# Removed in v4.49+. DeepSeek-V3/R1 custom code calls .get_max_length().
|
||
try:
|
||
from transformers.cache_utils import DynamicCache as _DC
|
||
if not hasattr(_DC, "get_max_length") and hasattr(_DC, "get_max_cache_shape"):
|
||
_DC.get_max_length = _DC.get_max_cache_shape
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 8. LogitsWarper → LogitsProcessor ──────────────────────────────
|
||
# LogitsWarper removed in v5.0 (deprecated v4.46). Drop-in alias.
|
||
# Affected: MiniCPM-o custom code
|
||
# NOTE: submodule patch runs here; top-level ``transformers.LogitsWarper``
|
||
# is deferred to _apply_deferred_shims() because the _LazyModule may reset
|
||
# its __dict__ during initial import.
|
||
try:
|
||
import transformers.generation.logits_process as _lp_mod
|
||
if not hasattr(_lp_mod, "LogitsWarper"):
|
||
from transformers.generation.logits_process import LogitsProcessor as _LP
|
||
_lp_mod.LogitsWarper = _LP
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 9. processing_utils._validate_images_text_input_order ──────────
|
||
# Removed in v5.0rc3. Kimi-VL custom code imports it.
|
||
try:
|
||
import transformers.processing_utils as _proc
|
||
if not hasattr(_proc, "_validate_images_text_input_order"):
|
||
def _validate_images_text_input_order(images=None, text=None, **kw):
|
||
return images, text
|
||
_proc._validate_images_text_input_order = _validate_images_text_input_order
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 10. TF/Flax weight constants (removed with TF backend) ─────────
|
||
try:
|
||
import transformers.utils as _tu # noqa: F811
|
||
for _cname, _cval in [
|
||
("TF_WEIGHTS_NAME", "tf_model.h5"),
|
||
("TF2_WEIGHTS_NAME", "tf_model.h5"),
|
||
]:
|
||
if not hasattr(_tu, _cname):
|
||
setattr(_tu, _cname, _cval)
|
||
except Exception:
|
||
pass
|
||
|
||
# ── 11. file_utils.cached_path → huggingface_hub fallback ──────────
|
||
# Removed in v4.22. Very old model repos use it for file download.
|
||
try:
|
||
import transformers.file_utils as _fu
|
||
if not hasattr(_fu, "cached_path"):
|
||
def _cached_path_shim(url_or_filename, cache_dir=None, **kwargs):
|
||
"""Minimal shim: local paths pass through, HF paths download."""
|
||
if os.path.exists(str(url_or_filename)):
|
||
return str(url_or_filename)
|
||
try:
|
||
from huggingface_hub import hf_hub_download
|
||
parts = str(url_or_filename).rsplit("/", 1)
|
||
if len(parts) == 2:
|
||
return hf_hub_download(repo_id=parts[0], filename=parts[1],
|
||
cache_dir=cache_dir)
|
||
except Exception:
|
||
pass
|
||
return str(url_or_filename)
|
||
_fu.cached_path = _cached_path_shim
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
# ── Deferred shims ──────────────────────────────────────────────────
|
||
# Some patches must wait until the _LazyModule has fully initialized
|
||
# (it replaces its __dict__ during bootstrap). We apply these once,
|
||
# lazily, the first time load_model() is called.
|
||
_DEFERRED_SHIMS_APPLIED = False
|
||
|
||
|
||
def _apply_deferred_shims():
|
||
global _DEFERRED_SHIMS_APPLIED
|
||
if _DEFERRED_SHIMS_APPLIED:
|
||
return
|
||
_DEFERRED_SHIMS_APPLIED = True
|
||
|
||
tf_mod = _sys.modules.get("transformers")
|
||
if tf_mod is None:
|
||
return
|
||
|
||
# LogitsWarper → LogitsProcessor on the top-level transformers namespace
|
||
try:
|
||
if not hasattr(tf_mod, "LogitsWarper"):
|
||
from transformers.generation.logits_process import LogitsProcessor
|
||
tf_mod.__dict__["LogitsWarper"] = LogitsProcessor
|
||
if hasattr(tf_mod, "_objects"):
|
||
tf_mod._objects["LogitsWarper"] = LogitsProcessor
|
||
except Exception:
|
||
pass
|
||
|
||
# is_tf_available / is_flax_available / is_safetensors_available
|
||
# on the top-level namespace (complements shim 3 which patches submodules)
|
||
try:
|
||
for name, val in [
|
||
("is_tf_available", lambda: False),
|
||
("is_flax_available", lambda: False),
|
||
("is_safetensors_available", lambda: True),
|
||
]:
|
||
if not hasattr(tf_mod, name):
|
||
tf_mod.__dict__[name] = val
|
||
if hasattr(tf_mod, "_objects"):
|
||
tf_mod._objects[name] = val
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
TASK_MODEL_MAP = {
|
||
"causal_lm": AutoModelForCausalLM,
|
||
"classification": AutoModelForSequenceClassification,
|
||
}
|
||
|
||
|
||
@dataclass
|
||
class ModelHandle:
|
||
"""Wrapper around a HF model + tokenizer with metadata useful for ablation."""
|
||
|
||
model: PreTrainedModel
|
||
tokenizer: PreTrainedTokenizerBase
|
||
config: AutoConfig
|
||
model_name: str
|
||
task: str
|
||
architecture: str = ""
|
||
num_layers: int = 0
|
||
num_heads: int = 0
|
||
hidden_size: int = 0
|
||
intermediate_size: int = 0
|
||
_original_state: Optional[dict] = field(default=None, repr=False)
|
||
_offload_dir: Optional[str] = field(default=None, repr=False)
|
||
|
||
def __post_init__(self):
|
||
cfg = self.config
|
||
self.architecture = cfg.model_type
|
||
# For composite configs (e.g. VL models like Qwen3.5), the text model
|
||
# attributes live under a nested text_config. Fall through to it when
|
||
# the top-level config doesn't have the standard attributes.
|
||
text_cfg = getattr(cfg, "text_config", None)
|
||
self.num_layers = getattr(cfg, "num_hidden_layers", 0) or (
|
||
getattr(text_cfg, "num_hidden_layers", 0) if text_cfg else 0
|
||
)
|
||
self.num_heads = getattr(cfg, "num_attention_heads", 0) or (
|
||
getattr(text_cfg, "num_attention_heads", 0) if text_cfg else 0
|
||
)
|
||
self.hidden_size = getattr(cfg, "hidden_size", 0) or (
|
||
getattr(text_cfg, "hidden_size", 0) if text_cfg else 0
|
||
)
|
||
self.intermediate_size = getattr(cfg, "intermediate_size", 0) or (
|
||
getattr(text_cfg, "intermediate_size", 0) if text_cfg else 0
|
||
)
|
||
|
||
def snapshot(self):
|
||
"""Save a copy of the model state dict so we can restore after ablation.
|
||
|
||
Tensors are moved to CPU to avoid doubling GPU memory usage on
|
||
multi-GPU (device_map) setups.
|
||
"""
|
||
self._original_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
|
||
|
||
def restore(self):
|
||
"""Restore the model to the snapshot state.
|
||
|
||
Moves CPU-saved tensors back to each parameter's current device.
|
||
"""
|
||
if self._original_state is None:
|
||
raise RuntimeError("No snapshot to restore — call .snapshot() first.")
|
||
# Map each key to the device where the model currently holds it
|
||
current_state = self.model.state_dict()
|
||
restored = {}
|
||
for k, v in self._original_state.items():
|
||
target = current_state[k].device if k in current_state else None
|
||
restored[k] = v.to(target) if target is not None else v
|
||
self.model.load_state_dict(restored)
|
||
|
||
def cleanup(self):
|
||
"""Remove temporary offload directory if one was auto-created."""
|
||
if self._offload_dir is not None:
|
||
import shutil
|
||
try:
|
||
shutil.rmtree(self._offload_dir, ignore_errors=True)
|
||
except Exception:
|
||
pass
|
||
self._offload_dir = None
|
||
|
||
def __del__(self):
|
||
self.cleanup()
|
||
|
||
def summary(self) -> dict:
|
||
return {
|
||
"model_name": self.model_name,
|
||
"architecture": self.architecture,
|
||
"task": self.task,
|
||
"num_layers": self.num_layers,
|
||
"num_heads": self.num_heads,
|
||
"hidden_size": self.hidden_size,
|
||
"intermediate_size": self.intermediate_size,
|
||
"total_params": sum(p.numel() for p in self.model.parameters()),
|
||
}
|
||
|
||
|
||
def _estimate_model_memory_gb(config: AutoConfig, dtype: torch.dtype) -> float:
|
||
"""Rough estimate of model weight memory in GB."""
|
||
# Estimate total params from config. For composite configs (VL models),
|
||
# fall through to text_config when top-level attributes are missing.
|
||
text_cfg = getattr(config, "text_config", None)
|
||
hidden = getattr(config, "hidden_size", 0) or (
|
||
getattr(text_cfg, "hidden_size", 0) if text_cfg else 0
|
||
)
|
||
n_layers = getattr(config, "num_hidden_layers", 0) or (
|
||
getattr(text_cfg, "num_hidden_layers", 0) if text_cfg else 0
|
||
)
|
||
intermediate = getattr(config, "intermediate_size", 0) or (
|
||
getattr(text_cfg, "intermediate_size", hidden * 4) if text_cfg else hidden * 4
|
||
)
|
||
vocab = getattr(config, "vocab_size", 0) or (
|
||
getattr(text_cfg, "vocab_size", 0) if text_cfg else 0
|
||
)
|
||
|
||
if hidden == 0 or n_layers == 0:
|
||
return 0.0
|
||
|
||
# For MoE models, the FFN is replicated per expert
|
||
num_experts = getattr(config, "num_local_experts", None) or getattr(config, "num_experts", 1)
|
||
|
||
# Per layer: attn (4 * hidden^2) + ffn (3 * hidden * intermediate * num_experts) + norms
|
||
per_layer = 4 * hidden * hidden + num_experts * 3 * hidden * intermediate
|
||
# Embedding + LM head
|
||
embedding = 2 * vocab * hidden
|
||
total_params = per_layer * n_layers + embedding
|
||
|
||
bytes_per_param = {torch.float32: 4, torch.float16: 2, torch.bfloat16: 2}.get(dtype, 2)
|
||
return total_params * bytes_per_param / (1024 ** 3)
|
||
|
||
|
||
def _available_gpu_memory_gb() -> float:
|
||
"""Return free accelerator memory in GB (CUDA, MPS, or 0 for CPU)."""
|
||
return dev.get_total_free_gb()
|
||
|
||
|
||
def _hf_token() -> str | None:
|
||
"""Return the HF_TOKEN from environment, or None."""
|
||
return os.environ.get("HF_TOKEN") or None
|
||
|
||
|
||
def load_model(
|
||
model_name: str,
|
||
task: str = "causal_lm",
|
||
device: str = "auto",
|
||
dtype: str = "float32",
|
||
trust_remote_code: bool = False,
|
||
num_labels: int = 2,
|
||
quantization: str | None = None,
|
||
offload_folder: str | None = None,
|
||
skip_snapshot: bool | None = None,
|
||
) -> ModelHandle:
|
||
"""Load a HuggingFace model and tokenizer, returning a ModelHandle.
|
||
|
||
Args:
|
||
model_name: HuggingFace model identifier (e.g. "gpt2", "meta-llama/Llama-2-7b-hf").
|
||
task: One of "causal_lm", "classification".
|
||
device: Torch device string. "auto" uses accelerate's device_map.
|
||
dtype: Weight dtype — "float32", "float16", "bfloat16".
|
||
trust_remote_code: Whether to trust remote code from the Hub.
|
||
num_labels: Number of labels for classification tasks.
|
||
quantization: None, "4bit", or "8bit". Requires bitsandbytes.
|
||
offload_folder: Directory for disk offloading when model exceeds GPU memory.
|
||
If None and offloading is needed, a temp directory is created automatically.
|
||
skip_snapshot: Controls initial state dict snapshot.
|
||
None (default): auto-decide based on GPU memory headroom.
|
||
True: always skip (saves memory).
|
||
False: always snapshot (force even for large models).
|
||
"""
|
||
_apply_deferred_shims()
|
||
|
||
if task not in TASK_MODEL_MAP:
|
||
raise ValueError(f"Unknown task {task!r}. Choose from {list(TASK_MODEL_MAP)}")
|
||
|
||
dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
|
||
if dtype not in dtype_map:
|
||
raise ValueError(f"Unknown dtype {dtype!r}. Choose from {list(dtype_map)}")
|
||
torch_dtype = dtype_map[dtype]
|
||
|
||
token = _hf_token()
|
||
|
||
try:
|
||
config = AutoConfig.from_pretrained(
|
||
model_name, trust_remote_code=trust_remote_code, token=token,
|
||
)
|
||
except PermissionError:
|
||
fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
|
||
os.makedirs(fallback_cache, exist_ok=True)
|
||
config = AutoConfig.from_pretrained(
|
||
model_name, trust_remote_code=trust_remote_code, cache_dir=fallback_cache,
|
||
token=token,
|
||
)
|
||
except OSError as e:
|
||
# Gated repo access denied — provide a clear, actionable error.
|
||
err_msg = str(e)
|
||
if "gated repo" in err_msg.lower() or "access to model" in err_msg.lower():
|
||
raise RuntimeError(
|
||
f"Access denied for gated model '{model_name}'.\n\n"
|
||
f"This model requires you to:\n"
|
||
f" 1. Accept the license at https://huggingface.co/{model_name}\n"
|
||
f" 2. Set your HF_TOKEN: export HF_TOKEN=hf_...\n"
|
||
f" (or add it to your HF Space secrets)\n\n"
|
||
f"Token {'is' if token else 'is NOT'} currently set."
|
||
) from e
|
||
raise
|
||
except (ValueError, KeyError) as e:
|
||
# Unrecognized model_type — don't silently escalate trust_remote_code.
|
||
# Provide a clear error with guidance instead.
|
||
raise RuntimeError(
|
||
f"Architecture '{model_name}' is not recognized by transformers "
|
||
f"{__import__('transformers').__version__}. "
|
||
f"Try: pip install --upgrade transformers\n"
|
||
f"If this model requires custom code, pass trust_remote_code=True explicitly."
|
||
) from e
|
||
|
||
# Memory estimation and warnings (skip for natively quantized models — estimate is wrong)
|
||
native_quant = getattr(config, "quantization_config", None)
|
||
est_gb = _estimate_model_memory_gb(config, torch_dtype) if native_quant is None else 0.0
|
||
gpu_gb = _available_gpu_memory_gb()
|
||
if est_gb > 0 and gpu_gb > 0:
|
||
logger.info(f"Estimated model size: {est_gb:.1f} GB | Available GPU: {gpu_gb:.1f} GB")
|
||
if est_gb > gpu_gb * 0.9 and quantization is None:
|
||
logger.warning(
|
||
f"Model (~{est_gb:.0f} GB) may exceed GPU memory ({gpu_gb:.0f} GB). "
|
||
f"Consider using quantization='4bit' or quantization='8bit'."
|
||
)
|
||
|
||
model_cls = TASK_MODEL_MAP[task]
|
||
load_kwargs: dict = {
|
||
"pretrained_model_name_or_path": model_name,
|
||
"config": config,
|
||
"torch_dtype": torch_dtype,
|
||
"trust_remote_code": trust_remote_code,
|
||
"token": token,
|
||
}
|
||
if task == "classification":
|
||
config.num_labels = num_labels
|
||
load_kwargs["config"] = config
|
||
|
||
# Quantization support (requires bitsandbytes)
|
||
if native_quant is not None:
|
||
# Model ships with native quantization (e.g. Mxfp4Config) — don't layer BitsAndBytes
|
||
# on top, and don't override its compute dtype with our torch_dtype
|
||
logger.info(
|
||
f"Model has native quantization ({type(native_quant).__name__}), "
|
||
f"skipping BitsAndBytes and using model's native dtype"
|
||
)
|
||
load_kwargs.pop("torch_dtype", None)
|
||
load_kwargs["device_map"] = "auto"
|
||
elif quantization in ("4bit", "8bit"):
|
||
# BitsAndBytes only works on NVIDIA CUDA GPUs.
|
||
resolved_device = dev.get_device(device)
|
||
if not dev.supports_bitsandbytes(resolved_device):
|
||
logger.warning(
|
||
"BitsAndBytes quantization is not supported on %s. "
|
||
"Loading in %s instead.",
|
||
resolved_device, dtype,
|
||
)
|
||
# On MPS, load normally to the device; on CPU, fall through.
|
||
if resolved_device == "mps":
|
||
device = "mps"
|
||
# Don't set quantization_config — fall through to normal loading.
|
||
else:
|
||
try:
|
||
import bitsandbytes # noqa: F401
|
||
except ImportError:
|
||
raise RuntimeError(
|
||
f"Quantization '{quantization}' requires bitsandbytes: "
|
||
f"pip install -U bitsandbytes>=0.46.1"
|
||
)
|
||
from transformers import BitsAndBytesConfig
|
||
|
||
# Enable fp32 CPU offload so that models too large to fit entirely on
|
||
# GPU (even quantized) can spill to CPU without crashing bitsandbytes.
|
||
# This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
|
||
# Mistral Large 3 675B, etc.) on single-GPU setups.
|
||
if quantization == "4bit":
|
||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||
load_in_4bit=True,
|
||
bnb_4bit_compute_dtype=torch_dtype,
|
||
bnb_4bit_quant_type="nf4",
|
||
llm_int8_enable_fp32_cpu_offload=True,
|
||
)
|
||
else:
|
||
load_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||
load_in_8bit=True,
|
||
llm_int8_enable_fp32_cpu_offload=True,
|
||
)
|
||
load_kwargs["device_map"] = "auto"
|
||
|
||
# device_map="auto" is only reliable on CUDA (accelerate doesn't support MPS).
|
||
if "device_map" not in load_kwargs and device == "auto":
|
||
resolved_device = dev.get_device(device)
|
||
if dev.supports_device_map_auto(resolved_device):
|
||
load_kwargs["device_map"] = "auto"
|
||
else:
|
||
# MPS / CPU: load to CPU first, then .to(device) after loading.
|
||
pass
|
||
|
||
# Offload support: provide a folder for disk offloading when GPU memory is insufficient
|
||
_offload_dir = None
|
||
if load_kwargs.get("device_map") == "auto":
|
||
if offload_folder:
|
||
_offload_dir = offload_folder
|
||
load_kwargs["offload_folder"] = offload_folder
|
||
else:
|
||
# Auto-create a temp offload dir so from_pretrained never crashes
|
||
# when Accelerate needs disk offloading
|
||
_offload_dir = tempfile.mkdtemp(prefix="obliteratus_offload_")
|
||
load_kwargs["offload_folder"] = _offload_dir
|
||
logger.info(f"Auto-created offload folder: {_offload_dir}")
|
||
|
||
# Reserve GPU headroom for inference (KV cache, activations, generate()).
|
||
# Without this, device_map="auto" packs 100% of layers onto GPU, leaving
|
||
# no room for forward passes or generation on tight-memory setups.
|
||
if dev.is_cuda():
|
||
max_memory = {}
|
||
for i in range(dev.device_count()):
|
||
total = torch.cuda.get_device_properties(i).total_memory
|
||
# Reserve 15% or 2 GiB (whichever is larger) for inference headroom
|
||
reserve = max(int(total * 0.15), 2 * 1024 ** 3)
|
||
usable = total - reserve
|
||
max_memory[i] = f"{usable // (1024 ** 2)}MiB"
|
||
# Allow overflow to CPU RAM, capped at 85% of physical memory
|
||
# to leave room for the OS, Python runtime, and serialization buffers.
|
||
total_ram, _ = dev._system_memory_gb()
|
||
cpu_budget_gb = int(total_ram * 0.85)
|
||
max_memory["cpu"] = f"{max(cpu_budget_gb, 4)}GiB"
|
||
load_kwargs["max_memory"] = max_memory
|
||
logger.info(
|
||
f"GPU memory budget: {', '.join(f'GPU{k}={v}' for k, v in max_memory.items() if k != 'cpu')}"
|
||
)
|
||
|
||
try:
|
||
model = model_cls.from_pretrained(**load_kwargs)
|
||
except OSError as e:
|
||
err_msg = str(e)
|
||
if "gated repo" in err_msg.lower() or "access to model" in err_msg.lower():
|
||
raise RuntimeError(
|
||
f"Access denied for gated model '{model_name}'.\n\n"
|
||
f"This model requires you to:\n"
|
||
f" 1. Accept the license at https://huggingface.co/{model_name}\n"
|
||
f" 2. Set your HF_TOKEN: export HF_TOKEN=hf_...\n"
|
||
f" (or add it to your HF Space secrets)\n\n"
|
||
f"Token {'is' if token else 'is NOT'} currently set."
|
||
) from e
|
||
raise
|
||
except PermissionError as e:
|
||
# Cache dir (typically ~/.cache/huggingface) is not writable — common in
|
||
# containers running as UID with no home dir. Retry with /tmp cache.
|
||
logger.warning(
|
||
"PermissionError loading model (%s). Retrying with cache_dir=/tmp/hf_home/hub", e
|
||
)
|
||
fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
|
||
os.makedirs(fallback_cache, exist_ok=True)
|
||
load_kwargs["cache_dir"] = fallback_cache
|
||
model = model_cls.from_pretrained(**load_kwargs)
|
||
except (ValueError, KeyError) as e:
|
||
err_msg = str(e)
|
||
if "does not recognize this architecture" in err_msg or "model type" in err_msg:
|
||
model_type = getattr(config, "model_type", "unknown")
|
||
raise RuntimeError(
|
||
f"Model architecture '{model_type}' is not supported by transformers "
|
||
f"{__import__('transformers').__version__}. "
|
||
f"Run: pip install --upgrade transformers\n"
|
||
f"If this model was released very recently, it may require "
|
||
f"pip install git+https://github.com/huggingface/transformers.git"
|
||
) from e
|
||
raise
|
||
|
||
if device not in ("auto",) and quantization is None and native_quant is None:
|
||
model = model.to(device)
|
||
elif device == "auto" and not dev.supports_device_map_auto():
|
||
# MPS / CPU: device_map wasn't used, move model to best device.
|
||
resolved = dev.get_device()
|
||
model = model.to(resolved)
|
||
|
||
model.eval()
|
||
|
||
# Free accelerator cache after loading
|
||
dev.empty_cache()
|
||
|
||
try:
|
||
tokenizer = AutoTokenizer.from_pretrained(
|
||
model_name, trust_remote_code=trust_remote_code, token=token,
|
||
)
|
||
except PermissionError:
|
||
fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
|
||
tokenizer = AutoTokenizer.from_pretrained(
|
||
model_name, trust_remote_code=trust_remote_code, cache_dir=fallback_cache,
|
||
token=token,
|
||
)
|
||
if tokenizer.pad_token is None:
|
||
tokenizer.pad_token = tokenizer.eos_token
|
||
|
||
handle = ModelHandle(
|
||
model=model,
|
||
tokenizer=tokenizer,
|
||
config=config,
|
||
model_name=model_name,
|
||
task=task,
|
||
_offload_dir=_offload_dir,
|
||
)
|
||
|
||
# Skip snapshot for large models to avoid doubling memory usage
|
||
if skip_snapshot is True:
|
||
pass # user explicitly opted out
|
||
elif skip_snapshot is False:
|
||
handle.snapshot() # user explicitly forced snapshot
|
||
else:
|
||
# Auto-decide: skip when GPU is present and memory is tight.
|
||
# For natively quantized models (est_gb==0), check actual GPU usage instead.
|
||
if gpu_gb > 0 and native_quant is not None:
|
||
# Model is pre-quantized but we can't estimate its true size.
|
||
# Check actual free memory after loading — if less than 40% free, skip snapshot.
|
||
free_gb = dev.get_total_free_gb()
|
||
if free_gb < gpu_gb * 0.4:
|
||
logger.warning(
|
||
f"Auto-skipping state dict snapshot for natively quantized model "
|
||
f"(free GPU: {free_gb:.1f} GB / {gpu_gb:.1f} GB). "
|
||
f"Use skip_snapshot=False to force."
|
||
)
|
||
else:
|
||
handle.snapshot()
|
||
elif gpu_gb > 0 and est_gb > 0 and est_gb > gpu_gb * 0.5:
|
||
logger.warning(
|
||
f"Auto-skipping state dict snapshot to save memory "
|
||
f"(model ~{est_gb:.0f} GB vs GPU {gpu_gb:.0f} GB). "
|
||
f"Use skip_snapshot=False to force."
|
||
)
|
||
else:
|
||
handle.snapshot()
|
||
|
||
return handle
|