From 34032b78214929f0ecd2ea5ae34914cdeb061e0b Mon Sep 17 00:00:00 2001 From: Stella Biderman Date: Wed, 11 Mar 2026 15:51:16 -0400 Subject: [PATCH] Add remote SSH execution support for GPU nodes Adds --remote [user@]host flag to obliterate, run, and tourney commands, enabling execution on remote GPU nodes via SSH. Also supports a remote: section in YAML configs. The remote runner handles SSH connectivity checks, GPU detection, auto-installation of obliteratus, log streaming, and result syncing back to the local machine via scp. --- examples/remote_gpu_node.yaml | 41 ++++ obliteratus/__init__.py | 8 + obliteratus/cli.py | 181 ++++++++++++++- obliteratus/config.py | 19 ++ obliteratus/remote.py | 417 ++++++++++++++++++++++++++++++++++ 5 files changed, 663 insertions(+), 3 deletions(-) create mode 100644 examples/remote_gpu_node.yaml create mode 100644 obliteratus/remote.py diff --git a/examples/remote_gpu_node.yaml b/examples/remote_gpu_node.yaml new file mode 100644 index 0000000..b6b533c --- /dev/null +++ b/examples/remote_gpu_node.yaml @@ -0,0 +1,41 @@ +# Example: Run an ablation study on a remote GPU node via SSH. +# +# Usage: +# obliteratus run examples/remote_gpu_node.yaml +# +# The 'remote' section tells Obliteratus to SSH into the specified host, +# install obliteratus if needed, run the pipeline there, and copy results +# back to the local machine. 
+# +# You can also use --remote on any command instead of a YAML section: +# obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct --remote root@gpu-node --ssh-key ~/.ssh/id_rsa + +model: + name: meta-llama/Llama-3.1-8B-Instruct + task: causal_lm + dtype: float16 + device: auto + +dataset: + name: wikitext + split: test + max_samples: 500 + +strategies: + - name: layer_removal + params: + layer_indices: [10, 11, 12] + +metrics: [perplexity] +batch_size: 8 +max_length: 512 +output_dir: results/remote_example + +remote: + host: gpu-node.example.com + user: root + port: 22 + ssh_key: ~/.ssh/id_rsa + remote_dir: /tmp/obliteratus_run + python: python3 + sync_results: true diff --git a/obliteratus/__init__.py b/obliteratus/__init__.py index 7f70058..2039325 100644 --- a/obliteratus/__init__.py +++ b/obliteratus/__init__.py @@ -17,6 +17,8 @@ __all__ = [ "TourneyResult", "get_adaptive_recommendation", "AdaptiveRecommendation", + "RemoteRunner", + "RemoteConfig", ] @@ -60,4 +62,10 @@ def __getattr__(name): if name == "AdaptiveRecommendation": from obliteratus.adaptive_defaults import AdaptiveRecommendation return AdaptiveRecommendation + if name == "RemoteRunner": + from obliteratus.remote import RemoteRunner + return RemoteRunner + if name == "RemoteConfig": + from obliteratus.remote import RemoteConfig + return RemoteConfig raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/obliteratus/cli.py b/obliteratus/cli.py index e32fb25..b9595ed 100644 --- a/obliteratus/cli.py +++ b/obliteratus/cli.py @@ -22,6 +22,35 @@ _BANNER = r""" """ +def _add_remote_args(parser): + """Add --remote execution flags to a subcommand parser.""" + remote_group = parser.add_argument_group("remote execution") + remote_group.add_argument( + "--remote", type=str, default=None, metavar="[USER@]HOST", + help="Run on a remote GPU node via SSH (e.g. 
root@gpu-node or just gpu-node)", + ) + remote_group.add_argument( + "--ssh-key", type=str, default=None, + help="Path to SSH private key (default: use SSH agent or ~/.ssh/id_rsa)", + ) + remote_group.add_argument( + "--ssh-port", type=int, default=22, + help="SSH port on remote host (default: 22)", + ) + remote_group.add_argument( + "--remote-dir", type=str, default="/tmp/obliteratus_run", + help="Working directory on the remote machine (default: /tmp/obliteratus_run)", + ) + remote_group.add_argument( + "--remote-python", type=str, default="python3", + help="Python binary on the remote machine (default: python3)", + ) + remote_group.add_argument( + "--no-sync", action="store_true", default=False, + help="Don't copy results back to local machine after remote run", + ) + + def main(argv: list[str] | None = None): console.print(_BANNER) parser = argparse.ArgumentParser( @@ -40,6 +69,7 @@ def main(argv: list[str] | None = None): default=None, help="Apply a preset (e.g. quick, full, attention, jailbreak, guardrail)", ) + _add_remote_args(run_parser) # --- info --- info_parser = subparsers.add_parser("info", help="Print model architecture info") @@ -144,9 +174,11 @@ def main(argv: list[str] | None = None): help="One-click: remove refusal directions from a model (SOTA multi-technique)", ) _add_obliterate_args(abl_parser) + _add_remote_args(abl_parser) # Backward-compat alias (hidden from help) abl_alias = subparsers.add_parser("abliterate", help=argparse.SUPPRESS) _add_obliterate_args(abl_alias) + _add_remote_args(abl_alias) # --- report --- report_parser = subparsers.add_parser("report", help="Regenerate report from saved results") @@ -180,6 +212,7 @@ def main(argv: list[str] | None = None): "--methods", type=str, nargs="+", default=None, help="Override: only run these methods (space-separated)", ) + _add_remote_args(tourney_parser) # --- recommend --- recommend_parser = subparsers.add_parser( @@ -197,7 +230,10 @@ def main(argv: list[str] | None = None): args = 
parser.parse_args(argv) if args.command == "run": - _cmd_run(args) + if getattr(args, "remote", None): + _cmd_remote_run(args) + else: + _cmd_run(args) elif args.command == "interactive": _cmd_interactive() elif args.command == "models": @@ -217,9 +253,15 @@ def main(argv: list[str] | None = None): elif args.command == "recommend": _cmd_recommend(args) elif args.command == "tourney": - _cmd_tourney(args) + if getattr(args, "remote", None): + _cmd_remote_tourney(args) + else: + _cmd_tourney(args) elif args.command in ("obliterate", "abliterate"): - _cmd_abliterate(args) + if getattr(args, "remote", None): + _cmd_remote_abliterate(args) + else: + _cmd_abliterate(args) def _cmd_ui(args): @@ -314,6 +356,33 @@ def _cmd_run(args): config = StudyConfig.from_dict(raw) if args.output_dir: config.output_dir = args.output_dir + + # If YAML has a remote: section, dispatch to remote runner + if config.remote is not None: + from obliteratus.remote import RemoteConfig as _RC, RemoteRunner + + rc = _RC( + host=config.remote.host, + user=config.remote.user, + port=config.remote.port, + ssh_key=config.remote.ssh_key, + remote_dir=config.remote.remote_dir, + python=config.remote.python, + sync_results=config.remote.sync_results, + ) + runner = RemoteRunner(rc) + result_path = runner.run_config( + local_config_path=args.config, + local_output_dir=config.output_dir, + preset=args.preset, + ) + if result_path: + console.print(f"\n[bold green]Remote run complete.[/] Results at: [cyan]{result_path}[/]") + else: + console.print("[red]Remote run failed. 
def _make_remote_runner(args):
    """Create a RemoteRunner from the CLI ``--remote`` family of flags.

    ``args`` is the parsed namespace of a subcommand that registered the
    remote-execution options (``remote``, ``ssh_port``, ``ssh_key``,
    ``remote_dir``, ``remote_python``, ``no_sync``).
    """
    from obliteratus.remote import RemoteConfig, RemoteRunner

    rc = RemoteConfig.from_cli_args(
        args.remote,
        port=args.ssh_port,
        ssh_key=args.ssh_key,
        remote_dir=args.remote_dir,
        python=args.remote_python,
        sync_results=not args.no_sync,  # --no-sync inverts the config flag
    )
    return RemoteRunner(rc)


