add litellm

This commit is contained in:
Zishan
2025-02-28 16:13:55 +01:00
parent 335715111d
commit c125ad141b
2 changed files with 92 additions and 5 deletions
+2 -1
View File
@@ -2,4 +2,5 @@ fastapi==0.115.7
httpx==0.28.1
invariant-sdk>=0.0.10
starlette-compress==1.4.0
uvicorn==0.34.0
uvicorn==0.34.0
litellm
+90 -4
View File
@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from starlette.responses import StreamingResponse
from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS
from utils.explorer import push_trace
import routes.open_ai as open_ai
import litellm
# FastAPI router object for this module; the route handlers below are
# presumably registered on it (decorators not visible here — confirm).
proxy = APIRouter()
@@ -86,6 +88,24 @@ async def anthropic_v1_messages_proxy(
response, dataset_name, request_body_json, invariant_authorization
)
async def push_to_explorer_with_openai_format(
    dataset_name: str,
    merged_response: dict[str, Any],
    request_body: dict[str, Any],
    invariant_authorization: str,
) -> None:
    """Push the full trace to the Invariant Explorer.

    The trace is the request's ``messages`` followed by the ``message`` of
    every entry in the response's ``choices``.

    Args:
        dataset_name: Forwarded to ``push_trace`` as ``dataset_name``.
        merged_response: Response dict; its ``"choices"`` entries must each
            contain a ``"message"`` key.
        request_body: Request dict; its ``"messages"`` list is read but
            never modified.
        invariant_authorization: Forwarded to ``push_trace`` unchanged.

    Raises:
        KeyError: If an entry in ``choices`` has no ``"message"`` key.
    """
    # Build a new list instead of extending request_body["messages"] in
    # place, so the caller's request payload is left untouched. `or []` also
    # tolerates an explicit None stored under either key.
    messages = list(request_body.get("messages") or [])
    messages.extend(
        choice["message"] for choice in merged_response.get("choices") or []
    )

    # push_trace receives a list of traces; this call pushes exactly one.
    _ = await push_trace(
        dataset_name=dataset_name,
        messages=[messages],
        invariant_authorization=invariant_authorization,
    )
async def push_to_explorer(
dataset_name: str,
@@ -97,7 +117,6 @@ async def push_to_explorer(
# Combine the messages from the request body and Anthropic response
messages = request_body.get("messages", [])
messages += [merged_response]
transformed_messages = connvert_anthropic_to_invariant_message_format(messages)
_ = await push_trace(
dataset_name=dataset_name,
@@ -127,10 +146,12 @@ async def handle_non_streaming_response(
)
# Only push the trace to explorer if the last message is an end turn message
if dataset_name:
await push_to_explorer(
request_messages_in_openai_format = connvert_anthropic_to_invariant_message_format(request_body_json.get("messages", []))
response_in_openai_format = await convert_to_litellm_messages(response, request_body_json)
await open_ai.push_to_explorer(
dataset_name,
json_response,
request_body_json,
response_in_openai_format.json(),
{"messages": request_messages_in_openai_format},
invariant_authorization,
)
return Response(
@@ -140,6 +161,71 @@ async def handle_non_streaming_response(
headers=dict(response.headers),
)
async def convert_to_litellm_messages(
    raw_response: httpx.Response,
    request_body_json: dict[str, Any],
) -> "litellm.ModelResponse":
    """Transform a raw Anthropic HTTP response with litellm's provider config.

    Looks up litellm's chat config for the ``anthropic`` provider and runs
    its ``transform_response`` on ``raw_response``, filling a fresh
    ``ModelResponse``.

    Args:
        raw_response: The HTTP response to transform.
        request_body_json: The JSON body of the originating request; its
            ``"model"`` and ``"messages"`` entries are passed to litellm.

    Returns:
        Whatever ``transform_response`` returns — expected to be the
        populated ``ModelResponse`` (the caller invokes ``.json()`` on it).

    Note:
        The function contains no ``await``; it stays ``async`` because the
        existing call site awaits it.
    """
    import os
    import time
    import uuid

    import tiktoken
    from litellm import ProviderConfigManager
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.utils import LlmProviders, ModelResponse

    model = request_body_json.get("model")
    messages = request_body_json.get("messages")

    # No extra parameters are supplied to litellm; the same (empty) dicts are
    # shared between the logging object and the transform call.
    optional_params: dict[str, Any] = {}
    litellm_params: dict[str, Any] = {}

    model_response = ModelResponse()
    model_response.usage = litellm.Usage()

    # transform_response requires a litellm logging object, so build one for
    # a synthetic, non-streaming "completion" call with a fresh call id.
    logging_obj = Logging(
        function_id="None",
        model=model,
        messages=messages,
        call_type="completion",
        stream=False,
        start_time=time.time(),
        litellm_call_id=str(uuid.uuid4()),
    )
    logging_obj.update_environment_variables(
        litellm_params=litellm_params,
        optional_params=optional_params,
    )

    config = ProviderConfigManager.get_provider_chat_config(
        model=model,
        provider=LlmProviders("anthropic"),
    )
    return config.transform_response(
        model=model,
        raw_response=raw_response,
        model_response=model_response,
        logging_obj=logging_obj,
        # May be None when ANTHROPIC_API_KEY is not set in the environment.
        api_key=os.environ.get("ANTHROPIC_API_KEY"),
        request_data=request_body_json,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        encoding=tiktoken.get_encoding("cl100k_base"),
        json_mode=False,
    )
async def handle_streaming_response(
client: httpx.AsyncClient,