Add remote SSH execution support for GPU nodes

Adds --remote [user@]host flag to obliterate, run, and tourney commands,
enabling execution on remote GPU nodes via SSH. Also supports a remote:
section in YAML configs. The remote runner handles SSH connectivity checks,
GPU detection, auto-installation of obliteratus, log streaming, and result
syncing back to the local machine via scp.
This commit is contained in:
Stella Biderman
2026-03-11 15:51:16 -04:00
parent 84bdf5d978
commit 34032b7821
5 changed files with 663 additions and 3 deletions
+41
View File
@@ -0,0 +1,41 @@
# Example: Run an ablation study on a remote GPU node via SSH.
#
# Usage:
#   obliteratus run examples/remote_gpu_node.yaml
#
# The 'remote' section tells Obliteratus to SSH into the specified host,
# install obliteratus if needed, run the pipeline there, and copy results
# back to the local machine.
#
# You can also use --remote on any command instead of a YAML section:
#   obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct --remote root@gpu-node --ssh-key ~/.ssh/id_rsa

model:
  name: meta-llama/Llama-3.1-8B-Instruct
  task: causal_lm
  dtype: float16
  device: auto            # auto-select the accelerator on the remote node

dataset:
  name: wikitext
  split: test
  max_samples: 500        # cap evaluation set for a quicker run

strategies:
  - name: layer_removal
    params:
      layer_indices: [10, 11, 12]

metrics: [perplexity]
batch_size: 8
max_length: 512
output_dir: results/remote_example

# Remote execution settings (mirrors the RemoteConfig dataclass).
remote:
  host: gpu-node.example.com
  user: root
  port: 22
  ssh_key: ~/.ssh/id_rsa          # private key; omit to use the SSH agent
  remote_dir: /tmp/obliteratus_run
  python: python3                 # python binary on the remote machine
  sync_results: true              # scp results back when the run finishes
