Mirror of https://github.com/invariantlabs-ai/invariant-gateway.git (synced 2026-02-12 14:32:45 +00:00)
Update comments and function names.
@@ -32,7 +32,6 @@ from gateway.routes.instrumentation import (
 )

 gateway = APIRouter()

 MISSING_ANTHROPIC_AUTH_HEADER = "Missing Anthropic authorization header"
 FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
 END_REASONS = ["end_turn", "max_tokens", "stop_sequence"]
@@ -66,13 +65,7 @@ async def anthropic_v1_messages_gateway(
     config: GatewayConfig = Depends(GatewayConfigManager.get_config),
     header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
 ):
-    """
-    Proxy calls to the Anthropic APIs
-
-    All Anthropic-specific cases (SSE, message conversion) handled by provider
-    """
-
-    # Standard Anthropic request setup
+    """Proxy calls to the Anthropic APIs"""
     headers = {
         k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
     }
@@ -111,12 +104,10 @@ async def anthropic_v1_messages_gateway(
         request=request,
     )

-    # Create Anthropic provider
     provider = AnthropicProvider()

-    # Handle streaming vs non-streaming
+    # Handle streaming and non-streaming
     if request_json.get("stream"):
-        # Use the base class directly - it handles SSE processing via the provider
         response = InstrumentedStreamingResponse(
             context=context,
             client=client,
@@ -176,7 +167,7 @@ def update_merged_response(


 class AnthropicProvider(BaseProvider):
-    """Complete Anthropic provider covering all cases"""
+    """Concrete implementation of BaseProvider for Anthropic"""

     def get_provider_name(self) -> str:
         return "anthropic"
@@ -339,7 +330,7 @@ class AnthropicProvider(BaseProvider):
         return events, remaining

     def is_streaming_complete(self, _: dict[str, Any], chunk_text: str = "") -> bool:
-        """Anthropic completion detection"""
+        """Anthropic streaming completion detection"""
         return "message_stop" in chunk_text

     def initialize_streaming_response(self) -> dict[str, Any]:
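For reference, a minimal sketch (not part of this commit) of how a consumer loop might drive the provider hooks shown above (initialize_streaming_response, process_streaming_chunk, is_streaming_complete). The consume function and its chunks argument are invented for illustration; in the gateway this logic lives inside InstrumentedStreamingResponse.

    from typing import Any, AsyncIterator

    async def consume(provider, chunks: AsyncIterator[bytes]) -> dict[str, Any]:
        # `provider` is assumed to expose the three methods named above.
        merged = provider.initialize_streaming_response()
        async for chunk in chunks:
            provider.process_streaming_chunk(chunk, merged, {})
            if provider.is_streaming_complete({}, chunk_text=chunk.decode()):
                break  # Anthropic marks the end of a stream with a "message_stop" event
        return merged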
@@ -52,11 +52,7 @@ async def gemini_generate_content_gateway(
     config: GatewayConfig = Depends(GatewayConfigManager.get_config),
     header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
 ) -> Response:
-    """
-    Proxy calls to the Gemini APIs
-
-    All Gemini-specific cases (message conversion, end_of_stream behavior) handled by provider
-    """
+    """Proxy calls to the Gemini APIs"""

     # Gemini endpoint validation
     if endpoint not in ["generateContent", "streamGenerateContent"]:
@@ -117,12 +113,10 @@ async def gemini_generate_content_gateway(
         request=request,
     )

-    # Create Gemini provider
     provider = GeminiProvider()

     # Handle streaming and non-streaming
     if alt == "sse" or endpoint == "streamGenerateContent":
-        # Use the base class directly - it handles Gemini streaming via the provider
         response = InstrumentedStreamingResponse(
             context=context,
             client=client,
@@ -214,7 +208,7 @@ def make_refusal(


 class GeminiProvider(BaseProvider):
-    """Complete Gemini provider covering all cases"""
+    """Concrete implementation of BaseProvider for Gemini"""

     def get_provider_name(self) -> str:
         return "gemini"
@@ -222,7 +216,7 @@ class GeminiProvider(BaseProvider):
     def combine_messages(
         self, request_json: dict[str, Any], response_json: dict[str, Any]
     ) -> list[dict[str, Any]]:
-        """Gemini message combination with format conversion"""
+        """Gemini messages combination with format conversion"""
         converted_requests = convert_request(request_json)
         converted_responses = convert_response(response_json) if response_json else []

@@ -298,7 +292,7 @@ class GeminiProvider(BaseProvider):
     def should_push_trace(
         self, merged_response: dict[str, Any], has_errors: bool
     ) -> bool:
-        """Gemini push criteria"""
+        """Gemini push trace criteria"""
         return has_errors or (
             merged_response.get("candidates", [])
             and merged_response["candidates"][0].get("finishReason") is not None
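As a standalone illustration (not part of the commit), the push criterion above can be restated and exercised against sample payloads. The helper name gemini_should_push and the payload values are invented for this example only.

    from typing import Any

    def gemini_should_push(merged_response: dict[str, Any], has_errors: bool) -> bool:
        # Same criterion as GeminiProvider.should_push_trace in the hunk above.
        return bool(
            has_errors
            or (
                merged_response.get("candidates", [])
                and merged_response["candidates"][0].get("finishReason") is not None
            )
        )

    assert gemini_should_push({"candidates": [{"finishReason": "STOP"}]}, has_errors=False)
    assert not gemini_should_push({"candidates": [{}]}, has_errors=False)
    assert gemini_should_push({}, has_errors=True)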
@@ -307,7 +301,7 @@ class GeminiProvider(BaseProvider):
     def process_streaming_chunk(
         self, chunk: bytes, merged_response: dict[str, Any], _: dict[str, Any]
     ) -> None:
-        """Gemini streaming hunk processing"""
+        """Gemini streaming chunk processing"""
         chunk_text = chunk.decode().strip()
         if not chunk_text:
             return
@@ -71,13 +71,6 @@ class BaseInstrumentedResponse(ABC):
         This can be used for input guardrails or other pre-processing tasks.
         """

-    @abstractmethod
-    async def on_end(self):
-        """
-        Post-processing hook.
-        This can be used for output guardrails or other post-processing tasks.
-        """
-
     @abstractmethod
     async def on_chunk(self, chunk: Any):
         """
@@ -85,6 +78,13 @@ class BaseInstrumentedResponse(ABC):
         This can be used for streaming responses to handle each chunk as it arrives.
         """

+    @abstractmethod
+    async def on_end(self):
+        """
+        Post-processing hook.
+        This can be used for output guardrails or other post-processing tasks.
+        """
+
     async def check_guardrails_common(
         self, messages: list[dict[str, Any]], action: GuardrailAction
     ) -> dict[str, Any]:
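For orientation, a hypothetical sketch of the hook protocol, assuming BaseInstrumentedResponse is importable. The class name is invented, only the hooks visible in this hunk (on_chunk, on_end) are overridden, and the pre-processing hook plus the other abstract members are defined outside this hunk, so the class is not instantiable as-is.

    from typing import Any

    class LoggingInstrumentedResponse(BaseInstrumentedResponse):
        # Hypothetical example; the real subclasses below run guardrail checks
        # and push traces instead of logging.
        async def on_chunk(self, chunk: Any):
            # Called for each chunk of a streaming response as it arrives.
            print(f"received chunk of {len(chunk)} bytes")

        async def on_end(self):
            # Post-processing hook: output guardrails, pushing the trace, etc.
            print("response finished")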
@@ -208,8 +208,8 @@ class BaseInstrumentedResponse(ABC):
                 location="response",
             )

-    async def push_successful_trace(self, response_data: dict[str, Any]) -> None:
-        """Push successful trace"""
+    async def push_trace_to_explorer(self, response_data: dict[str, Any]) -> None:
+        """Push trace to explorer if dataset is configured"""
         if self.context.dataset_name:
             should_push = self.provider.should_push_trace(
                 response_data,
@@ -382,7 +382,7 @@ class InstrumentedStreamingResponse(BaseInstrumentedResponse):

     async def on_end(self) -> ExtraItem | None:
         """Run post-processing after the streaming response ends."""
-        await self.push_successful_trace(self.merged_response)
+        await self.push_trace_to_explorer(self.merged_response)

     async def event_generator(self):
         """Generic event generator using provider protocol"""
@@ -428,8 +428,7 @@ class InstrumentedResponse(BaseInstrumentedResponse):
         if result:  # If guardrails failed
             return result

-        # Push successful trace
-        await self.push_successful_trace(self.response_json)
+        await self.push_trace_to_explorer(self.response_json)

     async def event_generator(self):
         """
@@ -111,13 +111,8 @@ async def openai_chat_completions_gateway(
     config: GatewayConfig = Depends(GatewayConfigManager.get_config),
     header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
 ) -> Response:
-    """
-    Proxy calls to the OpenAI APIs
-
-    All OpenAI-specific cases handled by the provider and base classes
-    """
+    """Proxy calls to the OpenAI chat completions endpoint"""

-    # Standard OpenAI request setup
     headers = {
         k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
     }
@@ -156,12 +151,10 @@ async def openai_chat_completions_gateway(
         request=request,
     )

-    # Create OpenAI provider
     provider = OpenAIProvider()

-    # Handle streaming vs non-streaming
+    # Handle streaming and non-streaming
     if request_json.get("stream", False):
-        # Use the base class directly - it handles everything via the provider
         response = InstrumentedStreamingResponse(
             context=context,
             client=client,
@@ -182,7 +175,7 @@ async def openai_chat_completions_gateway(


 class OpenAIProvider(BaseProvider):
-    """Complete OpenAI provider covering all cases"""
+    """Concrete implementation of BaseProvider for OpenAI"""

     def get_provider_name(self) -> str:
         return "openai"
@@ -251,8 +244,6 @@ class OpenAIProvider(BaseProvider):
                 }
             }
         )
-        # return an extra error chunk (without preventing the original
-        # chunk to go through after)
         return ExtraItem(f"data: {error_chunk}\n\n".encode(), end_of_stream=True)

     def should_push_trace(
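For context (not part of the commit), the extra SSE frame built above would look roughly like this, assuming error_chunk is a JSON string produced earlier in the method; the payload here is a placeholder.

    import json

    error_chunk = json.dumps({"error": {"message": "guardrail violation"}})  # placeholder payload
    frame = f"data: {error_chunk}\n\n".encode()
    # frame == b'data: {"error": {"message": "guardrail violation"}}\n\n'
    # Per the diff, the gateway wraps this in ExtraItem(..., end_of_stream=True) so the
    # extra error chunk is emitted without preventing the original chunk from going through.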