mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-06-06 22:34:02 +02:00
Add multi-GPU support with --gpus flag
Adds --gpus flag to obliterate, run, and tourney commands for controlling which GPUs to use (sets CUDA_VISIBLE_DEVICES). Works both locally and with --remote. Models are automatically split across selected GPUs via accelerate's device_map="auto". Also adds gpus field to remote YAML config.
This commit is contained in:
@@ -9,6 +9,11 @@
|
||||
#
|
||||
# You can also use --remote on any command instead of a YAML section:
|
||||
# obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct --remote root@gpu-node --ssh-key ~/.ssh/id_rsa
|
||||
#
|
||||
# Multi-GPU: Models are automatically split across all available GPUs via
|
||||
# accelerate's device_map="auto". Use --gpus or the gpus: field to select
|
||||
# specific GPUs:
|
||||
# obliteratus obliterate model --remote root@gpu-node --gpus 0,1,2,3
|
||||
|
||||
model:
|
||||
name: meta-llama/Llama-3.1-8B-Instruct
|
||||
@@ -39,3 +44,4 @@ remote:
|
||||
remote_dir: /tmp/obliteratus_run
|
||||
python: python3
|
||||
sync_results: true
|
||||
# gpus: "0,1,2,3" # uncomment to select specific GPUs (default: all)
|
||||
|
||||
@@ -22,6 +22,19 @@ _BANNER = r"""
|
||||
"""
|
||||
|
||||
|
||||
def _add_gpu_args(parser):
|
||||
"""Add --gpus flag for multi-GPU control."""
|
||||
gpu_group = parser.add_argument_group("GPU selection")
|
||||
gpu_group.add_argument(
|
||||
"--gpus", type=str, default=None, metavar="IDS",
|
||||
help=(
|
||||
"Comma-separated GPU IDs to use (e.g. '0,1,2,3' or 'all'). "
|
||||
"Sets CUDA_VISIBLE_DEVICES. By default uses all available GPUs. "
|
||||
"Models are automatically split across selected GPUs via accelerate."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _add_remote_args(parser):
|
||||
"""Add --remote execution flags to a subcommand parser."""
|
||||
remote_group = parser.add_argument_group("remote execution")
|
||||
@@ -51,6 +64,28 @@ def _add_remote_args(parser):
|
||||
)
|
||||
|
||||
|
||||
def _apply_gpu_selection(args):
|
||||
"""Set CUDA_VISIBLE_DEVICES based on --gpus flag (for local runs only)."""
|
||||
import os
|
||||
|
||||
gpus = getattr(args, "gpus", None)
|
||||
if gpus is None or getattr(args, "remote", None):
|
||||
return # skip for remote runs (handled by remote runner)
|
||||
|
||||
if gpus.lower() == "all":
|
||||
return # use all GPUs (default behavior)
|
||||
|
||||
# Validate: should be comma-separated integers
|
||||
try:
|
||||
gpu_ids = [int(g.strip()) for g in gpus.split(",")]
|
||||
except ValueError:
|
||||
console.print(f"[red]Invalid --gpus value: {gpus!r}. Expected comma-separated integers or 'all'.[/]")
|
||||
raise SystemExit(1)
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_ids)
|
||||
console.print(f"[dim]Using GPUs: {gpu_ids} (CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']})[/dim]")
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None):
|
||||
console.print(_BANNER)
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -69,6 +104,7 @@ def main(argv: list[str] | None = None):
|
||||
default=None,
|
||||
help="Apply a preset (e.g. quick, full, attention, jailbreak, guardrail)",
|
||||
)
|
||||
_add_gpu_args(run_parser)
|
||||
_add_remote_args(run_parser)
|
||||
|
||||
# --- info ---
|
||||
@@ -174,10 +210,12 @@ def main(argv: list[str] | None = None):
|
||||
help="One-click: remove refusal directions from a model (SOTA multi-technique)",
|
||||
)
|
||||
_add_obliterate_args(abl_parser)
|
||||
_add_gpu_args(abl_parser)
|
||||
_add_remote_args(abl_parser)
|
||||
# Backward-compat alias (hidden from help)
|
||||
abl_alias = subparsers.add_parser("abliterate", help=argparse.SUPPRESS)
|
||||
_add_obliterate_args(abl_alias)
|
||||
_add_gpu_args(abl_alias)
|
||||
_add_remote_args(abl_alias)
|
||||
|
||||
# --- report ---
|
||||
@@ -212,6 +250,7 @@ def main(argv: list[str] | None = None):
|
||||
"--methods", type=str, nargs="+", default=None,
|
||||
help="Override: only run these methods (space-separated)",
|
||||
)
|
||||
_add_gpu_args(tourney_parser)
|
||||
_add_remote_args(tourney_parser)
|
||||
|
||||
# --- recommend ---
|
||||
@@ -229,6 +268,9 @@ def main(argv: list[str] | None = None):
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
# Apply GPU selection early (before any CUDA init)
|
||||
_apply_gpu_selection(args)
|
||||
|
||||
if args.command == "run":
|
||||
if getattr(args, "remote", None):
|
||||
_cmd_remote_run(args)
|
||||
@@ -369,6 +411,7 @@ def _cmd_run(args):
|
||||
remote_dir=config.remote.remote_dir,
|
||||
python=config.remote.python,
|
||||
sync_results=config.remote.sync_results,
|
||||
gpus=config.remote.gpus,
|
||||
)
|
||||
runner = RemoteRunner(rc)
|
||||
result_path = runner.run_config(
|
||||
@@ -733,6 +776,7 @@ def _make_remote_runner(args):
|
||||
remote_dir=args.remote_dir,
|
||||
python=args.remote_python,
|
||||
sync_results=not args.no_sync,
|
||||
gpus=getattr(args, "gpus", None),
|
||||
)
|
||||
return RemoteRunner(rc)
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ class RemoteConfig:
|
||||
remote_dir: str = "/tmp/obliteratus_run"
|
||||
python: str = "python3"
|
||||
sync_results: bool = True
|
||||
gpus: str | None = None # comma-separated GPU IDs or "all"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
+23
-5
@@ -45,6 +45,7 @@ class RemoteConfig:
|
||||
install_timeout: int = 600 # seconds
|
||||
python: str = "python3" # remote python binary
|
||||
sync_results: bool = True
|
||||
gpus: str | None = None # comma-separated GPU IDs or "all"
|
||||
|
||||
@property
|
||||
def ssh_target(self) -> str:
|
||||
@@ -151,14 +152,31 @@ class RemoteRunner:
|
||||
|
||||
def check_gpu(self) -> str | None:
|
||||
"""Check for CUDA GPUs on remote. Returns nvidia-smi output or None."""
|
||||
result = self.run_ssh("nvidia-smi --query-gpu=name,memory.total --format=csv,noheader", timeout=30)
|
||||
result = self.run_ssh(
|
||||
"nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv,noheader",
|
||||
timeout=30,
|
||||
)
|
||||
if isinstance(result, subprocess.CompletedProcess) and result.returncode == 0:
|
||||
gpu_info = result.stdout.strip()
|
||||
self.on_log(f"Remote GPUs: {gpu_info}")
|
||||
lines = gpu_info.split("\n")
|
||||
self.on_log(f"Remote GPUs ({len(lines)} detected):")
|
||||
for line in lines:
|
||||
self.on_log(f" {line.strip()}")
|
||||
if self.config.gpus and self.config.gpus.lower() != "all":
|
||||
self.on_log(f" Selected GPUs: {self.config.gpus}")
|
||||
else:
|
||||
self.on_log(f" Using: all {len(lines)} GPUs")
|
||||
return gpu_info
|
||||
self.on_log("[yellow]No GPUs detected on remote (nvidia-smi failed)[/]")
|
||||
return None
|
||||
|
||||
def _env_prefix(self) -> str:
|
||||
"""Build environment variable prefix for remote commands (e.g. CUDA_VISIBLE_DEVICES)."""
|
||||
parts = []
|
||||
if self.config.gpus and self.config.gpus.lower() != "all":
|
||||
parts.append(f"CUDA_VISIBLE_DEVICES={self.config.gpus}")
|
||||
return " ".join(parts) + " " if parts else ""
|
||||
|
||||
def ensure_obliteratus(self) -> bool:
|
||||
"""Install or update obliteratus on the remote if needed."""
|
||||
# Check if already installed
|
||||
@@ -224,7 +242,7 @@ class RemoteRunner:
|
||||
remote_output = output_dir or f"{self.config.remote_dir}/output/{model.replace('/', '_')}"
|
||||
|
||||
parts = [
|
||||
self.config.python, "-m", "obliteratus",
|
||||
self._env_prefix() + self.config.python, "-m", "obliteratus",
|
||||
"obliterate", shlex.quote(model),
|
||||
"--output-dir", shlex.quote(remote_output),
|
||||
"--method", method,
|
||||
@@ -251,7 +269,7 @@ class RemoteRunner:
|
||||
def build_run_command(self, remote_config_path: str, output_dir: str | None = None, preset: str | None = None) -> str:
|
||||
"""Build remote 'obliteratus run' command."""
|
||||
parts = [
|
||||
self.config.python, "-m", "obliteratus",
|
||||
self._env_prefix() + self.config.python, "-m", "obliteratus",
|
||||
"run", shlex.quote(remote_config_path),
|
||||
]
|
||||
if output_dir:
|
||||
@@ -276,7 +294,7 @@ class RemoteRunner:
|
||||
remote_output = output_dir or f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"
|
||||
|
||||
parts = [
|
||||
self.config.python, "-m", "obliteratus",
|
||||
self._env_prefix() + self.config.python, "-m", "obliteratus",
|
||||
"tourney", shlex.quote(model),
|
||||
"--output-dir", shlex.quote(remote_output),
|
||||
"--device", device,
|
||||
|
||||
Reference in New Issue
Block a user