def _cmd_remote_abliterate(args):
    """Run ``obliterate`` on the remote host and report where results landed.

    Only options the user actually set are forwarded, so the defaults of
    ``RemoteRunner.build_obliterate_command`` apply to everything else.
    Exits with status 1 when the remote run (or the result sync) fails.
    """
    from rich.panel import Panel

    runner = _make_remote_runner(args)

    kwargs = {}
    # String options: forwarded when non-empty.
    for name in ("method", "device", "dtype", "quantization", "direction_method"):
        value = getattr(args, name, None)
        if value:
            kwargs[name] = value
    # Numeric options: 0 / 0.0 are meaningful, so only None means "unset".
    for name in ("n_directions", "regularization", "refinement_passes", "verify_sample_size"):
        value = getattr(args, name, None)
        if value is not None:
            kwargs[name] = value
    if getattr(args, "large_model", False):
        kwargs["large_model"] = True

    result_path = runner.run_obliterate(
        model=args.model,
        local_output_dir=args.output_dir,
        **kwargs,
    )

    if not result_path:
        console.print("[red]Remote abliteration failed. Check logs above.[/]")
        raise SystemExit(1)

    console.print(
        Panel(
            f"[bold green]Remote abliteration complete![/]\n\n"
            f"  Results at: [cyan]{result_path}[/]\n\n"
            f"  [dim]Load with:[/] AutoModelForCausalLM.from_pretrained('{result_path}')",
            border_style="green",
            title="[bold green]REBIRTH COMPLETE (remote)[/]",
        )
    )


