From 37b292a48a9b8702ac58ae79b5c188fb96c27164 Mon Sep 17 00:00:00 2001 From: Alexander Myasoedov Date: Tue, 17 Dec 2024 14:16:49 +0200 Subject: [PATCH] fix(add file self probe endpoint): --- agentic_security/models/schemas.py | 7 +++ agentic_security/routes/probe.py | 31 +++++++++++- agentic_security/routes/test_probe.py | 72 +++++++++++++++++++++++++++ agentic_security/test_spec_assets.py | 2 +- 4 files changed, 109 insertions(+), 3 deletions(-) diff --git a/agentic_security/models/schemas.py b/agentic_security/models/schemas.py index 3c9bab6..742a5ff 100644 --- a/agentic_security/models/schemas.py +++ b/agentic_security/models/schemas.py @@ -68,5 +68,12 @@ class CompletionRequest(BaseModel): frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0) +class FileProbeResponse(BaseModel): + """Response model for file probe endpoint.""" + + text: str + model: str + + class Table(BaseModel): table: list[dict] diff --git a/agentic_security/routes/probe.py b/agentic_security/routes/probe.py index 612f03d..5cce7a2 100644 --- a/agentic_security/routes/probe.py +++ b/agentic_security/routes/probe.py @@ -1,8 +1,8 @@ import random -from fastapi import APIRouter +from fastapi import APIRouter, File, Header, HTTPException, UploadFile -from ..models.schemas import Probe +from ..models.schemas import FileProbeResponse, Probe from ..probe_actor.refusal import REFUSAL_MARKS from ..probe_data import REGISTRY @@ -31,6 +31,33 @@ def self_probe(probe: Probe): } +@router.post("/v1/self-probe-file", response_model=FileProbeResponse) +async def self_probe_file( + file: UploadFile = File(...), + model: str = "whisper-large-v3", + authorization: str = Header(...), +): + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + api_key = authorization.replace("Bearer ", "") + if not api_key: + raise HTTPException(status_code=401, detail="Missing API key") + + if not file.filename or not file.filename.lower().endswith( + (".m4a", ".mp3", ".wav") + ): + raise HTTPException( + status_code=400, + detail="Invalid file format. Supported formats: m4a, mp3, wav", + ) + + # For testing purposes, return mock transcription + mock_text = "This is a mock transcription of the audio file." + + return FileProbeResponse(text=mock_text, model=model) + + @router.get("/v1/data-config") async def data_config(): return [m for m in REGISTRY] diff --git a/agentic_security/routes/test_probe.py b/agentic_security/routes/test_probe.py index 476b01c..2b8831f 100644 --- a/agentic_security/routes/test_probe.py +++ b/agentic_security/routes/test_probe.py @@ -1,3 +1,4 @@ +import io import pytest from fastapi.testclient import TestClient @@ -95,3 +96,74 @@ def test_refusal_rate(): assert ( 0.15 <= refusal_rate <= 0.25 ), f"Refusal rate {refusal_rate} is outside expected range" + + +def test_self_probe_file_endpoint(): + """Test /v1/self-probe-file endpoint with valid input""" + # Create a mock audio file + file_content = b"mock audio content" + file = io.BytesIO(file_content) + files = {"file": ("test.m4a", file, "audio/m4a")} + headers = {"Authorization": "Bearer test_api_key"} + + response = client.post( + "/v1/self-probe-file", + files=files, + headers=headers, + data={"model": "whisper-large-v3"}, + ) + assert response.status_code == 200 + + data = response.json() + assert "text" in data + assert "model" in data + assert data["model"] == "whisper-large-v3" + + +def test_self_probe_file_invalid_auth(): + """Test /v1/self-probe-file endpoint with invalid authorization""" + file_content = b"mock audio content" + file = io.BytesIO(file_content) + files = {"file": ("test.m4a", file, "audio/m4a")} + + # Test missing auth header + response = client.post("/v1/self-probe-file", files=files) + assert response.status_code == 422 + + # Test invalid auth format + headers = {"Authorization": "InvalidFormat test_api_key"} + response = client.post("/v1/self-probe-file", files=files, headers=headers) + assert response.status_code == 401 + + # Test empty token + headers = {"Authorization": "Bearer "} + response = client.post("/v1/self-probe-file", files=files, headers=headers) + assert response.status_code == 401 + + +def test_self_probe_file_invalid_format(): + """Test /v1/self-probe-file endpoint with invalid file format""" + file_content = b"mock content" + file = io.BytesIO(file_content) + files = {"file": ("test.txt", file, "text/plain")} + headers = {"Authorization": "Bearer test_api_key"} + + response = client.post( + "/v1/self-probe-file", + files=files, + headers=headers, + data={"model": "whisper-large-v3"}, + ) + assert response.status_code == 400 + assert "Invalid file format" in response.json()["detail"] + + +def test_self_probe_file_missing_file(): + """Test /v1/self-probe-file endpoint with missing file""" + headers = {"Authorization": "Bearer test_api_key"} + response = client.post( + "/v1/self-probe-file", + headers=headers, + data={"model": "whisper-large-v3"}, + ) + assert response.status_code == 422 diff --git a/agentic_security/test_spec_assets.py b/agentic_security/test_spec_assets.py index 92d3432..28972a4 100644 --- a/agentic_security/test_spec_assets.py +++ b/agentic_security/test_spec_assets.py @@ -67,7 +67,7 @@ Content-Type: application/json FILE_SPEC = """ -POST http://0.0.0.0:9094/v1/self-probe +POST http://0.0.0.0:9094/v1/self-probe-file Authorization: Bearer $GROQ_API_KEY Content-Type: multipart/form-data