From 32fa1c080d174834b4ecab47af036013fd2d6ecf Mon Sep 17 00:00:00 2001 From: Hemang Date: Mon, 3 Mar 2025 11:44:53 +0100 Subject: [PATCH] Add a Gemini API proxy route. --- .env | 1 + proxy/routes/gemini.py | 61 ++++++++++++++++++++++++++++++++++++++++++ proxy/serve.py | 3 +++ run.sh | 8 +++--- 4 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 proxy/routes/gemini.py diff --git a/.env b/.env index ab9c670..96e30a7 100644 --- a/.env +++ b/.env @@ -3,3 +3,4 @@ # If you want to push to a local instance of explorer, then specify the app-api docker container name like: # http://:8000 to push to the local explorer instance. INVARIANT_API_URL=https://explorer.invariantlabs.ai +GUADRAILS_API_URL=https://guardrail.invariantnet.com diff --git a/proxy/routes/gemini.py b/proxy/routes/gemini.py new file mode 100644 index 0000000..4475a06 --- /dev/null +++ b/proxy/routes/gemini.py @@ -0,0 +1,61 @@ +"""Proxy service to forward requests to the Gemini APIs""" + +import json + +from common.config_manager import ProxyConfig, ProxyConfigManager +from fastapi import APIRouter, Depends, Request, Response +from utils.constants import IGNORED_HEADERS + +proxy = APIRouter() + + +def _extract_dataset_name_and_endpoint(endpoint: str): + """Extracts the dataset name and endpoint from the given endpoint.""" + endpoint_parts = endpoint.split("/") + dataset_name = None + if endpoint_parts[1] == "models": + # Case 1: Without dataset_name + # `endpoint = /models/:generateContent` + reconstructed_endpoint = "/".join(endpoint_parts) + elif endpoint_parts[2] == "models": + # Case 2: With dataset_name + # `endpoint = //models/:generateContent` + dataset_name = endpoint_parts[0] + reconstructed_endpoint = "/".join(endpoint_parts[1:]) + else: + # Case 3: Invalid endpoint + return Response( + content=f"Invalid endpoint: {endpoint} - the endpoint should be in the format: \ + /api/v1/proxy/gemini//models/:generateContent or \ + /api/v1/proxy/gemini//models/:generateContent", + status_code=400, + ) + return dataset_name, reconstructed_endpoint + + +@proxy.post( + "/gemini/{endpoint:path}", +) +async def gemini_generate_content_proxy( + request: Request, + endpoint: str, + config: ProxyConfig = Depends(ProxyConfigManager.get_config), # pylint: disable=unused-argument +) -> Response: + """Proxy calls to the OpenAI APIs""" + headers = { + k: v for k, v in request.headers.items() if k.lower() not in IGNORED_HEADERS + } + headers["accept-encoding"] = "identity" + + request_body_bytes = await request.body() + request_body_json = json.loads(request_body_bytes) + api_key = headers.get("x-goog-api-key") + print(f"API Key: {api_key}") + print("request body json: ", request_body_json) + dataset_name, reconstructed_endpoint = _extract_dataset_name_and_endpoint(endpoint) + + print(f"API Key: {api_key}") + print("Processed Endpoint: ", reconstructed_endpoint) + print("Dataset Name: ", dataset_name) + + return {} diff --git a/proxy/serve.py b/proxy/serve.py index 29ae04c..d7503b4 100644 --- a/proxy/serve.py +++ b/proxy/serve.py @@ -3,6 +3,7 @@ import fastapi import uvicorn from routes.anthropic import proxy as anthropic_proxy +from routes.gemini import proxy as gemini_proxy from routes.open_ai import proxy as open_ai_proxy from starlette_compress import CompressMiddleware @@ -26,6 +27,8 @@ router.include_router(open_ai_proxy, prefix="/proxy", tags=["open_ai_proxy"]) router.include_router(anthropic_proxy, prefix="/proxy", tags=["anthropic_proxy"]) +router.include_router(gemini_proxy, prefix="/proxy", tags=["gemini_proxy"]) + app.include_router(router) diff --git a/run.sh b/run.sh index 4f29c59..b50c94f 100755 --- a/run.sh +++ b/run.sh @@ -20,11 +20,13 @@ up() { shift done - if [[ -n "$POLICIES_FILE_PATH" && -f "$POLICIES_FILE_PATH" ]]; then - POLICIES_FILE_PATH=$(realpath "$POLICIES_FILE_PATH") - else + if [[ -n "$POLICIES_FILE_PATH" ]]; then + if [[ -f "$POLICIES_FILE_PATH" ]]; then + POLICIES_FILE_PATH=$(realpath "$POLICIES_FILE_PATH") + else echo "Error: Specified policies file does not exist: $POLICIES_FILE_PATH" exit 1 + fi fi # Start Docker Compose with the correct environment variable