def _cmd_remote_run(args):
    """Upload the YAML config, run the study remotely and report the result path.

    Exits with status 1 when the remote run (or the result sync) fails.
    """
    runner = _make_remote_runner(args)
    result_path = runner.run_config(
        local_config_path=args.config,
        local_output_dir=args.output_dir,
        preset=args.preset,
    )
    if not result_path:
        console.print("[red]Remote run failed. Check logs above.[/]")
        raise SystemExit(1)
    console.print(f"\n[bold green]Remote run complete.[/] Results at: [cyan]{result_path}[/]")


def _cmd_remote_tourney(args):
    """Run a tournament on the remote host and report where results landed.

    Exits with status 1 when the remote run (or the result sync) fails.
    """
    from rich.panel import Panel

    runner = _make_remote_runner(args)

    # Forward only what was explicitly set. Passing None through would
    # override the runner's own defaults and end up as a non-string element
    # in the command it joins together.
    requested = {
        "device": args.device,
        "dtype": args.dtype,
        "quantization": args.quantization,
        "methods": args.methods,
        "hub_org": args.hub_org,
        "hub_repo": args.hub_repo,
        "dataset": args.dataset,
    }
    options = {name: value for name, value in requested.items() if value is not None}

    result_path = runner.run_tourney(
        model=args.model,
        local_output_dir=args.output_dir,
        **options,
    )
    if not result_path:
        console.print("[red]Remote tournament failed. Check logs above.[/]")
        raise SystemExit(1)

    console.print(
        Panel(
            f"[bold green]Remote tournament complete![/]\n\n"
            f"  Results at: [cyan]{result_path}[/]",
            border_style="green",
            title="[bold green]TOURNAMENT COMPLETE (remote)[/]",
        )
    )
@dataclass
class RemoteConfig:
    """Settings from the optional ``remote:`` section of a study config.

    When present, the study is executed on another machine over SSH
    instead of locally. Only ``host`` is required.
    """

    host: str  # machine to connect to
    user: str = "root"  # SSH login name
    port: int = 22  # SSH port
    ssh_key: str | None = None  # private key path; None leaves key selection to ssh
    remote_dir: str = "/tmp/obliteratus_run"  # working directory on the remote machine
    python: str = "python3"  # Python binary used on the remote machine
    sync_results: bool = True  # copy results back after the run
@dataclass
class RemoteConfig:
    """SSH connection and remote execution settings."""

    host: str
    user: str = "root"
    port: int = 22
    ssh_key: str | None = None  # private key path; "~" is expanded when used
    remote_dir: str = "/tmp/obliteratus_run"
    install_timeout: int = 600  # seconds
    python: str = "python3"  # remote python binary
    sync_results: bool = True

    @property
    def ssh_target(self) -> str:
        """Return the ``user@host`` string handed to ssh and scp."""
        return f"{self.user}@{self.host}"

    @classmethod
    def from_cli_args(cls, remote_str: str, **kwargs) -> RemoteConfig:
        """Parse 'user@host' or just 'host' from CLI --remote flag.

        The split happens at the last ``@``. A missing or empty user part
        (``host`` or ``@host``) falls back to ``root``. Remaining keyword
        arguments are passed straight to the constructor.
        """
        user, _, host = remote_str.rpartition("@")
        return cls(host=host, user=user or "root", **kwargs)

    @classmethod
    def from_dict(cls, d: dict) -> RemoteConfig:
        """Build a config from a mapping, ignoring keys that are not fields."""
        return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})


