mirror of
https://github.com/elder-plinius/OBLITERATUS.git
synced 2026-04-29 22:47:50 +02:00
Add remote SSH execution support for GPU nodes
Adds --remote [user@]host flag to obliterate, run, and tourney commands, enabling execution on remote GPU nodes via SSH. Also supports a remote: section in YAML configs. The remote runner handles SSH connectivity checks, GPU detection, auto-installation of obliteratus, log streaming, and result syncing back to the local machine via scp.
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
# Example: Run an ablation study on a remote GPU node via SSH.
|
||||
#
|
||||
# Usage:
|
||||
# obliteratus run examples/remote_gpu_node.yaml
|
||||
#
|
||||
# The 'remote' section tells Obliteratus to SSH into the specified host,
|
||||
# install obliteratus if needed, run the pipeline there, and copy results
|
||||
# back to the local machine.
|
||||
#
|
||||
# You can also use --remote on any command instead of a YAML section:
|
||||
# obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct --remote root@gpu-node --ssh-key ~/.ssh/id_rsa
|
||||
|
||||
model:
|
||||
name: meta-llama/Llama-3.1-8B-Instruct
|
||||
task: causal_lm
|
||||
dtype: float16
|
||||
device: auto
|
||||
|
||||
dataset:
|
||||
name: wikitext
|
||||
split: test
|
||||
max_samples: 500
|
||||
|
||||
strategies:
|
||||
- name: layer_removal
|
||||
params:
|
||||
layer_indices: [10, 11, 12]
|
||||
|
||||
metrics: [perplexity]
|
||||
batch_size: 8
|
||||
max_length: 512
|
||||
output_dir: results/remote_example
|
||||
|
||||
remote:
|
||||
host: gpu-node.example.com
|
||||
user: root
|
||||
port: 22
|
||||
ssh_key: ~/.ssh/id_rsa
|
||||
remote_dir: /tmp/obliteratus_run
|
||||
python: python3
|
||||
sync_results: true
|
||||
@@ -17,6 +17,8 @@ __all__ = [
|
||||
"TourneyResult",
|
||||
"get_adaptive_recommendation",
|
||||
"AdaptiveRecommendation",
|
||||
"RemoteRunner",
|
||||
"RemoteConfig",
|
||||
]
|
||||
|
||||
|
||||
@@ -60,4 +62,10 @@ def __getattr__(name):
|
||||
if name == "AdaptiveRecommendation":
|
||||
from obliteratus.adaptive_defaults import AdaptiveRecommendation
|
||||
return AdaptiveRecommendation
|
||||
if name == "RemoteRunner":
|
||||
from obliteratus.remote import RemoteRunner
|
||||
return RemoteRunner
|
||||
if name == "RemoteConfig":
|
||||
from obliteratus.remote import RemoteConfig
|
||||
return RemoteConfig
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
+178
-3
@@ -22,6 +22,35 @@ _BANNER = r"""
|
||||
"""
|
||||
|
||||
|
||||
def _add_remote_args(parser):
|
||||
"""Add --remote execution flags to a subcommand parser."""
|
||||
remote_group = parser.add_argument_group("remote execution")
|
||||
remote_group.add_argument(
|
||||
"--remote", type=str, default=None, metavar="[USER@]HOST",
|
||||
help="Run on a remote GPU node via SSH (e.g. root@gpu-node or just gpu-node)",
|
||||
)
|
||||
remote_group.add_argument(
|
||||
"--ssh-key", type=str, default=None,
|
||||
help="Path to SSH private key (default: use SSH agent or ~/.ssh/id_rsa)",
|
||||
)
|
||||
remote_group.add_argument(
|
||||
"--ssh-port", type=int, default=22,
|
||||
help="SSH port on remote host (default: 22)",
|
||||
)
|
||||
remote_group.add_argument(
|
||||
"--remote-dir", type=str, default="/tmp/obliteratus_run",
|
||||
help="Working directory on the remote machine (default: /tmp/obliteratus_run)",
|
||||
)
|
||||
remote_group.add_argument(
|
||||
"--remote-python", type=str, default="python3",
|
||||
help="Python binary on the remote machine (default: python3)",
|
||||
)
|
||||
remote_group.add_argument(
|
||||
"--no-sync", action="store_true", default=False,
|
||||
help="Don't copy results back to local machine after remote run",
|
||||
)
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None):
|
||||
console.print(_BANNER)
|
||||
parser = argparse.ArgumentParser(
|
||||
@@ -40,6 +69,7 @@ def main(argv: list[str] | None = None):
|
||||
default=None,
|
||||
help="Apply a preset (e.g. quick, full, attention, jailbreak, guardrail)",
|
||||
)
|
||||
_add_remote_args(run_parser)
|
||||
|
||||
# --- info ---
|
||||
info_parser = subparsers.add_parser("info", help="Print model architecture info")
|
||||
@@ -144,9 +174,11 @@ def main(argv: list[str] | None = None):
|
||||
help="One-click: remove refusal directions from a model (SOTA multi-technique)",
|
||||
)
|
||||
_add_obliterate_args(abl_parser)
|
||||
_add_remote_args(abl_parser)
|
||||
# Backward-compat alias (hidden from help)
|
||||
abl_alias = subparsers.add_parser("abliterate", help=argparse.SUPPRESS)
|
||||
_add_obliterate_args(abl_alias)
|
||||
_add_remote_args(abl_alias)
|
||||
|
||||
# --- report ---
|
||||
report_parser = subparsers.add_parser("report", help="Regenerate report from saved results")
|
||||
@@ -180,6 +212,7 @@ def main(argv: list[str] | None = None):
|
||||
"--methods", type=str, nargs="+", default=None,
|
||||
help="Override: only run these methods (space-separated)",
|
||||
)
|
||||
_add_remote_args(tourney_parser)
|
||||
|
||||
# --- recommend ---
|
||||
recommend_parser = subparsers.add_parser(
|
||||
@@ -197,7 +230,10 @@ def main(argv: list[str] | None = None):
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
if args.command == "run":
|
||||
_cmd_run(args)
|
||||
if getattr(args, "remote", None):
|
||||
_cmd_remote_run(args)
|
||||
else:
|
||||
_cmd_run(args)
|
||||
elif args.command == "interactive":
|
||||
_cmd_interactive()
|
||||
elif args.command == "models":
|
||||
@@ -217,9 +253,15 @@ def main(argv: list[str] | None = None):
|
||||
elif args.command == "recommend":
|
||||
_cmd_recommend(args)
|
||||
elif args.command == "tourney":
|
||||
_cmd_tourney(args)
|
||||
if getattr(args, "remote", None):
|
||||
_cmd_remote_tourney(args)
|
||||
else:
|
||||
_cmd_tourney(args)
|
||||
elif args.command in ("obliterate", "abliterate"):
|
||||
_cmd_abliterate(args)
|
||||
if getattr(args, "remote", None):
|
||||
_cmd_remote_abliterate(args)
|
||||
else:
|
||||
_cmd_abliterate(args)
|
||||
|
||||
|
||||
def _cmd_ui(args):
|
||||
@@ -314,6 +356,33 @@ def _cmd_run(args):
|
||||
config = StudyConfig.from_dict(raw)
|
||||
if args.output_dir:
|
||||
config.output_dir = args.output_dir
|
||||
|
||||
# If YAML has a remote: section, dispatch to remote runner
|
||||
if config.remote is not None:
|
||||
from obliteratus.remote import RemoteConfig as _RC, RemoteRunner
|
||||
|
||||
rc = _RC(
|
||||
host=config.remote.host,
|
||||
user=config.remote.user,
|
||||
port=config.remote.port,
|
||||
ssh_key=config.remote.ssh_key,
|
||||
remote_dir=config.remote.remote_dir,
|
||||
python=config.remote.python,
|
||||
sync_results=config.remote.sync_results,
|
||||
)
|
||||
runner = RemoteRunner(rc)
|
||||
result_path = runner.run_config(
|
||||
local_config_path=args.config,
|
||||
local_output_dir=config.output_dir,
|
||||
preset=args.preset,
|
||||
)
|
||||
if result_path:
|
||||
console.print(f"\n[bold green]Remote run complete.[/] Results at: [cyan]{result_path}[/]")
|
||||
else:
|
||||
console.print("[red]Remote run failed. Check logs above.[/]")
|
||||
raise SystemExit(1)
|
||||
return
|
||||
|
||||
run_study(config)
|
||||
|
||||
|
||||
@@ -653,5 +722,111 @@ def _cmd_abliterate(args):
|
||||
)
|
||||
|
||||
|
||||
def _make_remote_runner(args):
|
||||
"""Create a RemoteRunner from CLI --remote flags."""
|
||||
from obliteratus.remote import RemoteConfig, RemoteRunner
|
||||
|
||||
rc = RemoteConfig.from_cli_args(
|
||||
args.remote,
|
||||
port=args.ssh_port,
|
||||
ssh_key=args.ssh_key,
|
||||
remote_dir=args.remote_dir,
|
||||
python=args.remote_python,
|
||||
sync_results=not args.no_sync,
|
||||
)
|
||||
return RemoteRunner(rc)
|
||||
|
||||
|
||||
def _cmd_remote_abliterate(args):
|
||||
from rich.panel import Panel
|
||||
|
||||
runner = _make_remote_runner(args)
|
||||
|
||||
kwargs = {}
|
||||
if args.method:
|
||||
kwargs["method"] = args.method
|
||||
if args.device:
|
||||
kwargs["device"] = args.device
|
||||
if args.dtype:
|
||||
kwargs["dtype"] = args.dtype
|
||||
if args.quantization:
|
||||
kwargs["quantization"] = args.quantization
|
||||
if args.n_directions is not None:
|
||||
kwargs["n_directions"] = args.n_directions
|
||||
if getattr(args, "direction_method", None):
|
||||
kwargs["direction_method"] = args.direction_method
|
||||
if args.regularization is not None:
|
||||
kwargs["regularization"] = args.regularization
|
||||
if args.refinement_passes is not None:
|
||||
kwargs["refinement_passes"] = args.refinement_passes
|
||||
if getattr(args, "large_model", False):
|
||||
kwargs["large_model"] = True
|
||||
if getattr(args, "verify_sample_size", None) is not None:
|
||||
kwargs["verify_sample_size"] = args.verify_sample_size
|
||||
|
||||
result_path = runner.run_obliterate(
|
||||
model=args.model,
|
||||
local_output_dir=args.output_dir,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if result_path:
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold green]Remote abliteration complete![/]\n\n"
|
||||
f" Results at: [cyan]{result_path}[/]\n\n"
|
||||
f" [dim]Load with:[/] AutoModelForCausalLM.from_pretrained('{result_path}')",
|
||||
border_style="green",
|
||||
title="[bold green]REBIRTH COMPLETE (remote)[/]",
|
||||
)
|
||||
)
|
||||
else:
|
||||
console.print("[red]Remote abliteration failed. Check logs above.[/]")
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
def _cmd_remote_run(args):
|
||||
runner = _make_remote_runner(args)
|
||||
result_path = runner.run_config(
|
||||
local_config_path=args.config,
|
||||
local_output_dir=args.output_dir,
|
||||
preset=args.preset,
|
||||
)
|
||||
if result_path:
|
||||
console.print(f"\n[bold green]Remote run complete.[/] Results at: [cyan]{result_path}[/]")
|
||||
else:
|
||||
console.print("[red]Remote run failed. Check logs above.[/]")
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
def _cmd_remote_tourney(args):
|
||||
from rich.panel import Panel
|
||||
|
||||
runner = _make_remote_runner(args)
|
||||
result_path = runner.run_tourney(
|
||||
model=args.model,
|
||||
local_output_dir=args.output_dir,
|
||||
device=args.device,
|
||||
dtype=args.dtype,
|
||||
quantization=args.quantization,
|
||||
methods=args.methods,
|
||||
hub_org=args.hub_org,
|
||||
hub_repo=args.hub_repo,
|
||||
dataset=args.dataset,
|
||||
)
|
||||
if result_path:
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold green]Remote tournament complete![/]\n\n"
|
||||
f" Results at: [cyan]{result_path}[/]",
|
||||
border_style="green",
|
||||
title="[bold green]TOURNAMENT COMPLETE (remote)[/]",
|
||||
)
|
||||
)
|
||||
else:
|
||||
console.print("[red]Remote tournament failed. Check logs above.[/]")
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -35,6 +35,19 @@ class StrategyConfig:
|
||||
params: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteConfig:
|
||||
"""Optional remote execution settings for running on a GPU node via SSH."""
|
||||
|
||||
host: str
|
||||
user: str = "root"
|
||||
port: int = 22
|
||||
ssh_key: str | None = None
|
||||
remote_dir: str = "/tmp/obliteratus_run"
|
||||
python: str = "python3"
|
||||
sync_results: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class StudyConfig:
|
||||
"""Top-level configuration for an ablation run."""
|
||||
@@ -46,6 +59,7 @@ class StudyConfig:
|
||||
batch_size: int = 8
|
||||
max_length: int = 512
|
||||
output_dir: str = "results"
|
||||
remote: RemoteConfig | None = None
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> StudyConfig:
|
||||
@@ -82,6 +96,10 @@ class StudyConfig:
|
||||
model = ModelConfig(**d["model"])
|
||||
dataset = DatasetConfig(**d["dataset"])
|
||||
strategies = [StrategyConfig(**s) for s in d["strategies"]]
|
||||
remote = None
|
||||
if "remote" in d and d["remote"]:
|
||||
remote = RemoteConfig(**d["remote"])
|
||||
|
||||
return cls(
|
||||
model=model,
|
||||
dataset=dataset,
|
||||
@@ -90,6 +108,7 @@ class StudyConfig:
|
||||
batch_size=d.get("batch_size", 8),
|
||||
max_length=d.get("max_length", 512),
|
||||
output_dir=d.get("output_dir", "results"),
|
||||
remote=remote,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
"""Remote execution support for Obliteratus.
|
||||
|
||||
Run abliteration pipelines on remote GPU nodes via SSH. The remote machine
|
||||
must have CUDA-capable GPUs and a Python environment. Obliteratus will be
|
||||
auto-installed if not present.
|
||||
|
||||
Usage (CLI):
|
||||
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct \
|
||||
--remote user@gpu-node \
|
||||
--ssh-key ~/.ssh/id_rsa
|
||||
|
||||
Usage (YAML config):
|
||||
remote:
|
||||
host: gpu-node
|
||||
user: root
|
||||
ssh_key: ~/.ssh/id_rsa
|
||||
remote_dir: /tmp/obliteratus_run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RemoteConfig:
|
||||
"""SSH connection and remote execution settings."""
|
||||
|
||||
host: str
|
||||
user: str = "root"
|
||||
port: int = 22
|
||||
ssh_key: str | None = None
|
||||
remote_dir: str = "/tmp/obliteratus_run"
|
||||
install_timeout: int = 600 # seconds
|
||||
python: str = "python3" # remote python binary
|
||||
sync_results: bool = True
|
||||
|
||||
@property
|
||||
def ssh_target(self) -> str:
|
||||
return f"{self.user}@{self.host}"
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, remote_str: str, **kwargs) -> RemoteConfig:
|
||||
"""Parse 'user@host' or just 'host' from CLI --remote flag."""
|
||||
if "@" in remote_str:
|
||||
user, host = remote_str.rsplit("@", 1)
|
||||
else:
|
||||
user = "root"
|
||||
host = remote_str
|
||||
return cls(host=host, user=user, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> RemoteConfig:
|
||||
return cls(**{k: v for k, v in d.items() if k in cls.__dataclass_fields__})
|
||||
|
||||
|
||||
class RemoteRunner:
|
||||
"""Execute Obliteratus commands on a remote machine via SSH."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RemoteConfig,
|
||||
on_log: Callable[[str], None] | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self.on_log = on_log or (lambda msg: console.print(f"[dim][remote][/] {msg}"))
|
||||
|
||||
def _ssh_base_cmd(self) -> list[str]:
|
||||
"""Build base SSH command with common options."""
|
||||
cmd = [
|
||||
"ssh",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", "ConnectTimeout=30",
|
||||
"-p", str(self.config.port),
|
||||
]
|
||||
if self.config.ssh_key:
|
||||
key_path = os.path.expanduser(self.config.ssh_key)
|
||||
cmd.extend(["-i", key_path])
|
||||
cmd.append(self.config.ssh_target)
|
||||
return cmd
|
||||
|
||||
def _scp_base_cmd(self) -> list[str]:
|
||||
"""Build base SCP command."""
|
||||
cmd = [
|
||||
"scp",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "BatchMode=yes",
|
||||
"-P", str(self.config.port),
|
||||
"-r",
|
||||
]
|
||||
if self.config.ssh_key:
|
||||
key_path = os.path.expanduser(self.config.ssh_key)
|
||||
cmd.extend(["-i", key_path])
|
||||
return cmd
|
||||
|
||||
def run_ssh(self, remote_cmd: str, stream: bool = False, timeout: int | None = None) -> subprocess.CompletedProcess | int:
|
||||
"""Run a command on the remote host.
|
||||
|
||||
If stream=True, streams stdout/stderr in real-time and returns the
|
||||
exit code. Otherwise returns CompletedProcess.
|
||||
"""
|
||||
cmd = self._ssh_base_cmd() + [remote_cmd]
|
||||
|
||||
if stream:
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip("\n")
|
||||
self.on_log(line)
|
||||
proc.wait(timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
self.on_log("[red]Remote command timed out[/]")
|
||||
return 1
|
||||
return proc.returncode
|
||||
else:
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def check_connection(self) -> bool:
|
||||
"""Verify SSH connectivity."""
|
||||
self.on_log(f"Testing SSH connection to {self.config.ssh_target}...")
|
||||
result = self.run_ssh("echo ok", timeout=30)
|
||||
if isinstance(result, subprocess.CompletedProcess) and result.returncode == 0:
|
||||
self.on_log("SSH connection OK")
|
||||
return True
|
||||
self.on_log("[red]SSH connection failed[/]")
|
||||
return False
|
||||
|
||||
def check_gpu(self) -> str | None:
|
||||
"""Check for CUDA GPUs on remote. Returns nvidia-smi output or None."""
|
||||
result = self.run_ssh("nvidia-smi --query-gpu=name,memory.total --format=csv,noheader", timeout=30)
|
||||
if isinstance(result, subprocess.CompletedProcess) and result.returncode == 0:
|
||||
gpu_info = result.stdout.strip()
|
||||
self.on_log(f"Remote GPUs: {gpu_info}")
|
||||
return gpu_info
|
||||
self.on_log("[yellow]No GPUs detected on remote (nvidia-smi failed)[/]")
|
||||
return None
|
||||
|
||||
def ensure_obliteratus(self) -> bool:
|
||||
"""Install or update obliteratus on the remote if needed."""
|
||||
# Check if already installed
|
||||
check = self.run_ssh(
|
||||
f"{self.config.python} -c \"import obliteratus; print(obliteratus.__version__)\"",
|
||||
timeout=30,
|
||||
)
|
||||
if isinstance(check, subprocess.CompletedProcess) and check.returncode == 0:
|
||||
version = check.stdout.strip()
|
||||
self.on_log(f"Obliteratus {version} already installed on remote")
|
||||
return True
|
||||
|
||||
# Install from PyPI or git
|
||||
self.on_log("Installing obliteratus on remote...")
|
||||
install_cmd = (
|
||||
f"{self.config.python} -m pip install --quiet "
|
||||
f"'obliteratus @ git+https://github.com/EleutherAI/OBLITERATUS.git'"
|
||||
)
|
||||
rc = self.run_ssh(install_cmd, stream=True, timeout=self.config.install_timeout)
|
||||
if rc != 0:
|
||||
self.on_log("[red]Failed to install obliteratus on remote[/]")
|
||||
return False
|
||||
|
||||
self.on_log("Obliteratus installed successfully")
|
||||
return True
|
||||
|
||||
def sync_results_back(self, remote_output_dir: str, local_output_dir: str) -> bool:
|
||||
"""Copy results from remote back to local machine via scp."""
|
||||
local_path = Path(local_output_dir)
|
||||
local_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.on_log(f"Syncing results: {self.config.ssh_target}:{remote_output_dir} -> {local_output_dir}")
|
||||
|
||||
cmd = self._scp_base_cmd() + [
|
||||
f"{self.config.ssh_target}:{remote_output_dir}/",
|
||||
str(local_path),
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600)
|
||||
if result.returncode == 0:
|
||||
self.on_log(f"Results synced to {local_output_dir}")
|
||||
return True
|
||||
else:
|
||||
self.on_log(f"[red]SCP failed: {result.stderr}[/]")
|
||||
return False
|
||||
|
||||
def build_obliterate_command(
|
||||
self,
|
||||
model: str,
|
||||
output_dir: str | None = None,
|
||||
method: str = "advanced",
|
||||
device: str = "auto",
|
||||
dtype: str = "float16",
|
||||
quantization: str | None = None,
|
||||
n_directions: int | None = None,
|
||||
direction_method: str | None = None,
|
||||
regularization: float | None = None,
|
||||
refinement_passes: int | None = None,
|
||||
large_model: bool = False,
|
||||
verify_sample_size: int | None = None,
|
||||
) -> str:
|
||||
"""Build the remote obliteratus CLI command."""
|
||||
remote_output = output_dir or f"{self.config.remote_dir}/output/{model.replace('/', '_')}"
|
||||
|
||||
parts = [
|
||||
self.config.python, "-m", "obliteratus",
|
||||
"obliterate", shlex.quote(model),
|
||||
"--output-dir", shlex.quote(remote_output),
|
||||
"--method", method,
|
||||
"--device", device,
|
||||
"--dtype", dtype,
|
||||
]
|
||||
if quantization:
|
||||
parts.extend(["--quantization", quantization])
|
||||
if n_directions is not None:
|
||||
parts.extend(["--n-directions", str(n_directions)])
|
||||
if direction_method:
|
||||
parts.extend(["--direction-method", direction_method])
|
||||
if regularization is not None:
|
||||
parts.extend(["--regularization", str(regularization)])
|
||||
if refinement_passes is not None:
|
||||
parts.extend(["--refinement-passes", str(refinement_passes)])
|
||||
if large_model:
|
||||
parts.append("--large-model")
|
||||
if verify_sample_size is not None:
|
||||
parts.extend(["--verify-sample-size", str(verify_sample_size)])
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
def build_run_command(self, remote_config_path: str, output_dir: str | None = None, preset: str | None = None) -> str:
|
||||
"""Build remote 'obliteratus run' command."""
|
||||
parts = [
|
||||
self.config.python, "-m", "obliteratus",
|
||||
"run", shlex.quote(remote_config_path),
|
||||
]
|
||||
if output_dir:
|
||||
parts.extend(["--output-dir", shlex.quote(output_dir)])
|
||||
if preset:
|
||||
parts.extend(["--preset", preset])
|
||||
return " ".join(parts)
|
||||
|
||||
def build_tourney_command(
|
||||
self,
|
||||
model: str,
|
||||
output_dir: str | None = None,
|
||||
device: str = "auto",
|
||||
dtype: str = "float16",
|
||||
quantization: str | None = None,
|
||||
methods: list[str] | None = None,
|
||||
hub_org: str | None = None,
|
||||
hub_repo: str | None = None,
|
||||
dataset: str = "builtin",
|
||||
) -> str:
|
||||
"""Build remote 'obliteratus tourney' command."""
|
||||
remote_output = output_dir or f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"
|
||||
|
||||
parts = [
|
||||
self.config.python, "-m", "obliteratus",
|
||||
"tourney", shlex.quote(model),
|
||||
"--output-dir", shlex.quote(remote_output),
|
||||
"--device", device,
|
||||
"--dtype", dtype,
|
||||
"--dataset", dataset,
|
||||
]
|
||||
if quantization:
|
||||
parts.extend(["--quantization", quantization])
|
||||
if hub_org:
|
||||
parts.extend(["--hub-org", hub_org])
|
||||
if hub_repo:
|
||||
parts.extend(["--hub-repo", hub_repo])
|
||||
if methods:
|
||||
parts.extend(["--methods"] + methods)
|
||||
return " ".join(parts)
|
||||
|
||||
def upload_config(self, local_config_path: str) -> str:
|
||||
"""Upload a YAML config file to the remote."""
|
||||
remote_path = f"{self.config.remote_dir}/config.yaml"
|
||||
self.run_ssh(f"mkdir -p {shlex.quote(self.config.remote_dir)}")
|
||||
|
||||
cmd = self._scp_base_cmd()
|
||||
# scp uses -P not -p, already handled in _scp_base_cmd
|
||||
cmd += [local_config_path, f"{self.config.ssh_target}:{remote_path}"]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Failed to upload config: {result.stderr}")
|
||||
self.on_log(f"Config uploaded to {remote_path}")
|
||||
return remote_path
|
||||
|
||||
def run_obliterate(
|
||||
self,
|
||||
model: str,
|
||||
local_output_dir: str | None = None,
|
||||
**kwargs,
|
||||
) -> str | None:
|
||||
"""Full remote obliteration: setup, run, sync results.
|
||||
|
||||
Returns local path to results, or None on failure.
|
||||
"""
|
||||
# 1. Verify connection
|
||||
if not self.check_connection():
|
||||
return None
|
||||
|
||||
# 2. Check GPUs
|
||||
self.check_gpu()
|
||||
|
||||
# 3. Ensure obliteratus is installed
|
||||
if not self.ensure_obliteratus():
|
||||
return None
|
||||
|
||||
# 4. Create remote working directory
|
||||
self.run_ssh(f"mkdir -p {shlex.quote(self.config.remote_dir)}")
|
||||
|
||||
# 5. Build and run the command
|
||||
remote_output = f"{self.config.remote_dir}/output/{model.replace('/', '_')}"
|
||||
cmd = self.build_obliterate_command(model, output_dir=remote_output, **kwargs)
|
||||
self.on_log(f"Running: {cmd}")
|
||||
|
||||
rc = self.run_ssh(cmd, stream=True)
|
||||
if rc != 0:
|
||||
self.on_log(f"[red]Remote obliteration failed (exit code {rc})[/]")
|
||||
return None
|
||||
|
||||
# 6. Sync results back
|
||||
if self.config.sync_results:
|
||||
local_output = local_output_dir or f"abliterated/{model.replace('/', '_')}"
|
||||
if self.sync_results_back(remote_output, local_output):
|
||||
return local_output
|
||||
return None
|
||||
|
||||
self.on_log(f"Results on remote: {remote_output}")
|
||||
return remote_output
|
||||
|
||||
def run_config(
|
||||
self,
|
||||
local_config_path: str,
|
||||
local_output_dir: str | None = None,
|
||||
preset: str | None = None,
|
||||
) -> str | None:
|
||||
"""Upload config, run study remotely, sync results."""
|
||||
if not self.check_connection():
|
||||
return None
|
||||
self.check_gpu()
|
||||
if not self.ensure_obliteratus():
|
||||
return None
|
||||
|
||||
# Upload config
|
||||
remote_config = self.upload_config(local_config_path)
|
||||
|
||||
# Determine remote output dir
|
||||
remote_output = f"{self.config.remote_dir}/results"
|
||||
cmd = self.build_run_command(remote_config, output_dir=remote_output, preset=preset)
|
||||
self.on_log(f"Running: {cmd}")
|
||||
|
||||
rc = self.run_ssh(cmd, stream=True)
|
||||
if rc != 0:
|
||||
self.on_log(f"[red]Remote run failed (exit code {rc})[/]")
|
||||
return None
|
||||
|
||||
if self.config.sync_results:
|
||||
local_output = local_output_dir or "results"
|
||||
if self.sync_results_back(remote_output, local_output):
|
||||
return local_output
|
||||
return None
|
||||
|
||||
return remote_output
|
||||
|
||||
def run_tourney(
|
||||
self,
|
||||
model: str,
|
||||
local_output_dir: str | None = None,
|
||||
**kwargs,
|
||||
) -> str | None:
|
||||
"""Run tournament remotely, sync results."""
|
||||
if not self.check_connection():
|
||||
return None
|
||||
self.check_gpu()
|
||||
if not self.ensure_obliteratus():
|
||||
return None
|
||||
|
||||
remote_output = f"{self.config.remote_dir}/tourney/{model.replace('/', '_')}"
|
||||
cmd = self.build_tourney_command(model, output_dir=remote_output, **kwargs)
|
||||
self.on_log(f"Running: {cmd}")
|
||||
|
||||
rc = self.run_ssh(cmd, stream=True)
|
||||
if rc != 0:
|
||||
self.on_log(f"[red]Remote tourney failed (exit code {rc})[/]")
|
||||
return None
|
||||
|
||||
if self.config.sync_results:
|
||||
local_output = local_output_dir or f"/tmp/obliteratus_tourney/{model.replace('/', '_')}"
|
||||
if self.sync_results_back(remote_output, local_output):
|
||||
return local_output
|
||||
return None
|
||||
|
||||
return remote_output
|
||||
Reference in New Issue
Block a user