Use pyproject.toml instead of requirements.txt and fix some broken tests.

This commit is contained in:
Hemang
2025-04-10 14:01:11 +02:00
committed by Hemang Sarkar
parent 6b6f33bde6
commit 5bf121bbda
23 changed files with 99 additions and 81 deletions

View File

@@ -48,7 +48,7 @@ jobs:
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: ./gateway context: ./gateway
file: ./gateway/Dockerfile.gateway file: Dockerfile.gateway
platforms: linux/amd64 platforms: linux/amd64
push: true push: true
tags: ${{ env.tags }} tags: ${{ env.tags }}

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12

16
Dockerfile.gateway Normal file
View File

@@ -0,0 +1,16 @@
FROM python:3.12
WORKDIR /srv/gateway
COPY pyproject.toml ./
COPY gateway/ ./
COPY README.md /srv/gateway/README.md
RUN pip install --no-cache-dir .
RUN chmod +x /srv/gateway/run.sh
# Run the application
CMD ["./run.sh"]

View File

@@ -2,8 +2,8 @@ services:
invariant-gateway: invariant-gateway:
container_name: invariant-gateway container_name: invariant-gateway
build: build:
context: ./gateway context: .
dockerfile: ../gateway/Dockerfile.gateway dockerfile: Dockerfile.gateway
working_dir: /srv/gateway working_dir: /srv/gateway
env_file: env_file:
- .env - .env

View File

@@ -1,14 +0,0 @@
FROM python:3.10
COPY ./requirements.txt /srv/gateway/requirements.txt
WORKDIR /srv/gateway
RUN pip install --no-cache-dir -r requirements.txt
COPY . /srv/gateway
# Ensure run.sh is executable
RUN chmod +x /srv/gateway/run.sh
WORKDIR /srv/gateway
# Run the application
CMD ["./run.sh"]

View File

@@ -8,8 +8,7 @@ from typing import Optional
import fastapi import fastapi
from httpx import HTTPStatusError from httpx import HTTPStatusError
from common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet from gateway.common.guardrails import Guardrail, GuardrailAction, GuardrailRuleSet
from common.authorization import extract_authorization_from_headers
def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]: def extract_policy_from_headers(request: Optional[fastapi.Request]) -> Optional[str]:
@@ -40,7 +39,7 @@ class GatewayConfig:
Loads the guardrails from the file specified in GUARDRAILS_FILE_PATH. Loads the guardrails from the file specified in GUARDRAILS_FILE_PATH.
Returns the guardrails file content as a string. Returns the guardrails file content as a string.
""" """
from integrations.guardrails import _preload from gateway.integrations.guardrails import _preload
guardrails_file = os.getenv("GUARDRAILS_FILE_PATH", "") guardrails_file = os.getenv("GUARDRAILS_FILE_PATH", "")
if not guardrails_file: if not guardrails_file:

View File

@@ -5,9 +5,9 @@ from typing import Any, Dict, Optional
import fastapi import fastapi
from common.config_manager import GatewayConfig from gateway.common.config_manager import GatewayConfig
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
from common.authorization import ( from gateway.common.authorization import (
extract_guardrail_service_authorization_from_headers, extract_guardrail_service_authorization_from_headers,
) )

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, List
from fastapi import HTTPException from fastapi import HTTPException
from common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction from gateway.common.guardrails import GuardrailRuleSet, Guardrail, GuardrailAction
from invariant_sdk.async_client import AsyncClient from invariant_sdk.async_client import AsyncClient
from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse from invariant_sdk.types.push_traces import PushTracesRequest, PushTracesResponse
from invariant_sdk.types.annotations import AnnotationCreate from invariant_sdk.types.annotations import AnnotationCreate

View File

@@ -8,9 +8,10 @@ from functools import wraps
from fastapi import HTTPException from fastapi import HTTPException
import httpx import httpx
from common.guardrails import Guardrail
from common.request_context import RequestContext from gateway.common.guardrails import Guardrail
from common.authorization import ( from gateway.common.request_context import RequestContext
from gateway.common.authorization import (
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER, INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER,
) )

View File

