mirror of
https://github.com/FuzzingLabs/fuzzforge_ai.git
synced 2026-05-21 04:16:48 +02:00
feat: seed governance config and responses routing
This commit is contained in:
@@ -0,0 +1,557 @@
|
||||
"""Bootstrap the Bifrost proxy with providers and default virtual keys.
|
||||
|
||||
This script runs inside a one-shot container during docker-compose startup.
|
||||
It will:
|
||||
1. Wait for the proxy health endpoint to respond.
|
||||
2. Configure any upstream providers for which an env key is present.
|
||||
3. Create (or reuse) the default virtual key for the task agent.
|
||||
4. Persist the generated key back into volumes/env/.env so the agent uses it.
|
||||
|
||||
The script is idempotent: rerunning it leaves existing configs in place and
reuses the default virtual key when the proxy store already holds one.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable, Mapping, Sequence
|
||||
|
||||
# Base URL of the proxy; any trailing slash is removed so paths can be appended.
PROXY_BASE_URL = os.environ.get("PROXY_BASE_URL", "http://llm-proxy:8080").rstrip("/")

# Files read and written by the bootstrap; each location can be overridden
# through the environment variable of the same name.
ENV_FILE_PATH = Path(os.environ.get("ENV_FILE_PATH", "/bootstrap/env/.env"))
BIFROST_ENV_FILE_PATH = Path(
    os.environ.get("BIFROST_ENV_FILE_PATH", "/bootstrap/env/.env.bifrost")
)
CONFIG_FILE_PATH = Path(
    os.environ.get("CONFIG_FILE_PATH", "/bootstrap/data/config.json")
)

# Identity of the virtual key created for the task agent.
DEFAULT_VIRTUAL_KEY_NAME = "task-agent default"
DEFAULT_VIRTUAL_KEY_USER = "fuzzforge-task-agent"

# Sentinel value that is never treated as a real provider or virtual key.
PLACEHOLDER_KEY = "sk-proxy-default"

# Upper bound, in seconds, on how long wait_for_proxy() polls the health endpoint.
MAX_WAIT_SECONDS = 120

# Models used for a provider when no BIFROST_<PROVIDER>_MODELS override is set.
DEFAULT_PROVIDER_MODELS: dict[str, list[str]] = {"openai": ["gpt-5"]}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ProviderSpec:
    """Immutable description of one upstream LLM provider.

    The spec names the provider and the environment variables its API key
    can be read from.
    """

    # Provider identifier as sent to the proxy (e.g. "openai").
    name: str
    # Dedicated variable holding this provider's key.
    env_var: str
    # Other variables consulted, in order, when ``env_var`` yields nothing.
    fallback_env_vars: tuple[str, ...] = ()

    @property
    def env_reference(self) -> str:
        """Return the ``env.<VAR>`` reference string for this provider's key."""
        return "env." + self.env_var
