mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-05-16 04:49:04 +02:00
95 lines
3.2 KiB
Python
import pytest
|
|
import uuid
|
|
from litellm import completion
|
|
import litellm
|
|
import time
|
|
import requests
|
|
import os
|
|
|
|
# litellm model id ("<provider>/<model>") -> name of the environment variable
# that must hold the API key for that provider. Tests for a model are skipped
# when its variable is unset. Insertion order fixes the parametrization order.
MODEL_API_KEYS = {
    "openai/gpt-4o": "OPENAI_API_KEY",
    "gemini/gemini-2.0-flash": "GEMINI_API_KEY",
    "anthropic/claude-3-5-haiku-20241022": "ANTHROPIC_API_KEY",
}
|
|
|
|
def _dataset_name_for(litellm_model: str) -> str:
    """Return a unique explorer dataset name for one run of *litellm_model*.

    litellm model ids contain a "/" (e.g. "openai/gpt-4o"). It is replaced with
    "-" so the dataset name stays a single URL path segment when embedded in
    the gateway and explorer URLs.
    """
    safe_model = litellm_model.replace("/", "-")
    return f"test-dataset-litellm-{safe_model}-{uuid.uuid4()}"


def _gateway_base_url(
    gateway_url: str, litellm_model: str, dataset_name: str, push_to_explorer: bool
) -> str:
    """Build the gateway base URL that litellm should send requests to.

    Shape: ``{gateway_url}/api/v1/gateway[/{dataset_name}]/{provider}``. The
    dataset segment is present only when pushing to the explorer. For Gemini,
    ``/v1beta/models/{model}`` is appended because Gemini expects the model
    name in the URL.
    """
    provider, _, model_name = litellm_model.partition("/")
    base_url = f"{gateway_url}/api/v1/gateway"
    if push_to_explorer:
        base_url += f"/{dataset_name}"
    base_url += f"/{provider}"
    if provider == "gemini":
        base_url += f"/v1beta/models/{model_name}"
    return base_url


def _wait_for_traces(
    explorer_api_url: str,
    dataset_name: str,
    timeout_s: float = 10.0,
    poll_interval_s: float = 0.5,
) -> list:
    """Poll the explorer until *dataset_name* has traces or *timeout_s* elapses.

    The trace is saved asynchronously, so a single fixed sleep is either
    wasteful or flaky. Returns the decoded trace list, which may be empty if
    the timeout is reached. Raises ``requests.HTTPError`` if the final poll
    returned an error status.
    """
    url = f"{explorer_api_url}/api/v1/dataset/byuser/developer/{dataset_name}/traces"
    deadline = time.monotonic() + timeout_s
    while True:
        response = requests.get(url, timeout=5)
        if response.ok:
            traces = response.json()
            if traces:
                return traces
        if time.monotonic() >= deadline:
            # Surface a real HTTP error instead of a confusing JSON/KeyError later.
            response.raise_for_status()
            return response.json()
        time.sleep(poll_interval_s)


@pytest.mark.parametrize("litellm_model", list(MODEL_API_KEYS))
@pytest.mark.parametrize(
    "do_stream, push_to_explorer",
    [(False, False)],
)
async def test_chat_completion(
    explorer_api_url: str,
    litellm_model: str,
    gateway_url: str,
    do_stream: bool,
    push_to_explorer: bool,
):
    """Test a plain chat completion routed through the gateway via litellm.

    For each model in MODEL_API_KEYS, a litellm ``completion`` call is sent
    through the Invariant gateway. The test asserts that the answer mentions
    Paris. When ``push_to_explorer`` is set, it also asserts that exactly one
    trace was stored in the explorer dataset. That trace must contain the user
    question and the assistant's reply.

    Skipped when the provider's API key environment variable is not set.
    """
    # NOTE(review): declared async (presumably run by pytest-asyncio), but every
    # call below is blocking. Kept async for consistency with the test suite.
    api_key_env_var = MODEL_API_KEYS[litellm_model]
    if not os.getenv(api_key_env_var):
        pytest.skip(f"Skipping {litellm_model} because {api_key_env_var} is not set")

    dataset_name = _dataset_name_for(litellm_model)
    base_url = _gateway_base_url(
        gateway_url, litellm_model, dataset_name, push_to_explorer
    )
    print(f"base_url: {base_url}")

    chat_response = completion(
        model=litellm_model,
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        extra_headers={"Invariant-Authorization": "Bearer <some-key>"},
        stream=do_stream,
        base_url=base_url,
    )

    # Verify the chat response.
    if do_stream:
        expected_assistant_message = "".join(
            chunk.choices[0].delta.content
            for chunk in chat_response
            if chunk.choices and chunk.choices[0].delta.content
        )
    else:
        expected_assistant_message = chat_response.choices[0].message.content
    assert "PARIS" in expected_assistant_message.upper()

    if not push_to_explorer:
        return

    traces = _wait_for_traces(explorer_api_url, dataset_name)
    assert len(traces) == 1
    trace_id = traces[0]["id"]

    trace_response = requests.get(
        f"{explorer_api_url}/api/v1/trace/{trace_id}",
        timeout=5,
    )
    trace_response.raise_for_status()
    trace = trace_response.json()

    # Annotations are not part of what this test verifies.
    for message in trace["messages"]:
        message.pop("annotations", None)

    # Verify the trace messages.
    assert trace["messages"] == [
        {
            "role": "user",
            "content": "What is the capital of France?",
        },
        {
            "role": "assistant",
            "content": expected_assistant_message,
        },
    ]