class RemoteRunner:
    """Execute Obliteratus commands on a remote machine via SSH.

    Every operation shells out to the local ``ssh`` / ``scp`` binaries.
    Progress and errors are reported through ``on_log``; the high-level
    ``run_*`` methods return ``None`` on failure instead of raising.
    """

    # Options shared by every ssh and scp invocation.
    # NOTE(review): StrictHostKeyChecking=no skips host-key verification,
    # which is convenient for freshly provisioned nodes but allows
    # man-in-the-middle attacks. Consider "accept-new" if the supported
    # OpenSSH versions allow it.
    _COMMON_OPTS = (
        "-o", "StrictHostKeyChecking=no",
        "-o", "BatchMode=yes",
        "-o", "ConnectTimeout=30",
    )

    def __init__(
        self,
        config: RemoteConfig,
        on_log: Callable[[str], None] | None = None,
    ):
        """Store the connection settings and the log sink.

        Args:
            config: Connection and execution settings.
            on_log: Called with each log line; defaults to printing on the
                module-level rich console.
        """
        self.config = config
        self.on_log = on_log or (lambda msg: console.print(f"[dim][remote][/] {msg}"))

    # ------------------------------------------------------------------
    # Command construction helpers
    # ------------------------------------------------------------------

    def _identity_args(self) -> list[str]:
        """Return ``["-i", key]`` when an SSH key is configured, else ``[]``."""
        if not self.config.ssh_key:
            return []
        return ["-i", os.path.expanduser(self.config.ssh_key)]

    def _ssh_base_cmd(self) -> list[str]:
        """Build base SSH command with common options."""
        return [
            "ssh",
            *self._COMMON_OPTS,
            "-p", str(self.config.port),
            *self._identity_args(),
            self.config.ssh_target,
        ]

    def _scp_base_cmd(self) -> list[str]:
        """Build base SCP command (recursive; port is ``-P`` for scp)."""
        return [
            "scp",
            *self._COMMON_OPTS,
            "-P", str(self.config.port),
            "-r",
            *self._identity_args(),
        ]

    @staticmethod
    def _append_option(parts: list[str], flag: str, value) -> None:
        """Append ``flag value`` to *parts*, shell-quoted; skip when value is None."""
        if value is not None:
            parts.extend([flag, shlex.quote(str(value))])

    # ------------------------------------------------------------------
    # Low-level execution
    # ------------------------------------------------------------------

    def run_ssh(
        self,
        remote_cmd: str,
        stream: bool = False,
        timeout: int | None = None,
    ) -> subprocess.CompletedProcess | int:
        """Run a command on the remote host.

        Args:
            remote_cmd: Command line interpreted by the remote shell.
            stream: Forward merged stdout/stderr to ``on_log`` line by line.
            timeout: Overall limit in seconds; ``None`` means no limit.

        Returns:
            With ``stream=True`` the exit code (``1`` if the timeout hit).
            Otherwise the ``CompletedProcess`` with captured text output.

        Raises:
            subprocess.TimeoutExpired: Non-streaming mode only.
        """
        cmd = self._ssh_base_cmd() + [remote_cmd]

        if not stream:
            return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)

        import threading  # local import: only the streaming path needs a watchdog

        timed_out = threading.Event()
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        ) as proc:

            def _kill_on_deadline() -> None:
                # Reading stdout blocks until the remote side closes it, so
                # the deadline is enforced from a timer thread.
                if proc.poll() is None:
                    timed_out.set()
                    proc.kill()

            watchdog = None
            if timeout is not None:
                watchdog = threading.Timer(timeout, _kill_on_deadline)
                watchdog.daemon = True
                watchdog.start()
            try:
                # NOTE(review): lines are passed through verbatim; with the
                # default rich-based on_log, bracketed text in remote output
                # is presumably parsed as console markup. Confirm.
                for line in proc.stdout:
                    self.on_log(line.rstrip("\n"))
                proc.wait()
            except BaseException:
                # Do not leave ssh running if logging fails or on Ctrl-C.
                proc.kill()
                raise
            finally:
                if watchdog is not None:
                    watchdog.cancel()

        if timed_out.is_set():
            self.on_log("[red]Remote command timed out[/]")
            return 1
        return proc.returncode

    def _run_captured(self, remote_cmd: str, timeout: int) -> subprocess.CompletedProcess | None:
        """Run *remote_cmd* without streaming; return None if it times out."""
        try:
            return self.run_ssh(remote_cmd, timeout=timeout)
        except subprocess.TimeoutExpired:
            self.on_log(f"[red]SSH command timed out after {timeout}s[/]")
            return None

    # ------------------------------------------------------------------
    # Pre-flight checks
    # ------------------------------------------------------------------

    def check_connection(self) -> bool:
        """Verify SSH connectivity."""
        self.on_log(f"Testing SSH connection to {self.config.ssh_target}...")
        result = self._run_captured("echo ok", timeout=30)
        if result is not None and result.returncode == 0:
            self.on_log("SSH connection OK")
            return True
        self.on_log("[red]SSH connection failed[/]")
        return False

    def check_gpu(self) -> str | None:
        """Check for CUDA GPUs on remote. Returns nvidia-smi output or None."""
        result = self._run_captured(
            "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader",
            timeout=30,
        )
        if result is not None and result.returncode == 0:
            gpu_info = result.stdout.strip()
            self.on_log(f"Remote GPUs: {gpu_info}")
            return gpu_info
        self.on_log("[yellow]No GPUs detected on remote (nvidia-smi failed)[/]")
        return None

    def ensure_obliteratus(self) -> bool:
        """Install obliteratus on the remote unless it is already importable."""
        check = self._run_captured(
            f"{self.config.python} -c \"import obliteratus; print(obliteratus.__version__)\"",
            timeout=30,
        )
        if check is not None and check.returncode == 0:
            version = check.stdout.strip()
            self.on_log(f"Obliteratus {version} already installed on remote")
            return True

        # Install straight from the git repository.
        self.on_log("Installing obliteratus on remote...")
        install_cmd = (
            f"{self.config.python} -m pip install --quiet "
            f"'obliteratus @ git+https://github.com/EleutherAI/OBLITERATUS.git'"
        )
        rc = self.run_ssh(install_cmd, stream=True, timeout=self.config.install_timeout)
        if rc != 0:
            self.on_log("[red]Failed to install obliteratus on remote[/]")
            return False

        self.on_log("Obliteratus installed successfully")
        return True

    # ------------------------------------------------------------------
    # File transfer
    # ------------------------------------------------------------------

    @staticmethod
    def _merge_tree(src: Path, dst: Path) -> None:
        """Move *src* onto *dst*, merging directories and overwriting files."""
        if src.is_dir() and not src.is_symlink():
            dst.mkdir(parents=True, exist_ok=True)
            # Materialize the listing first: entries are moved away while looping.
            for child in sorted(src.iterdir()):
                RemoteRunner._merge_tree(child, dst / child.name)
        else:
            os.replace(src, dst)

    def sync_results_back(self, remote_output_dir: str, local_output_dir: str) -> bool:
        """Copy results from remote back to local machine via scp.

        The *contents* of ``remote_output_dir`` end up directly inside
        ``local_output_dir``. ``scp -r`` into an existing directory would
        nest them under the remote directory's basename, so the transfer
        goes to a staging directory first and is merged from there.

        Returns:
            True on success, False if scp fails, times out, or the fetched
            files cannot be moved into place.
        """
        import posixpath
        import tempfile

        local_path = Path(local_output_dir)
        local_path.mkdir(parents=True, exist_ok=True)

        self.on_log(f"Syncing results: {self.config.ssh_target}:{remote_output_dir} -> {local_output_dir}")

        remote_src = remote_output_dir.rstrip("/") or "/"
        # Stage inside the target so the final moves stay on one filesystem.
        with tempfile.TemporaryDirectory(prefix=".obliteratus-sync-", dir=local_path) as staging:
            cmd = self._scp_base_cmd() + [
                f"{self.config.ssh_target}:{remote_src}",
                staging,
            ]
            try:
                result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600)
            except subprocess.TimeoutExpired:
                self.on_log("[red]SCP failed: transfer timed out[/]")
                return False
            if result.returncode != 0:
                self.on_log(f"[red]SCP failed: {result.stderr}[/]")
                return False

            fetched = Path(staging) / posixpath.basename(remote_src)
            try:
                if fetched.is_dir():
                    self._merge_tree(fetched, local_path)
                else:
                    os.replace(fetched, local_path / fetched.name)
            except OSError as err:
                self.on_log(f"[red]Could not move synced results into place: {err}[/]")
                return False

        self.on_log(f"Results synced to {local_output_dir}")
        return True

    @staticmethod
    def _strip_remote_section(yaml_text: str) -> str:
        """Return *yaml_text* without its top-level ``remote:`` mapping.

        This is a line-based filter, not a YAML parser: it drops the line
        that starts the top-level ``remote`` key and everything up to the
        next line that begins a top-level node.
        """
        kept = []
        dropping = False
        for line in yaml_text.splitlines(keepends=True):
            starts_top_level_node = (
                bool(line.strip())
                and not line[0].isspace()
                and not line.startswith("#")
            )
            if starts_top_level_node:
                key, sep, _ = line.partition(":")
                dropping = bool(sep) and key.strip().strip("'\"") == "remote"
            if not dropping:
                kept.append(line)
        return "".join(kept)

    def upload_config(self, local_config_path: str) -> str:
        """Upload a YAML config file to the remote.

        The ``remote:`` section is removed from the uploaded copy; otherwise
        the remote ``obliteratus run`` would see it and try to dispatch over
        SSH again instead of running the study on the node.

        Returns:
            The path of the config on the remote machine.

        Raises:
            RuntimeError: If the local file cannot be read or scp fails.
        """
        import tempfile

        remote_path = f"{self.config.remote_dir}/config.yaml"
        try:
            config_text = Path(local_config_path).read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as err:
            raise RuntimeError(f"Failed to upload config: {err}") from err

        self.run_ssh(f"mkdir -p {shlex.quote(self.config.remote_dir)}")

        with tempfile.TemporaryDirectory(prefix="obliteratus-config-") as staging:
            sanitized = Path(staging) / "config.yaml"
            sanitized.write_text(self._strip_remote_section(config_text), encoding="utf-8")
            cmd = self._scp_base_cmd() + [
                str(sanitized),
                f"{self.config.ssh_target}:{remote_path}",
            ]
            try:
                result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
            except subprocess.TimeoutExpired as err:
                raise RuntimeError("Failed to upload config: scp timed out") from err

        if result.returncode != 0:
            raise RuntimeError(f"Failed to upload config: {result.stderr}")
        self.on_log(f"Config uploaded to {remote_path}")
        return remote_path

    # ------------------------------------------------------------------
    # Remote command lines
    # ------------------------------------------------------------------

    def build_obliterate_command(
        self,
        model: str,
        output_dir: str | None = None,
        method: str = "advanced",
        device: str = "auto",
        dtype: str = "float16",
        quantization: str | None = None,
        n_directions: int | None = None,
        direction_method: str | None = None,
        regularization: float | None = None,
        refinement_passes: int | None = None,
        large_model: bool = False,
        verify_sample_size: int | None = None,
    ) -> str:
        """Build the remote obliteratus CLI command.

        Every value is shell-quoted because the string is interpreted by the
        remote shell. An option passed as ``None`` is left out so the remote
        CLI applies its own default.
        """
        remote_output = output_dir or f"{self.config.remote_dir}/output/{model.replace('/', '_')}"

        parts = [
            self.config.python, "-m", "obliteratus",
            "obliterate", shlex.quote(model),
            "--output-dir", shlex.quote(remote_output),
        ]
        self._append_option(parts, "--method", method)
        self._append_option(parts, "--device", device)
        self._append_option(parts, "--dtype", dtype)
        self._append_option(parts, "--quantization", quantization or None)
        self._append_option(parts, "--n-directions", n_directions)
        self._append_option(parts, "--direction-method", direction_method or None)
        self._append_option(parts, "--regularization", regularization)
        self._append_option(parts, "--refinement-passes", refinement_passes)
        if large_model:
            parts.append("--large-model")
        self._append_option(parts, "--verify-sample-size", verify_sample_size)

        return " ".join(parts)

    def build_run_command(self, remote_config_path: str, output_dir: str | None = None, preset: str | None = None) -> str:
        """Build remote 'obliteratus run' command (all values shell-quoted)."""
        parts = [
            self.config.python, "-m", "obliteratus",
            "run", shlex.quote(remote_config_path),
        ]
        self._append_option(parts, "--output-dir", output_dir or None)
        self._append_option(parts, "--preset", preset or None)
        return " ".join(parts)

    def build_tourney_command(
        self,
        model: str,
        output_dir: str | None = None,
        device: str = "auto",
        dtype: str = "float16",
        quantization: str | None = None,
        methods: list[str] | None = None,
        hub_org: str | None = None,
        hub_repo: str | None = None,
        dataset: str = "builtin",
    ) -> str:
        """Build remote 'obliteratus tourney' command.

        Values are shell-quoted; an option passed as ``None`` is left out so
        the remote CLI applies its own default.
        """
        remote_output = output_dir or f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"

        parts = [
            self.config.python, "-m", "obliteratus",
            "tourney", shlex.quote(model),
            "--output-dir", shlex.quote(remote_output),
        ]
        self._append_option(parts, "--device", device)
        self._append_option(parts, "--dtype", dtype)
        self._append_option(parts, "--dataset", dataset)
        self._append_option(parts, "--quantization", quantization or None)
        self._append_option(parts, "--hub-org", hub_org or None)
        self._append_option(parts, "--hub-repo", hub_repo or None)
        if methods:
            parts.append("--methods")
            parts.extend(shlex.quote(str(name)) for name in methods)
        return " ".join(parts)

    # ------------------------------------------------------------------
    # High-level workflows
    # ------------------------------------------------------------------

    def _prepare_remote(self) -> bool:
        """Run the shared pre-flight steps; False means the run cannot start."""
        if not self.check_connection():
            return False
        self.check_gpu()  # informational only: a missing GPU does not abort
        return self.ensure_obliteratus()

    def _run_and_collect(self, cmd: str, label: str, remote_output: str, local_output: str) -> str | None:
        """Stream *cmd* remotely, then sync results if configured.

        Returns the local results path when synced, the remote path when
        syncing is disabled, or None on any failure.
        """
        self.on_log(f"Running: {cmd}")
        rc = self.run_ssh(cmd, stream=True)
        if rc != 0:
            self.on_log(f"[red]Remote {label} failed (exit code {rc})[/]")
            return None

        if self.config.sync_results:
            if self.sync_results_back(remote_output, local_output):
                return local_output
            return None

        self.on_log(f"Results on remote: {remote_output}")
        return remote_output

    def run_obliterate(
        self,
        model: str,
        local_output_dir: str | None = None,
        **kwargs,
    ) -> str | None:
        """Full remote obliteration: setup, run, sync results.

        Extra keyword arguments go to ``build_obliterate_command``.

        Returns local path to results, or None on failure. When result
        syncing is disabled the remote path is returned instead.
        """
        if not self._prepare_remote():
            return None

        # Make sure the working directory exists before the run writes to it.
        self.run_ssh(f"mkdir -p {shlex.quote(self.config.remote_dir)}")

        model_slug = model.replace("/", "_")
        remote_output = f"{self.config.remote_dir}/output/{model_slug}"
        cmd = self.build_obliterate_command(model, output_dir=remote_output, **kwargs)
        return self._run_and_collect(
            cmd,
            "obliteration",
            remote_output,
            local_output_dir or f"abliterated/{model_slug}",
        )

    def run_config(
        self,
        local_config_path: str,
        local_output_dir: str | None = None,
        preset: str | None = None,
    ) -> str | None:
        """Upload config, run study remotely, sync results.

        Returns the results path (local when synced, remote otherwise), or
        None on failure, including a failed config upload.
        """
        if not self._prepare_remote():
            return None

        try:
            remote_config = self.upload_config(local_config_path)
        except RuntimeError as err:
            # Keep the "None on failure" contract instead of a traceback.
            self.on_log(f"[red]{err}[/]")
            return None

        remote_output = f"{self.config.remote_dir}/results"
        cmd = self.build_run_command(remote_config, output_dir=remote_output, preset=preset)
        return self._run_and_collect(cmd, "run", remote_output, local_output_dir or "results")

    def run_tourney(
        self,
        model: str,
        local_output_dir: str | None = None,
        **kwargs,
    ) -> str | None:
        """Run tournament remotely, sync results.

        Extra keyword arguments go to ``build_tourney_command``. Returns the
        results path (local when synced, remote otherwise), or None on failure.
        """
        if not self._prepare_remote():
            return None

        model_slug = model.replace("/", "_")
        remote_output = f"{self.config.remote_dir}/tourney/{model_slug}"
        cmd = self.build_tourney_command(model, output_dir=remote_output, **kwargs)
        return self._run_and_collect(
            cmd,
            "tourney",
            remote_output,
            local_output_dir or f"/tmp/obliteratus_tourney/{model_slug}",
        )