|
||||
|
||||
|
||||
# Providers the bootstrap knows how to register. Each one reads its key from a
# dedicated BIFROST_* variable and falls back to the vendor's usual variable.
PROVIDERS: tuple[ProviderSpec, ...] = (
    ProviderSpec(
        name="openai",
        env_var="BIFROST_OPENAI_KEY",
        fallback_env_vars=("OPENAI_API_KEY",),
    ),
    ProviderSpec(
        name="anthropic",
        env_var="BIFROST_ANTHROPIC_KEY",
        fallback_env_vars=("ANTHROPIC_API_KEY",),
    ),
    ProviderSpec(
        name="gemini",
        env_var="BIFROST_GEMINI_KEY",
        fallback_env_vars=("GEMINI_API_KEY",),
    ),
    ProviderSpec(
        name="mistral",
        env_var="BIFROST_MISTRAL_KEY",
        fallback_env_vars=("MISTRAL_API_KEY",),
    ),
    ProviderSpec(
        name="openrouter",
        env_var="BIFROST_OPENROUTER_KEY",
        fallback_env_vars=("OPENROUTER_API_KEY",),
    ),
)
|
||||
|
||||
|
||||
# Matches a whole string in the 8-4-4-4-12 hexadecimal UUID layout; both
# upper- and lower-case hex digits are accepted.
UUID_PATTERN = re.compile(
    r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
|
||||
|
||||
|
||||
def looks_like_virtual_key(candidate: str | None) -> bool:
    """Tell whether *candidate* has the shape of a proxy-issued virtual key.

    A value qualifies when, after trimming whitespace, it is either a UUID
    or starts with ``sk-proxy-``. ``None``, blank strings and the placeholder
    key never qualify.
    """
    token = (candidate or "").strip()
    if token in ("", PLACEHOLDER_KEY):
        return False
    return bool(UUID_PATTERN.match(token)) or token.startswith("sk-proxy-")
|
||||
|
||||
|
||||
def set_env_value(lines: list[str], key: str, value: str) -> tuple[list[str], bool]:
    """Set ``key=value`` in a list of env-file lines, editing the list in place.

    The first non-comment line assigning *key* is rewritten with its leading
    whitespace preserved; when no such line exists the assignment is appended.

    Returns:
        The same list object and a flag telling whether it was modified.
    """
    marker = key + "="
    assignment = marker + value

    for position, raw in enumerate(lines):
        body = raw.lstrip()
        # Comments and lines assigning some other key are left alone.
        if body.startswith("#") or not body.startswith(marker):
            continue
        if body == assignment:
            return lines, False
        leading = raw[: len(raw) - len(body)]
        lines[position] = leading + assignment
        return lines, True

    lines.append(assignment)
    return lines, True
|
||||
|
||||
|
||||
def parse_env_lines(lines: list[str]) -> dict[str, str]:
    """Turn env-file lines into a ``{name: value}`` dictionary.

    Blank lines, ``#`` comments and lines without ``=`` are ignored. Each
    line is split at its first ``=``; a later assignment of the same name
    replaces an earlier one.
    """
    parsed: dict[str, str] = {}
    for raw in lines:
        entry = raw.strip()
        if entry.startswith("#"):
            continue
        name, separator, rest = entry.partition("=")
        # An empty separator means the line held no "=" (this covers blanks).
        if separator:
            parsed[name] = rest
    return parsed
|
||||
|
||||
|
||||
def resolve_provider_key(
    provider: ProviderSpec,
    env_map: dict[str, str],
    bifrost_map: dict[str, str],
) -> tuple[str | None, str | None, str | None]:
    """Locate the API key to use for *provider*.

    Sources are tried in this order:
      1. ``provider.env_var`` in *bifrost_map* (label ``"bifrost"``);
      2. ``provider.env_var`` in *env_map*, then in the process environment
         (label ``"env"``);
      3. each of ``provider.fallback_env_vars`` in *env_map* or the process
         environment (label ``"fallback"``).

    Placeholder values are rejected for sources 2 and 3, and an
    ``OPENAI_API_KEY`` fallback that looks like a virtual key is skipped.

    Returns:
        ``(key, variable_name, source_label)``, or three ``None`` values when
        nothing usable was found.
    """
    primary = provider.env_var

    stored = (bifrost_map.get(primary) or "").strip()
    if stored:
        return stored, primary, "bifrost"

    for raw in (env_map.get(primary), os.getenv(primary)):
        cleaned = (raw or "").strip()
        if cleaned and cleaned != PLACEHOLDER_KEY:
            return cleaned, primary, "env"

    for name in provider.fallback_env_vars:
        cleaned = (env_map.get(name) or os.getenv(name) or "").strip()
        if cleaned in ("", PLACEHOLDER_KEY):
            continue
        # OPENAI_API_KEY may hold the proxy-issued key rather than a real one.
        if name == "OPENAI_API_KEY" and looks_like_virtual_key(cleaned):
            continue
        return cleaned, name, "fallback"

    return None, None, None
|
||||
|
||||
|
||||
def ensure_provider_env_export(
    lines: list[str], provider: ProviderSpec, key_value: str
) -> tuple[list[str], bool]:
    """Record *key_value* under the provider's dedicated variable in *lines*.

    Provider secrets are kept under their BIFROST_* names so that future
    restarts inject them into the proxy container environment automatically.
    When the lines change, the variable is also set in this process's
    environment.

    Returns:
        The updated lines and a flag telling whether they changed.
    """
    outcome = set_env_value(lines, provider.env_var, key_value)
    if outcome[1]:
        os.environ[provider.env_var] = key_value
    return outcome
|
||||
|
||||
|
||||
def get_models_for_provider(
    provider: ProviderSpec,
    env_map: dict[str, str],
    bifrost_map: dict[str, str],
) -> list[str]:
    """Return the model names to expose for *provider*.

    A comma-separated ``BIFROST_<PROVIDER>_MODELS`` override is looked up in
    the process environment, then *env_map*, then *bifrost_map*. Blank items
    are dropped. When no override yields any model, the provider's entry in
    ``DEFAULT_PROVIDER_MODELS`` is used (an empty list if it has none).

    The result is always a fresh list, so callers may mutate it without
    altering the module-level defaults.
    """
    env_var = f"BIFROST_{provider.name.upper()}_MODELS"
    raw_value = (
        os.getenv(env_var)
        or env_map.get(env_var)
        or bifrost_map.get(env_var)
    )
    if raw_value:
        models = [item.strip() for item in raw_value.split(",") if item.strip()]
        if models:
            return models
    # Copy so the shared default list is never handed out by reference.
    return list(DEFAULT_PROVIDER_MODELS.get(provider.name, ()))
|
||||
|
||||
|
||||
def _should_use_responses_api(
|
||||
provider: ProviderSpec,
|
||||
models: list[str],
|
||||
env_map: dict[str, str],
|
||||
bifrost_map: dict[str, str],
|
||||
) -> bool:
|
||||
if provider.name != "openai":
|
||||
return False
|
||||
|
||||
env_var = "BIFROST_OPENAI_USE_RESPONSES_API"
|
||||
raw_value = (
|
||||
os.getenv(env_var)
|
||||
or env_map.get(env_var)
|
||||
or bifrost_map.get(env_var)
|
||||
)
|
||||
if raw_value and raw_value.strip().lower() in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
|
||||
for model in models:
|
||||
suffix = model.split("/", 1)[-1]
|
||||
if suffix.startswith("gpt-5") or suffix.startswith("o1"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def build_provider_config_entry(
    provider: ProviderSpec,
    env_map: dict[str, str],
    bifrost_map: dict[str, str],
) -> dict[str, object]:
    """Build the config entry for *provider*.

    The entry holds a single key that references the provider's environment
    variable (never the secret itself), lists its models and has weight 1.0.
    An ``openai_key_config`` enabling the Responses API is attached when
    ``_should_use_responses_api`` says so.
    """
    model_names = get_models_for_provider(provider, env_map, bifrost_map)
    wants_responses_api = _should_use_responses_api(
        provider, model_names, env_map, bifrost_map
    )

    key_config: dict[str, object] = {
        "value": provider.env_reference,
        "models": model_names,
        "weight": 1.0,
    }
    if wants_responses_api:
        key_config["openai_key_config"] = {"use_responses_api": True}

    return {"keys": [key_config]}
|
||||
|
||||
|
||||
def _default_client_config() -> dict[str, object]:
|
||||
return {"drop_excess_requests": False}
|
||||
|
||||
|
||||
def _default_config_store_config() -> dict[str, object]:
|
||||
return {
|
||||
"enabled": True,
|
||||
"type": "sqlite",
|
||||
"config": {"path": "./config.db"},
|
||||
}
|
||||
|
||||
|
||||
def _config_section(container: dict, key: str, factory: type):
    """Return ``container[key]``, replacing it with ``factory()`` if missing or mistyped."""
    current = container.get(key)
    if not isinstance(current, factory):
        current = factory()
        container[key] = current
    return current


def _virtual_key_provider_configs(
    providers_config: dict[str, dict[str, object]],
) -> list[dict[str, object]]:
    """Build one provider_configs item per provider, allowing all of its models."""
    configs: list[dict[str, object]] = []
    for provider_name, entry in providers_config.items():
        allowed_models: list[str] = []
        for key_entry in entry.get("keys", []):
            allowed_models.extend(key_entry.get("models") or [])
        configs.append(
            {
                "provider": provider_name,
                "weight": 1.0,
                "allowed_models": allowed_models,
            }
        )
    return configs


def update_config_file(
    providers_config: dict[str, dict[str, object]],
    virtual_key_value: str | None = None,
) -> None:
    """Merge provider entries and the default virtual key into config.json.

    Nothing is done when *providers_config* is empty. Otherwise the existing
    file, if any, is loaded; unparsable or non-object content is discarded
    and regenerated. Provider entries that differ from the stored ones are
    replaced. When *virtual_key_value* is given, the default virtual key
    record under ``governance.virtual_keys`` is created or brought up to
    date.

    The file is rewritten only when something actually changed. An existing
    virtual key record keeps its ``created_at``; ``updated_at`` moves only
    when one of the record's fields differs. Rerunning the bootstrap with
    the same inputs therefore leaves the file untouched.

    Args:
        providers_config: Mapping of provider name to its config entry.
        virtual_key_value: Value of the default virtual key, if one is known.
    """
    if not providers_config:
        return

    config_data: dict[str, object] = {}
    if CONFIG_FILE_PATH.exists():
        try:
            loaded = json.loads(CONFIG_FILE_PATH.read_text() or "{}")
        except json.JSONDecodeError:
            log(
                "Existing config.json is invalid JSON; regenerating from provider metadata"
            )
        else:
            if isinstance(loaded, dict):
                config_data = loaded
            else:
                # Valid JSON that is not an object cannot hold the sections below.
                log(
                    "Existing config.json is not a JSON object; regenerating from provider metadata"
                )

    providers_section = _config_section(config_data, "providers", dict)
    config_data.setdefault("client", _default_client_config())
    config_data.setdefault("config_store", _default_config_store_config())

    changed = False
    for name, entry in providers_config.items():
        if providers_section.get(name) != entry:
            providers_section[name] = entry
            changed = True

    if virtual_key_value:
        governance_section = _config_section(config_data, "governance", dict)
        vk_list = _config_section(governance_section, "virtual_keys", list)

        desired: dict[str, object] = {
            "id": f"{DEFAULT_VIRTUAL_KEY_USER}-vk",
            "name": DEFAULT_VIRTUAL_KEY_NAME,
            "description": "Default virtual key issued during bootstrap",
            "value": virtual_key_value,
            "is_active": True,
            "provider_configs": _virtual_key_provider_configs(providers_config),
        }
        existing = next(
            (
                item
                for item in vk_list
                if isinstance(item, dict)
                and (
                    item.get("name") == DEFAULT_VIRTUAL_KEY_NAME
                    or item.get("id") == desired["id"]
                )
            ),
            None,
        )

        now_iso = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
        if existing is None:
            vk_list.append({**desired, "created_at": now_iso, "updated_at": now_iso})
            changed = True
        elif any(existing.get(field) != value for field, value in desired.items()):
            existing.update(desired)
            # Keep the original creation time; only the update time moves.
            existing.setdefault("created_at", now_iso)
            existing["updated_at"] = now_iso
            changed = True

    if not changed:
        return

    CONFIG_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)
    CONFIG_FILE_PATH.write_text(json.dumps(config_data, indent=2, sort_keys=True) + "\n")
    log(f"Wrote provider config to {CONFIG_FILE_PATH}")
|
||||
|
||||
|
||||
def log(message: str) -> None:
    """Print *message* to stdout with the bootstrap prefix, flushing at once."""
    print("[llm-proxy-bootstrap]", message, flush=True)
|
||||
|
||||
|
||||
def wait_for_proxy() -> None:
    """Block until the proxy's ``/health`` endpoint answers with HTTP 200.

    The endpoint is polled every three seconds for at most
    ``MAX_WAIT_SECONDS``. Each probe has its own short timeout so that a
    connection that hangs cannot outlast the overall deadline.

    Raises:
        TimeoutError: If no probe succeeded before the deadline.
    """
    url = f"{PROXY_BASE_URL}/health"
    probe_timeout = 5  # seconds allowed for a single health request
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + MAX_WAIT_SECONDS
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=probe_timeout) as response:  # noqa: S310
                if response.status == 200:
                    log("Proxy health endpoint is reachable")
                    return
                log(f"Proxy health endpoint returned {response.status}; retrying")
        # URLError is an OSError; catching the base class also covers resets
        # and timeouts raised while the response is being read.
        except OSError as exc:  # pragma: no cover - best effort logging
            log(f"Proxy not ready yet: {exc}")
        time.sleep(3)
    raise TimeoutError(f"Timed out waiting for {url}")
