# fuzzforge_ai/ai/agents/task_agent/litellm_agent/state.py
"""Session state utilities for the LiteLLM hot-swap agent."""
from __future__ import annotations
from dataclasses import dataclass
import os
from typing import Any, Mapping, MutableMapping, Optional
import httpx
from .config import (
DEFAULT_MODEL,
DEFAULT_PROVIDER,
PROXY_BASE_URL,
STATE_MODEL_KEY,
STATE_PROMPT_KEY,
STATE_PROVIDER_KEY,
)


@dataclass(slots=True)
class HotSwapState:
    """Lightweight view of the hot-swap session state."""

    model: str = DEFAULT_MODEL
    provider: Optional[str] = None
    prompt: Optional[str] = None

    @classmethod
    def from_mapping(cls, mapping: Optional[Mapping[str, Any]]) -> "HotSwapState":
        """Build a state object from a raw session-state mapping, falling back to defaults."""
        if not mapping:
            return cls()
        raw_model = mapping.get(STATE_MODEL_KEY, DEFAULT_MODEL)
        raw_provider = mapping.get(STATE_PROVIDER_KEY)
        raw_prompt = mapping.get(STATE_PROMPT_KEY)
        # Non-string values are ignored; whitespace-only strings collapse to defaults.
        model = raw_model.strip() if isinstance(raw_model, str) else DEFAULT_MODEL
        provider = raw_provider.strip() if isinstance(raw_provider, str) else None
        if not provider and DEFAULT_PROVIDER:
            provider = DEFAULT_PROVIDER.strip() or None
        prompt = raw_prompt.strip() if isinstance(raw_prompt, str) else None
        return cls(
            model=model or DEFAULT_MODEL,
            provider=provider or None,
            prompt=prompt or None,
        )

    def persist(self, store: MutableMapping[str, object]) -> None:
        """Write the current state back into the session-state mapping."""
        store[STATE_MODEL_KEY] = self.model
        store[STATE_PROVIDER_KEY] = self.provider or None
        store[STATE_PROMPT_KEY] = self.prompt

    def describe(self) -> str:
        """Render a human-readable summary of the active configuration."""
        prompt_value = self.prompt or "(default prompt)"
        provider_value = self.provider or "(default provider)"
        return (
            "📊 Current Configuration\n"
            "━━━━━━━━━━━━━━━━━━━━━━\n"
            f"Model: {self.model}\n"
            f"Provider: {provider_value}\n"
            f"System Prompt: {prompt_value}\n"
            "━━━━━━━━━━━━━━━━━━━━━━"
        )

    def instantiate_llm(self):
        """Create a LiteLlm instance for the current state."""
        # Lazy imports to avoid an import cycle with the agent module.
        from google.adk.models.lite_llm import LiteLlm, LiteLLMClient
        from litellm.types.utils import Choices, Message, ModelResponse, Usage

        provider = (self.provider or DEFAULT_PROVIDER or "").lower()
        kwargs: dict[str, Any] = {"model": self.model}
        if self.provider:
            kwargs["custom_llm_provider"] = self.provider
        if PROXY_BASE_URL:
            # Route non-OpenAI providers through a provider-specific proxy path.
            if provider and provider != "openai":
                kwargs["api_base"] = f"{PROXY_BASE_URL.rstrip('/')}/{provider}"
            else:
                kwargs["api_base"] = PROXY_BASE_URL
            kwargs.setdefault("api_key", os.environ.get("OPENAI_API_KEY"))
        # Reasoning-oriented OpenAI models are served via the Responses API.
        model_suffix = self.model.split("/", 1)[-1]
        use_responses = provider == "openai" and (
            model_suffix.startswith("gpt-5") or model_suffix.startswith("o1")
        )
        if use_responses:
            kwargs.setdefault("use_responses_api", True)
        llm = LiteLlm(**kwargs)
        if use_responses and PROXY_BASE_URL:

            class _ResponsesAwareClient(LiteLLMClient):
                """Client wrapper that sends Responses-API calls straight to the proxy."""

                def __init__(self, base_client: LiteLLMClient, api_base: str, api_key: str):
                    self._base_client = base_client
                    self._api_base = api_base.rstrip("/")
                    self._api_key = api_key

                async def acompletion(self, model, messages, tools, **kwargs):  # type: ignore[override]
                    use_responses_api = kwargs.pop("use_responses_api", False)
                    if not use_responses_api:
                        # Plain chat-completions traffic passes through untouched.
                        return await self._base_client.acompletion(
                            model=model,
                            messages=messages,
                            tools=tools,
                            **kwargs,
                        )
                    resolved_model = model
                    if "/" not in resolved_model:
                        resolved_model = f"openai/{resolved_model}"
                    # Note: tool definitions are not forwarded on this path.
                    payload = {
                        "model": resolved_model,
                        "input": _messages_to_responses_input(messages),
                    }
                    timeout = kwargs.get("timeout", 60)
                    headers = {
                        "Authorization": f"Bearer {self._api_key}",
                        "Content-Type": "application/json",
                    }
                    async with httpx.AsyncClient(timeout=timeout) as client:
                        response = await client.post(
                            f"{self._api_base}/v1/responses",
                            json=payload,
                            headers=headers,
                        )
                    try:
                        response.raise_for_status()
                    except httpx.HTTPStatusError as exc:
                        raise RuntimeError(
                            f"Bifrost responses request failed: {exc.response.text}"
                        ) from exc
                    data = response.json()
                    text_output = _extract_output_text(data)
                    usage = data.get("usage", {})
                    # Re-shape the Responses payload into a ModelResponse so the
                    # rest of the LiteLLM pipeline can consume it unchanged.
                    return ModelResponse(
                        id=data.get("id"),
                        model=model,
                        choices=[
                            Choices(
                                finish_reason="stop",
                                index=0,
                                message=Message(role="assistant", content=text_output),
                                provider_specific_fields={"bifrost_response": data},
                            )
                        ],
                        usage=Usage(
                            prompt_tokens=usage.get("input_tokens"),
                            completion_tokens=usage.get("output_tokens"),
                            reasoning_tokens=usage.get("output_tokens_details", {}).get(
                                "reasoning_tokens"
                            ),
                            total_tokens=usage.get("total_tokens"),
                        ),
                    )

            llm.llm_client = _ResponsesAwareClient(
                llm.llm_client,
                PROXY_BASE_URL,
                os.environ.get("OPENAI_API_KEY", ""),
            )
        return llm

    @property
    def display_model(self) -> str:
        if self.provider:
            return f"{self.provider}/{self.model}"
        return self.model
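

# Illustrative round-trip (a sketch, not part of the module's API): raw session
# state in, a normalized HotSwapState out, and the cleaned values written back.
# The keys come from the STATE_*_KEY constants imported from .config above.
#
#     raw = {STATE_MODEL_KEY: "  openai/gpt-4o  ", STATE_PROVIDER_KEY: ""}
#     state = HotSwapState.from_mapping(raw)  # model stripped; provider falls back
#     print(state.describe())                 # human-readable summary
#     state.persist(raw)                      # normalized values written back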


def apply_state_to_agent(invocation_context, state: HotSwapState) -> None:
    """Update the provided agent with a LiteLLM instance matching state."""
    agent = invocation_context.agent
    agent.model = state.instantiate_llm()


def _messages_to_responses_input(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert chat-completions style messages into Responses API input items."""
    inputs: list[dict[str, Any]] = []
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        text_segments: list[str] = []
        # Content may be a plain string or a list of text parts.
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict):
                    text = item.get("text") or item.get("content")
                    if text:
                        text_segments.append(str(text))
                elif isinstance(item, str):
                    text_segments.append(item)
        elif isinstance(content, str):
            text_segments.append(content)
        text = "\n".join(segment.strip() for segment in text_segments if segment)
        if not text:
            continue
        # Assistant turns are prior model output; everything else is input.
        entry_type = "output_text" if role == "assistant" else "input_text"
        inputs.append(
            {
                "role": role,
                "content": [{"type": entry_type, "text": text}],
            }
        )
    if not inputs:
        # Guard against sending an empty input list: fall back to a blank user turn.
        inputs.append(
            {
                "role": "user",
                "content": [{"type": "input_text", "text": ""}],
            }
        )
    return inputs
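

# A sketch of the mapping the helper above performs (shapes inferred from this
# module's code, not from an official schema): a chat-completions message like
#     {"role": "user", "content": "hello"}
# becomes the Responses-style input item
#     {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}
# while assistant turns are tagged "output_text" instead of "input_text".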


def _extract_output_text(response_json: dict[str, Any]) -> str:
    """Collect the text parts of a Responses API payload into a single string."""
    outputs = response_json.get("output", [])
    collected: list[str] = []
    for item in outputs:
        if isinstance(item, dict) and item.get("type") == "message":
            for part in item.get("content", []):
                if isinstance(part, dict) and part.get("type") == "output_text":
                    text = part.get("text", "")
                    if text:
                        collected.append(str(text))
    return "\n\n".join(collected).strip()