mirror of
https://github.com/FuzzingLabs/fuzzforge_ai.git
synced 2026-02-12 19:12:49 +00:00
- Disable FuzzForge MCP connection (no Prefect backend) - Add a2a_wrapper module for programmatic A2A agent tasks - Add task_agent (LiteLLM A2A agent) on port 10900 - Create volumes/env/ for centralized Docker config - Update docker-compose.yml with task-agent service - Remove workflow_automation_skill from agent card
138 lines
4.6 KiB
Python
138 lines
4.6 KiB
Python
"""Callbacks and instruction providers for the LiteLLM hot-swap agent."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Optional
|
|
|
|
from google.adk.agents.callback_context import CallbackContext
|
|
from google.adk.agents.readonly_context import ReadonlyContext
|
|
from google.adk.models.llm_request import LlmRequest
|
|
from google.genai import types
|
|
|
|
from .config import CONTROL_PREFIX, DEFAULT_MODEL
|
|
from .control import HotSwapCommand, parse_control_message, parse_model_spec
|
|
from .prompts import BASE_INSTRUCTION
|
|
from .state import HotSwapState, apply_state_to_agent
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
def provide_instruction(ctx: ReadonlyContext | None = None) -> str:
    """Build the system instruction text from the hot-swap state on *ctx*.

    The state mapping is taken from ``ctx.state`` when that attribute exists;
    otherwise ``None`` is handed to ``HotSwapState.from_mapping``. The stored
    prompt is used when it is truthy, falling back to ``BASE_INSTRUCTION``,
    and a trailing "Active model" line is appended.

    Args:
        ctx: Context whose ``state`` attribute holds the session state. May be
            ``None`` or lack the attribute.

    Returns:
        The instruction text followed by a blank line and the active model.
    """
    hot_swap = HotSwapState.from_mapping(getattr(ctx, "state", None))

    custom_prompt = hot_swap.prompt
    if custom_prompt:
        instruction = custom_prompt
    else:
        # No (or empty) per-session prompt: use the packaged default.
        instruction = BASE_INSTRUCTION

    return f"{instruction}\n\nActive model: {hot_swap.display_model}"
|
|
|
|
|
|
def _ensure_state(callback_context: CallbackContext) -> HotSwapState:
    """Load the hot-swap state for this callback and persist it right back.

    The state is rebuilt from ``callback_context.state`` and then written to
    that same mapping via ``HotSwapState.persist`` before being returned.
    NOTE(review): presumably the immediate persist seeds defaults into a fresh
    session - confirm against ``HotSwapState``.
    """
    session_store = callback_context.state
    hot_swap = HotSwapState.from_mapping(session_store)
    hot_swap.persist(session_store)
    return hot_swap
|
|
|
|
|
|
def _session_id(callback_context: CallbackContext) -> str:
    """Return the session id of *callback_context* for use in log messages.

    The ``session`` attribute is looked up on the context itself first and,
    when that is missing or ``None``, on the context's private
    ``_invocation_context``. Every lookup goes through ``getattr`` with a
    default so this helper cannot raise ``AttributeError``: it is called from
    inside ``except`` handlers, where a secondary failure would mask the
    error actually being reported.

    Args:
        callback_context: The callback context to inspect.

    Returns:
        The session's ``id`` attribute, or ``"unknown-session"`` when no
        session (or no ``id`` attribute) can be found.
    """
    session = getattr(callback_context, "session", None)
    if session is None:
        # Fall back to the invocation context; tolerate its absence too.
        invocation = getattr(callback_context, "_invocation_context", None)
        session = getattr(invocation, "session", None)
    return getattr(session, "id", "unknown-session")
|
|
|
|
|
|
async def before_model_callback(
    callback_context: CallbackContext,
    llm_request: LlmRequest,
) -> Optional[types.Content]:
    """Ensure outgoing requests use the active model from session state.

    Loads (and persists) the hot-swap state, tries to apply it to the agent
    via ``apply_state_to_agent``, and then sets ``llm_request.model`` to the
    stored model, or to ``DEFAULT_MODEL`` when no model is stored.

    Args:
        callback_context: Context giving access to the session state and the
            underlying invocation context.
        llm_request: The outgoing request; its ``model`` field is overwritten.

    Returns:
        Always ``None``; this callback never supplies a response of its own.
    """
    state = _ensure_state(callback_context)
    try:
        apply_state_to_agent(callback_context._invocation_context, state)
    except Exception:  # pragma: no cover - defensive logging
        # Best-effort: a failed apply is logged and the request still goes out.
        # The session id comes from _session_id() rather than
        # callback_context.session.id so that a context without a ``session``
        # attribute cannot make this handler raise.
        _LOGGER.exception(
            "Failed to apply LiteLLM model '%s' (provider=%s) for session %s",
            state.model,
            state.provider,
            _session_id(callback_context),
        )
    llm_request.model = state.model or DEFAULT_MODEL
    return None
|
|
|
|
|
|
async def before_agent_callback(
    callback_context: CallbackContext,
) -> Optional[types.Content]:
    """Handle hot-swap control messages before the agent runs.

    Only the text of the first part of the user content is inspected. When it
    starts with ``CONTROL_PREFIX`` and ``parse_control_message`` recognises
    it, the command is applied to the session's hot-swap state and a short
    confirmation (or error) message is returned.

    Args:
        callback_context: Context exposing the user content, the session
            state, and the underlying invocation context.

    Returns:
        ``None`` when the message is not a control message, otherwise a
        model-authored content object describing the outcome.
        NOTE(review): returning content here is expected to short-circuit the
        agent run per ADK callback conventions - confirm against the ADK docs.
    """
    incoming = callback_context.user_content
    parts = incoming.parts if incoming else None
    if not parts:
        return None

    text = (parts[0].text or "").strip()
    if not text.startswith(CONTROL_PREFIX):
        return None

    parsed = parse_control_message(text)
    if not parsed:
        return None
    verb, argument = parsed

    hot_swap = _ensure_state(callback_context)

    # --- model swap -------------------------------------------------------
    if verb is HotSwapCommand.MODEL:
        if not argument:
            return _render("❌ Missing model specification for hot-swap.")

        try:
            spec = parse_model_spec(argument)
        except ValueError as exc:
            return _render(f"❌ Invalid model specification: {exc}")

        hot_swap.model = spec.model
        hot_swap.provider = spec.provider
        hot_swap.persist(callback_context.state)

        try:
            apply_state_to_agent(callback_context._invocation_context, hot_swap)
        except Exception:  # pragma: no cover - defensive logging
            # Applying is best-effort; the new state has already been stored.
            _LOGGER.exception(
                "Failed to apply LiteLLM model '%s' (provider=%s) for session %s",
                hot_swap.model,
                hot_swap.provider,
                _session_id(callback_context),
            )

        _LOGGER.info(
            "Hot-swapped model to %s (provider=%s, session=%s)",
            hot_swap.model,
            hot_swap.provider,
            _session_id(callback_context),
        )
        return _render(f"✅ Model switched to: {hot_swap.display_model}")

    # --- prompt swap ------------------------------------------------------
    if verb is HotSwapCommand.PROMPT:
        cleaned = (argument or "").strip()
        # An empty payload clears the custom prompt.
        hot_swap.prompt = cleaned if cleaned else None
        hot_swap.persist(callback_context.state)

        if not hot_swap.prompt:
            return _render(
                "✅ System prompt cleared. Reverting to default instruction."
            )

        _LOGGER.info("Updated prompt for session %s", _session_id(callback_context))
        return _render(
            "✅ System prompt updated. This change takes effect immediately."
        )

    # --- config dump ------------------------------------------------------
    if verb is HotSwapCommand.GET_CONFIG:
        return _render(hot_swap.describe())

    # Parsed as a control message, but the verb is not one handled above.
    available = ", ".join(HotSwapCommand.choices())
    return _render(
        f"⚠️ Unsupported hot-swap command. Available verbs: {available}."
    )
|
|
|
|
|
|
def _render(message: str) -> types.ModelContent:
    """Wrap *message* in a ``ModelContent`` holding a single text part."""
    text_part = types.Part(text=message)
    return types.ModelContent(parts=[text_part])
|