+8
View File
@@ -17,6 +17,8 @@ __all__ = [
"TourneyResult",
"get_adaptive_recommendation",
"AdaptiveRecommendation",
"RemoteRunner",
"RemoteConfig",
]
@@ -60,4 +62,10 @@ def __getattr__(name):
if name == "AdaptiveRecommendation":
from obliteratus.adaptive_defaults import AdaptiveRecommendation
return AdaptiveRecommendation
if name == "RemoteRunner":
from obliteratus.remote import RemoteRunner
return RemoteRunner
if name == "RemoteConfig":
from obliteratus.remote import RemoteConfig
return RemoteConfig
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+178 -3
View File
@@ -22,6 +22,35 @@ _BANNER = r"""
"""
def _add_remote_args(parser):
"""Add --remote execution flags to a subcommand parser."""
remote_group = parser.add_argument_group("remote execution")
remote_group.add_argument(
"--remote", type=str, default=None, metavar="[USER@]HOST",
help="Run on a remote GPU node via SSH (e.g. root@gpu-node or just gpu-node)",
)
remote_group.add_argument(
"--ssh-key", type=str, default=None,
help="Path to SSH private key (default: use SSH agent or ~/.ssh/id_rsa)",
)
remote_group.add_argument(
"--ssh-port", type=int, default=22,
help="SSH port on remote host (default: 22)",
)
remote_group.add_argument(
"--remote-dir", type=str, default="/tmp/obliteratus_run",
help="Working directory on the remote machine (default: /tmp/obliteratus_run)",
)
remote_group.add_argument(
"--remote-python", type=str, default="python3",
help="Python binary on the remote machine (default: python3)",
)
remote_group.add_argument(
"--no-sync", action="store_true", default=False,
help="Don't copy results back to local machine after remote run",
)
def main(argv: list[str] | None = None):
console.print(_BANNER)
parser = argparse.ArgumentParser(
@@ -40,6 +69,7 @@ def main(argv: list[str] | None = None):
default=None,
help="Apply a preset (e.g. quick, full, attention, jailbreak, guardrail)",
)
_add_remote_args(run_parser)
# --- info ---
info_parser = subparsers.add_parser("info", help="Print model architecture info")
@@ -144,9 +174,11 @@ def main(argv: list[str] | None = None):
help="One-click: remove refusal directions from a model (SOTA multi-technique)",
)
_add_obliterate_args(abl_parser)
_add_remote_args(abl_parser)
# Backward-compat alias (hidden from help)
abl_alias = subparsers.add_parser("abliterate", help=argparse.SUPPRESS)
_add_obliterate_args(abl_alias)
_add_remote_args(abl_alias)
# --- report ---
report_parser = subparsers.add_parser("report", help="Regenerate report from saved results")
@@ -180,6 +212,7 @@ def main(argv: list[str] | None = None):
"--methods", type=str, nargs="+", default=None,
help="Override: only run these methods (space-separated)",
)
_add_remote_args(tourney_parser)
# --- recommend ---
recommend_parser = subparsers.add_parser(
@@ -197,7 +230,10 @@ def main(argv: list[str] | None = None):
args = parser.parse_args(argv)
if args.command == "run":
_cmd_run(args)
if getattr(args, "remote", None):
_cmd_remote_run(args)
else:
_cmd_run(args)
elif args.command == "interactive":
_cmd_interactive()
elif args.command == "models":
@@ -217,9 +253,15 @@ def main(argv: list[str] | None = None):
elif args.command == "recommend":
_cmd_recommend(args)
elif args.command == "tourney":
_cmd_tourney(args)
if getattr(args, "remote", None):
_cmd_remote_tourney(args)
else:
_cmd_tourney(args)
elif args.command in ("obliterate", "abliterate"):
_cmd_abliterate(args)
if getattr(args, "remote", None):
_cmd_remote_abliterate(args)
else:
_cmd_abliterate(args)
def _cmd_ui(args):
@@ -314,6 +356,33 @@ def _cmd_run(args):
config = StudyConfig.from_dict(raw)
if args.output_dir:
config.output_dir = args.output_dir
# If YAML has a remote: section, dispatch to remote runner
if config.remote is not None:
from obliteratus.remote import RemoteConfig as _RC, RemoteRunner
rc = _RC(
host=config.remote.host,
user=config.remote.user,
port=config.remote.port,
ssh_key=config.remote.ssh_key,
remote_dir=config.remote.remote_dir,
python=config.remote.python,
sync_results=config.remote.sync_results,
)
runner = RemoteRunner(rc)
result_path = runner.run_config(
local_config_path=args.config,
local_output_dir=config.output_dir,
preset=args.preset,
)
if result_path:
console.print(f"\n[bold green]Remote run complete.[/] Results at: [cyan]{result_path}[/]")
else:
console.print("[red]Remote run failed. Check logs above.[/]")
raise SystemExit(1)
return
run_study(config)
@@ -653,5 +722,111 @@ def _cmd_abliterate(args):
)
def _make_remote_runner(args):
    """Build a RemoteRunner from the --remote family of CLI flags."""
    from obliteratus.remote import RemoteConfig, RemoteRunner

    # from_cli_args splits the "[user@]host" string; everything else maps
    # one-to-one onto RemoteConfig fields.
    remote_cfg = RemoteConfig.from_cli_args(
        args.remote,
        port=args.ssh_port,
        ssh_key=args.ssh_key,
        remote_dir=args.remote_dir,
        python=args.remote_python,
        sync_results=not args.no_sync,
    )
    return RemoteRunner(remote_cfg)
def _cmd_remote_abliterate(args):
    """Run the obliterate pipeline on a remote GPU node and report the outcome."""
    from rich.panel import Panel

    runner = _make_remote_runner(args)

    # Forward only the options the user actually set; each group below keeps
    # the original per-flag semantics (truthiness vs. explicit None checks,
    # getattr defaults for flags that may be absent on alias parsers).
    kwargs = {}
    for name in ("method", "device", "dtype", "quantization"):
        value = getattr(args, name)
        if value:
            kwargs[name] = value
    for name in ("n_directions", "regularization", "refinement_passes"):
        value = getattr(args, name)
        if value is not None:
            kwargs[name] = value
    if getattr(args, "direction_method", None):
        kwargs["direction_method"] = args.direction_method
    if getattr(args, "large_model", False):
        kwargs["large_model"] = True
    if getattr(args, "verify_sample_size", None) is not None:
        kwargs["verify_sample_size"] = args.verify_sample_size

    result_path = runner.run_obliterate(
        model=args.model,
        local_output_dir=args.output_dir,
        **kwargs,
    )
    if not result_path:
        console.print("[red]Remote abliteration failed. Check logs above.[/]")
        raise SystemExit(1)
    console.print(
        Panel(
            f"[bold green]Remote abliteration complete![/]\n\n"
            f" Results at: [cyan]{result_path}[/]\n\n"
            f" [dim]Load with:[/] AutoModelForCausalLM.from_pretrained('{result_path}')",
            border_style="green",
            title="[bold green]REBIRTH COMPLETE (remote)[/]",
        )
    )
