Add multi-GPU support with --gpus flag

Adds --gpus flag to obliterate, run, and tourney commands for controlling
which GPUs to use (sets CUDA_VISIBLE_DEVICES). Works both locally and with
--remote. Models are automatically split across selected GPUs via
accelerate's device_map="auto". Also adds gpus field to remote YAML config.
This commit is contained in:
Stella Biderman
2026-03-12 12:35:29 -04:00
parent 34032b7821
commit cbdb772eb9
4 changed files with 74 additions and 5 deletions
+6
View File
@@ -9,6 +9,11 @@
#
# You can also use --remote on any command instead of a YAML section:
# obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct --remote root@gpu-node --ssh-key ~/.ssh/id_rsa
#
# Multi-GPU: Models are automatically split across all available GPUs via
# accelerate's device_map="auto". Use --gpus or the gpus: field to select
# specific GPUs:
# obliteratus obliterate model --remote root@gpu-node --gpus 0,1,2,3
model:
name: meta-llama/Llama-3.1-8B-Instruct
@@ -39,3 +44,4 @@ remote:
remote_dir: /tmp/obliteratus_run
python: python3
sync_results: true
# gpus: "0,1,2,3" # uncomment to select specific GPUs (default: all)
+44
View File
@@ -22,6 +22,19 @@ _BANNER = r"""
"""
def _add_gpu_args(parser):
    """Register the ``--gpus`` option on *parser* in a "GPU selection" group.

    The option takes one string value (displayed as ``IDS`` in usage output)
    and is ``None`` when the flag is omitted.
    """
    # Assemble the help text up front so the add_argument call stays short.
    gpus_help = "".join([
        "Comma-separated GPU IDs to use (e.g. '0,1,2,3' or 'all'). ",
        "Sets CUDA_VISIBLE_DEVICES. By default uses all available GPUs. ",
        "Models are automatically split across selected GPUs via accelerate.",
    ])
    selection = parser.add_argument_group("GPU selection")
    selection.add_argument(
        "--gpus",
        metavar="IDS",
        default=None,
        type=str,
        help=gpus_help,
    )
def _add_remote_args(parser):
"""Add --remote execution flags to a subcommand parser."""
remote_group = parser.add_argument_group("remote execution")
@@ -51,6 +64,28 @@ def _add_remote_args(parser):
)
def _apply_gpu_selection(args):
    """Set CUDA_VISIBLE_DEVICES based on the --gpus flag (local runs only).

    Nothing is changed when ``--gpus`` was not given, when ``--remote`` is
    set (the value is handed to the remote runner instead), or when the
    value is ``all`` (any case, surrounding whitespace ignored).

    Args:
        args: Parsed argparse namespace. ``gpus`` and ``remote`` are read
            with ``getattr`` so sub-commands lacking those flags are fine.

    Raises:
        SystemExit: With status 1 when the value is not a comma-separated
            list of non-negative integers.
    """
    import os

    gpus = getattr(args, "gpus", None)
    if gpus is None or getattr(args, "remote", None):
        return  # skip for remote runs (handled by remote runner)

    spec = gpus.strip()
    if spec.lower() == "all":
        return  # use all GPUs (default behavior)

    # Validate: comma-separated, non-negative integers. A negative or
    # malformed entry would otherwise be exported verbatim and could hide
    # GPUs instead of failing loudly here.
    try:
        gpu_ids = [int(part.strip()) for part in spec.split(",")]
    except ValueError:
        gpu_ids = None
    if gpu_ids is None or any(gpu_id < 0 for gpu_id in gpu_ids):
        console.print(
            f"[red]Invalid --gpus value: {gpus!r}. "
            "Expected comma-separated non-negative integers or 'all'.[/]"
        )
        raise SystemExit(1)

    # Re-join from the parsed ints so the exported value is normalized
    # (no stray whitespace, no leading '+').
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in gpu_ids)
    console.print(f"[dim]Using GPUs: {gpu_ids} (CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']})[/dim]")
