Commit 51f621d0a2 by Stella Biderman, 2026-03-13 17:13:50 -04:00
Save model snapshot to CPU to avoid OOM on multi-GPU setups

The snapshot() deepcopy was cloning tensors on their original GPU
devices, doubling VRAM usage. For a 234 GB model sharded across 6
A100-80GB GPUs (~39 GB each), this left no room for the copy.

Now snapshot() stores tensors on CPU and restore() moves them back
to each parameter's current device.

"""Load HuggingFace models and wrap them for ablation."""
from __future__ import annotations
import copy
import logging
import os
import tempfile
from dataclasses import dataclass, field
from typing import Optional
import sys as _sys
import torch
from obliteratus import device as dev
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
logger = logging.getLogger(__name__)
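# Typical usage (illustrative sketch, not from the original docs; the import
# path is assumed, so adjust it to wherever this module lives in the package):
#
#     handle = load_model("gpt2", task="causal_lm", dtype="float16")
#     handle.snapshot()   # cache the original weights on CPU
#     ... ablate heads / neurons on handle.model ...
#     handle.restore()    # copy the cached weights back to their devices
#     handle.cleanup()    # remove any auto-created offload directory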
# ---------------------------------------------------------------------------
# Compat shims for transformers ≥5.0 breaking changes.
#
# Many HuggingFace model repos ship custom modeling code (loaded via
# trust_remote_code=True) that imports symbols from their pre-5.x locations.
# We monkey-patch the old module paths so loading works without downgrading.
#
# Every section is wrapped in try/except so a failure in one shim never
# breaks unrelated functionality. Patches are purely additive — we never
# remove attributes that already exist.
# ---------------------------------------------------------------------------
# ── 1. utils.generic → utils.output_capturing ──────────────────────
# OutputRecorder, check_model_inputs, _CAN_RECORD_REGISTRY moved.
# Affected: MiniMax-M2.x, DeepSeek-V3
try:
    import transformers.utils.generic as _tfu_generic
    try:
        from transformers.utils import output_capturing as _oc
        for _old, _new in [
            ("OutputRecorder", "OutputRecorder"),
            ("check_model_inputs", "capture_outputs"),
            ("_CAN_RECORD_REGISTRY", "_CAN_RECORD_REGISTRY"),
        ]:
            if not hasattr(_tfu_generic, _old) and hasattr(_oc, _new):
                setattr(_tfu_generic, _old, getattr(_oc, _new))
    except ImportError:
        pass
except Exception:
    pass
# ── 2. utils.generic.working_or_temp_dir ───────────────────────────
# Removed in 5.x. Trivial contextmanager replacement.
# Affected: GLM-4 / ChatGLM custom code
try:
    import transformers.utils.generic as _tfu_generic  # noqa: F811 may already be imported
    if not hasattr(_tfu_generic, "working_or_temp_dir"):
        import contextlib as _ctxlib
        import tempfile as _tmpmod

        @_ctxlib.contextmanager
        def _working_or_temp_dir(working_dir=None):
            if working_dir is not None:
                yield working_dir
            else:
                with _tmpmod.TemporaryDirectory() as tmp:
                    yield tmp

        _tfu_generic.working_or_temp_dir = _working_or_temp_dir
except Exception:
    pass
# ── 3. utils.import_utils: removed availability checks ─────────────
# is_torch_fx_available    → removed (torch.fx always present in torch≥2.0)
# is_tf_available          → removed (TF backend dropped in v5)
# is_flax_available        → removed (Flax backend dropped in v5)
# is_safetensors_available → removed (safetensors is now mandatory)
# Affected: various model repos that defensively check backends
try:
    import transformers.utils.import_utils as _tfu_imports
    _import_shims = {
        "is_torch_fx_available": lambda: True,
        "is_tf_available": lambda: False,
        "is_flax_available": lambda: False,
        "is_safetensors_available": lambda: True,
    }
    for _name, _fn in _import_shims.items():
        if not hasattr(_tfu_imports, _name):
            setattr(_tfu_imports, _name, _fn)
    # Also patch the top-level transformers.utils re-export so both
    # ``from transformers.utils import is_tf_available`` and
    # ``from transformers.utils.import_utils import is_tf_available`` work.
    try:
        import transformers.utils as _tu
        for _name, _fn in _import_shims.items():
            if not hasattr(_tu, _name):
                setattr(_tu, _name, _fn)
    except Exception:
        pass
except Exception:
    pass
# ── 4. pytorch_utils: removed version-check constants ──────────────
# ``is_torch_greater_or_equal_than_X_Y`` constants removed in v4.48+.
# Affected: DeepSeek-V3/R1/V2-Lite, MiniCPM3, older custom code
try:
    import transformers.pytorch_utils as _pt_utils
    # transformers ≥5.0 requires torch ≥2.0, so every historical gate is True.
    for _ver in [
        "is_torch_greater_or_equal_than_2_4",
        "is_torch_greater_or_equal_than_2_3",
        "is_torch_greater_or_equal_than_2_2",
        "is_torch_greater_or_equal_than_2_1",
        "is_torch_greater_or_equal_than_2_0",
        "is_torch_greater_or_equal_than_1_13",
        "is_torch_greater_or_equal_than_1_12",
        "is_torch_greater_or_equal_than_1_11",
        "is_torch_greater_or_equal_than_1_10",
        "is_torch_greater_or_equal_than_1_9",
        "is_torch_greater_or_equal_than_1_8",
        "is_torch_greater_or_equal_than_1_6",
    ]:
        if not hasattr(_pt_utils, _ver):
            setattr(_pt_utils, _ver, True)
except Exception:
    pass
# ── 5. generation_utils module → transformers.generation ────────────
# Entire module removed; old custom code does
# ``from transformers.generation_utils import GenerationMixin``
# Affected: older generation-customising model repos
try:
    import transformers.generation_utils  # noqa: F401 already exists
except ModuleNotFoundError:
    try:
        import transformers.generation as _gen
        _sys.modules["transformers.generation_utils"] = _gen
    except Exception:
        pass
# ── 6. deepspeed module → transformers.integrations.deepspeed ───────
# Affected: model repos with DeepSpeed training code
try:
    import transformers.deepspeed  # noqa: F401 already exists
except ModuleNotFoundError:
    try:
        import transformers.integrations.deepspeed as _ds
        _sys.modules["transformers.deepspeed"] = _ds
    except Exception:
        pass
# ── 7. DynamicCache.get_max_length → get_max_cache_shape ───────────
# Removed in v4.49+. DeepSeek-V3/R1 custom code calls .get_max_length().
try:
    from transformers.cache_utils import DynamicCache as _DC
    if not hasattr(_DC, "get_max_length") and hasattr(_DC, "get_max_cache_shape"):
        _DC.get_max_length = _DC.get_max_cache_shape
except Exception:
    pass
# ── 8. LogitsWarper → LogitsProcessor ──────────────────────────────
# LogitsWarper removed in v5.0 (deprecated v4.46). Drop-in alias.
# Affected: MiniCPM-o custom code
# NOTE: submodule patch runs here; top-level ``transformers.LogitsWarper``
# is deferred to _apply_deferred_shims() because the _LazyModule may reset
# its __dict__ during initial import.
try:
    import transformers.generation.logits_process as _lp_mod
    if not hasattr(_lp_mod, "LogitsWarper"):
        from transformers.generation.logits_process import LogitsProcessor as _LP
        _lp_mod.LogitsWarper = _LP
except Exception:
    pass
# ── 9. processing_utils._validate_images_text_input_order ──────────
# Removed in v5.0rc3. Kimi-VL custom code imports it.
try:
    import transformers.processing_utils as _proc
    if not hasattr(_proc, "_validate_images_text_input_order"):
        def _validate_images_text_input_order(images=None, text=None, **kw):
            return images, text
        _proc._validate_images_text_input_order = _validate_images_text_input_order
except Exception:
    pass
# ── 10. TF/Flax weight constants (removed with TF backend) ─────────
try:
    import transformers.utils as _tu  # noqa: F811
    # Historical values: the TF1 checkpoint name was "model.ckpt";
    # "tf_model.h5" was the TF2 weights file.
    for _cname, _cval in [
        ("TF_WEIGHTS_NAME", "model.ckpt"),
        ("TF2_WEIGHTS_NAME", "tf_model.h5"),
    ]:
        if not hasattr(_tu, _cname):
            setattr(_tu, _cname, _cval)
except Exception:
    pass
# ── 11. file_utils.cached_path → huggingface_hub fallback ──────────
# Removed in v4.22. Very old model repos use it for file download.
try:
    import transformers.file_utils as _fu
    if not hasattr(_fu, "cached_path"):
        def _cached_path_shim(url_or_filename, cache_dir=None, **kwargs):
            """Minimal shim: local paths pass through, HF paths download."""
            if os.path.exists(str(url_or_filename)):
                return str(url_or_filename)
            try:
                from huggingface_hub import hf_hub_download
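                # Heuristic (assumption, not verified against the old behavior):
                # treat "org/repo/filename" as a Hub reference and split on the
                # last "/" into repo_id and filename.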
                parts = str(url_or_filename).rsplit("/", 1)
                if len(parts) == 2:
                    return hf_hub_download(repo_id=parts[0], filename=parts[1],
                                           cache_dir=cache_dir)
            except Exception:
                pass
            return str(url_or_filename)
        _fu.cached_path = _cached_path_shim
except Exception:
    pass
# ── Deferred shims ──────────────────────────────────────────────────
# Some patches must wait until the _LazyModule has fully initialized
# (it replaces its __dict__ during bootstrap). We apply these once,
# lazily, the first time load_model() is called.
_DEFERRED_SHIMS_APPLIED = False
def _apply_deferred_shims():
    global _DEFERRED_SHIMS_APPLIED
    if _DEFERRED_SHIMS_APPLIED:
        return
    _DEFERRED_SHIMS_APPLIED = True
    tf_mod = _sys.modules.get("transformers")
    if tf_mod is None:
        return
    # LogitsWarper → LogitsProcessor on the top-level transformers namespace
    try:
        if not hasattr(tf_mod, "LogitsWarper"):
            from transformers.generation.logits_process import LogitsProcessor
            tf_mod.__dict__["LogitsWarper"] = LogitsProcessor
            if hasattr(tf_mod, "_objects"):
                tf_mod._objects["LogitsWarper"] = LogitsProcessor
    except Exception:
        pass
    # is_tf_available / is_flax_available / is_safetensors_available
    # on the top-level namespace (complements shim 3 which patches submodules)
    try:
        for name, val in [
            ("is_tf_available", lambda: False),
            ("is_flax_available", lambda: False),
            ("is_safetensors_available", lambda: True),
        ]:
            if not hasattr(tf_mod, name):
                tf_mod.__dict__[name] = val
                if hasattr(tf_mod, "_objects"):
                    tf_mod._objects[name] = val
    except Exception:
        pass
TASK_MODEL_MAP = {
    "causal_lm": AutoModelForCausalLM,
    "classification": AutoModelForSequenceClassification,
}
@dataclass
class ModelHandle:
    """Wrapper around a HF model + tokenizer with metadata useful for ablation."""

    model: PreTrainedModel
    tokenizer: PreTrainedTokenizerBase
    config: AutoConfig
    model_name: str
    task: str
    architecture: str = ""
    num_layers: int = 0
    num_heads: int = 0
    hidden_size: int = 0
    intermediate_size: int = 0
    _original_state: Optional[dict] = field(default=None, repr=False)
    _offload_dir: Optional[str] = field(default=None, repr=False)

    def __post_init__(self):
        cfg = self.config
        self.architecture = cfg.model_type
        # For composite configs (e.g. VL models like Qwen3.5), the text model
        # attributes live under a nested text_config. Fall through to it when
        # the top-level config doesn't have the standard attributes.
        text_cfg = getattr(cfg, "text_config", None)
        self.num_layers = getattr(cfg, "num_hidden_layers", 0) or (
            getattr(text_cfg, "num_hidden_layers", 0) if text_cfg else 0
        )
        self.num_heads = getattr(cfg, "num_attention_heads", 0) or (
            getattr(text_cfg, "num_attention_heads", 0) if text_cfg else 0
        )
        self.hidden_size = getattr(cfg, "hidden_size", 0) or (
            getattr(text_cfg, "hidden_size", 0) if text_cfg else 0
        )
        self.intermediate_size = getattr(cfg, "intermediate_size", 0) or (
            getattr(text_cfg, "intermediate_size", 0) if text_cfg else 0
        )

    def snapshot(self):
        """Save a copy of the model state dict so we can restore after ablation.

        Tensors are moved to CPU to avoid doubling GPU memory usage on
        multi-GPU (device_map) setups.
        """
        self._original_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
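        # .cpu() already copies tensors that live on an accelerator; the extra
        # .clone() matters for shards that are already on CPU, where .cpu()
        # returns the same tensor and the snapshot would otherwise alias the
        # live weights.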
    def restore(self):
        """Restore the model to the snapshot state.

        Moves CPU-saved tensors back to each parameter's current device.
        """
        if self._original_state is None:
            raise RuntimeError("No snapshot to restore — call .snapshot() first.")
        # Map each key to the device where the model currently holds it
        current_state = self.model.state_dict()
        restored = {}
        for k, v in self._original_state.items():
            target = current_state[k].device if k in current_state else None
            restored[k] = v.to(target) if target is not None else v
        self.model.load_state_dict(restored)

    def cleanup(self):
        """Remove temporary offload directory if one was auto-created."""
        if self._offload_dir is not None:
            import shutil
            try:
                shutil.rmtree(self._offload_dir, ignore_errors=True)
            except Exception:
                pass
            self._offload_dir = None

    def __del__(self):
        self.cleanup()

    def summary(self) -> dict:
        return {
            "model_name": self.model_name,
            "architecture": self.architecture,
            "task": self.task,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "hidden_size": self.hidden_size,
            "intermediate_size": self.intermediate_size,
            "total_params": sum(p.numel() for p in self.model.parameters()),
        }
def _estimate_model_memory_gb(config: AutoConfig, dtype: torch.dtype) -> float:
    """Rough estimate of model weight memory in GB."""
    # Estimate total params from config. For composite configs (VL models),
    # fall through to text_config when top-level attributes are missing.
    text_cfg = getattr(config, "text_config", None)
    hidden = getattr(config, "hidden_size", 0) or (
        getattr(text_cfg, "hidden_size", 0) if text_cfg else 0
    )
    n_layers = getattr(config, "num_hidden_layers", 0) or (
        getattr(text_cfg, "num_hidden_layers", 0) if text_cfg else 0
    )
    intermediate = getattr(config, "intermediate_size", 0) or (
        getattr(text_cfg, "intermediate_size", hidden * 4) if text_cfg else hidden * 4
    )
    vocab = getattr(config, "vocab_size", 0) or (
        getattr(text_cfg, "vocab_size", 0) if text_cfg else 0
    )
    if hidden == 0 or n_layers == 0:
        return 0.0
    # For MoE models, the FFN is replicated per expert
    num_experts = getattr(config, "num_local_experts", None) or getattr(config, "num_experts", 1)
    # Per layer: attn (4 * hidden^2) + ffn (3 * hidden * intermediate * num_experts) + norms
    per_layer = 4 * hidden * hidden + num_experts * 3 * hidden * intermediate
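    # Sanity check (illustrative numbers, not from the original source): a
    # Llama-2-7B-like config (hidden=4096, intermediate=11008, 32 layers,
    # vocab=32000, 1 expert) gives per_layer ≈ 4*4096**2 + 3*4096*11008 ≈ 202M
    # params, so 32 layers plus the 2*32000*4096 embedding term come to ≈ 6.7B
    # params, i.e. roughly 12.6 GiB in float16.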
    # Embedding + LM head
    embedding = 2 * vocab * hidden
    total_params = per_layer * n_layers + embedding
    bytes_per_param = {torch.float32: 4, torch.float16: 2, torch.bfloat16: 2}.get(dtype, 2)
    return total_params * bytes_per_param / (1024 ** 3)
def _available_gpu_memory_gb() -> float:
    """Return free accelerator memory in GB (CUDA, MPS, or 0 for CPU)."""
    return dev.get_total_free_gb()
def _hf_token() -> str | None:
    """Return the HF_TOKEN from environment, or None."""
    return os.environ.get("HF_TOKEN") or None
def load_model(
    model_name: str,
    task: str = "causal_lm",
    device: str = "auto",
    dtype: str = "float32",
    trust_remote_code: bool = False,
    num_labels: int = 2,
    quantization: str | None = None,
    offload_folder: str | None = None,
    skip_snapshot: bool | None = None,
) -> ModelHandle:
    """Load a HuggingFace model and tokenizer, returning a ModelHandle.

    Args:
        model_name: HuggingFace model identifier (e.g. "gpt2", "meta-llama/Llama-2-7b-hf").
        task: One of "causal_lm", "classification".
        device: Torch device string. "auto" uses accelerate's device_map.
        dtype: Weight dtype — "float32", "float16", "bfloat16".
        trust_remote_code: Whether to trust remote code from the Hub.
        num_labels: Number of labels for classification tasks.
        quantization: None, "4bit", or "8bit". Requires bitsandbytes.
        offload_folder: Directory for disk offloading when model exceeds GPU memory.
            If None and offloading is needed, a temp directory is created automatically.
        skip_snapshot: Controls initial state dict snapshot.
            None (default): auto-decide based on GPU memory headroom.
            True: always skip (saves memory).
            False: always snapshot (force even for large models).
    """
    _apply_deferred_shims()
    if task not in TASK_MODEL_MAP:
        raise ValueError(f"Unknown task {task!r}. Choose from {list(TASK_MODEL_MAP)}")
    dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
    if dtype not in dtype_map:
        raise ValueError(f"Unknown dtype {dtype!r}. Choose from {list(dtype_map)}")
    torch_dtype = dtype_map[dtype]
    token = _hf_token()
    try:
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token,
        )
    except PermissionError:
        fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
        os.makedirs(fallback_cache, exist_ok=True)
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, cache_dir=fallback_cache,
            token=token,
        )
    except OSError as e:
        # Gated repo access denied — provide a clear, actionable error.
        err_msg = str(e)
        if "gated repo" in err_msg.lower() or "access to model" in err_msg.lower():
            raise RuntimeError(
                f"Access denied for gated model '{model_name}'.\n\n"
                f"This model requires you to:\n"
                f" 1. Accept the license at https://huggingface.co/{model_name}\n"
                f" 2. Set your HF_TOKEN: export HF_TOKEN=hf_...\n"
                f" (or add it to your HF Space secrets)\n\n"
                f"Token {'is' if token else 'is NOT'} currently set."
            ) from e
        raise
    except (ValueError, KeyError) as e:
        # Unrecognized model_type — don't silently escalate trust_remote_code.
        # Provide a clear error with guidance instead.
        raise RuntimeError(
            f"Architecture '{model_name}' is not recognized by transformers "
            f"{__import__('transformers').__version__}. "
            f"Try: pip install --upgrade transformers\n"
            f"If this model requires custom code, pass trust_remote_code=True explicitly."
        ) from e
    # Memory estimation and warnings (skip for natively quantized models — estimate is wrong)
    native_quant = getattr(config, "quantization_config", None)
    est_gb = _estimate_model_memory_gb(config, torch_dtype) if native_quant is None else 0.0
    gpu_gb = _available_gpu_memory_gb()
    if est_gb > 0 and gpu_gb > 0:
        logger.info(f"Estimated model size: {est_gb:.1f} GB | Available GPU: {gpu_gb:.1f} GB")
        if est_gb > gpu_gb * 0.9 and quantization is None:
            logger.warning(
                f"Model (~{est_gb:.0f} GB) may exceed GPU memory ({gpu_gb:.0f} GB). "
                f"Consider using quantization='4bit' or quantization='8bit'."
            )
    model_cls = TASK_MODEL_MAP[task]
    load_kwargs: dict = {
        "pretrained_model_name_or_path": model_name,
        "config": config,
        "torch_dtype": torch_dtype,
        "trust_remote_code": trust_remote_code,
        "token": token,
    }
    if task == "classification":
        config.num_labels = num_labels
        load_kwargs["config"] = config
    # Quantization support (requires bitsandbytes)
    if native_quant is not None:
        # Model ships with native quantization (e.g. Mxfp4Config) — don't layer BitsAndBytes
        # on top, and don't override its compute dtype with our torch_dtype
        logger.info(
            f"Model has native quantization ({type(native_quant).__name__}), "
            f"skipping BitsAndBytes and using model's native dtype"
        )
        load_kwargs.pop("torch_dtype", None)
        load_kwargs["device_map"] = "auto"
    elif quantization in ("4bit", "8bit"):
        # BitsAndBytes only works on NVIDIA CUDA GPUs.
        resolved_device = dev.get_device(device)
        if not dev.supports_bitsandbytes(resolved_device):
            logger.warning(
                "BitsAndBytes quantization is not supported on %s. "
                "Loading in %s instead.",
                resolved_device, dtype,
            )
            # On MPS, load normally to the device; on CPU, fall through.
            if resolved_device == "mps":
                device = "mps"
            # Don't set quantization_config — fall through to normal loading.
        else:
            try:
                import bitsandbytes  # noqa: F401
            except ImportError:
                raise RuntimeError(
                    f"Quantization '{quantization}' requires bitsandbytes: "
                    f"pip install -U bitsandbytes>=0.46.1"
                )
            from transformers import BitsAndBytesConfig
            # Enable fp32 CPU offload so that models too large to fit entirely on
            # GPU (even quantized) can spill to CPU without crashing bitsandbytes.
            # This is critical for frontier MoE models (GLM-5 744B, DeepSeek-V3 685B,
            # Mistral Large 3 675B, etc.) on single-GPU setups.
            if quantization == "4bit":
                load_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch_dtype,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_enable_fp32_cpu_offload=True,
                )
            else:
                load_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_enable_fp32_cpu_offload=True,
                )
            load_kwargs["device_map"] = "auto"
    # device_map="auto" is only reliable on CUDA (accelerate doesn't support MPS).
    if "device_map" not in load_kwargs and device == "auto":
        resolved_device = dev.get_device(device)
        if dev.supports_device_map_auto(resolved_device):
            load_kwargs["device_map"] = "auto"
        else:
            # MPS / CPU: load to CPU first, then .to(device) after loading.
            pass
    # Offload support: provide a folder for disk offloading when GPU memory is insufficient
    _offload_dir = None
    if load_kwargs.get("device_map") == "auto":
        if offload_folder:
            _offload_dir = offload_folder
            load_kwargs["offload_folder"] = offload_folder
        else:
            # Auto-create a temp offload dir so from_pretrained never crashes
            # when Accelerate needs disk offloading
            _offload_dir = tempfile.mkdtemp(prefix="obliteratus_offload_")
            load_kwargs["offload_folder"] = _offload_dir
            logger.info(f"Auto-created offload folder: {_offload_dir}")
        # Reserve GPU headroom for inference (KV cache, activations, generate()).
        # Without this, device_map="auto" packs 100% of layers onto GPU, leaving
        # no room for forward passes or generation on tight-memory setups.
        if dev.is_cuda():
            max_memory = {}
            for i in range(dev.device_count()):
                total = torch.cuda.get_device_properties(i).total_memory
                # Reserve 15% or 2 GiB (whichever is larger) for inference headroom
                reserve = max(int(total * 0.15), 2 * 1024 ** 3)
                usable = total - reserve
                max_memory[i] = f"{usable // (1024 ** 2)}MiB"
            # Allow overflow to CPU RAM, capped at 85% of physical memory
            # to leave room for the OS, Python runtime, and serialization buffers.
            total_ram, _ = dev._system_memory_gb()
            cpu_budget_gb = int(total_ram * 0.85)
            max_memory["cpu"] = f"{max(cpu_budget_gb, 4)}GiB"
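            # Illustrative budget (numbers assumed, not from the original source):
            # an 80 GiB card reporting 81920 MiB of total memory ends up with
            # max_memory[i] = "69632MiB" (85%), and a host with 512 GB of RAM
            # contributes max_memory["cpu"] = "435GiB".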
            load_kwargs["max_memory"] = max_memory
            logger.info(
                f"GPU memory budget: {', '.join(f'GPU{k}={v}' for k, v in max_memory.items() if k != 'cpu')}"
            )
    try:
        model = model_cls.from_pretrained(**load_kwargs)
    except OSError as e:
        err_msg = str(e)
        if "gated repo" in err_msg.lower() or "access to model" in err_msg.lower():
            raise RuntimeError(
                f"Access denied for gated model '{model_name}'.\n\n"
                f"This model requires you to:\n"
                f" 1. Accept the license at https://huggingface.co/{model_name}\n"
                f" 2. Set your HF_TOKEN: export HF_TOKEN=hf_...\n"
                f" (or add it to your HF Space secrets)\n\n"
                f"Token {'is' if token else 'is NOT'} currently set."
            ) from e
        raise
    except PermissionError as e:
        # Cache dir (typically ~/.cache/huggingface) is not writable — common in
        # containers running as UID with no home dir. Retry with /tmp cache.
        logger.warning(
            "PermissionError loading model (%s). Retrying with cache_dir=/tmp/hf_home/hub", e
        )
        fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
        os.makedirs(fallback_cache, exist_ok=True)
        load_kwargs["cache_dir"] = fallback_cache
        model = model_cls.from_pretrained(**load_kwargs)
    except (ValueError, KeyError) as e:
        err_msg = str(e)
        if "does not recognize this architecture" in err_msg or "model type" in err_msg:
            model_type = getattr(config, "model_type", "unknown")
            raise RuntimeError(
                f"Model architecture '{model_type}' is not supported by transformers "
                f"{__import__('transformers').__version__}. "
                f"Run: pip install --upgrade transformers\n"
                f"If this model was released very recently, it may require "
                f"pip install git+https://github.com/huggingface/transformers.git"
            ) from e
        raise
    if device not in ("auto",) and quantization is None and native_quant is None:
        model = model.to(device)
    elif device == "auto" and not dev.supports_device_map_auto():
        # MPS / CPU: device_map wasn't used, move model to best device.
        resolved = dev.get_device()
        model = model.to(resolved)
    model.eval()
    # Free accelerator cache after loading
    dev.empty_cache()
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=token,
        )
    except PermissionError:
        fallback_cache = os.path.join(tempfile.gettempdir(), "hf_home", "hub")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, cache_dir=fallback_cache,
            token=token,
        )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    handle = ModelHandle(
        model=model,
        tokenizer=tokenizer,
        config=config,
        model_name=model_name,
        task=task,
        _offload_dir=_offload_dir,
    )
    # Skip snapshot for large models to avoid doubling memory usage
    if skip_snapshot is True:
        pass  # user explicitly opted out
    elif skip_snapshot is False:
        handle.snapshot()  # user explicitly forced snapshot
    else:
        # Auto-decide: skip when GPU is present and memory is tight.
        # For natively quantized models (est_gb==0), check actual GPU usage instead.
        if gpu_gb > 0 and native_quant is not None:
            # Model is pre-quantized but we can't estimate its true size.
            # Check actual free memory after loading — if less than 40% free, skip snapshot.
            free_gb = dev.get_total_free_gb()
            if free_gb < gpu_gb * 0.4:
                logger.warning(
                    f"Auto-skipping state dict snapshot for natively quantized model "
                    f"(free GPU: {free_gb:.1f} GB / {gpu_gb:.1f} GB). "
                    f"Use skip_snapshot=False to force."
                )
            else:
                handle.snapshot()
        elif gpu_gb > 0 and est_gb > 0 and est_gb > gpu_gb * 0.5:
            logger.warning(
                f"Auto-skipping state dict snapshot to save memory "
                f"(model ~{est_gb:.0f} GB vs GPU {gpu_gb:.0f} GB). "
                f"Use skip_snapshot=False to force."
            )
        else:
            handle.snapshot()
    return handle