def _cmd_remote_run(args):
    """Execute a YAML-configured study on a remote node (from --remote flags)."""
    runner = _make_remote_runner(args)
    result_path = runner.run_config(
        local_config_path=args.config,
        local_output_dir=args.output_dir,
        preset=args.preset,
    )
    # Guard clause: failure exits non-zero, success reports the local path.
    if not result_path:
        console.print("[red]Remote run failed. Check logs above.[/]")
        raise SystemExit(1)
    console.print(f"\n[bold green]Remote run complete.[/] Results at: [cyan]{result_path}[/]")
def _cmd_remote_tourney(args):
    """Run a method tournament on a remote GPU node and report the outcome."""
    from rich.panel import Panel

    runner = _make_remote_runner(args)
    result_path = runner.run_tourney(
        model=args.model,
        local_output_dir=args.output_dir,
        device=args.device,
        dtype=args.dtype,
        quantization=args.quantization,
        methods=args.methods,
        hub_org=args.hub_org,
        hub_repo=args.hub_repo,
        dataset=args.dataset,
    )
    # Guard clause: failure exits non-zero, success shows a summary panel.
    if not result_path:
        console.print("[red]Remote tournament failed. Check logs above.[/]")
        raise SystemExit(1)
    console.print(
        Panel(
            f"[bold green]Remote tournament complete![/]\n\n"
            f" Results at: [cyan]{result_path}[/]",
            border_style="green",
            title="[bold green]TOURNAMENT COMPLETE (remote)[/]",
        )
    )
# Allow direct execution (e.g. `python -m obliteratus`).
if __name__ == "__main__":
    main()
+19
View File
@@ -35,6 +35,19 @@ class StrategyConfig:
params: dict[str, Any] = field(default_factory=dict)
@dataclass
class RemoteConfig:
    """Optional remote execution settings for running on a GPU node via SSH."""

    # SSH hostname or IP address of the GPU node (required).
    host: str
    # Login user on the remote machine.
    user: str = "root"
    # SSH port on the remote host.
    port: int = 22
    # Path to a private key; None falls back to the SSH agent / defaults.
    ssh_key: str | None = None
    # Working directory created on the remote for configs and outputs.
    remote_dir: str = "/tmp/obliteratus_run"
    # Python executable to invoke on the remote.
    python: str = "python3"
    # Whether to copy results back to the local machine after the run.
    sync_results: bool = True
@dataclass
class StudyConfig:
"""Top-level configuration for an ablation run."""
@@ -46,6 +59,7 @@ class StudyConfig:
batch_size: int = 8
max_length: int = 512
output_dir: str = "results"
remote: RemoteConfig | None = None
@classmethod
def from_yaml(cls, path: str | Path) -> StudyConfig:
@@ -82,6 +96,10 @@ class StudyConfig:
model = ModelConfig(**d["model"])
dataset = DatasetConfig(**d["dataset"])
strategies = [StrategyConfig(**s) for s in d["strategies"]]
remote = None
if "remote" in d and d["remote"]:
remote = RemoteConfig(**d["remote"])
return cls(
model=model,
dataset=dataset,
@@ -90,6 +108,7 @@ class StudyConfig:
batch_size=d.get("batch_size", 8),
max_length=d.get("max_length", 512),
output_dir=d.get("output_dir", "results"),
remote=remote,
)
def to_dict(self) -> dict:
+417
View File
@@ -0,0 +1,417 @@
"""Remote execution support for Obliteratus.
Run abliteration pipelines on remote GPU nodes via SSH. The remote machine
must have CUDA-capable GPUs and a Python environment. Obliteratus will be
auto-installed if not present.
Usage (CLI):
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct \
--remote user@gpu-node \
--ssh-key ~/.ssh/id_rsa
Usage (YAML config):
remote:
host: gpu-node
user: root
ssh_key: ~/.ssh/id_rsa
remote_dir: /tmp/obliteratus_run
"""
from __future__ import annotations
import os
import shlex
import subprocess
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable
from rich.console import Console
console = Console()
@dataclass
class RemoteConfig:
    """SSH connection and remote execution settings."""

    host: str
    user: str = "root"
    port: int = 22
    ssh_key: str | None = None
    remote_dir: str = "/tmp/obliteratus_run"
    install_timeout: int = 600  # seconds allowed for the remote pip install
    python: str = "python3"  # remote python binary
    sync_results: bool = True

    @property
    def ssh_target(self) -> str:
        """The 'user@host' string handed to ssh/scp."""
        return "{}@{}".format(self.user, self.host)

    @classmethod
    def from_cli_args(cls, remote_str: str, **kwargs) -> RemoteConfig:
        """Parse 'user@host' or just 'host' from CLI --remote flag."""
        # rpartition splits on the LAST '@', matching rsplit('@', 1) so
        # usernames containing '@' keep working.
        user, sep, host = remote_str.rpartition("@")
        if not sep:
            user, host = "root", remote_str
        return cls(host=host, user=user, **kwargs)

    @classmethod
    def from_dict(cls, d: dict) -> RemoteConfig:
        """Build from a mapping, silently dropping unknown keys."""
        known = cls.__dataclass_fields__
        return cls(**{key: value for key, value in d.items() if key in known})
