feat(api): centralize configuration in Settings and validate completion request parameters

This commit is contained in:
Alexander Myasoedov
2024-11-29 16:19:27 +02:00
parent e7cf291433
commit 65edfe8930
+24 -14
View File
@@ -9,7 +9,7 @@ from fastapi import BackgroundTasks, FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger
from pydantic import BaseModel
from pydantic import BaseModel, Field
from starlette.middleware.base import BaseHTTPMiddleware
from .http_spec import LLMSpec
@@ -24,6 +24,18 @@ origins = [
"*",
]
# Configuration
# Values of an environment flag that mean "off" (compared case-insensitively,
# after stripping surrounding whitespace).
_FALSY_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def _env_flag(name: str, default: bool = False) -> bool:
    """Read the environment variable *name* as a boolean.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is unset; ``False`` when it is set to an
        empty string, ``0``, ``false``, ``no`` or ``off`` (any letter case);
        ``True`` for any other value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    # os.getenv returns the raw string, and any non-empty string is truthy,
    # so "false" / "0" must be recognised explicitly.
    return raw.strip().lower() not in _FALSY_ENV_VALUES


class Settings:
    """Application-wide configuration constants."""

    # NOTE(review): the three limits below are not referenced anywhere in the
    # visible code; their units and consumers should be confirmed.
    MAX_BUDGET: int = 1000
    MAX_DATASETS: int = 10
    RATE_LIMIT: str = "100/minute"

    # Read once, when the class body is executed. Always a real bool, so
    # DISABLE_TELEMETRY=false or =0 leaves telemetry enabled.
    DISABLE_TELEMETRY: bool = _env_flag("DISABLE_TELEMETRY", default=False)

    # Checked by the completions proxy endpoint before it waits for a reply.
    FEATURE_PROXY: bool = False


settings = Settings()
# Middleware setup
app.add_middleware(
CORSMiddleware,
@@ -37,10 +49,6 @@ tools_inbox = Queue()
# Global stop event for cancelling scans
stop_event = Event()

# Backward-compatible aliases. The authoritative values live on `settings`;
# these names are kept only so code that still reads the module-level
# globals keeps working and cannot drift from the Settings values.
FEATURE_PROXY = settings.FEATURE_PROXY
DISABLE_TELEMETRY = settings.DISABLE_TELEMETRY
@app.get("/")
async def root():
@@ -57,7 +65,7 @@ async def main_js():
@app.get("/telemetry.js")
async def telemetry_js():
    """Serve the telemetry script from the package's ``static`` directory.

    Returns:
        A ``FileResponse`` for ``telemetry_disabled.js`` when
        ``settings.DISABLE_TELEMETRY`` is truthy, otherwise for
        ``telemetry.js``.
    """
    static_dir = Path(__file__).parent / "static"
    # `settings` is the single source of truth for the flag; the separate
    # module-level DISABLE_TELEMETRY check was redundant.
    script_name = (
        "telemetry_disabled.js" if settings.DISABLE_TELEMETRY else "telemetry.js"
    )
    return FileResponse(static_dir / script_name)
@@ -183,15 +191,17 @@ class Message(BaseModel):
class CompletionRequest(BaseModel):
    """Model for completion requests.

    Each optional parameter is declared exactly once. Numeric parameters
    carry inclusive ``ge``/``le`` bounds, so out-of-range values are rejected
    during model validation.
    """

    model: str
    messages: list[Message]
    # Inclusive range 0.0-2.0, default 0.7.
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    # Inclusive range 0.0-1.0, default 1.0.
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    # Between 1 and 10 inclusive, default 1.
    n: int = Field(default=1, ge=1, le=10)
    # Optional; None when the client supplies no stop sequences.
    stop: list[str] | None = None
    # Between 1 and 4096 inclusive, default 100.
    max_tokens: int = Field(default=100, ge=1, le=4096)
    # Both penalties: inclusive range -2.0 to 2.0, default 0.0.
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
# OpenAI proxy endpoint
@@ -206,7 +216,7 @@ async def proxy_completions(request: CompletionRequest):
ready = Event()
ref = dict(message=message, reply="", ready=ready)
tools_inbox.put_nowait(ref)
if FEATURE_PROXY:
if settings.FEATURE_PROXY:
# Proxy to agent
await ready.wait()
reply = ref["reply"]