def main(argv: list[str] | None = None):
console.print(_BANNER)
parser = argparse.ArgumentParser(
@@ -69,6 +104,7 @@ def main(argv: list[str] | None = None):
default=None,
help="Apply a preset (e.g. quick, full, attention, jailbreak, guardrail)",
)
_add_gpu_args(run_parser)
_add_remote_args(run_parser)
# --- info ---
@@ -174,10 +210,12 @@ def main(argv: list[str] | None = None):
help="One-click: remove refusal directions from a model (SOTA multi-technique)",
)
_add_obliterate_args(abl_parser)
_add_gpu_args(abl_parser)
_add_remote_args(abl_parser)
# Backward-compat alias (hidden from help)
abl_alias = subparsers.add_parser("abliterate", help=argparse.SUPPRESS)
_add_obliterate_args(abl_alias)
_add_gpu_args(abl_alias)
_add_remote_args(abl_alias)
# --- report ---
@@ -212,6 +250,7 @@ def main(argv: list[str] | None = None):
"--methods", type=str, nargs="+", default=None,
help="Override: only run these methods (space-separated)",
)
_add_gpu_args(tourney_parser)
_add_remote_args(tourney_parser)
# --- recommend ---
@@ -229,6 +268,9 @@ def main(argv: list[str] | None = None):
args = parser.parse_args(argv)
# Apply GPU selection early (before any CUDA init)
_apply_gpu_selection(args)
if args.command == "run":
if getattr(args, "remote", None):
_cmd_remote_run(args)
@@ -369,6 +411,7 @@ def _cmd_run(args):
remote_dir=config.remote.remote_dir,
python=config.remote.python,
sync_results=config.remote.sync_results,
gpus=config.remote.gpus,
)
runner = RemoteRunner(rc)
result_path = runner.run_config(
@@ -733,6 +776,7 @@ def _make_remote_runner(args):
remote_dir=args.remote_dir,
python=args.remote_python,
sync_results=not args.no_sync,
gpus=getattr(args, "gpus", None),
)
return RemoteRunner(rc)
+1
View File
@@ -46,6 +46,7 @@ class RemoteConfig:
remote_dir: str = "/tmp/obliteratus_run"
python: str = "python3"
sync_results: bool = True
gpus: str | None = None # comma-separated GPU IDs or "all"
@dataclass
+23 -5
View File
@@ -45,6 +45,7 @@ class RemoteConfig:
install_timeout: int = 600 # seconds
python: str = "python3" # remote python binary
sync_results: bool = True
gpus: str | None = None # comma-separated GPU IDs or "all"
@property
def ssh_target(self) -> str:
@@ -151,14 +152,31 @@ class RemoteRunner:
def check_gpu(self) -> str | None:
"""Check for CUDA GPUs on remote. Returns nvidia-smi output or None."""
result = self.run_ssh("nvidia-smi --query-gpu=name,memory.total --format=csv,noheader", timeout=30)
result = self.run_ssh(
"nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv,noheader",
timeout=30,
)
if isinstance(result, subprocess.CompletedProcess) and result.returncode == 0:
gpu_info = result.stdout.strip()
self.on_log(f"Remote GPUs: {gpu_info}")
lines = gpu_info.split("\n")
self.on_log(f"Remote GPUs ({len(lines)} detected):")
for line in lines:
self.on_log(f" {line.strip()}")
if self.config.gpus and self.config.gpus.lower() != "all":
self.on_log(f" Selected GPUs: {self.config.gpus}")
else:
self.on_log(f" Using: all {len(lines)} GPUs")
return gpu_info
self.on_log("[yellow]No GPUs detected on remote (nvidia-smi failed)[/]")
return None
def _env_prefix(self) -> str:
    """Build environment variable prefix for remote commands (e.g. CUDA_VISIBLE_DEVICES).

    Returns:
        ``"CUDA_VISIBLE_DEVICES=<ids> "`` (with a trailing space, ready to be
        prepended to a command) when ``config.gpus`` names specific GPUs;
        an empty string when it is unset, blank, or ``all``.
    """
    parts = []
    gpus = (self.config.gpus or "").strip()
    if gpus and gpus.lower() != "all":
        # Collapse whitespace around the commas, then shell-quote: the value
        # comes from the CLI or a config file and ends up inside a shell
        # command string, so it must not be able to break out of the
        # assignment.
        ids = ",".join(piece.strip() for piece in gpus.split(","))
        parts.append(f"CUDA_VISIBLE_DEVICES={shlex.quote(ids)}")
    return " ".join(parts) + " " if parts else ""
def ensure_obliteratus(self) -> bool:
"""Install or update obliteratus on the remote if needed."""
# Check if already installed
@@ -224,7 +242,7 @@ class RemoteRunner:
remote_output = output_dir or f"{self.config.remote_dir}/output/{model.replace('/', '_')}"
parts = [
self.config.python, "-m", "obliteratus",
self._env_prefix() + self.config.python, "-m", "obliteratus",
"obliterate", shlex.quote(model),
"--output-dir", shlex.quote(remote_output),
"--method", method,
@@ -251,7 +269,7 @@ class RemoteRunner:
def build_run_command(self, remote_config_path: str, output_dir: str | None = None, preset: str | None = None) -> str:
"""Build remote 'obliteratus run' command."""
parts = [
self.config.python, "-m", "obliteratus",
self._env_prefix() + self.config.python, "-m", "obliteratus",
"run", shlex.quote(remote_config_path),
]
if output_dir:
@@ -276,7 +294,7 @@ class RemoteRunner:
remote_output = output_dir or f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"
parts = [
self.config.python, "-m", "obliteratus",
self._env_prefix() + self.config.python, "-m", "obliteratus",
"tourney", shlex.quote(model),
"--output-dir", shlex.quote(remote_output),
"--device", device,