diff --git a/agentic_security/app.py b/agentic_security/app.py index 59fbe58..28fd40a 100644 --- a/agentic_security/app.py +++ b/agentic_security/app.py @@ -1,287 +1,28 @@ -import os -import random -from asyncio import Event, Queue -from datetime import datetime -from logging import config -from pathlib import Path - -from fastapi import BackgroundTasks, FastAPI, HTTPException, Request, Response -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, StreamingResponse -from loguru import logger -from pydantic import BaseModel, Field -from starlette.middleware.base import BaseHTTPMiddleware - -from .http_spec import LLMSpec -from .probe_actor import fuzzer -from .probe_actor.refusal import REFUSAL_MARKS -from .probe_data import REGISTRY -from .report_chart import plot_security_report - -# Create the FastAPI app instance -app = FastAPI() -origins = [ - "*", -] - - -# Configuration -class Settings: - MAX_BUDGET = 1000 - MAX_DATASETS = 10 - RATE_LIMIT = "100/minute" - DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) - FEATURE_PROXY = False - - -settings = Settings() - -# Middleware setup -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], # Allows all methods - allow_headers=["*"], # Allows all headers +from .core.app import create_app +from .core.logging import setup_logging +from .middleware.cors import setup_cors +from .middleware.logging import LogNon200ResponsesMiddleware +from .routes import ( + static_router, + scan_router, + probe_router, + proxy_router, + report_router, ) -tools_inbox = Queue() -# Global stop event for cancelling scans -stop_event = Event() # Added stop_event to cancel the scan +# Create the FastAPI app +app = create_app() - -@app.get("/") -async def root(): - agentic_security_path = Path(__file__).parent - return FileResponse(f"{agentic_security_path}/static/index.html") - - -@app.get("/main.js") -async def main_js(): - agentic_security_path = Path(__file__).parent - return FileResponse(f"{agentic_security_path}/static/main.js") - - -@app.get("/telemetry.js") -async def telemetry_js(): - agentic_security_path = Path(__file__).parent - if settings.DISABLE_TELEMETRY: - return FileResponse(f"{agentic_security_path}/static/telemetry_disabled.js") - return FileResponse(f"{agentic_security_path}/static/telemetry.js") - - -@app.get("/favicon.ico") -async def favicon(): - agentic_security_path = Path(__file__).parent - return FileResponse(f"{agentic_security_path}/static/favicon.ico") - - -class LLMInfo(BaseModel): - spec: str - - -@app.post("/verify") -async def verify(info: LLMInfo): - - spec = LLMSpec.from_string(info.spec) - r = await spec.probe("test") - if r.status_code >= 400: - raise HTTPException(status_code=r.status_code, detail=r.text) - return dict( - status_code=r.status_code, - body=r.text, - elapsed=r.elapsed.total_seconds(), - timestamp=datetime.now().isoformat(), - ) - - -class Scan(BaseModel): - llmSpec: str - maxBudget: int - datasets: list[dict] = [] - optimize: bool = False - - -class ScanResult(BaseModel): - module: str - tokens: int - cost: float - progress: float - failureRate: float = 0.0 - - -def streaming_response_generator(scan_parameters: Scan): - # The generator function for StreamingResponse - request_factory = LLMSpec.from_string(scan_parameters.llmSpec) - - async def _gen(): - async for scan_result in fuzzer.perform_scan( - request_factory=request_factory, - max_budget=scan_parameters.maxBudget, - datasets=scan_parameters.datasets, - tools_inbox=tools_inbox, - optimize=scan_parameters.optimize, - stop_event=stop_event, # Pass the stop_event to the generator - ): - yield scan_result + "\n" # Adding a newline for separation - - return _gen() - - -@app.post("/scan") -async def scan(scan_parameters: Scan, background_tasks: BackgroundTasks): - - # Initiates streaming of scan results - return StreamingResponse( - streaming_response_generator(scan_parameters), media_type="application/json" - ) - - -class Probe(BaseModel): - prompt: str - - -@app.post("/v1/self-probe") -def self_probe(probe: Probe): - refuse = random.random() < 0.2 - message = random.choice(REFUSAL_MARKS) if refuse else "This is a test!" - message = probe.prompt + " " + message - return { - "id": "chatcmpl-abc123", - "object": "chat.completion", - "created": 1677858242, - "model": "gpt-3.5-turbo-0613", - "usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20}, - "choices": [ - { - "message": {"role": "assistant", "content": message}, - "logprobs": None, - "finish_reason": "stop", - "index": 0, - } - ], - } - - -@app.get("/v1/data-config") -async def data_config(): - return [m for m in REGISTRY] - - -@app.get("/failures") -async def failures_csv(): - if not Path("failures.csv").exists(): - return {"error": "No failures found"} - return FileResponse("failures.csv") - - -class Table(BaseModel): - table: list[dict] - - -@app.post("/plot.jpeg", response_class=Response) -async def get_plot(table: Table): - buf = plot_security_report(table.table) - return StreamingResponse(buf, media_type="image/jpeg") - - -class Message(BaseModel): - role: str - content: str - - -class CompletionRequest(BaseModel): - """Model for completion requests.""" - - model: str - messages: list[Message] - temperature: float = Field(default=0.7, ge=0.0, le=2.0) - top_p: float = Field(default=1.0, ge=0.0, le=1.0) - n: int = Field(default=1, ge=1, le=10) - stop: list[str] | None = None - max_tokens: int = Field(default=100, ge=1, le=4096) - presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) - frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) - - -# OpenAI proxy endpoint -@app.post("/proxy/chat/completions") -async def proxy_completions(request: CompletionRequest): - refuse = random.random() < 0.2 - message = random.choice(REFUSAL_MARKS) if refuse else "This is a test!" - prompt_content = " ".join( - [msg.content for msg in request.messages if msg.role == "user"] - ) - message = prompt_content + " " + message - ready = Event() - ref = dict(message=message, reply="", ready=ready) - tools_inbox.put_nowait(ref) - if settings.FEATURE_PROXY: - # Proxy to agent - await ready.wait() - reply = ref["reply"] - return reply - # Simulate a completion response - return { - "id": "chatcmpl-abc123", - "object": "chat.completion", - "created": 1677858242, - "model": "gpt-3.5-turbo-0613", - "usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20}, - "choices": [ - { - "message": {"role": "assistant", "content": message}, - "logprobs": None, - "finish_reason": "stop", - "index": 0, - } - ], - } - - -config.dictConfig( - { - "version": 1, - "disable_existing_loggers": True, - "handlers": { - "console": { - "class": "logging.StreamHandler", - }, - }, - "root": { - "handlers": ["console"], - "level": "INFO", - }, - "loggers": { - "uvicorn.access": { - "level": "ERROR", # Set higher log level to suppress info logs globally - "handlers": ["console"], - "propagate": False, - } - }, - } -) - - -@app.post("/stop") -async def stop_scan(): - stop_event.set() # Set the stop event to cancel the scan - return {"status": "Scan stopped"} - - -class LogNon200ResponsesMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - try: - response = await call_next(request) - except Exception as e: - logger.exception("Yikes") - raise e - if response.status_code != 200: - logger.error( - f"{request.method} {request.url} - Status code: {response.status_code}" - ) - return response - - -# Add middleware to the application +# Setup middleware +setup_cors(app) app.add_middleware(LogNon200ResponsesMiddleware) + +# Setup logging +setup_logging() + +# Register routers +app.include_router(static_router) +app.include_router(scan_router) +app.include_router(probe_router) +app.include_router(proxy_router) +app.include_router(report_router) diff --git a/agentic_security/core/app.py b/agentic_security/core/app.py new file mode 100644 index 0000000..7a8d55d --- /dev/null +++ b/agentic_security/core/app.py @@ -0,0 +1,21 @@ +from asyncio import Event, Queue +from fastapi import FastAPI + +tools_inbox: Queue = Queue() +stop_event: Event = Event() + + +def create_app() -> FastAPI: + """Create and configure the FastAPI application.""" + app = FastAPI() + return app + + +def get_tools_inbox() -> Queue: + """Get the global tools inbox queue.""" + return tools_inbox + + +def get_stop_event() -> Event: + """Get the global stop event.""" + return stop_event diff --git a/agentic_security/core/logging.py b/agentic_security/core/logging.py new file mode 100644 index 0000000..c602f7a --- /dev/null +++ b/agentic_security/core/logging.py @@ -0,0 +1,26 @@ +from logging import config + + +def setup_logging(): + config.dictConfig( + { + "version": 1, + "disable_existing_loggers": True, + "handlers": { + "console": { + "class": "logging.StreamHandler", + }, + }, + "root": { + "handlers": ["console"], + "level": "INFO", + }, + "loggers": { + "uvicorn.access": { + "level": "ERROR", # Set higher log level to suppress info logs globally + "handlers": ["console"], + "propagate": False, + } + }, + } + ) diff --git a/agentic_security/lib.py b/agentic_security/lib.py index 51a2563..751fe36 100644 --- a/agentic_security/lib.py +++ b/agentic_security/lib.py @@ -5,7 +5,8 @@ import colorama import tqdm.asyncio from tabulate import tabulate -from agentic_security.app import Scan, streaming_response_generator +from agentic_security.models.schemas import Scan +from agentic_security.routes.scan import streaming_response_generator from agentic_security.probe_data import REGISTRY RESET = colorama.Style.RESET_ALL diff --git a/agentic_security/middleware/cors.py b/agentic_security/middleware/cors.py new file mode 100644 index 0000000..5390fe7 --- /dev/null +++ b/agentic_security/middleware/cors.py @@ -0,0 +1,14 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + + +def setup_cors(app: FastAPI): + origins = ["*"] + + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers + ) diff --git a/agentic_security/middleware/logging.py b/agentic_security/middleware/logging.py new file mode 100644 index 0000000..9171a4f --- /dev/null +++ b/agentic_security/middleware/logging.py @@ -0,0 +1,17 @@ +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +from loguru import logger + + +class LogNon200ResponsesMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + try: + response = await call_next(request) + except Exception as e: + logger.exception("Yikes") + raise e + if response.status_code != 200: + logger.error( + f"{request.method} {request.url} - Status code: {response.status_code}" + ) + return response diff --git a/agentic_security/models/schemas.py b/agentic_security/models/schemas.py new file mode 100644 index 0000000..9c1569a --- /dev/null +++ b/agentic_security/models/schemas.py @@ -0,0 +1,56 @@ +import os +from pydantic import BaseModel, Field + + +class Settings: + MAX_BUDGET = 1000 + MAX_DATASETS = 10 + RATE_LIMIT = "100/minute" + DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) + FEATURE_PROXY = False + + +class LLMInfo(BaseModel): + spec: str + + +class Scan(BaseModel): + llmSpec: str + maxBudget: int + datasets: list[dict] = [] + optimize: bool = False + + +class ScanResult(BaseModel): + module: str + tokens: int + cost: float + progress: float + failureRate: float = 0.0 + + +class Probe(BaseModel): + prompt: str + + +class Message(BaseModel): + role: str + content: str + + +class CompletionRequest(BaseModel): + """Model for completion requests.""" + + model: str + messages: list[Message] + temperature: float = Field(default=0.7, ge=0.0, le=2.0) + top_p: float = Field(default=1.0, ge=0.0, le=1.0) + n: int = Field(default=1, ge=1, le=10) + stop: list[str] | None = None + max_tokens: int = Field(default=100, ge=1, le=4096) + presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) + frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) + + +class Table(BaseModel): + table: list[dict] diff --git a/agentic_security/routes/__init__.py b/agentic_security/routes/__init__.py new file mode 100644 index 0000000..d334d65 --- /dev/null +++ b/agentic_security/routes/__init__.py @@ -0,0 +1,13 @@ +from .static import router as static_router +from .scan import router as scan_router +from .probe import router as probe_router +from .proxy import router as proxy_router +from .report import router as report_router + +__all__ = [ + "static_router", + "scan_router", + "probe_router", + "proxy_router", + "report_router", +] diff --git a/agentic_security/routes/probe.py b/agentic_security/routes/probe.py new file mode 100644 index 0000000..b2bbddf --- /dev/null +++ b/agentic_security/routes/probe.py @@ -0,0 +1,34 @@ +import random +from fastapi import APIRouter +from ..models.schemas import Probe +from ..probe_actor.refusal import REFUSAL_MARKS +from ..probe_data import REGISTRY + +router = APIRouter() + + +@router.post("/v1/self-probe") +def self_probe(probe: Probe): + refuse = random.random() < 0.2 + message = random.choice(REFUSAL_MARKS) if refuse else "This is a test!" + message = probe.prompt + " " + message + return { + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo-0613", + "usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20}, + "choices": [ + { + "message": {"role": "assistant", "content": message}, + "logprobs": None, + "finish_reason": "stop", + "index": 0, + } + ], + } + + +@router.get("/v1/data-config") +async def data_config(): + return [m for m in REGISTRY] diff --git a/agentic_security/routes/proxy.py b/agentic_security/routes/proxy.py new file mode 100644 index 0000000..9732ed7 --- /dev/null +++ b/agentic_security/routes/proxy.py @@ -0,0 +1,45 @@ +import random +from asyncio import Event +from fastapi import APIRouter +from ..models.schemas import CompletionRequest, Settings +from ..probe_actor.refusal import REFUSAL_MARKS +from ..core.app import get_tools_inbox + +router = APIRouter() + + +@router.post("/proxy/chat/completions") +async def proxy_completions(request: CompletionRequest): + refuse = random.random() < 0.2 + message = random.choice(REFUSAL_MARKS) if refuse else "This is a test!" + prompt_content = " ".join( + [msg.content for msg in request.messages if msg.role == "user"] + ) + message = prompt_content + " " + message + ready = Event() + ref = dict(message=message, reply="", ready=ready) + tools_inbox = get_tools_inbox() + await tools_inbox.put(ref) + + if Settings.FEATURE_PROXY: + # Proxy to agent + await ready.wait() + reply = ref["reply"] + return reply + + # Simulate a completion response + return { + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1677858242, + "model": "gpt-3.5-turbo-0613", + "usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20}, + "choices": [ + { + "message": {"role": "assistant", "content": message}, + "logprobs": None, + "finish_reason": "stop", + "index": 0, + } + ], + } diff --git a/agentic_security/routes/report.py b/agentic_security/routes/report.py new file mode 100644 index 0000000..25988aa --- /dev/null +++ b/agentic_security/routes/report.py @@ -0,0 +1,20 @@ +from pathlib import Path +from fastapi import APIRouter, Response +from fastapi.responses import FileResponse, StreamingResponse +from ..models.schemas import Table +from ..report_chart import plot_security_report + +router = APIRouter() + + +@router.get("/failures") +async def failures_csv(): + if not Path("failures.csv").exists(): + return {"error": "No failures found"} + return FileResponse("failures.csv") + + +@router.post("/plot.jpeg", response_class=Response) +async def get_plot(table: Table): + buf = plot_security_report(table.table) + return StreamingResponse(buf, media_type="image/jpeg") diff --git a/agentic_security/routes/scan.py b/agentic_security/routes/scan.py new file mode 100644 index 0000000..cd7062d --- /dev/null +++ b/agentic_security/routes/scan.py @@ -0,0 +1,53 @@ +from datetime import datetime +from fastapi import APIRouter, BackgroundTasks, HTTPException +from fastapi.responses import StreamingResponse +from ..models.schemas import LLMInfo, Scan +from ..http_spec import LLMSpec +from ..probe_actor import fuzzer +from ..core.app import get_tools_inbox, get_stop_event + +router = APIRouter() + + +@router.post("/verify") +async def verify(info: LLMInfo): + spec = LLMSpec.from_string(info.spec) + r = await spec.probe("test") + if r.status_code >= 400: + raise HTTPException(status_code=r.status_code, detail=r.text) + return dict( + status_code=r.status_code, + body=r.text, + elapsed=r.elapsed.total_seconds(), + timestamp=datetime.now().isoformat(), + ) + + +def streaming_response_generator(scan_parameters: Scan): + request_factory = LLMSpec.from_string(scan_parameters.llmSpec) + + async def _gen(): + async for scan_result in fuzzer.perform_scan( + request_factory=request_factory, + max_budget=scan_parameters.maxBudget, + datasets=scan_parameters.datasets, + tools_inbox=get_tools_inbox(), + optimize=scan_parameters.optimize, + stop_event=get_stop_event(), + ): + yield scan_result + "\n" + + return _gen() + + +@router.post("/scan") +async def scan(scan_parameters: Scan, background_tasks: BackgroundTasks): + return StreamingResponse( + streaming_response_generator(scan_parameters), media_type="application/json" + ) + + +@router.post("/stop") +async def stop_scan(): + get_stop_event().set() + return {"status": "Scan stopped"} diff --git a/agentic_security/routes/static.py b/agentic_security/routes/static.py new file mode 100644 index 0000000..a97bdf8 --- /dev/null +++ b/agentic_security/routes/static.py @@ -0,0 +1,32 @@ +from pathlib import Path +from fastapi import APIRouter +from fastapi.responses import FileResponse +from ..models.schemas import Settings + +router = APIRouter() + + +@router.get("/") +async def root(): + agentic_security_path = Path(__file__).parent.parent + return FileResponse(f"{agentic_security_path}/static/index.html") + + +@router.get("/main.js") +async def main_js(): + agentic_security_path = Path(__file__).parent.parent + return FileResponse(f"{agentic_security_path}/static/main.js") + + +@router.get("/telemetry.js") +async def telemetry_js(): + agentic_security_path = Path(__file__).parent.parent + if Settings.DISABLE_TELEMETRY: + return FileResponse(f"{agentic_security_path}/static/telemetry_disabled.js") + return FileResponse(f"{agentic_security_path}/static/telemetry.js") + + +@router.get("/favicon.ico") +async def favicon(): + agentic_security_path = Path(__file__).parent.parent + return FileResponse(f"{agentic_security_path}/static/favicon.ico") diff --git a/pyproject.toml b/pyproject.toml index 14532ff..9bbc252 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "agentic_security" -version = "0.2.6" +version = "0.3.0" description = "Agentic LLM vulnerability scanner" authors = ["Alexander Miasoiedov "] maintainers = ["Alexander Miasoiedov "]