|
||||
|
||||
|
||||
def request_json(
    path: str,
    *,
    method: str = "GET",
    payload: dict | None = None,
    timeout: float = 30.0,
) -> tuple[int, str]:
    """Send an HTTP request to the proxy and return ``(status, body_text)``.

    Args:
        path: Path appended to ``PROXY_BASE_URL``.
        method: HTTP method to use.
        payload: Optional object sent as a JSON request body.
        timeout: Seconds to wait on the connection before giving up, so a
            stalled proxy cannot hang the bootstrap indefinitely.

    Returns:
        The status code and decoded body. HTTP error responses are returned
        the same way rather than raised.

    Raises:
        urllib.error.URLError: If the proxy cannot be reached at all.
    """
    url = f"{PROXY_BASE_URL}{path}"
    data = None
    headers = {"Accept": "application/json"}
    if payload is not None:
        data = json.dumps(payload).encode("utf-8")
        headers["Content-Type"] = "application/json"

    request = urllib.request.Request(url, data=data, headers=headers, method=method)
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:  # noqa: S310
            # Undecodable bytes are replaced so an odd body cannot raise here.
            body = response.read().decode("utf-8", errors="replace")
            return response.status, body
    except urllib.error.HTTPError as exc:
        try:
            body = exc.read().decode("utf-8", errors="replace")
        finally:
            # The error object wraps the response; release its connection.
            exc.close()
        return exc.code, body
