From cbdb772eb9c0270d434967ff69f3e9f43c97e906 Mon Sep 17 00:00:00 2001 From: Stella Biderman Date: Thu, 12 Mar 2026 12:35:29 -0400 Subject: [PATCH] Add multi-GPU support with --gpus flag Adds --gpus flag to obliterate, run, and tourney commands for controlling which GPUs to use (sets CUDA_VISIBLE_DEVICES). Works both locally and with --remote. Models are automatically split across selected GPUs via accelerate's device_map="auto". Also adds gpus field to remote YAML config. --- examples/remote_gpu_node.yaml | 6 +++++ obliteratus/cli.py | 44 +++++++++++++++++++++++++++++++++++ obliteratus/config.py | 1 + obliteratus/remote.py | 28 ++++++++++++++++++---- 4 files changed, 74 insertions(+), 5 deletions(-) diff --git a/examples/remote_gpu_node.yaml b/examples/remote_gpu_node.yaml index b6b533c..6d6951f 100644 --- a/examples/remote_gpu_node.yaml +++ b/examples/remote_gpu_node.yaml @@ -9,6 +9,11 @@ # # You can also use --remote on any command instead of a YAML section: # obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct --remote root@gpu-node --ssh-key ~/.ssh/id_rsa +# +# Multi-GPU: Models are automatically split across all available GPUs via +# accelerate's device_map="auto". Use --gpus or the gpus: field to select +# specific GPUs: +# obliteratus obliterate model --remote root@gpu-node --gpus 0,1,2,3 model: name: meta-llama/Llama-3.1-8B-Instruct @@ -39,3 +44,4 @@ remote: remote_dir: /tmp/obliteratus_run python: python3 sync_results: true + # gpus: "0,1,2,3" # uncomment to select specific GPUs (default: all) diff --git a/obliteratus/cli.py b/obliteratus/cli.py index b9595ed..aa37c94 100644 --- a/obliteratus/cli.py +++ b/obliteratus/cli.py @@ -22,6 +22,19 @@ _BANNER = r""" """ +def _add_gpu_args(parser): + """Add --gpus flag for multi-GPU control.""" + gpu_group = parser.add_argument_group("GPU selection") + gpu_group.add_argument( + "--gpus", type=str, default=None, metavar="IDS", + help=( + "Comma-separated GPU IDs to use (e.g. '0,1,2,3' or 'all'). " + "Sets CUDA_VISIBLE_DEVICES. By default uses all available GPUs. " + "Models are automatically split across selected GPUs via accelerate." + ), + ) + + def _add_remote_args(parser): """Add --remote execution flags to a subcommand parser.""" remote_group = parser.add_argument_group("remote execution") @@ -51,6 +64,28 @@ def _add_remote_args(parser): ) +def _apply_gpu_selection(args): + """Set CUDA_VISIBLE_DEVICES based on --gpus flag (for local runs only).""" + import os + + gpus = getattr(args, "gpus", None) + if gpus is None or getattr(args, "remote", None): + return # skip for remote runs (handled by remote runner) + + if gpus.lower() == "all": + return # use all GPUs (default behavior) + + # Validate: should be comma-separated integers + try: + gpu_ids = [int(g.strip()) for g in gpus.split(",")] + except ValueError: + console.print(f"[red]Invalid --gpus value: {gpus!r}. Expected comma-separated integers or 'all'.[/]") + raise SystemExit(1) + + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_ids) + console.print(f"[dim]Using GPUs: {gpu_ids} (CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']})[/dim]") + + def main(argv: list[str] | None = None): console.print(_BANNER) parser = argparse.ArgumentParser( @@ -69,6 +104,7 @@ def main(argv: list[str] | None = None): default=None, help="Apply a preset (e.g. quick, full, attention, jailbreak, guardrail)", ) + _add_gpu_args(run_parser) _add_remote_args(run_parser) # --- info --- @@ -174,10 +210,12 @@ def main(argv: list[str] | None = None): help="One-click: remove refusal directions from a model (SOTA multi-technique)", ) _add_obliterate_args(abl_parser) + _add_gpu_args(abl_parser) _add_remote_args(abl_parser) # Backward-compat alias (hidden from help) abl_alias = subparsers.add_parser("abliterate", help=argparse.SUPPRESS) _add_obliterate_args(abl_alias) + _add_gpu_args(abl_alias) _add_remote_args(abl_alias) # --- report --- @@ -212,6 +250,7 @@ def main(argv: list[str] | None = None): "--methods", type=str, nargs="+", default=None, help="Override: only run these methods (space-separated)", ) + _add_gpu_args(tourney_parser) _add_remote_args(tourney_parser) # --- recommend --- @@ -229,6 +268,9 @@ def main(argv: list[str] | None = None): args = parser.parse_args(argv) + # Apply GPU selection early (before any CUDA init) + _apply_gpu_selection(args) + if args.command == "run": if getattr(args, "remote", None): _cmd_remote_run(args) @@ -369,6 +411,7 @@ def _cmd_run(args): remote_dir=config.remote.remote_dir, python=config.remote.python, sync_results=config.remote.sync_results, + gpus=config.remote.gpus, ) runner = RemoteRunner(rc) result_path = runner.run_config( @@ -733,6 +776,7 @@ def _make_remote_runner(args): remote_dir=args.remote_dir, python=args.remote_python, sync_results=not args.no_sync, + gpus=getattr(args, "gpus", None), ) return RemoteRunner(rc) diff --git a/obliteratus/config.py b/obliteratus/config.py index 1a31e31..e12d837 100644 --- a/obliteratus/config.py +++ b/obliteratus/config.py @@ -46,6 +46,7 @@ class RemoteConfig: remote_dir: str = "/tmp/obliteratus_run" python: str = "python3" sync_results: bool = True + gpus: str | None = None # comma-separated GPU IDs or "all" @dataclass diff --git a/obliteratus/remote.py b/obliteratus/remote.py index 9659e04..57c7be7 100644 --- a/obliteratus/remote.py +++ b/obliteratus/remote.py @@ -45,6 +45,7 @@ class RemoteConfig: install_timeout: int = 600 # seconds python: str = "python3" # remote python binary sync_results: bool = True + gpus: str | None = None # comma-separated GPU IDs or "all" @property def ssh_target(self) -> str: @@ -151,14 +152,31 @@ class RemoteRunner: def check_gpu(self) -> str | None: """Check for CUDA GPUs on remote. Returns nvidia-smi output or None.""" - result = self.run_ssh("nvidia-smi --query-gpu=name,memory.total --format=csv,noheader", timeout=30) + result = self.run_ssh( + "nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv,noheader", + timeout=30, + ) if isinstance(result, subprocess.CompletedProcess) and result.returncode == 0: gpu_info = result.stdout.strip() - self.on_log(f"Remote GPUs: {gpu_info}") + lines = gpu_info.split("\n") + self.on_log(f"Remote GPUs ({len(lines)} detected):") + for line in lines: + self.on_log(f" {line.strip()}") + if self.config.gpus and self.config.gpus.lower() != "all": + self.on_log(f" Selected GPUs: {self.config.gpus}") + else: + self.on_log(f" Using: all {len(lines)} GPUs") return gpu_info self.on_log("[yellow]No GPUs detected on remote (nvidia-smi failed)[/]") return None + def _env_prefix(self) -> str: + """Build environment variable prefix for remote commands (e.g. CUDA_VISIBLE_DEVICES).""" + parts = [] + if self.config.gpus and self.config.gpus.lower() != "all": + parts.append(f"CUDA_VISIBLE_DEVICES={self.config.gpus}") + return " ".join(parts) + " " if parts else "" + def ensure_obliteratus(self) -> bool: """Install or update obliteratus on the remote if needed.""" # Check if already installed @@ -224,7 +242,7 @@ class RemoteRunner: remote_output = output_dir or f"{self.config.remote_dir}/output/{model.replace('/', '_')}" parts = [ - self.config.python, "-m", "obliteratus", + self._env_prefix() + self.config.python, "-m", "obliteratus", "obliterate", shlex.quote(model), "--output-dir", shlex.quote(remote_output), "--method", method, @@ -251,7 +269,7 @@ class RemoteRunner: def build_run_command(self, remote_config_path: str, output_dir: str | None = None, preset: str | None = None) -> str: """Build remote 'obliteratus run' command.""" parts = [ - self.config.python, "-m", "obliteratus", + self._env_prefix() + self.config.python, "-m", "obliteratus", "run", shlex.quote(remote_config_path), ] if output_dir: @@ -276,7 +294,7 @@ class RemoteRunner: remote_output = output_dir or f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}" parts = [ - self.config.python, "-m", "obliteratus", + self._env_prefix() + self.config.python, "-m", "obliteratus", "tourney", shlex.quote(model), "--output-dir", shlex.quote(remote_output), "--device", device,