mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-05-28 18:11:29 +02:00
a2bb748f1b
This reverts commit 1a6e2577bb.
436 lines
15 KiB
Python
436 lines
15 KiB
Python
"""Remote execution support for Obliteratus.
|
|
|
|
Run abliteration pipelines on remote GPU nodes via SSH. The remote machine
|
|
must have CUDA-capable GPUs and a Python environment. Obliteratus will be
|
|
auto-installed if not present.
|
|
|
|
Usage (CLI):
|
|
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct \
|
|
--remote user@gpu-node \
|
|
--ssh-key ~/.ssh/id_rsa
|
|
|
|
Usage (YAML config):
|
|
remote:
|
|
host: gpu-node
|
|
user: root
|
|
ssh_key: ~/.ssh/id_rsa
|
|
remote_dir: /tmp/obliteratus_run
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import shlex
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
from rich.console import Console
|
|
|
|
console = Console()
|
|
|
|
|
|
@dataclass
|
|
class RemoteConfig:
|
|
"""SSH connection and remote execution settings."""
|
|
|
|
host: str
|
|
user: str = "root"
|
|
port: int = 22
|
|
ssh_key: str | None = None
|
|
remote_dir: str = "/tmp/obliteratus_run"
|
|
install_timeout: int = 600 # seconds
|
|
python: str = "python3" # remote python binary
|
|
sync_results: bool = True
|
|
gpus: str | None = None # comma-separated GPU IDs or "all"
|
|
|
|
@property
|
|
def ssh_target(self) -> str:
|
|
return f"{self.user}@{self.host}"
|
|
|
|
@classmethod
|
|
def from_cli_args(cls, remote_str: str, **kwargs) -> RemoteConfig:
|
|
"""Parse 'user@host' or just 'host' from CLI --remote flag."""
|
|
if "@" in remote_str:
|
|
user, host = remote_str.rsplit("@", 1)
|
|
else:
|
|
user = "root"
|
|
host = remote_str
|
|
return cls(host=host, user=user, **kwargs)
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict) -> RemoteConfig:
|
|
return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
|
|
|
|
|
|
class RemoteRunner:
|
|
"""Execute Obliteratus commands on a remote machine via SSH."""
|
|
|
|
def __init__(
|
|
self,
|
|
config: RemoteConfig,
|
|
on_log: Callable[[str], None] | None = None,
|
|
):
|
|
self.config = config
|
|
self.on_log = on_log or (lambda msg: console.print(f"[dim][remote][/] {msg}"))
|
|
|
|
def _ssh_base_cmd(self) -> list[str]:
|
|
"""Build base SSH command with common options."""
|
|
cmd = [
|
|
"ssh",
|
|
"-o", "StrictHostKeyChecking=no",
|
|
"-o", "BatchMode=yes",
|
|
"-o", "ConnectTimeout=30",
|
|
"-p", str(self.config.port),
|
|
]
|
|
if self.config.ssh_key:
|
|
key_path = os.path.expanduser(self.config.ssh_key)
|
|
cmd.extend(["-i", key_path])
|
|
cmd.append(self.config.ssh_target)
|
|
return cmd
|
|
|
|
def _scp_base_cmd(self) -> list[str]:
|
|
"""Build base SCP command."""
|
|
cmd = [
|
|
"scp",
|
|
"-o", "StrictHostKeyChecking=no",
|
|
"-o", "BatchMode=yes",
|
|
"-P", str(self.config.port),
|
|
"-r",
|
|
]
|
|
if self.config.ssh_key:
|
|
key_path = os.path.expanduser(self.config.ssh_key)
|
|
cmd.extend(["-i", key_path])
|
|
return cmd
|
|
|
|
def run_ssh(self, remote_cmd: str, stream: bool = False, timeout: int | None = None) -> subprocess.CompletedProcess | int:
|
|
"""Run a command on the remote host.
|
|
|
|
If stream=True, streams stdout/stderr in real-time and returns the
|
|
exit code. Otherwise returns CompletedProcess.
|
|
"""
|
|
cmd = self._ssh_base_cmd() + [remote_cmd]
|
|
|
|
if stream:
|
|
proc = subprocess.Popen(
|
|
cmd,
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.STDOUT,
|
|
text=True,
|
|
bufsize=1,
|
|
)
|
|
try:
|
|
for line in proc.stdout:
|
|
line = line.rstrip("\n")
|
|
self.on_log(line)
|
|
proc.wait(timeout=timeout)
|
|
except subprocess.TimeoutExpired:
|
|
proc.kill()
|
|
self.on_log("[red]Remote command timed out[/]")
|
|
return 1
|
|
return proc.returncode
|
|
else:
|
|
return subprocess.run(
|
|
cmd,
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=timeout,
|
|
)
|
|
|
|
def check_connection(self) -> bool:
|
|
"""Verify SSH connectivity."""
|
|
self.on_log(f"Testing SSH connection to {self.config.ssh_target}...")
|
|
result = self.run_ssh("echo ok", timeout=30)
|
|
if isinstance(result, subprocess.CompletedProcess) and result.returncode == 0:
|
|
self.on_log("SSH connection OK")
|
|
return True
|
|
self.on_log("[red]SSH connection failed[/]")
|
|
return False
|
|
|
|
def check_gpu(self) -> str | None:
|
|
"""Check for CUDA GPUs on remote. Returns nvidia-smi output or None."""
|
|
result = self.run_ssh(
|
|
"nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv,noheader",
|
|
timeout=30,
|
|
)
|
|
if isinstance(result, subprocess.CompletedProcess) and result.returncode == 0:
|
|
gpu_info = result.stdout.strip()
|
|
lines = gpu_info.split("\n")
|
|
self.on_log(f"Remote GPUs ({len(lines)} detected):")
|
|
for line in lines:
|
|
self.on_log(f" {line.strip()}")
|
|
if self.config.gpus and self.config.gpus.lower() != "all":
|
|
self.on_log(f" Selected GPUs: {self.config.gpus}")
|
|
else:
|
|
self.on_log(f" Using: all {len(lines)} GPUs")
|
|
return gpu_info
|
|
self.on_log("[yellow]No GPUs detected on remote (nvidia-smi failed)[/]")
|
|
return None
|
|
|
|
def _env_prefix(self) -> str:
|
|
"""Build environment variable prefix for remote commands (e.g. CUDA_VISIBLE_DEVICES)."""
|
|
parts = []
|
|
if self.config.gpus and self.config.gpus.lower() != "all":
|
|
parts.append(f"CUDA_VISIBLE_DEVICES={self.config.gpus}")
|
|
return " ".join(parts) + " " if parts else ""
|
|
|
|
def ensure_obliteratus(self) -> bool:
|
|
"""Install or update obliteratus on the remote if needed."""
|
|
# Check if already installed
|
|
check = self.run_ssh(
|
|
f"{self.config.python} -c \"import obliteratus; print(obliteratus.__version__)\"",
|
|
timeout=30,
|
|
)
|
|
if isinstance(check, subprocess.CompletedProcess) and check.returncode == 0:
|
|
version = check.stdout.strip()
|
|
self.on_log(f"Obliteratus {version} already installed on remote")
|
|
return True
|
|
|
|
# Install from PyPI or git
|
|
self.on_log("Installing obliteratus on remote...")
|
|
install_cmd = (
|
|
f"{self.config.python} -m pip install --quiet "
|
|
f"git+https://github.com/StellaAthena/OBLITERATUS.git"
|
|
)
|
|
rc = self.run_ssh(install_cmd, stream=True, timeout=self.config.install_timeout)
|
|
if rc != 0:
|
|
self.on_log("[red]Failed to install obliteratus on remote[/]")
|
|
return False
|
|
|
|
self.on_log("Obliteratus installed successfully")
|
|
return True
|
|
|
|
def sync_results_back(self, remote_output_dir: str, local_output_dir: str) -> bool:
|
|
"""Copy results from remote back to local machine via scp."""
|
|
local_path = Path(local_output_dir)
|
|
local_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
self.on_log(f"Syncing results: {self.config.ssh_target}:{remote_output_dir} -> {local_output_dir}")
|
|
|
|
cmd = self._scp_base_cmd() + [
|
|
f"{self.config.ssh_target}:{remote_output_dir}/",
|
|
str(local_path),
|
|
]
|
|
|
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600)
|
|
if result.returncode == 0:
|
|
self.on_log(f"Results synced to {local_output_dir}")
|
|
return True
|
|
else:
|
|
self.on_log(f"[red]SCP failed: {result.stderr}[/]")
|
|
return False
|
|
|
|
def build_obliterate_command(
|
|
self,
|
|
model: str,
|
|
output_dir: str | None = None,
|
|
method: str = "advanced",
|
|
device: str = "auto",
|
|
dtype: str = "float16",
|
|
quantization: str | None = None,
|
|
n_directions: int | None = None,
|
|
direction_method: str | None = None,
|
|
regularization: float | None = None,
|
|
refinement_passes: int | None = None,
|
|
large_model: bool = False,
|
|
verify_sample_size: int | None = None,
|
|
) -> str:
|
|
"""Build the remote obliteratus CLI command."""
|
|
remote_output = output_dir or f"{self.config.remote_dir}/output/{model.replace('/', '_')}"
|
|
|
|
parts = [
|
|
self._env_prefix() + self.config.python, "-m", "obliteratus",
|
|
"obliterate", shlex.quote(model),
|
|
"--output-dir", shlex.quote(remote_output),
|
|
"--method", method,
|
|
"--device", device,
|
|
"--dtype", dtype,
|
|
]
|
|
if quantization:
|
|
parts.extend(["--quantization", quantization])
|
|
if n_directions is not None:
|
|
parts.extend(["--n-directions", str(n_directions)])
|
|
if direction_method:
|
|
parts.extend(["--direction-method", direction_method])
|
|
if regularization is not None:
|
|
parts.extend(["--regularization", str(regularization)])
|
|
if refinement_passes is not None:
|
|
parts.extend(["--refinement-passes", str(refinement_passes)])
|
|
if large_model:
|
|
parts.append("--large-model")
|
|
if verify_sample_size is not None:
|
|
parts.extend(["--verify-sample-size", str(verify_sample_size)])
|
|
|
|
return " ".join(parts)
|
|
|
|
def build_run_command(self, remote_config_path: str, output_dir: str | None = None, preset: str | None = None) -> str:
|
|
"""Build remote 'obliteratus run' command."""
|
|
parts = [
|
|
self._env_prefix() + self.config.python, "-m", "obliteratus",
|
|
"run", shlex.quote(remote_config_path),
|
|
]
|
|
if output_dir:
|
|
parts.extend(["--output-dir", shlex.quote(output_dir)])
|
|
if preset:
|
|
parts.extend(["--preset", preset])
|
|
return " ".join(parts)
|
|
|
|
def build_tourney_command(
|
|
self,
|
|
model: str,
|
|
output_dir: str | None = None,
|
|
device: str = "auto",
|
|
dtype: str = "float16",
|
|
quantization: str | None = None,
|
|
methods: list[str] | None = None,
|
|
hub_org: str | None = None,
|
|
hub_repo: str | None = None,
|
|
dataset: str = "builtin",
|
|
) -> str:
|
|
"""Build remote 'obliteratus tourney' command."""
|
|
remote_output = output_dir or f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"
|
|
|
|
parts = [
|
|
self._env_prefix() + self.config.python, "-m", "obliteratus",
|
|
"tourney", shlex.quote(model),
|
|
"--output-dir", shlex.quote(remote_output),
|
|
"--device", device,
|
|
"--dtype", dtype,
|
|
"--dataset", dataset,
|
|
]
|
|
if quantization:
|
|
parts.extend(["--quantization", quantization])
|
|
if hub_org:
|
|
parts.extend(["--hub-org", hub_org])
|
|
if hub_repo:
|
|
parts.extend(["--hub-repo", hub_repo])
|
|
if methods:
|
|
parts.extend(["--methods"] + methods)
|
|
return " ".join(parts)
|
|
|
|
def upload_config(self, local_config_path: str) -> str:
|
|
"""Upload a YAML config file to the remote."""
|
|
remote_path = f"{self.config.remote_dir}/config.yaml"
|
|
self.run_ssh(f"mkdir -p {shlex.quote(self.config.remote_dir)}")
|
|
|
|
cmd = self._scp_base_cmd()
|
|
# scp uses -P not -p, already handled in _scp_base_cmd
|
|
cmd += [local_config_path, f"{self.config.ssh_target}:{remote_path}"]
|
|
|
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
|
|
if result.returncode != 0:
|
|
raise RuntimeError(f"Failed to upload config: {result.stderr}")
|
|
self.on_log(f"Config uploaded to {remote_path}")
|
|
return remote_path
|
|
|
|
def run_obliterate(
|
|
self,
|
|
model: str,
|
|
local_output_dir: str | None = None,
|
|
**kwargs,
|
|
) -> str | None:
|
|
"""Full remote obliteration: setup, run, sync results.
|
|
|
|
Returns local path to results, or None on failure.
|
|
"""
|
|
# 1. Verify connection
|
|
if not self.check_connection():
|
|
return None
|
|
|
|
# 2. Check GPUs
|
|
self.check_gpu()
|
|
|
|
# 3. Ensure obliteratus is installed
|
|
if not self.ensure_obliteratus():
|
|
return None
|
|
|
|
# 4. Create remote working directory
|
|
self.run_ssh(f"mkdir -p {shlex.quote(self.config.remote_dir)}")
|
|
|
|
# 5. Build and run the command
|
|
remote_output = f"{self.config.remote_dir}/output/{model.replace('/', '_')}"
|
|
cmd = self.build_obliterate_command(model, output_dir=remote_output, **kwargs)
|
|
self.on_log(f"Running: {cmd}")
|
|
|
|
rc = self.run_ssh(cmd, stream=True)
|
|
if rc != 0:
|
|
self.on_log(f"[red]Remote obliteration failed (exit code {rc})[/]")
|
|
return None
|
|
|
|
# 6. Sync results back
|
|
if self.config.sync_results:
|
|
local_output = local_output_dir or f"abliterated/{model.replace('/', '_')}"
|
|
if self.sync_results_back(remote_output, local_output):
|
|
return local_output
|
|
return None
|
|
|
|
self.on_log(f"Results on remote: {remote_output}")
|
|
return remote_output
|
|
|
|
def run_config(
|
|
self,
|
|
local_config_path: str,
|
|
local_output_dir: str | None = None,
|
|
preset: str | None = None,
|
|
) -> str | None:
|
|
"""Upload config, run study remotely, sync results."""
|
|
if not self.check_connection():
|
|
return None
|
|
self.check_gpu()
|
|
if not self.ensure_obliteratus():
|
|
return None
|
|
|
|
# Upload config
|
|
remote_config = self.upload_config(local_config_path)
|
|
|
|
# Determine remote output dir
|
|
remote_output = f"{self.config.remote_dir}/results"
|
|
cmd = self.build_run_command(remote_config, output_dir=remote_output, preset=preset)
|
|
self.on_log(f"Running: {cmd}")
|
|
|
|
rc = self.run_ssh(cmd, stream=True)
|
|
if rc != 0:
|
|
self.on_log(f"[red]Remote run failed (exit code {rc})[/]")
|
|
return None
|
|
|
|
if self.config.sync_results:
|
|
local_output = local_output_dir or "results"
|
|
if self.sync_results_back(remote_output, local_output):
|
|
return local_output
|
|
return None
|
|
|
|
return remote_output
|
|
|
|
def run_tourney(
|
|
self,
|
|
model: str,
|
|
local_output_dir: str | None = None,
|
|
**kwargs,
|
|
) -> str | None:
|
|
"""Run tournament remotely, sync results."""
|
|
if not self.check_connection():
|
|
return None
|
|
self.check_gpu()
|
|
if not self.ensure_obliteratus():
|
|
return None
|
|
|
|
remote_output = f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"
|
|
cmd = self.build_tourney_command(model, output_dir=remote_output, **kwargs)
|
|
self.on_log(f"Running: {cmd}")
|
|
|
|
rc = self.run_ssh(cmd, stream=True)
|
|
if rc != 0:
|
|
self.on_log(f"[red]Remote tourney failed (exit code {rc})[/]")
|
|
return None
|
|
|
|
if self.config.sync_results:
|
|
local_output = local_output_dir or f"/tmp/obliteratus_tourney/{model.replace('/', '_')}"
|
|
if self.sync_results_back(remote_output, local_output):
|
|
return local_output
|
|
return None
|
|
|
|
return remote_output
|