|
||||
|
||||
|
||||
def post_json(path: str, payload: dict) -> tuple[int, str]:
    """POST *payload* as JSON to *path* on the proxy; return ``(status, body)``."""
    status, body = request_json(path, payload=payload, method="POST")
    return status, body
|
||||
|
||||
|
||||
def get_json(path: str) -> tuple[int, str]:
    """GET *path* from the proxy; return ``(status, body)``."""
    status, body = request_json(path, method="GET")
    return status, body
|
||||
|
||||
|
||||
def configure_providers() -> dict[str, dict[str, object]]:
    """Register every provider that has a usable key with the proxy.

    For each entry of ``PROVIDERS`` with a resolvable key, the provider is
    POSTed to ``/api/providers``. Status 200/201 (created) and 409 (already
    present) count as success; any other status is logged and the provider
    is skipped. For each successful provider the key is exported to the
    process environment, mirrored into the Bifrost env file lines if they
    differ, and a config entry is built.

    Returns:
        Mapping of provider name to its config entry, covering only the
        providers that were registered successfully.
    """
    env_map = parse_env_lines(read_env_file())
    bifrost_lines = read_bifrost_env_file()
    bifrost_map = parse_env_lines(bifrost_lines)
    needs_write = False
    registered: dict[str, dict[str, object]] = {}

    for provider in PROVIDERS:
        secret = resolve_provider_key(provider, env_map, bifrost_map)[0]
        if not secret:
            continue

        key_spec = {"value": secret, "models": [], "weight": 1.0}
        status, body = post_json(
            "/api/providers", {"provider": provider.name, "keys": [key_spec]}
        )
        if status == 409:
            log(f"Provider '{provider.name}' already exists (409)")
        elif status in (200, 201):
            log(f"Configured provider '{provider.name}'")
        else:
            log(
                "Failed to configure provider '%s' (%s): %s"
                % (provider.name, status, body)
            )
            continue

        os.environ[provider.env_var] = secret
        if bifrost_map.get(provider.env_var, "") != secret:
            bifrost_lines, line_changed = ensure_provider_env_export(
                bifrost_lines, provider, secret
            )
            needs_write = needs_write or line_changed
            bifrost_map[provider.env_var] = secret

        registered[provider.name] = build_provider_config_entry(
            provider, env_map, bifrost_map
        )

    # The env file is written once, after all providers have been processed.
    if needs_write:
        write_bifrost_env_file(bifrost_lines)
    return registered