class RemoteRunner:
    """Execute Obliteratus commands on a remote machine via SSH."""

    def __init__(
        self,
        config: RemoteConfig,
        on_log: Callable[[str], None] | None = None,
    ):
        self.config = config
        if on_log is None:
            # Default log sink: dim "[remote]"-prefixed console lines.
            def on_log(msg: str) -> None:
                console.print(f"[dim][remote][/] {msg}")
        self.on_log = on_log
def _ssh_base_cmd(self) -> list[str]:
"""Build base SSH command with common options."""
cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-o", "BatchMode=yes",
"-o", "ConnectTimeout=30",
"-p", str(self.config.port),
]
if self.config.ssh_key:
key_path = os.path.expanduser(self.config.ssh_key)
cmd.extend(["-i", key_path])
cmd.append(self.config.ssh_target)
return cmd
def _scp_base_cmd(self) -> list[str]:
"""Build base SCP command."""
cmd = [
"scp",
"-o", "StrictHostKeyChecking=no",
"-o", "BatchMode=yes",
"-P", str(self.config.port),
"-r",
]
if self.config.ssh_key:
key_path = os.path.expanduser(self.config.ssh_key)
cmd.extend(["-i", key_path])
return cmd
def run_ssh(self, remote_cmd: str, stream: bool = False, timeout: int | None = None) -> subprocess.CompletedProcess | int:
    """Run a command on the remote host.

    If stream=True, streams stdout/stderr in real-time and returns the
    exit code. Otherwise returns CompletedProcess.
    """
    cmd = self._ssh_base_cmd() + [remote_cmd]
    if stream:
        # Merge stderr into stdout so one reader sees all output in order;
        # bufsize=1 keeps the stream line-buffered for prompt log relay.
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )
        try:
            for line in proc.stdout:
                line = line.rstrip("\n")
                self.on_log(line)
            # NOTE(review): wait() is only reached after stdout closes, so
            # this timeout does not bound a command that keeps printing —
            # confirm whether a hard wall-clock limit is intended here.
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # Synthetic non-zero exit code signals failure to callers.
            proc.kill()
            self.on_log("[red]Remote command timed out[/]")
            return 1
        return proc.returncode
    else:
        return subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
def check_connection(self) -> bool:
    """Verify SSH connectivity with a trivial remote echo."""
    self.on_log(f"Testing SSH connection to {self.config.ssh_target}...")
    result = self.run_ssh("echo ok", timeout=30)
    ok = isinstance(result, subprocess.CompletedProcess) and result.returncode == 0
    if ok:
        self.on_log("SSH connection OK")
    else:
        self.on_log("[red]SSH connection failed[/]")
    return ok
