feat(add probe image endpoint):

This commit is contained in:
Alexander Myasoedov
2024-12-17 15:01:51 +02:00
parent 9e8b9ec33e
commit a38bcec50f
2 changed files with 54 additions and 5 deletions
+14 -5
View File
@@ -9,11 +9,7 @@ from ..probe_data import REGISTRY
router = APIRouter()
@router.post("/v1/self-probe")
def self_probe(probe: Probe):
refuse = random.random() < 0.2
message = random.choice(REFUSAL_MARKS) if refuse else "This is a test!"
message = probe.prompt + " " + message
def make_mock_response(message: str) -> dict:
return {
"id": "chatcmpl-abc123",
"object": "chat.completion",
@@ -31,6 +27,14 @@ def self_probe(probe: Probe):
}
@router.post("/v1/self-probe")
def self_probe(probe: Probe):
refuse = random.random() < 0.2
message = random.choice(REFUSAL_MARKS) if refuse else "This is a test!"
message = probe.prompt + " " + message
return make_mock_response(message)
@router.post("/v1/self-probe-file", response_model=FileProbeResponse)
async def self_probe_file(
file: UploadFile = File(...),
@@ -58,6 +62,11 @@ async def self_probe_file(
return FileProbeResponse(text=mock_text, model=model)
@router.post("/v1/self-probe-image")
async def self_probe_image():
return make_mock_response(message="This is a mock response for the image.")
@router.get("/v1/data-config")
async def data_config():
return [m for m in REGISTRY]
+40
View File
@@ -1,4 +1,5 @@
import io
import pytest
from fastapi.testclient import TestClient
@@ -167,3 +168,42 @@ def test_self_probe_file_missing_file():
data={"model": "whisper-large-v3"},
)
assert response.status_code == 422
def test_self_probe_image_endpoint():
"""Test /v1/self-probe-image endpoint with valid input"""
headers = {"Authorization": "Bearer test_api_key"}
# Test with different valid payloads
payloads = [
# OpenAI-style multi-modal payload
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,mockbase64data"},
},
],
}
],
# Simple text payload
{"message": "Test message"},
# Nested payload
{"level1": {"level2": "test"}},
# Empty object
{},
# Empty array
[],
]
for payload in payloads:
response = client.post("/v1/self-probe-image", json=payload, headers=headers)
assert response.status_code == 200, (payload, response.json())
data = response.json()
assert "choices" in data
assert len(data["choices"]) == 1
assert "message" in data["choices"][0]