|
||||
|
||||
|
||||
def read_env_file() -> list[str]:
    """Return the lines of the main env file, without line endings.

    Raises:
        FileNotFoundError: If the file at ``ENV_FILE_PATH`` does not exist.
    """
    if ENV_FILE_PATH.exists():
        return ENV_FILE_PATH.read_text().splitlines()
    raise FileNotFoundError(
        f"Expected env file at {ENV_FILE_PATH}. Copy volumes/env/.env.example first."
    )
|
||||
|
||||
|
||||
def write_env_file(lines: Iterable[str]) -> None:
    """Overwrite the main env file with *lines*, newline-joined and terminated."""
    content = "\n".join(lines)
    ENV_FILE_PATH.write_text(content + "\n")
|
||||
|
||||
|
||||
def read_bifrost_env_file() -> list[str]:
    """Return the lines of the Bifrost env file, or an empty list if it is absent."""
    if BIFROST_ENV_FILE_PATH.exists():
        return BIFROST_ENV_FILE_PATH.read_text().splitlines()
    return []
|
||||
|
||||
|
||||
def write_bifrost_env_file(lines: Iterable[str]) -> None:
    """Overwrite the Bifrost env file with *lines*, creating its directory if needed."""
    target = BIFROST_ENV_FILE_PATH
    target.parent.mkdir(parents=True, exist_ok=True)
    content = "\n".join(lines)
    target.write_text(content + "\n")
