Update comments and function names.

Authored by Hemang on 2025-06-11 22:17:03 +02:00, committed by Hemang Sarkar
parent b6b738a9aa
commit 491a279f6e
4 changed files with 23 additions and 48 deletions

View File

@@ -32,7 +32,6 @@ from gateway.routes.instrumentation import (
)
gateway = APIRouter()
MISSING_ANTHROPIC_AUTH_HEADER = "Missing Anthropic authorization header"
FAILED_TO_PUSH_TRACE = "Failed to push trace to the dataset: "
END_REASONS = ["end_turn", "max_tokens", "stop_sequence"]
@@ -66,13 +65,7 @@ async def anthropic_v1_messages_gateway(
config: GatewayConfig = Depends(GatewayConfigManager.get_config),
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
):
"""
Proxy calls to the Anthropic APIs
All Anthropic-specific cases (SSE, message conversion) handled by provider
"""
# Standard Anthropic request setup
"""Proxy calls to the Anthropic APIs"""
headers = {
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
}
@@ -111,12 +104,10 @@ async def anthropic_v1_messages_gateway(
request=request,
)
# Create Anthropic provider
provider = AnthropicProvider()
# Handle streaming vs non-streaming
# Handle streaming and non-streaming
if request_json.get("stream"):
# Use the base class directly - it handles SSE processing via the provider
response = InstrumentedStreamingResponse(
context=context,
client=client,
@@ -176,7 +167,7 @@ def update_merged_response(
class AnthropicProvider(BaseProvider):
"""Complete Anthropic provider covering all cases"""
"""Concrete implementation of BaseProvider for Anthropic"""
def get_provider_name(self) -> str:
return "anthropic"
@@ -339,7 +330,7 @@ class AnthropicProvider(BaseProvider):
return events, remaining
def is_streaming_complete(self, _: dict[str, Any], chunk_text: str = "") -> bool:
"""Anthropic completion detection"""
"""Anthropic streaming completion detection"""
return "message_stop" in chunk_text
def initialize_streaming_response(self) -> dict[str, Any]:
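For readers skimming the diff, here is a minimal, standalone sketch of the provider protocol these hunks touch. The method names (get_provider_name, is_streaming_complete, initialize_streaming_response) and the "message_stop" check come from the diff itself; the base-class signatures and the accumulator shape are simplified assumptions, not the repository's actual BaseProvider:

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseProvider(ABC):
    """Minimal stand-in for the gateway's provider protocol (simplified)."""

    @abstractmethod
    def get_provider_name(self) -> str: ...

    @abstractmethod
    def is_streaming_complete(
        self, merged_response: dict[str, Any], chunk_text: str = ""
    ) -> bool: ...

    @abstractmethod
    def initialize_streaming_response(self) -> dict[str, Any]: ...


class AnthropicProvider(BaseProvider):
    """Concrete provider, mirroring the docstrings in the hunks above."""

    def get_provider_name(self) -> str:
        return "anthropic"

    def is_streaming_complete(self, _: dict[str, Any], chunk_text: str = "") -> bool:
        # Anthropic SSE streams end with a "message_stop" event.
        return "message_stop" in chunk_text

    def initialize_streaming_response(self) -> dict[str, Any]:
        # Empty accumulator that streamed chunks get merged into (shape assumed).
        return {"content": [], "stop_reason": None}


if __name__ == "__main__":
    provider = AnthropicProvider()
    print(provider.get_provider_name())  # -> anthropic
    print(provider.is_streaming_complete({}, "event: message_stop"))  # -> True
```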

View File

@@ -52,11 +52,7 @@ async def gemini_generate_content_gateway(
config: GatewayConfig = Depends(GatewayConfigManager.get_config),
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
) -> Response:
"""
Proxy calls to the Gemini APIs
All Gemini-specific cases (message conversion, end_of_stream behavior) handled by provider
"""
"""Proxy calls to the Gemini APIs"""
# Gemini endpoint validation
if endpoint not in ["generateContent", "streamGenerateContent"]:
@@ -117,12 +113,10 @@ async def gemini_generate_content_gateway(
request=request,
)
# Create Gemini provider
provider = GeminiProvider()
# Handle streaming and non-streaming
if alt == "sse" or endpoint == "streamGenerateContent":
# Use the base class directly - it handles Gemini streaming via the provider
response = InstrumentedStreamingResponse(
context=context,
client=client,
@@ -214,7 +208,7 @@ def make_refusal(
class GeminiProvider(BaseProvider):
"""Complete Gemini provider covering all cases"""
"""Concrete implementation of BaseProvider for Gemini"""
def get_provider_name(self) -> str:
return "gemini"
@@ -222,7 +216,7 @@ class GeminiProvider(BaseProvider):
def combine_messages(
self, request_json: dict[str, Any], response_json: dict[str, Any]
) -> list[dict[str, Any]]:
"""Gemini message combination with format conversion"""
"""Gemini messages combination with format conversion"""
converted_requests = convert_request(request_json)
converted_responses = convert_response(response_json) if response_json else []
@@ -298,7 +292,7 @@ class GeminiProvider(BaseProvider):
def should_push_trace(
self, merged_response: dict[str, Any], has_errors: bool
) -> bool:
"""Gemini push criteria"""
"""Gemini push trace criteria"""
return has_errors or (
merged_response.get("candidates", [])
and merged_response["candidates"][0].get("finishReason") is not None
@@ -307,7 +301,7 @@ class GeminiProvider(BaseProvider):
def process_streaming_chunk(
self, chunk: bytes, merged_response: dict[str, Any], _: dict[str, Any]
) -> None:
"""Gemini streaming hunk processing"""
"""Gemini streaming chunk processing"""
chunk_text = chunk.decode().strip()
if not chunk_text:
return
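The two Gemini docstrings renamed above sit on should_push_trace and process_streaming_chunk. As a standalone illustration of the push criterion the first of those hunks shows (push on errors, or once the first candidate reports a finishReason), here is a hedged sketch; the real method is an instance method on GeminiProvider and may differ in detail:

```python
from typing import Any


def should_push_trace(merged_response: dict[str, Any], has_errors: bool) -> bool:
    # Push the trace when errors occurred, or once the first candidate
    # carries a finishReason (i.e. generation actually finished).
    candidates = merged_response.get("candidates", [])
    return has_errors or bool(
        candidates and candidates[0].get("finishReason") is not None
    )


print(should_push_trace({"candidates": [{"finishReason": "STOP"}]}, False))  # True
print(should_push_trace({"candidates": [{}]}, False))                        # False
print(should_push_trace({}, True))                                           # True
```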

View File

@@ -71,13 +71,6 @@ class BaseInstrumentedResponse(ABC):
This can be used for input guardrails or other pre-processing tasks.
"""
@abstractmethod
async def on_end(self):
"""
Post-processing hook.
This can be used for output guardrails or other post-processing tasks.
"""
@abstractmethod
async def on_chunk(self, chunk: Any):
"""
@@ -85,6 +78,13 @@ class BaseInstrumentedResponse(ABC):
This can be used for streaming responses to handle each chunk as it arrives.
"""
@abstractmethod
async def on_end(self):
"""
Post-processing hook.
This can be used for output guardrails or other post-processing tasks.
"""
async def check_guardrails_common(
self, messages: list[dict[str, Any]], action: GuardrailAction
) -> dict[str, Any]:
@@ -208,8 +208,8 @@ class BaseInstrumentedResponse(ABC):
location="response",
)
async def push_successful_trace(self, response_data: dict[str, Any]) -> None:
"""Push successful trace"""
async def push_trace_to_explorer(self, response_data: dict[str, Any]) -> None:
"""Push trace to explorer if dataset is configured"""
if self.context.dataset_name:
should_push = self.provider.should_push_trace(
response_data,
@@ -382,7 +382,7 @@ class InstrumentedStreamingResponse(BaseInstrumentedResponse):
async def on_end(self) -> ExtraItem | None:
"""Run post-processing after the streaming response ends."""
await self.push_successful_trace(self.merged_response)
await self.push_trace_to_explorer(self.merged_response)
async def event_generator(self):
"""Generic event generator using provider protocol"""
@@ -428,8 +428,7 @@ class InstrumentedResponse(BaseInstrumentedResponse):
if result: # If guardrails failed
return result
# Push successful trace
await self.push_successful_trace(self.response_json)
await self.push_trace_to_explorer(self.response_json)
async def event_generator(self):
"""

View File

@@ -111,13 +111,8 @@ async def openai_chat_completions_gateway(
config: GatewayConfig = Depends(GatewayConfigManager.get_config),
header_guardrails: GuardrailRuleSet = Depends(extract_guardrails_from_header),
) -> Response:
"""
Proxy calls to the OpenAI APIs
"""Proxy calls to the OpenAI chat completions endpoint"""
All OpenAI-specific cases handled by the provider and base classes
"""
# Standard OpenAI request setup
headers = {
k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS
}
@@ -156,12 +151,10 @@ async def openai_chat_completions_gateway(
request=request,
)
# Create OpenAI provider
provider = OpenAIProvider()
# Handle streaming vs non-streaming
# Handle streaming and non-streaming
if request_json.get("stream", False):
# Use the base class directly - it handles everything via the provider
response = InstrumentedStreamingResponse(
context=context,
client=client,
@@ -182,7 +175,7 @@ async def openai_chat_completions_gateway(
class OpenAIProvider(BaseProvider):
"""Complete OpenAI provider covering all cases"""
"""Concrete implementation of BaseProvider for OpenAI"""
def get_provider_name(self) -> str:
return "openai"
@@ -251,8 +244,6 @@ class OpenAIProvider(BaseProvider):
}
}
)
# return an extra error chunk (without preventing the original
# chunk to go through after)
return ExtraItem(f"data: {error_chunk}\n\n".encode(), end_of_stream=True)
def should_push_trace(
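The OpenAI hunk above drops the comment explaining the extra error chunk but keeps the behaviour: the guardrail error is wrapped in OpenAI-style SSE framing and returned as a final item. A self-contained sketch with a stand-in ExtraItem (the real class lives in the gateway and may carry more fields):

```python
import json
from dataclasses import dataclass


@dataclass
class ExtraItem:
    # Stand-in for the gateway's ExtraItem: an extra payload appended to the
    # stream, optionally marking the end of the SSE stream.
    content: bytes
    end_of_stream: bool = False


def make_error_extra_item(message: str) -> ExtraItem:
    # Wrap the error in OpenAI-style SSE framing ("data: <json>\n\n") and
    # flag it as the final item, mirroring the return statement shown above.
    error_chunk = json.dumps({"error": {"message": message}})
    return ExtraItem(f"data: {error_chunk}\n\n".encode(), end_of_stream=True)


item = make_error_extra_item("guardrail violation")
print(item.content.decode(), item.end_of_stream)
```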