mirror of
https://github.com/msoedov/agentic_security.git
synced 2026-06-24 06:09:55 +02:00
Merge pull request #113 from Praveenk8051/feat/extension-with-sample-tests
feat(operator): add agent testing functionality with endpoint
This commit is contained in:
@@ -1,9 +1,16 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from httpx import LLMSpec
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_ai import Agent, RunContext
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentSpecification(BaseModel):
|
||||
name: str | None = Field(None, description="Name of the LLM/agent")
|
||||
@@ -13,9 +20,9 @@ class AgentSpecification(BaseModel):
|
||||
configuration: dict[str, Any] | None = Field(
|
||||
None, description="Configuration settings"
|
||||
)
|
||||
endpoint: str | None = Field(None, description="Endpoint URL of the deployed agent")
|
||||
|
||||
|
||||
# Define the OperatorToolBox class
|
||||
class OperatorToolBox:
|
||||
def __init__(self, spec: AgentSpecification, datasets: list[dict[str, Any]]):
|
||||
self.spec = spec
|
||||
@@ -29,7 +36,6 @@ class OperatorToolBox:
|
||||
return self.datasets
|
||||
|
||||
def validate(self) -> bool:
|
||||
# Validate the tool box based on the specification
|
||||
if not self.spec.name or not self.spec.version:
|
||||
self.failures.append("Invalid specification: Name or version is missing.")
|
||||
return False
|
||||
@@ -39,28 +45,70 @@ class OperatorToolBox:
|
||||
return True
|
||||
|
||||
def stop(self) -> None:
|
||||
# Stop the tool box
|
||||
print("Stopping the toolbox...")
|
||||
logger.info("Stopping the toolbox...")
|
||||
|
||||
def run(self) -> None:
|
||||
# Run the tool box
|
||||
print("Running the toolbox...")
|
||||
logger.info("Running the toolbox...")
|
||||
|
||||
def get_results(self) -> list[dict[str, Any]]:
|
||||
# Get the results
|
||||
return self.datasets
|
||||
|
||||
def get_failures(self) -> list[str]:
|
||||
# Handle failure
|
||||
return self.failures
|
||||
|
||||
def run_operation(self, operation: str) -> str:
|
||||
# Run an operation based on the specification
|
||||
if operation not in ["dataset1", "dataset2", "dataset3"]:
|
||||
self.failures.append(f"Operation '{operation}' failed: Dataset not found.")
|
||||
return f"Operation '{operation}' failed: Dataset not found."
|
||||
return f"Operation '{operation}' executed successfully."
|
||||
|
||||
async def test(self, description: str, sample_test: dict[str, Any]) -> str:
|
||||
agent = Agent(
|
||||
"openai:gpt-4o",
|
||||
result_type=LLMSpec,
|
||||
system_prompt="Extract the LLM specification from the input",
|
||||
)
|
||||
|
||||
async with agent.run_stream(description) as result:
|
||||
async for spec in result.stream():
|
||||
self.spec.endpoint = spec.url
|
||||
|
||||
# Verify access to the endpoint
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
access_response = await client.get(spec.url)
|
||||
access_response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.failures.append(f"HTTP error occurred: {e}")
|
||||
logger.error(f"Access verification failed: {e}")
|
||||
return f"Access verification failed: {e}"
|
||||
except Exception as e:
|
||||
self.failures.append(f"An error occurred: {e}")
|
||||
logger.error(f"Access verification failed: {e}")
|
||||
return f"Access verification failed: {e}"
|
||||
|
||||
# Run the sample test
|
||||
try:
|
||||
test_response = await client.post(
|
||||
f"{spec.url}/test", json=sample_test
|
||||
)
|
||||
test_response.raise_for_status()
|
||||
response_data = test_response.json()
|
||||
if "choices" in response_data and len(response_data["choices"]) > 0:
|
||||
return f"Testing agent at {spec.url} succeeded: {response_data}"
|
||||
else:
|
||||
self.failures.append("Invalid response format")
|
||||
logger.error("Sample test failed: Invalid response format")
|
||||
return "Sample test failed: Invalid response format"
|
||||
except httpx.HTTPStatusError as e:
|
||||
self.failures.append(f"HTTP error occurred: {e}")
|
||||
logger.error(f"Sample test failed: {e}")
|
||||
return f"Sample test failed: {e}"
|
||||
except Exception as e:
|
||||
self.failures.append(f"An error occurred: {e}")
|
||||
logger.error(f"Sample test failed: {e}")
|
||||
return f"Sample test failed: {e}"
|
||||
|
||||
|
||||
# Initialize OperatorToolBox with AgentSpecification
|
||||
spec = AgentSpecification(
|
||||
@@ -71,24 +119,19 @@ spec = AgentSpecification(
|
||||
configuration={"max_tokens": 100},
|
||||
)
|
||||
|
||||
# dataset_manager_agent.py
|
||||
|
||||
|
||||
# Initialize OperatorToolBox
|
||||
toolbox = OperatorToolBox(spec=spec, datasets=["dataset1", "dataset2", "dataset3"])
|
||||
|
||||
# Define the agent with OperatorToolBox as its dependency
|
||||
dataset_manager_agent = Agent(
|
||||
model="gpt-4",
|
||||
deps_type=OperatorToolBox,
|
||||
result_type=str, # The agent will return string results
|
||||
result_type=str,
|
||||
system_prompt="You can validate the toolbox, run operations, and retrieve results or failures.",
|
||||
)
|
||||
|
||||
|
||||
@dataset_manager_agent.tool
|
||||
async def validate_toolbox(ctx: RunContext[OperatorToolBox]) -> str:
|
||||
"""Validate the OperatorToolBox."""
|
||||
is_valid = ctx.deps.validate()
|
||||
if is_valid:
|
||||
return "ToolBox validation successful."
|
||||
@@ -98,14 +141,12 @@ async def validate_toolbox(ctx: RunContext[OperatorToolBox]) -> str:
|
||||
|
||||
@dataset_manager_agent.tool
|
||||
async def execute_operation(ctx: RunContext[OperatorToolBox], operation: str) -> str:
|
||||
"""Execute an operation on a dataset."""
|
||||
result = ctx.deps.run_operation(operation)
|
||||
return result
|
||||
|
||||
|
||||
@dataset_manager_agent.tool
|
||||
async def retrieve_results(ctx: RunContext[OperatorToolBox]) -> str:
|
||||
"""Retrieve the results of operations."""
|
||||
results = ctx.deps.get_results()
|
||||
if results:
|
||||
formatted_results = "\n".join([f"{op}: {res}" for op, res in results.items()])
|
||||
@@ -116,7 +157,6 @@ async def retrieve_results(ctx: RunContext[OperatorToolBox]) -> str:
|
||||
|
||||
@dataset_manager_agent.tool
|
||||
async def retrieve_failures(ctx: RunContext[OperatorToolBox]) -> str:
|
||||
"""Retrieve the list of failures."""
|
||||
failures = ctx.deps.get_failures()
|
||||
if failures:
|
||||
formatted_failures = "\n".join(failures)
|
||||
@@ -125,6 +165,14 @@ async def retrieve_failures(ctx: RunContext[OperatorToolBox]) -> str:
|
||||
return "No failures recorded."
|
||||
|
||||
|
||||
@dataset_manager_agent.tool
|
||||
async def test_agent(
|
||||
ctx: RunContext[OperatorToolBox], description: str, sample_test: dict[str, Any]
|
||||
) -> str:
|
||||
result = await ctx.deps.test(description, sample_test)
|
||||
return result
|
||||
|
||||
|
||||
# Synchronous run example
|
||||
def run_dataset_manager_agent_sync():
|
||||
prompts = [
|
||||
@@ -133,10 +181,18 @@ def run_dataset_manager_agent_sync():
|
||||
"Execute operation on 'dataset4'.", # This should fail
|
||||
"Retrieve the results.",
|
||||
"Retrieve any failures.",
|
||||
"Test my openAI compatible agent deployed at localhost:3000",
|
||||
]
|
||||
|
||||
sample_test = {"prompt": "Hello, how are you?", "max_tokens": 5}
|
||||
|
||||
for prompt in prompts:
|
||||
result = dataset_manager_agent.run_sync(prompt, deps=toolbox)
|
||||
if "Test my" in prompt:
|
||||
result = dataset_manager_agent.run_sync(
|
||||
prompt, deps=toolbox, sample_test=sample_test
|
||||
)
|
||||
else:
|
||||
result = dataset_manager_agent.run_sync(prompt, deps=toolbox)
|
||||
print(f"Prompt: {prompt}")
|
||||
print(f"Response: {result.data}\n")
|
||||
|
||||
@@ -149,10 +205,18 @@ async def run_dataset_manager_agent_async():
|
||||
"Execute operation on 'dataset4'.", # This should fail
|
||||
"Retrieve the results.",
|
||||
"Retrieve any failures.",
|
||||
"Test my openAI compatible agent deployed at localhost:3000",
|
||||
]
|
||||
|
||||
sample_test = {"prompt": "Hello, how are you?", "max_tokens": 5}
|
||||
|
||||
for prompt in prompts:
|
||||
result = await dataset_manager_agent.run(prompt, deps=toolbox)
|
||||
if "Test my" in prompt:
|
||||
result = await dataset_manager_agent.run(
|
||||
prompt, deps=toolbox, sample_test=sample_test
|
||||
)
|
||||
else:
|
||||
result = await dataset_manager_agent.run(prompt, deps=toolbox)
|
||||
print(f"Prompt: {prompt}")
|
||||
print(f"Response: {result.data}\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user