|
||||
|
||||
|
||||
def current_env_key() -> str | None:
    """Return the current ``OPENAI_API_KEY`` value, trimmed of whitespace.

    The process environment is consulted first. If the variable is unset or
    empty there, the first line of the env file that begins with
    ``OPENAI_API_KEY=`` supplies the value. ``None`` is returned when neither
    source has it.
    """
    from_environment = os.getenv("OPENAI_API_KEY")
    if from_environment:
        return from_environment.strip()

    # Fall back to the file when the container environment lacks the variable.
    marker = "OPENAI_API_KEY="
    return next(
        (
            line.partition("=")[2].strip()
            for line in read_env_file()
            if line.startswith(marker)
        ),
        None,
    )
|
||||
|
||||
|
||||
def _extract_key_value(record: Mapping[str, object]) -> str | None:
|
||||
value = record.get("value") or record.get("key")
|
||||
if value:
|
||||
return str(value)
|
||||
budget = record.get("virtual_key") if isinstance(record.get("virtual_key"), Mapping) else None
|
||||
if isinstance(budget, Mapping):
|
||||
inner_value = budget.get("value") or budget.get("key")
|
||||
if inner_value:
|
||||
return str(inner_value)
|
||||
return None
|
||||
|
||||
|
||||
def find_existing_virtual_key() -> Mapping[str, object] | None:
    """Look up the default virtual key record in the proxy.

    The list endpoint may answer either with a bare JSON list or with an
    object carrying a ``virtual_keys`` list. The first mapping whose trimmed
    ``name`` equals ``DEFAULT_VIRTUAL_KEY_NAME`` is returned.

    Returns:
        The matching record, or ``None`` if the request failed, the response
        could not be understood, or no record matched.
    """
    status, body = get_json("/api/governance/virtual-keys")
    if status != 200:
        log(f"Could not list virtual keys ({status}): {body}")
        return None

    try:
        parsed = json.loads(body)
    except json.JSONDecodeError as exc:  # pragma: no cover - defensive
        log(f"Failed to parse virtual key list: {exc}")
        return None

    if isinstance(parsed, list):
        raw_items = parsed
    elif isinstance(parsed, dict) and isinstance(parsed.get("virtual_keys"), list):
        raw_items = parsed["virtual_keys"]
    else:
        log("Virtual key list response in unexpected format; skipping lookup")
        return None

    # Non-mapping items are ignored rather than treated as errors.
    return next(
        (
            item
            for item in raw_items
            if isinstance(item, Mapping)
            and str(item.get("name", "")).strip() == DEFAULT_VIRTUAL_KEY_NAME
        ),
        None,
    )