def check_gpu(self) -> str | None:
    """Probe the remote for CUDA GPUs. Returns nvidia-smi output or None."""
    probe = "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader"
    result = self.run_ssh(probe, timeout=30)
    succeeded = isinstance(result, subprocess.CompletedProcess) and result.returncode == 0
    if not succeeded:
        self.on_log("[yellow]No GPUs detected on remote (nvidia-smi failed)[/]")
        return None
    gpu_info = result.stdout.strip()
    self.on_log(f"Remote GPUs: {gpu_info}")
    return gpu_info
def ensure_obliteratus(self) -> bool:
    """Install obliteratus on the remote if it is not already importable."""
    # Probe by importing the package and printing its version.
    probe = (
        f"{self.config.python} -c "
        f"\"import obliteratus; print(obliteratus.__version__)\""
    )
    check = self.run_ssh(probe, timeout=30)
    if isinstance(check, subprocess.CompletedProcess) and check.returncode == 0:
        version = check.stdout.strip()
        self.on_log(f"Obliteratus {version} already installed on remote")
        return True

    # Not importable: pip-install straight from the git repository.
    self.on_log("Installing obliteratus on remote...")
    install_cmd = (
        f"{self.config.python} -m pip install --quiet "
        f"'obliteratus @ git+https://github.com/EleutherAI/OBLITERATUS.git'"
    )
    rc = self.run_ssh(install_cmd, stream=True, timeout=self.config.install_timeout)
    if rc != 0:
        self.on_log("[red]Failed to install obliteratus on remote[/]")
        return False
    self.on_log("Obliteratus installed successfully")
    return True
def sync_results_back(self, remote_output_dir: str, local_output_dir: str) -> bool:
    """Copy results from remote back to local machine via scp."""
    destination = Path(local_output_dir)
    destination.mkdir(parents=True, exist_ok=True)
    self.on_log(f"Syncing results: {self.config.ssh_target}:{remote_output_dir} -> {local_output_dir}")
    scp_cmd = self._scp_base_cmd()
    scp_cmd.append(f"{self.config.ssh_target}:{remote_output_dir}/")
    scp_cmd.append(str(destination))
    result = subprocess.run(scp_cmd, capture_output=True, text=True, timeout=3600)
    if result.returncode != 0:
        self.on_log(f"[red]SCP failed: {result.stderr}[/]")
        return False
    self.on_log(f"Results synced to {local_output_dir}")
    return True
def build_obliterate_command(
self,
model: str,
output_dir: str | None = None,
method: str = "advanced",
device: str = "auto",
dtype: str = "float16",
quantization: str | None = None,
n_directions: int | None = None,
direction_method: str | None = None,
regularization: float | None = None,
refinement_passes: int | None = None,
large_model: bool = False,
verify_sample_size: int | None = None,
) -> str:
"""Build the remote obliteratus CLI command."""
remote_output = output_dir or f"{self.config.remote_dir}/output/{model.replace('/', '_')}"
parts = [
self.config.python, "-m", "obliteratus",
"obliterate", shlex.quote(model),
"--output-dir", shlex.quote(remote_output),
"--method", method,
"--device", device,
"--dtype", dtype,
]
if quantization:
parts.extend(["--quantization", quantization])
if n_directions is not None:
parts.extend(["--n-directions", str(n_directions)])
if direction_method:
parts.extend(["--direction-method", direction_method])
if regularization is not None:
parts.extend(["--regularization", str(regularization)])
if refinement_passes is not None:
parts.extend(["--refinement-passes", str(refinement_passes)])
if large_model:
parts.append("--large-model")
if verify_sample_size is not None:
parts.extend(["--verify-sample-size", str(verify_sample_size)])
return " ".join(parts)
def build_run_command(self, remote_config_path: str, output_dir: str | None = None, preset: str | None = None) -> str:
"""Build remote 'obliteratus run' command."""
parts = [
self.config.python, "-m", "obliteratus",
"run", shlex.quote(remote_config_path),
]
if output_dir:
parts.extend(["--output-dir", shlex.quote(output_dir)])
if preset:
parts.extend(["--preset", preset])
return " ".join(parts)
def build_tourney_command(
self,
model: str,
output_dir: str | None = None,
device: str = "auto",
dtype: str = "float16",
quantization: str | None = None,
methods: list[str] | None = None,
hub_org: str | None = None,
hub_repo: str | None = None,
dataset: str = "builtin",
) -> str:
"""Build remote 'obliteratus tourney' command."""
remote_output = output_dir or f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"
parts = [
self.config.python, "-m", "obliteratus",
"tourney", shlex.quote(model),
"--output-dir", shlex.quote(remote_output),
"--device", device,
"--dtype", dtype,
"--dataset", dataset,
]
if quantization:
parts.extend(["--quantization", quantization])
if hub_org:
parts.extend(["--hub-org", hub_org])
if hub_repo:
parts.extend(["--hub-repo", hub_repo])
if methods:
parts.extend(["--methods"] + methods)
return " ".join(parts)
def upload_config(self, local_config_path: str) -> str:
    """Upload a YAML config file to the remote. Returns the remote path.

    Raises RuntimeError if the scp transfer fails.
    """
    remote_path = f"{self.config.remote_dir}/config.yaml"
    # Make sure the destination directory exists first.
    self.run_ssh(f"mkdir -p {shlex.quote(self.config.remote_dir)}")
    scp_cmd = self._scp_base_cmd() + [
        local_config_path,
        f"{self.config.ssh_target}:{remote_path}",
    ]
    result = subprocess.run(scp_cmd, capture_output=True, text=True, timeout=60)
    if result.returncode != 0:
        raise RuntimeError(f"Failed to upload config: {result.stderr}")
    self.on_log(f"Config uploaded to {remote_path}")
    return remote_path
