mirror of
https://github.com/invariantlabs-ai/invariant-gateway.git
synced 2026-07-01 16:55:31 +02:00
add litellm
This commit is contained in:
@@ -2,4 +2,5 @@ fastapi==0.115.7
|
||||
httpx==0.28.1
|
||||
invariant-sdk>=0.0.10
|
||||
starlette-compress==1.4.0
|
||||
uvicorn==0.34.0
|
||||
uvicorn==0.34.0
|
||||
litellm
|
||||
@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from starlette.responses import StreamingResponse
|
||||
from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
|
||||
from utils.explorer import push_trace
|
||||
import routes.open_ai as open_ai
|
||||
import litellm
|
||||
|
||||
proxy = APIRouter()
|
||||
|
||||
@@ -86,6 +88,24 @@ async def anthropic_v1_messages_proxy(
|
||||
response, dataset_name, request_body_json, invariant_authorization
|
||||
)
|
||||
|
||||
async def push_to_explorer_with_openai_format(
    dataset_name: str,
    merged_response: dict[str, Any],
    request_body: dict[str, Any],
    invariant_authorization: str,
) -> None:
    """Push an OpenAI-format conversation trace to the Invariant Explorer.

    The pushed trace is the request's ``messages`` followed by the
    ``message`` entry of every item in the response's ``choices``.

    Args:
        dataset_name: Dataset name forwarded to ``push_trace``.
        merged_response: Response dict; its ``choices`` list (if any)
            supplies the trailing messages.
        request_body: Request dict; its ``messages`` list (if any)
            supplies the leading messages. It is not modified.
        invariant_authorization: Authorization value forwarded to
            ``push_trace``.
    """
    # Copy the request messages: extending the list returned by ``.get``
    # in place would also append the response messages to the caller's
    # request body. ``or []`` additionally tolerates an explicit ``None``.
    messages = list(request_body.get("messages") or [])
    messages.extend(
        choice["message"] for choice in merged_response.get("choices") or []
    )

    # ``messages`` is wrapped in an outer list, as in the original call;
    # presumably push_trace takes a list of traces -- confirm against
    # utils.explorer.
    _ = await push_trace(
        dataset_name=dataset_name,
        messages=[messages],
        invariant_authorization=invariant_authorization,
    )
|
||||
|
||||
async def push_to_explorer(
|
||||
dataset_name: str,
|
||||
@@ -97,7 +117,6 @@ async def push_to_explorer(
|
||||
# Combine the messages from the request body and Anthropic response
|
||||
messages = request_body.get("messages", [])
|
||||
messages += [merged_response]
|
||||
|
||||
transformed_messages = connvert_anthropic_to_invariant_message_format(messages)
|
||||
_ = await push_trace(
|
||||
dataset_name=dataset_name,
|
||||
@@ -127,10 +146,12 @@ async def handle_non_streaming_response(
|
||||
)
|
||||
# Only push the trace to explorer if the last message is an end turn message
|
||||
if dataset_name:
|
||||
await push_to_explorer(
|
||||
request_messages_in_openai_format = connvert_anthropic_to_invariant_message_format(request_body_json.get("messages", []))
|
||||
response_in_openai_format = await convert_to_litellm_messages(response, request_body_json)
|
||||
await open_ai.push_to_explorer(
|
||||
dataset_name,
|
||||
json_response,
|
||||
request_body_json,
|
||||
response_in_openai_format.json(),
|
||||
{"messages": request_messages_in_openai_format},
|
||||
invariant_authorization,
|
||||
)
|
||||
return Response(
|
||||
@@ -140,6 +161,71 @@ async def handle_non_streaming_response(
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
|
||||
async def convert_to_litellm_messages(
    raw_response: httpx.Response,
    request_body_json: dict[str, Any],
) -> "litellm.ModelResponse":
    """Transform a raw Anthropic HTTP response via litellm's provider config.

    Looks up litellm's chat config for the ``"anthropic"`` provider and runs
    its ``transform_response`` on ``raw_response``, using the model and
    messages taken from the original request body.

    The coroutine contains no ``await``; it stays ``async`` so existing
    callers that await it keep working.

    Args:
        raw_response: The HTTP response received from the upstream API.
        request_body_json: The JSON body of the request that produced
            ``raw_response``; its ``model`` and ``messages`` entries are read.

    Returns:
        Whatever ``transform_response`` returns -- presumably the populated
        ``ModelResponse`` (the caller invokes ``.json()`` on it); confirm
        against the installed litellm version.
    """
    # Imported locally, as in the original, so that these dependencies are
    # only loaded when a conversion is actually requested.
    import os
    import time
    import uuid

    import tiktoken
    from litellm import ProviderConfigManager
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.utils import LlmProviders, ModelResponse

    model = request_body_json.get("model")
    messages = request_body_json.get("messages")

    # Empty response object with a default usage record, handed to
    # transform_response as ``model_response``.
    model_response = ModelResponse()
    model_response.usage = litellm.Usage()

    # NOTE(review): this is the gateway's own environment key, not a key
    # taken from the incoming request -- confirm that is intended.
    api_key = os.environ.get("ANTHROPIC_API_KEY")

    custom_llm_provider = "anthropic"
    optional_params = {}
    litellm_params = {}

    # transform_response requires a logging object, so one is built here with
    # fixed placeholder values for a non-streaming completion call.
    logging_obj = Logging(
        function_id="None",
        model=model,
        messages=messages,
        call_type="completion",
        stream=False,
        start_time=time.time(),
        litellm_call_id=str(uuid.uuid4()),
    )
    logging_obj.update_environment_variables(
        litellm_params=litellm_params,
        optional_params=optional_params,
    )

    config = ProviderConfigManager.get_provider_chat_config(
        model=model,
        provider=LlmProviders(custom_llm_provider),
    )

    return config.transform_response(
        model=model,
        raw_response=raw_response,
        model_response=model_response,
        logging_obj=logging_obj,
        api_key=api_key,
        request_data=request_body_json,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        encoding=tiktoken.get_encoding("cl100k_base"),
        json_mode=False,
    )
|
||||
|
||||
async def handle_streaming_response(
|
||||
client: httpx.AsyncClient,
|
||||
|
||||
Reference in New Issue
Block a user