|
||||
|
||||
|
||||
def upsert_virtual_key() -> str | None:
    """Return the default virtual key, creating it in the proxy if necessary.

    A key already held by the proxy store is reused. Otherwise a new key is
    requested with a budget of 25.0 that resets every 7 days. A key found
    only in the environment is noted in the log but not reused.

    Returns:
        The key value, or ``None`` if creation failed or the response did
        not contain a key.
    """
    env_key = current_env_key()

    stored = find_existing_virtual_key()
    stored_key = _extract_key_value(stored) if stored else None
    if stored_key:
        log("Reusing existing virtual key from proxy store")
        return stored_key

    if env_key and looks_like_virtual_key(env_key):
        log(
            "Virtual key present in env but not found in proxy store; issuing a new key"
        )

    status, body = post_json(
        "/api/governance/virtual-keys",
        {
            "name": DEFAULT_VIRTUAL_KEY_NAME,
            "user_id": DEFAULT_VIRTUAL_KEY_USER,
            "budget": {"max_limit": 25.0, "reset_duration": "7d"},
        },
    )
    if status not in (200, 201):
        log(f"Failed to create virtual key ({status}): {body}")
        return None

    try:
        created = json.loads(body)
    except json.JSONDecodeError as exc:  # pragma: no cover - defensive
        log(f"Could not parse virtual key response: {exc}")
        return None

    issued = _extract_key_value(created)
    if issued:
        log("Generated new virtual key for task agent")
        return issued

    log(f"Virtual key response missing key field: {body}")
    return None
|
||||
|
||||
|
||||
def persist_key_to_env_file(new_key: str) -> None:
    """Store *new_key* as ``OPENAI_API_KEY`` in the env file and the process env.

    The edit is delegated to ``set_env_value`` so that it follows the same
    matching rules as the other env-file updates: an indented assignment is
    rewritten in place instead of a duplicate line being appended, and
    commented-out lines are ignored. The file is rewritten only when its
    content actually changes.

    Raises:
        FileNotFoundError: If the env file does not exist (raised by
            ``read_env_file``).
    """
    lines, changed = set_env_value(read_env_file(), "OPENAI_API_KEY", new_key)
    if changed:
        write_env_file(lines)
        log(f"Wrote virtual key to {ENV_FILE_PATH}")
    else:
        log(f"Virtual key already present in {ENV_FILE_PATH}")
    os.environ["OPENAI_API_KEY"] = new_key
|
||||
|
||||
|
||||
def main() -> int:
    """Run the whole bootstrap sequence and return a process exit code.

    Steps: wait for the proxy, register providers, obtain the default
    virtual key, persist it to the env file when it differs from the current
    one, then update config.json.

    Returns:
        0 on success; 1 if any step raised, after logging the error.
    """
    log("Bootstrapping Bifrost proxy")
    try:
        wait_for_proxy()
        providers_config = configure_providers()
        previous_key = current_env_key()
        issued_key = upsert_virtual_key()
        if issued_key and issued_key != previous_key:
            persist_key_to_env_file(issued_key)
        # Fall back to the pre-existing key when no key could be obtained.
        update_config_file(providers_config, issued_key or previous_key)
        log("Bootstrap complete")
    except Exception as exc:  # pragma: no cover - startup failure reported to logs
        log(f"Bootstrap failed: {exc}")
        return 1
    return 0
|
||||
|
||||
|
||||
# Script entry point: the process exit status is whatever main() returns.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user