def run_obliterate(
    self,
    model: str,
    local_output_dir: str | None = None,
    **kwargs,
) -> str | None:
    """Full remote obliteration: setup, run, sync results.

    Returns local path to results, or None on failure.
    """
    # Preflight: connectivity and install are fatal; missing GPUs only warn.
    if not self.check_connection():
        return None
    self.check_gpu()
    if not self.ensure_obliteratus():
        return None

    self.run_ssh(f"mkdir -p {shlex.quote(self.config.remote_dir)}")

    remote_output = f"{self.config.remote_dir}/output/{model.replace('/', '_')}"
    cmd = self.build_obliterate_command(model, output_dir=remote_output, **kwargs)
    self.on_log(f"Running: {cmd}")
    exit_code = self.run_ssh(cmd, stream=True)
    if exit_code != 0:
        self.on_log(f"[red]Remote obliteration failed (exit code {exit_code})[/]")
        return None

    if not self.config.sync_results:
        # Leave artifacts on the remote and just report where they live.
        self.on_log(f"Results on remote: {remote_output}")
        return remote_output
    local_output = local_output_dir or f"abliterated/{model.replace('/', '_')}"
    return local_output if self.sync_results_back(remote_output, local_output) else None
def run_config(
    self,
    local_config_path: str,
    local_output_dir: str | None = None,
    preset: str | None = None,
) -> str | None:
    """Upload config, run study remotely, sync results."""
    # Preflight checks mirror run_obliterate.
    if not self.check_connection():
        return None
    self.check_gpu()
    if not self.ensure_obliteratus():
        return None

    remote_config = self.upload_config(local_config_path)
    remote_output = f"{self.config.remote_dir}/results"
    cmd = self.build_run_command(remote_config, output_dir=remote_output, preset=preset)
    self.on_log(f"Running: {cmd}")
    exit_code = self.run_ssh(cmd, stream=True)
    if exit_code != 0:
        self.on_log(f"[red]Remote run failed (exit code {exit_code})[/]")
        return None

    if not self.config.sync_results:
        return remote_output
    local_output = local_output_dir or "results"
    return local_output if self.sync_results_back(remote_output, local_output) else None
def run_tourney(
    self,
    model: str,
    local_output_dir: str | None = None,
    **kwargs,
) -> str | None:
    """Run tournament remotely, sync results."""
    # Preflight checks mirror run_obliterate.
    if not self.check_connection():
        return None
    self.check_gpu()
    if not self.ensure_obliteratus():
        return None

    remote_output = f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"
    cmd = self.build_tourney_command(model, output_dir=remote_output, **kwargs)
    self.on_log(f"Running: {cmd}")
    exit_code = self.run_ssh(cmd, stream=True)
    if exit_code != 0:
        self.on_log(f"[red]Remote tourney failed (exit code {exit_code})[/]")
        return None

    if not self.config.sync_results:
        return remote_output
    local_output = local_output_dir or f"/tmp/obliteratus_tourney/{model.replace('/', '_')}"
    return local_output if self.sync_results_back(remote_output, local_output) else None