diff --git a/gateway/common/authorization.py b/gateway/common/authorization.py new file mode 100644 index 0000000..6a4e844 --- /dev/null +++ b/gateway/common/authorization.py @@ -0,0 +1,52 @@ +"""Common Authorization functions used in the gateway.""" + +from typing import Tuple, Optional +from fastapi import HTTPException, Request + +INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization" +API_KEYS_SEPARATOR = ";invariant-auth=" + + +def extract_authorization_from_headers( + request: Request, dataset_name: Optional[str], llm_provider_api_key_header: str +) -> Tuple[str, str]: + """ + Extracts the Invariant authorization and LLM Provider API key from the request headers. + + In case the user wants to push to Explorer (when dataset_name is not None), + the request headers must contain the Invariant API Key. + The invariant-authorization header contains the Invariant API Key as + "invariant-authorization": "Bearer " + {llm_provider_api_key_header} contains the LLM Provider API Key as + {llm_provider_api_key_header}: "" + + For some clients, it is not possible to pass a custom header + In such cases, the Invariant API Key is passed as part of the + {llm_provider_api_key_header} with the LLM Provider API Key + The header in that case becomes: + {llm_provider_api_key_header}: ";invariant-auth=" + """ + invariant_authorization = request.headers.get(INVARIANT_AUTHORIZATION_HEADER) + llm_provider_api_key = request.headers.get(llm_provider_api_key_header) + if dataset_name: + if invariant_authorization is None: + if llm_provider_api_key is None: + raise HTTPException( + status_code=400, detail="Missing LLM Provider API Key" + ) + + if API_KEYS_SEPARATOR not in llm_provider_api_key: + raise HTTPException(status_code=400, detail="Missing invariant api key") + + # Both the API keys are passed in the llm_provider_api_key_header + api_keys = request.headers.get(llm_provider_api_key_header).split( + API_KEYS_SEPARATOR + ) + if len(api_keys) != 2 or not api_keys[1].strip(): + raise HTTPException( + status_code=400, detail="Invalid API Key format" + ) + + invariant_authorization = f"Bearer {api_keys[1].strip()}" + llm_provider_api_key = f"{api_keys[0].strip()}" + return invariant_authorization, llm_provider_api_key diff --git a/gateway/utils/constants.py b/gateway/common/constants.py similarity index 83% rename from gateway/utils/constants.py rename to gateway/common/constants.py index 07434ff..02a8209 100644 --- a/gateway/utils/constants.py +++ b/gateway/common/constants.py @@ -13,4 +13,3 @@ IGNORED_HEADERS = [ ] CLIENT_TIMEOUT = 60.0 -INVARIANT_AUTHORIZATION_HEADER = "invariant-authorization" diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 45db6fb..fc8ab02 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -7,12 +7,12 @@ import httpx from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from starlette.responses import StreamingResponse -from utils.constants import ( +from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, - INVARIANT_AUTHORIZATION_HEADER, ) from utils.explorer import push_trace +from common.authorization import extract_authorization_from_headers gateway = APIRouter() @@ -56,42 +56,14 @@ async def anthropic_v1_messages_gateway( } headers["accept-encoding"] = "identity" - # In case the user wants to push to Explorer, the request must contain the Invariant API Key - # The invariant-authorization header contains the Invariant API Key - # "invariant-authorization": "Bearer " - # The x-api-key header contains the Anthropic API Key - # "x-api-key": "" - # - # For some clients, it is not possible to pass a custom header - # In such cases, the Invariant API Key is passed as part of the - # x-api-key header with the Anthropic API key. - # The header in that case becomes: - # "x-api-key": ";invariant-auth=" - invariant_authorization = None - if dataset_name: - if request.headers.get( - INVARIANT_AUTHORIZATION_HEADER - ) is None and ";invariant-auth=" not in request.headers.get( - ANTHROPIC_AUTHORIZATION_HEADER - ): - raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_API_KEY) - if request.headers.get(INVARIANT_AUTHORIZATION_HEADER): - invariant_authorization = request.headers.get( - INVARIANT_AUTHORIZATION_HEADER - ) - else: - header_value = request.headers.get(ANTHROPIC_AUTHORIZATION_HEADER) - api_keys = header_value.split(";invariant-auth=") - invariant_authorization = f"Bearer {api_keys[1].strip()}" - # Update the authorization header to pass the Anthropic API Key - headers[ANTHROPIC_AUTHORIZATION_HEADER] = f"{api_keys[0].strip()}" + invariant_authorization, anthopic_api_key = extract_authorization_from_headers( + request, dataset_name, ANTHROPIC_AUTHORIZATION_HEADER + ) + headers[ANTHROPIC_AUTHORIZATION_HEADER] = anthopic_api_key request_body = await request.body() - request_body_json = json.loads(request_body) - client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) - anthropic_request = client.build_request( "POST", "https://api.anthropic.com/v1/messages", diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index c60cb35..6dee3ec 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -7,10 +7,17 @@ import httpx from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response from fastapi.responses import StreamingResponse -from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS +from common.constants import ( + CLIENT_TIMEOUT, + IGNORED_HEADERS, +) +from common.authorization import extract_authorization_from_headers +from utils.explorer import push_trace gateway = APIRouter() +GEMINI_AUTHORIZATION_HEADER = "x-goog-api-key" + @gateway.post("/gemini/{api_version}/models/{model}:{endpoint}") async def gemini_generate_content_gateway( @@ -37,7 +44,15 @@ async def gemini_generate_content_gateway( } headers["accept-encoding"] = "identity" + invariant_authorization, gemini_api_key = extract_authorization_from_headers( + request, dataset_name, GEMINI_AUTHORIZATION_HEADER + ) + headers[GEMINI_AUTHORIZATION_HEADER] = gemini_api_key + request_body_bytes = await request.body() + request_body_json = json.loads(request_body_bytes) + print("Here is the request: ", request_body_json) + client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) gemini_api_url = f"https://generativelanguage.googleapis.com/{api_version}/models/{model}:{endpoint}" if alt == "sse": @@ -56,7 +71,9 @@ async def gemini_generate_content_gateway( dataset_name, ) response = await client.send(gemini_request) - return await handle_non_streaming_response(response, dataset_name) + return await handle_non_streaming_response( + response, dataset_name, request_body_json, invariant_authorization + ) async def stream_response( @@ -83,6 +100,7 @@ async def stream_response( continue # Yield chunk immediately to the client + print("Here is the response chunk: ", chunk) yield chunk # Send full merged response to the explorer @@ -93,13 +111,33 @@ async def stream_response( return StreamingResponse(event_generator(), media_type="text/event-stream") +async def push_to_explorer( + dataset_name: str, + merged_response: dict[str, Any], + request_body: dict[str, Any], + invariant_authorization: str, +) -> None: + """Pushes the full trace to the Invariant Explorer""" + # Combine the messages from the request body and the choices from the Gemini response + messages = request_body.get("messages", []) + messages += [choice["message"] for choice in merged_response.get("choices", [])] + _ = await push_trace( + dataset_name=dataset_name, + messages=[messages], + invariant_authorization=invariant_authorization, + ) + + async def handle_non_streaming_response( response: httpx.Response, dataset_name: Optional[str], + request_body_json: dict[str, Any], + invariant_authorization: Optional[str], ) -> Response: """Handles non-streaming Gemini responses""" try: json_response = response.json() + print("Here is the response: ", json_response) except json.JSONDecodeError as e: raise HTTPException( status_code=response.status_code, @@ -111,8 +149,9 @@ async def handle_non_streaming_response( detail=json_response.get("error", "Unknown error from Gemini API"), ) if dataset_name: - # Push to Explorer - pass + await push_to_explorer( + dataset_name, json_response, request_body_json, invariant_authorization + ) return Response( content=json.dumps(json_response), diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 0fd9d17..6688e84 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -7,12 +7,12 @@ import httpx from common.config_manager import GatewayConfig, GatewayConfigManager from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from fastapi.responses import StreamingResponse -from utils.constants import ( +from common.constants import ( CLIENT_TIMEOUT, IGNORED_HEADERS, - INVARIANT_AUTHORIZATION_HEADER, ) from utils.explorer import push_trace +from common.authorization import extract_authorization_from_headers gateway = APIRouter() @@ -47,43 +47,16 @@ async def openai_chat_completions_gateway( } headers["accept-encoding"] = "identity" + invariant_authorization, openai_api_key = extract_authorization_from_headers( + request, dataset_name, OPENAI_AUTHORIZATION_HEADER + ) + headers[OPENAI_AUTHORIZATION_HEADER] = openai_api_key + request_body_bytes = await request.body() request_body_json = json.loads(request_body_bytes) - # Check if the request is for streaming is_streaming = request_body_json.get("stream", False) - # In case the user wants to push to Explorer, the request must contain the Invariant API Key - # The invariant-authorization header contains the Invariant API Key - # "invariant-authorization": "Bearer " - # The authorization header contains the OpenAI API Key - # "authorization": "" - # - # For some clients, it is not possible to pass a custom header - # In such cases, the Invariant API Key is passed as part of the - # authorization header with the OpenAI API key. - # The header in that case becomes: - # "authorization": ";invariant-auth=" - invariant_authorization = None - if dataset_name: - if request.headers.get( - INVARIANT_AUTHORIZATION_HEADER - ) is None and ";invariant-auth=" not in request.headers.get( - OPENAI_AUTHORIZATION_HEADER - ): - raise HTTPException(status_code=400, detail=MISSING_INVARIANT_AUTH_API_KEY) - - if request.headers.get(INVARIANT_AUTHORIZATION_HEADER): - invariant_authorization = request.headers.get( - INVARIANT_AUTHORIZATION_HEADER - ) - else: - header_value = request.headers.get(OPENAI_AUTHORIZATION_HEADER) - api_keys = header_value.split(";invariant-auth=") - invariant_authorization = f"Bearer {api_keys[1].strip()}" - # Update the authorization header to pass the OpenAI API Key to the OpenAI API - headers[OPENAI_AUTHORIZATION_HEADER] = f"{api_keys[0].strip()}" - client = httpx.AsyncClient(timeout=httpx.Timeout(CLIENT_TIMEOUT)) open_ai_request = client.build_request( "POST",