# Mirror of https://github.com/elder-plinius/OBLITERATUS.git (synced 2026-04-30).
"""Load HuggingFace models and wrap them for ablation."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import copy
|
||
import logging
|
||
import os
|
||
import tempfile
|
||
from dataclasses import dataclass, field
|
||
from typing import Optional
|
||
|
||
import sys as _sys
|
||
|
||
import torch
|
||
from transformers import (
|
||
AutoConfig,
|
||
AutoModelForCausalLM,
|
||
AutoModelForSequenceClassification,
|
||
AutoTokenizer,
|
||
PreTrainedModel,
|
||
PreTrainedTokenizerBase,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Compat shims for transformers ≥5.0 breaking changes.
#
# Many HuggingFace model repos ship custom modeling code (loaded via
# trust_remote_code=True) that imports symbols from their pre-5.x locations.
# We monkey-patch the old module paths so loading works without downgrading.
#
# Every section is wrapped in try/except so a failure in one shim never
# breaks unrelated functionality. Patches are purely additive — we never
# remove attributes that already exist.
# ---------------------------------------------------------------------------
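#
# Example of the failure mode this addresses: a repo's custom modeling file
# doing ``from transformers.generation_utils import GenerationMixin`` raises
# ModuleNotFoundError on 5.x; shim 5 below aliases the removed module to
# ``transformers.generation`` so the old import resolves unchanged.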

# ── 1. utils.generic → utils.output_capturing ──────────────────────
# OutputRecorder, check_model_inputs, _CAN_RECORD_REGISTRY moved.
# Affected: MiniMax-M2.x, DeepSeek-V3
try:
    import transformers.utils.generic as _tfu_generic
    try:
        from transformers.utils import output_capturing as _oc
        for _old, _new in [
            ("OutputRecorder", "OutputRecorder"),
            ("check_model_inputs", "capture_outputs"),
            ("_CAN_RECORD_REGISTRY", "_CAN_RECORD_REGISTRY"),
        ]:
            if not hasattr(_tfu_generic, _old) and hasattr(_oc, _new):
                setattr(_tfu_generic, _old, getattr(_oc, _new))
    except ImportError:
        pass
except Exception:
    pass

# ── 2. utils.generic.working_or_temp_dir ───────────────────────────
# Removed in 5.x. Trivial contextmanager replacement.
# Affected: GLM-4 / ChatGLM custom code
try:
    import transformers.utils.generic as _tfu_generic  # noqa: F811 – may already be imported
    if not hasattr(_tfu_generic, "working_or_temp_dir"):
        import contextlib as _ctxlib
        import tempfile as _tmpmod

        @_ctxlib.contextmanager
        def _working_or_temp_dir(working_dir=None):
            if working_dir is not None:
                yield working_dir
            else:
                with _tmpmod.TemporaryDirectory() as tmp:
                    yield tmp

        _tfu_generic.working_or_temp_dir = _working_or_temp_dir
except Exception:
    pass

# ── 3. utils.import_utils: removed availability checks ─────────────
# is_torch_fx_available    → removed (torch.fx always present in torch≥2.0)
# is_tf_available          → removed (TF backend dropped in v5)
# is_flax_available        → removed (Flax backend dropped in v5)
# is_safetensors_available → removed (safetensors is now mandatory)
# Affected: various model repos that defensively check backends
try:
    import transformers.utils.import_utils as _tfu_imports
    _import_shims = {
        "is_torch_fx_available": lambda: True,
        "is_tf_available": lambda: False,
        "is_flax_available": lambda: False,
        "is_safetensors_available": lambda: True,
    }
    for _name, _fn in _import_shims.items():
        if not hasattr(_tfu_imports, _name):
            setattr(_tfu_imports, _name, _fn)
    # Also patch the top-level transformers.utils re-export so both
    # ``from transformers.utils import is_tf_available`` and
    # ``from transformers.utils.import_utils import is_tf_available`` work.
    try:
        import transformers.utils as _tu
        for _name, _fn in _import_shims.items():
            if not hasattr(_tu, _name):
                setattr(_tu, _name, _fn)
    except Exception:
        pass
except Exception:
    pass

# ── 4. pytorch_utils: removed version-check constants ──────────────
# ``is_torch_greater_or_equal_than_X_Y`` constants removed in v4.48+.
# Affected: DeepSeek-V3/R1/V2-Lite, MiniCPM3, older custom code
try:
    import transformers.pytorch_utils as _pt_utils
    # transformers ≥5.0 requires torch ≥2.0, so every historical gate is True.
    for _ver in [
        "is_torch_greater_or_equal_than_2_4",
        "is_torch_greater_or_equal_than_2_3",
        "is_torch_greater_or_equal_than_2_2",
        "is_torch_greater_or_equal_than_2_1",
        "is_torch_greater_or_equal_than_2_0",
        "is_torch_greater_or_equal_than_1_13",
        "is_torch_greater_or_equal_than_1_12",
        "is_torch_greater_or_equal_than_1_11",
        "is_torch_greater_or_equal_than_1_10",
        "is_torch_greater_or_equal_than_1_9",
        "is_torch_greater_or_equal_than_1_8",
        "is_torch_greater_or_equal_than_1_6",
    ]:
        if not hasattr(_pt_utils, _ver):
            setattr(_pt_utils, _ver, True)
except Exception:
    pass

# ── 5. generation_utils module → transformers.generation ────────────
# Entire module removed; old custom code does
# ``from transformers.generation_utils import GenerationMixin``
# Affected: older generation-customising model repos
try:
    import transformers.generation_utils  # noqa: F401 – already exists
except ModuleNotFoundError:
    try:
        import transformers.generation as _gen
        _sys.modules["transformers.generation_utils"] = _gen
    except Exception:
        pass

# ── 6. deepspeed module → transformers.integrations.deepspeed ───────
# Affected: model repos with DeepSpeed training code
try:
    import transformers.deepspeed  # noqa: F401 – already exists
except ModuleNotFoundError:
    try:
        import transformers.integrations.deepspeed as _ds
        _sys.modules["transformers.deepspeed"] = _ds
    except Exception:
        pass

# ── 7. DynamicCache.get_max_length → get_max_cache_shape ───────────
# Removed in v4.49+. DeepSeek-V3/R1 custom code calls .get_max_length().
try:
    from transformers.cache_utils import DynamicCache as _DC
    if not hasattr(_DC, "get_max_length") and hasattr(_DC, "get_max_cache_shape"):
        _DC.get_max_length = _DC.get_max_cache_shape
except Exception:
    pass

# ── 8. LogitsWarper → LogitsProcessor ──────────────────────────────
# LogitsWarper removed in v5.0 (deprecated v4.46). Drop-in alias.
# Affected: MiniCPM-o custom code
# NOTE: submodule patch runs here; top-level ``transformers.LogitsWarper``
# is deferred to _apply_deferred_shims() because the _LazyModule may reset
# its __dict__ during initial import.
try:
    import transformers.generation.logits_process as _lp_mod
    if not hasattr(_lp_mod, "LogitsWarper"):
        from transformers.generation.logits_process import LogitsProcessor as _LP
        _lp_mod.LogitsWarper = _LP
except Exception:
    pass

# ── 9. processing_utils._validate_images_text_input_order ──────────
# Removed in v5.0rc3. Kimi-VL custom code imports it.
try:
    import transformers.processing_utils as _proc
    if not hasattr(_proc, "_validate_images_text_input_order"):
        def _validate_images_text_input_order(images=None, text=None, **kw):
            return images, text
        _proc._validate_images_text_input_order = _validate_images_text_input_order
except Exception:
    pass

# ── 10. TF/Flax weight constants (removed with TF backend) ─────────
try:
    import transformers.utils as _tu  # noqa: F811
    for _cname, _cval in [
        ("TF_WEIGHTS_NAME", "tf_model.h5"),
        ("TF2_WEIGHTS_NAME", "tf_model.h5"),
    ]:
        if not hasattr(_tu, _cname):
            setattr(_tu, _cname, _cval)
except Exception:
    pass

# ── 11. file_utils.cached_path → huggingface_hub fallback ──────────
# Removed in v4.22. Very old model repos use it for file download.
try:
    import transformers.file_utils as _fu
    if not hasattr(_fu, "cached_path"):
        def _cached_path_shim(url_or_filename, cache_dir=None, **kwargs):
            """Minimal shim: local paths pass through, HF paths download."""
            if os.path.exists(str(url_or_filename)):
                return str(url_or_filename)
            try:
                from huggingface_hub import hf_hub_download
                parts = str(url_or_filename).rsplit("/", 1)
                if len(parts) == 2:
                    return hf_hub_download(repo_id=parts[0], filename=parts[1],
                                           cache_dir=cache_dir)
            except Exception:
                pass
            return str(url_or_filename)
        _fu.cached_path = _cached_path_shim
except Exception:
    pass


# ── Deferred shims ──────────────────────────────────────────────────
# Some patches must wait until the _LazyModule has fully initialized
# (it replaces its __dict__ during bootstrap). We apply these once,
# lazily, the first time load_model() is called.
_DEFERRED_SHIMS_APPLIED = False


def _apply_deferred_shims():
    global _DEFERRED_SHIMS_APPLIED
    if _DEFERRED_SHIMS_APPLIED:
        return
    _DEFERRED_SHIMS_APPLIED = True

    tf_mod = _sys.modules.get("transformers")
    if tf_mod is None:
        return

    # LogitsWarper → LogitsProcessor on the top-level transformers namespace
    try:
        if not hasattr(tf_mod, "LogitsWarper"):
            from transformers.generation.logits_process import LogitsProcessor
            tf_mod.__dict__["LogitsWarper"] = LogitsProcessor
            if hasattr(tf_mod, "_objects"):
                tf_mod._objects["LogitsWarper"] = LogitsProcessor
    except Exception:
        pass

    # is_tf_available / is_flax_available / is_safetensors_available
    # on the top-level namespace (complements shim 3 which patches submodules)
    try:
        for name, val in [
            ("is_tf_available", lambda: False),
            ("is_flax_available", lambda: False),
            ("is_safetensors_available", lambda: True),
        ]:
            if not hasattr(tf_mod, name):
                tf_mod.__dict__[name] = val
                if hasattr(tf_mod, "_objects"):
                    tf_mod._objects[name] = val
    except Exception:
        pass


TASK_MODEL_MAP = {
    "causal_lm": AutoModelForCausalLM,
    "classification": AutoModelForSequenceClassification,
}


@dataclass
class ModelHandle:
    """Wrapper around a HF model + tokenizer with metadata useful for ablation."""

    model: PreTrainedModel
    tokenizer: PreTrainedTokenizerBase
    config: AutoConfig
    model_name: str
    task: str
    architecture: str = ""
    num_layers: int = 0
    num_heads: int = 0
    hidden_size: int = 0
    intermediate_size: int = 0
    _original_state: Optional[dict] = field(default=None, repr=False)
    _offload_dir: Optional[str] = field(default=None, repr=False)

    def __post_init__(self):
        cfg = self.config
        self.architecture = cfg.model_type
        # For composite configs (e.g. VL models like Qwen3.5), the text model
        # attributes live under a nested text_config. Fall through to it when
        # the top-level config doesn't have the standard attributes.
        text_cfg = getattr(cfg, "text_config", None)
        self.num_layers = getattr(cfg, "num_hidden_layers", 0) or (
            getattr(text_cfg, "num_hidden_layers", 0) if text_cfg else 0
        )
        self.num_heads = getattr(cfg, "num_attention_heads", 0) or (
            getattr(text_cfg, "num_attention_heads", 0) if text_cfg else 0
        )
        self.hidden_size = getattr(cfg, "hidden_size", 0) or (
            getattr(text_cfg, "hidden_size", 0) if text_cfg else 0
        )
        self.intermediate_size = getattr(cfg, "intermediate_size", 0) or (
            getattr(text_cfg, "intermediate_size", 0) if text_cfg else 0
        )

    def snapshot(self):
        """Save a deep copy of the model state dict so we can restore after ablation."""
        self._original_state = copy.deepcopy(self.model.state_dict())

    def restore(self):
        """Restore the model to the snapshot state."""
        if self._original_state is None:
            raise RuntimeError("No snapshot to restore — call .snapshot() first.")
        self.model.load_state_dict(self._original_state)

    def cleanup(self):
        """Remove temporary offload directory if one was auto-created."""
        if self._offload_dir is not None:
            import shutil
            try:
                shutil.rmtree(self._offload_dir, ignore_errors=True)
            except Exception:
                pass
            self._offload_dir = None

    def __del__(self):
        self.cleanup()

    def summary(self) -> dict:
        return {
            "model_name": self.model_name,
            "architecture": self.architecture,
            "task": self.task,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "hidden_size": self.hidden_size,
            "intermediate_size": self.intermediate_size,
            "total_params": sum(p.numel() for p in self.model.parameters()),
        }
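
# Typical snapshot/restore cycle around an ablation pass (illustrative sketch;
# "gpt2" is just a stand-in model id):
#
#     handle = load_model("gpt2", dtype="float16")
#     handle.snapshot()                      # cache original weights
#     ...zero out heads / MLP rows here...   # mutate handle.model in place
#     handle.restore()                       # roll back to the snapshot
#     handle.cleanup()                       # drop any temp offload dir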


def _estimate_model_memory_gb(config: AutoConfig, dtype: torch.dtype) -> float:
    """Rough estimate of model weight memory in GB."""
    # Estimate total params from config. For composite configs (VL models),
    # fall through to text_config when top-level attributes are missing.
    text_cfg = getattr(config, "text_config", None)
    hidden = getattr(config, "hidden_size", 0) or (
        getattr(text_cfg, "hidden_size", 0) if text_cfg else 0
    )
    n_layers = getattr(config, "num_hidden_layers", 0) or (
        getattr(text_cfg, "num_hidden_layers", 0) if text_cfg else 0
    )
    intermediate = getattr(config, "intermediate_size", 0) or (
        getattr(text_cfg, "intermediate_size", hidden * 4) if text_cfg else hidden * 4
    )
    vocab = getattr(config, "vocab_size", 0) or (
        getattr(text_cfg, "vocab_size", 0) if text_cfg else 0
    )

    if hidden == 0 or n_layers == 0:
        return 0.0

    # For MoE models, the FFN is replicated per expert
    num_experts = getattr(config, "num_local_experts", None) or getattr(config, "num_experts", 1)

    # Per layer: attn (4 * hidden^2) + ffn (3 * hidden * intermediate * num_experts) + norms
    per_layer = 4 * hidden * hidden + num_experts * 3 * hidden * intermediate
    # Embedding + LM head
    embedding = 2 * vocab * hidden
    total_params = per_layer * n_layers + embedding

    bytes_per_param = {torch.float32: 4, torch.float16: 2, torch.bfloat16: 2}.get(dtype, 2)
    return total_params * bytes_per_param / (1024 ** 3)
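
# Worked example for the estimate above, assuming a dense Llama-2-7B-style
# config (hidden=4096, layers=32, intermediate=11008, vocab=32000, 1 expert):
#   per_layer ≈ 4*4096*4096 + 3*4096*11008 ≈ 202.4M params
#   total     ≈ 32 * 202.4M + 2*32000*4096 ≈ 6.74B params
#   float16   ≈ 6.74B * 2 bytes ≈ 12.6 GB, a rough but usable planning figure.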


def _available_gpu_memory_gb() -> float:
    """Return free GPU memory across all CUDA devices, in GB.

    Uses torch.cuda.mem_get_info which reports actual free memory,
    not total capacity. Falls back to total_memory if mem_get_info
    is unavailable (PyTorch < 1.10).
    """
    if not torch.cuda.is_available():
        return 0.0
    total_free = 0.0
    for i in range(torch.cuda.device_count()):
        try:
            free, _ = torch.cuda.mem_get_info(i)
            total_free += free / (1024 ** 3)
        except AttributeError:
            # Fallback for old PyTorch without mem_get_info
            props = torch.cuda.get_device_properties(i)
            total_free += props.total_memory / (1024 ** 3)
    return total_free


def _hf_token() -> str | None:
    """Return the HF_TOKEN from environment, or None."""
    return os.environ.get("HF_TOKEN") or None


def load_model(
    model_name: str,
    task: str = "causal_lm",
    device: str = "auto",
    dtype: str = "float32",
    trust_remote_code: bool = False,
    num_labels: int = 2,
    quantization: str | None = None,
    offload_folder: str | None = None,
    skip_snapshot: bool | None = None,
) -> ModelHandle:
    """Load a HuggingFace model and tokenizer, returning a ModelHandle.

    Args:
        model_name: HuggingFace model identifier (e.g. "gpt2", "meta-llama/Llama-2-7b-hf").
        task: One of "causal_lm", "classification".
        device: Torch device string. "auto" uses accelerate's device_map.
        dtype: Weight dtype — "float32", "float16", "bfloat16".
        trust_remote_code: Whether to trust remote code from the Hub.
        num_labels: Number of labels for classification tasks.
        quantization: None, "4bit", or "8bit". Requires bitsandbytes.
        offload_folder: Directory for disk offloading when model exceeds GPU memory.
            If None and offloading is needed, a temp directory is created automatically.
        skip_snapshot: Controls initial state dict snapshot.
            None (default): auto-decide based on GPU memory headroom.
            True: always skip (saves memory).
            False: always snapshot (force even for large models).
    """
    _apply_deferred_shims()

    if task not in TASK_MODEL_MAP:
        raise ValueError(f"Unknown task {task!r}. Choose from {list(TASK_MODEL_MAP)}")

    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    if dtype not in dtype_map:
        raise ValueError(f"Unknown dtype {dtype!r}. Choose from {list(dtype_map)}")
    torch_dtype = dtype_map[dtype]

    token = _hf_token()

    try:
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token,
        )
    except PermissionError:
        fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
        os.makedirs(fallback_cache, exist_ok=True)
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, cache_dir=fallback_cache,
            token=token,
        )
    except OSError as e:
        # Gated repo access denied — provide a clear, actionable error.
        err_msg = str(e)
        if "gated repo" in err_msg.lower() or "access to model" in err_msg.lower():
            raise RuntimeError(
                f"Access denied for gated model '{model_name}'.\n\n"
                f"This model requires you to:\n"
                f"  1. Accept the license at https://huggingface.co/{model_name}\n"
                f"  2. Set your HF_TOKEN: export HF_TOKEN=hf_...\n"
                f"     (or add it to your HF Space secrets)\n\n"
                f"Token {'is' if token else 'is NOT'} currently set."
            ) from e
        raise
    except (ValueError, KeyError) as e:
        # Unrecognized model_type — don't silently escalate trust_remote_code.
        # Provide a clear error with guidance instead.
        raise RuntimeError(
            f"Architecture '{model_name}' is not recognized by transformers "
            f"{__import__('transformers').__version__}. "
            f"Try: pip install --upgrade transformers\n"
            f"If this model requires custom code, pass trust_remote_code=True explicitly."
        ) from e

    # Memory estimation and warnings (skip for natively quantized models — estimate is wrong)
    native_quant = getattr(config, "quantization_config", None)
    est_gb = _estimate_model_memory_gb(config, torch_dtype) if native_quant is None else 0.0
    gpu_gb = _available_gpu_memory_gb()
    if est_gb > 0 and gpu_gb > 0:
        logger.info(f"Estimated model size: {est_gb:.1f} GB | Available GPU: {gpu_gb:.1f} GB")
        if est_gb > gpu_gb * 0.9 and quantization is None:
            logger.warning(
                f"Model (~{est_gb:.0f} GB) may exceed GPU memory ({gpu_gb:.0f} GB). "
                f"Consider using quantization='4bit' or quantization='8bit'."
            )

    model_cls = TASK_MODEL_MAP[task]
    load_kwargs: dict = {
        "pretrained_model_name_or_path": model_name,
        "config": config,
        "torch_dtype": torch_dtype,
        "trust_remote_code": trust_remote_code,
        "token": token,
    }
    if task == "classification":
        config.num_labels = num_labels
        load_kwargs["config"] = config

    # Quantization support (requires bitsandbytes)
    if native_quant is not None:
        # Model ships with native quantization (e.g. Mxfp4Config) — don't layer BitsAndBytes
        # on top, and don't override its compute dtype with our torch_dtype
        logger.info(
            f"Model has native quantization ({type(native_quant).__name__}), "
            f"skipping BitsAndBytes and using model's native dtype"
        )
        load_kwargs.pop("torch_dtype", None)
        load_kwargs["device_map"] = "auto"
    elif quantization in ("4bit", "8bit"):
        try:
            import bitsandbytes  # noqa: F401
        except ImportError:
            raise RuntimeError(
                f"Quantization '{quantization}' requires bitsandbytes: "
                f"pip install -U 'bitsandbytes>=0.46.1'"
            )
        from transformers import BitsAndBytesConfig

        # Enable fp32 CPU offload so that models too large to fit entirely on
        # GPU (even quantized) can spill to CPU without crashing bitsandbytes.
        # This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
        # Mistral Large 3 675B, etc.) on single-GPU setups.
        if quantization == "4bit":
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_quant_type="nf4",
                llm_int8_enable_fp32_cpu_offload=True,
            )
        else:
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,
            )
        load_kwargs["device_map"] = "auto"
    elif device == "auto":
        load_kwargs["device_map"] = "auto"

    # Offload support: provide a folder for disk offloading when GPU memory is insufficient
    _offload_dir = None
    if load_kwargs.get("device_map") == "auto":
        if offload_folder:
            _offload_dir = offload_folder
            load_kwargs["offload_folder"] = offload_folder
        else:
            # Auto-create a temp offload dir so from_pretrained never crashes
            # when Accelerate needs disk offloading
            _offload_dir = tempfile.mkdtemp(prefix="obliteratus_offload_")
            load_kwargs["offload_folder"] = _offload_dir
            logger.info(f"Auto-created offload folder: {_offload_dir}")

        # Reserve GPU headroom for inference (KV cache, activations, generate()).
        # Without this, device_map="auto" packs 100% of layers onto GPU, leaving
        # no room for forward passes or generation on tight-memory setups.
        if torch.cuda.is_available():
            max_memory = {}
            for i in range(torch.cuda.device_count()):
                total = torch.cuda.get_device_properties(i).total_memory
                # Reserve 15% or 2 GiB (whichever is larger) for inference headroom
                reserve = max(int(total * 0.15), 2 * 1024 ** 3)
                usable = total - reserve
                max_memory[i] = f"{usable // (1024 ** 2)}MiB"
            # Allow overflow to CPU RAM, capped at 85% of physical memory
            # to leave room for the OS, Python runtime, and serialization buffers.
            try:
                import psutil
                cpu_ram_gb = psutil.virtual_memory().total / (1024 ** 3)
            except ImportError:
                try:
                    cpu_ram_gb = os.sysconf("SC_PHYS_PAGES") * os.sysconf("SC_PAGE_SIZE") / (1024 ** 3)
                except (AttributeError, ValueError):
                    # os.sysconf is unavailable on non-POSIX platforms (Windows)
                    cpu_ram_gb = 16.0  # conservative fallback
            cpu_budget_gb = int(cpu_ram_gb * 0.85)
            max_memory["cpu"] = f"{max(cpu_budget_gb, 4)}GiB"
            load_kwargs["max_memory"] = max_memory
            logger.info(
                f"GPU memory budget: {', '.join(f'GPU{k}={v}' for k, v in max_memory.items() if k != 'cpu')}"
            )
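
            # Illustration (assuming a single GPU reporting exactly 24 GiB and a
            # host with 64 GiB of RAM): reserve = max(3.6 GiB, 2 GiB) = 3.6 GiB,
            # so max_memory comes out roughly {0: "20889MiB", "cpu": "54GiB"}.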

    try:
        model = model_cls.from_pretrained(**load_kwargs)
    except PermissionError as e:
        # Cache dir (typically ~/.cache/huggingface) is not writable — common in
        # containers running as a UID with no home dir. Retry with /tmp cache.
        logger.warning(
            "PermissionError loading model (%s). Retrying with cache_dir=/tmp/hf_home/hub", e
        )
        fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
        os.makedirs(fallback_cache, exist_ok=True)
        load_kwargs["cache_dir"] = fallback_cache
        model = model_cls.from_pretrained(**load_kwargs)
    except OSError as e:
        err_msg = str(e)
        if "gated repo" in err_msg.lower() or "access to model" in err_msg.lower():
            raise RuntimeError(
                f"Access denied for gated model '{model_name}'.\n\n"
                f"This model requires you to:\n"
                f"  1. Accept the license at https://huggingface.co/{model_name}\n"
                f"  2. Set your HF_TOKEN: export HF_TOKEN=hf_...\n"
                f"     (or add it to your HF Space secrets)\n\n"
                f"Token {'is' if token else 'is NOT'} currently set."
            ) from e
        raise
    except (ValueError, KeyError) as e:
        err_msg = str(e)
        if "does not recognize this architecture" in err_msg or "model type" in err_msg:
            model_type = getattr(config, "model_type", "unknown")
            raise RuntimeError(
                f"Model architecture '{model_type}' is not supported by transformers "
                f"{__import__('transformers').__version__}. "
                f"Run: pip install --upgrade transformers\n"
                f"If this model was released very recently, it may require "
                f"pip install git+https://github.com/huggingface/transformers.git"
            ) from e
        raise

    if device not in ("auto",) and quantization is None and native_quant is None:
        model = model.to(device)

    model.eval()

    # Free CUDA cache after loading
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token,
        )
    except PermissionError:
        fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, cache_dir=fallback_cache,
            token=token,
        )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    handle = ModelHandle(
        model=model,
        tokenizer=tokenizer,
        config=config,
        model_name=model_name,
        task=task,
        _offload_dir=_offload_dir,
    )

    # Skip snapshot for large models to avoid doubling memory usage
    if skip_snapshot is True:
        pass  # user explicitly opted out
    elif skip_snapshot is False:
        handle.snapshot()  # user explicitly forced snapshot
    else:
        # Auto-decide: skip when GPU is present and memory is tight.
        # For natively quantized models (est_gb==0), check actual GPU usage instead.
        if gpu_gb > 0 and native_quant is not None:
            # Model is pre-quantized but we can't estimate its true size.
            # Check actual free memory after loading — if less than 40% free, skip snapshot.
            free_gb = 0.0
            for i in range(torch.cuda.device_count()):
                free_gb += torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
            if free_gb < gpu_gb * 0.4:
                logger.warning(
                    f"Auto-skipping state dict snapshot for natively quantized model "
                    f"(free GPU: {free_gb:.1f} GB / {gpu_gb:.1f} GB). "
                    f"Use skip_snapshot=False to force."
                )
            else:
                handle.snapshot()
        elif gpu_gb > 0 and est_gb > 0 and est_gb > gpu_gb * 0.5:
            logger.warning(
                f"Auto-skipping state dict snapshot to save memory "
                f"(model ~{est_gb:.0f} GB vs GPU {gpu_gb:.0f} GB). "
                f"Use skip_snapshot=False to force."
            )
        else:
            handle.snapshot()

    return handle
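

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; "gpt2" is just a small stand-in model id).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    handle = load_model("gpt2", task="causal_lm", device="cpu", dtype="float32")
    print(handle.summary())

    # snapshot → mutate → restore round-trip; the "ablation" here is a no-op
    # zeroing of the first parameter, just to exercise the API.
    handle.snapshot()
    with torch.no_grad():
        next(handle.model.parameters()).zero_()
    handle.restore()
    handle.cleanup()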