@@ -1,6 +0,0 @@
fastapi==0.115.7
httpx==0.28.1
invariant-ai>=0.2.1
invariant-sdk>=0.0.10
starlette-compress==1.4.0
uvicorn==0.34.0

View File

@@ -8,27 +8,27 @@ import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse from starlette.responses import StreamingResponse
from common.authorization import extract_authorization_from_headers from gateway.common.authorization import extract_authorization_from_headers
from common.config_manager import ( from gateway.common.config_manager import (
GatewayConfig, GatewayConfig,
GatewayConfigManager, GatewayConfigManager,
GuardrailsInHeader, GuardrailsInHeader,
) )
from common.constants import ( from gateway.common.constants import (
CLIENT_TIMEOUT, CLIENT_TIMEOUT,
IGNORED_HEADERS, IGNORED_HEADERS,
) )
from common.guardrails import GuardrailAction, GuardrailRuleSet from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
from common.request_context import RequestContext from gateway.common.request_context import RequestContext
from converters.anthropic_to_invariant import ( from gateway.converters.anthropic_to_invariant import (
convert_anthropic_to_invariant_message_format, convert_anthropic_to_invariant_message_format,
) )
from integrations.explorer import ( from gateway.integrations.explorer import (
create_annotations_from_guardrails_errors, create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer, fetch_guardrails_from_explorer,
push_trace, push_trace,
) )
from integrations.guardrails import ( from gateway.integrations.guardrails import (
ExtraItem, ExtraItem,
InstrumentedResponse, InstrumentedResponse,
InstrumentedStreamingResponse, InstrumentedStreamingResponse,

View File

@@ -8,25 +8,25 @@ import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from common.authorization import extract_authorization_from_headers from gateway.common.authorization import extract_authorization_from_headers
from common.config_manager import ( from gateway.common.config_manager import (
GatewayConfig, GatewayConfig,
GatewayConfigManager, GatewayConfigManager,
GuardrailsInHeader, GuardrailsInHeader,
) )
from common.constants import ( from gateway.common.constants import (
CLIENT_TIMEOUT, CLIENT_TIMEOUT,
IGNORED_HEADERS, IGNORED_HEADERS,
) )
from common.guardrails import GuardrailAction, GuardrailRuleSet from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
from common.request_context import RequestContext from gateway.common.request_context import RequestContext
from converters.gemini_to_invariant import convert_request, convert_response from gateway.converters.gemini_to_invariant import convert_request, convert_response
from integrations.explorer import ( from gateway.integrations.explorer import (
create_annotations_from_guardrails_errors, create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer, fetch_guardrails_from_explorer,
push_trace, push_trace,
) )
from integrations.guardrails import ( from gateway.integrations.guardrails import (
ExtraItem, ExtraItem,
InstrumentedResponse, InstrumentedResponse,
InstrumentedStreamingResponse, InstrumentedStreamingResponse,

View File

@@ -8,24 +8,24 @@ import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from common.authorization import extract_authorization_from_headers from gateway.common.authorization import extract_authorization_from_headers
from common.config_manager import ( from gateway.common.config_manager import (
GatewayConfig, GatewayConfig,
GatewayConfigManager, GatewayConfigManager,
GuardrailsInHeader, GuardrailsInHeader,
) )
from common.constants import ( from gateway.common.constants import (
CLIENT_TIMEOUT, CLIENT_TIMEOUT,
IGNORED_HEADERS, IGNORED_HEADERS,
) )
from common.guardrails import GuardrailAction, GuardrailRuleSet from gateway.common.guardrails import GuardrailAction, GuardrailRuleSet
from common.request_context import RequestContext from gateway.common.request_context import RequestContext
from integrations.explorer import ( from gateway.integrations.explorer import (
create_annotations_from_guardrails_errors, create_annotations_from_guardrails_errors,
fetch_guardrails_from_explorer, fetch_guardrails_from_explorer,
push_trace, push_trace,
) )
from integrations.guardrails import ( from gateway.integrations.guardrails import (
ExtraItem, ExtraItem,
InstrumentedResponse, InstrumentedResponse,
InstrumentedStreamingResponse, InstrumentedStreamingResponse,

View File

@@ -1,5 +1,7 @@
#!/bin/bash #!/bin/bash
export PYTHONPATH=/srv
# Validate configuration # Validate configuration
python validate_config.py python validate_config.py

View File

@@ -2,11 +2,12 @@
import fastapi import fastapi
import uvicorn import uvicorn
from routes.anthropic import gateway as anthropic_gateway
from routes.gemini import gateway as gemini_gateway
from routes.open_ai import gateway as open_ai_gateway
from starlette_compress import CompressMiddleware from starlette_compress import CompressMiddleware
from gateway.routes.anthropic import gateway as anthropic_gateway
from gateway.routes.gemini import gateway as gemini_gateway
from gateway.routes.open_ai import gateway as open_ai_gateway
app = fastapi.app = fastapi.FastAPI( app = fastapi.app = fastapi.FastAPI(
docs_url="/api/v1/gateway/docs", docs_url="/api/v1/gateway/docs",
redoc_url="/api/v1/gateway/redoc", redoc_url="/api/v1/gateway/redoc",

View File

@@ -2,7 +2,7 @@
import sys import sys
from common.config_manager import GatewayConfigManager from gateway.common.config_manager import GatewayConfigManager
try: try:
_ = GatewayConfigManager.get_config() _ = GatewayConfigManager.get_config()

21
pyproject.toml Normal file
View File

@@ -0,0 +1,21 @@
[project]
name = "invariant-gateway"
version = "0.1.0"
description = "LLM proxy to observe and debug what your AI agents are doing"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"fastapi==0.115.7",
"httpx==0.28.1",
"invariant-ai>=0.2.1",
"invariant-sdk>=0.0.10",
"starlette-compress==1.4.0",
"uvicorn==0.34.0"
]
[tool.setuptools.packages.find]
where = ["."]
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

9
run.sh
View File

@@ -64,13 +64,12 @@ build() {
down() { down() {
# Bring down local services # Bring down local services
docker compose -f docker-compose.local.yml down docker compose -f docker-compose.local.yml down
GATEWAY_PATH=$(pwd)/gateway docker compose -f tests/integration/docker-compose.test.yml down GATEWAY_ROOT_PATH=$(pwd) docker compose -f tests/integration/docker-compose.test.yml down
} }
unit_tests() { unit_tests() {
echo "Running unit tests..." echo "Running unit tests..."
PYTHONPATH=. pytest tests/unit_tests $@
pytest tests/unit_tests $@
} }
integration_tests() { integration_tests() {
@@ -108,7 +107,7 @@ integration_tests() {
fi fi
fi fi
export GATEWAY_PATH=$(pwd)/gateway export GATEWAY_ROOT_PATH=$(pwd)
export GUARDRAILS_FILE_PATH="$TEST_GUARDRAILS_FILE_PATH" export GUARDRAILS_FILE_PATH="$TEST_GUARDRAILS_FILE_PATH"
# Start containers # Start containers
@@ -151,7 +150,7 @@ integration_tests() {
--env-file ./tests/integration/.env.test \ --env-file ./tests/integration/.env.test \
invariant-gateway-tests $@ invariant-gateway-tests $@
unset GATEWAY_PATH unset GATEWAY_ROOT_PATH
unset GUARDRAILS_FILE_PATH unset GUARDRAILS_FILE_PATH
} }

View File

@@ -24,8 +24,8 @@ services:
invariant-gateway: invariant-gateway:
container_name: invariant-gateway-test container_name: invariant-gateway-test
build: build:
context: ${GATEWAY_PATH} context: ${GATEWAY_ROOT_PATH}
dockerfile: ${GATEWAY_PATH}/Dockerfile.gateway dockerfile: ${GATEWAY_ROOT_PATH}/Dockerfile.gateway
depends_on: depends_on:
app-api: app-api:
condition: service_healthy condition: service_healthy
@@ -38,7 +38,7 @@ services:
- ${INVARIANT_API_KEY:+INVARIANT_API_KEY=${INVARIANT_API_KEY}} - ${INVARIANT_API_KEY:+INVARIANT_API_KEY=${INVARIANT_API_KEY}}
volumes: volumes:
- type: bind - type: bind
source: ${GATEWAY_PATH} source: ${GATEWAY_ROOT_PATH}/gateway
target: /srv/gateway target: /srv/gateway
- type: bind - type: bind
source: ${GUARDRAILS_FILE_PATH:-/dev/null} source: ${GUARDRAILS_FILE_PATH:-/dev/null}

View File

@@ -457,12 +457,12 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
assert ( assert (
annotations[0]["content"] == "ogre detected in response" annotations[0]["content"] == "ogre detected in response"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error" and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block" and annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
) )
assert ( assert (
annotations[1]["content"] == "Fiona detected in response" annotations[1]["content"] == "Fiona detected in response"
and annotations[1]["extra_metadata"]["source"] == "guardrails-error" and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
and annotations[1]["extra_metadata"]["guardrail-action"] == "log" and annotations[1]["extra_metadata"]["guardrail"]["action"] == "log"
) )
@@ -584,7 +584,7 @@ async def test_preguardrailing_with_guardrails_from_explorer(
assert ( assert (
annotations[0]["content"] == "pun detected in user message" annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error" and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block" and annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
if is_block_action if is_block_action
else "log" else "log"
) )

View File

@@ -435,12 +435,12 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
assert ( assert (
annotations[0]["content"] == "ogre detected in response" annotations[0]["content"] == "ogre detected in response"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error" and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block" and annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
) )
assert ( assert (
annotations[1]["content"] == "Fiona detected in response" annotations[1]["content"] == "Fiona detected in response"
and annotations[1]["extra_metadata"]["source"] == "guardrails-error" and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
and annotations[1]["extra_metadata"]["guardrail-action"] == "log" and annotations[1]["extra_metadata"]["guardrail"]["action"] == "log"
) )
@@ -550,7 +550,7 @@ async def test_preguardrailing_with_guardrails_from_explorer(
assert ( assert (
annotations[0]["content"] == "pun detected in user message" annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error" and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block" and annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
if is_block_action if is_block_action
else "log" else "log"
) )

View File

@@ -457,12 +457,12 @@ async def test_with_guardrails_from_explorer(explorer_api_url, gateway_url, do_s
assert ( assert (
annotations[0]["content"] == "ogre detected in response" annotations[0]["content"] == "ogre detected in response"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error" and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block" and annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
) )
assert ( assert (
annotations[1]["content"] == "Fiona detected in response" annotations[1]["content"] == "Fiona detected in response"
and annotations[1]["extra_metadata"]["source"] == "guardrails-error" and annotations[1]["extra_metadata"]["source"] == "guardrails-error"
and annotations[1]["extra_metadata"]["guardrail-action"] == "log" and annotations[1]["extra_metadata"]["guardrail"]["action"] == "log"
) )
@@ -581,7 +581,7 @@ async def test_preguardrailing_with_guardrails_from_explorer(
assert ( assert (
annotations[0]["content"] == "pun detected in user message" annotations[0]["content"] == "pun detected in user message"
and annotations[0]["extra_metadata"]["source"] == "guardrails-error" and annotations[0]["extra_metadata"]["source"] == "guardrails-error"
and annotations[0]["extra_metadata"]["guardrail-action"] == "block" and annotations[0]["extra_metadata"]["guardrail"]["action"] == "block"
if is_block_action if is_block_action
else "log" else "log"
) )

View File

@@ -7,10 +7,6 @@ import random
import string import string
import pytest import pytest
from gateway.common.config_manager import GatewayConfig
from gateway.common.request_context import RequestContext
# Add root folder (parent) to sys.path # Add root folder (parent) to sys.path
sys.path.append( sys.path.append(
os.path.dirname( os.path.dirname(
@@ -18,6 +14,8 @@ sys.path.append(
) )
) )
from gateway.common.config_manager import GatewayConfig
from gateway.common.request_context import RequestContext
from gateway.common.authorization import ( from gateway.common.authorization import (
INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER, INVARIANT_GUARDRAIL_SERVICE_AUTHORIZATION_HEADER,
extract_authorization_from_headers, extract_authorization_from_headers,