diff --git a/proxy/requirements.txt b/proxy/requirements.txt index a564b13..3127132 100644 --- a/proxy/requirements.txt +++ b/proxy/requirements.txt @@ -2,4 +2,5 @@ fastapi==0.115.7 httpx==0.28.1 invariant-sdk>=0.0.10 starlette-compress==1.4.0 -uvicorn==0.34.0 \ No newline at end of file +uvicorn==0.34.0 +litellm \ No newline at end of file diff --git a/proxy/routes/anthropic.py b/proxy/routes/anthropic.py index 182485c..d552c37 100644 --- a/proxy/routes/anthropic.py +++ b/proxy/routes/anthropic.py @@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from starlette.responses import StreamingResponse from utils.constants import CLIENT_TIMEOUT, IGNORED_HEADERS from utils.explorer import push_trace +import routes.open_ai as open_ai +import litellm proxy = APIRouter() @@ -86,6 +88,24 @@ async def anthropic_v1_messages_proxy( response, dataset_name, request_body_json, invariant_authorization ) +async def push_to_explorer_with_openai_format( + dataset_name: str, + merged_response: dict[str, Any], + request_body: dict[str, Any], + invariant_authorization: str, +) -> None: + """Pushes the full trace to the Invariant Explorer""" + + # Combine the messages from the request body and the choices from the OpenAI response + messages = request_body.get("messages", []) + print("request body: ", request_body.get("messages", [])) + messages += [choice["message"] for choice in merged_response.get("choices", [])] + + _ = await push_trace( + dataset_name=dataset_name, + messages=[messages], + invariant_authorization=invariant_authorization, + ) async def push_to_explorer( dataset_name: str, @@ -97,7 +117,6 @@ async def push_to_explorer( # Combine the messages from the request body and Anthropic response messages = request_body.get("messages", []) messages += [merged_response] - transformed_messages = connvert_anthropic_to_invariant_message_format(messages) _ = await push_trace( dataset_name=dataset_name, @@ -127,10 +146,12 @@ async def handle_non_streaming_response( ) # Only push the trace to explorer if the last message is an end turn message if dataset_name: - await push_to_explorer( + request_messages_in_openai_format = connvert_anthropic_to_invariant_message_format(request_body_json.get("messages", [])) + response_in_openai_format = await convert_to_litellm_messages(response, request_body_json) + await open_ai.push_to_explorer( dataset_name, - json_response, - request_body_json, + response_in_openai_format.json(), + {"messages": request_messages_in_openai_format}, invariant_authorization, ) return Response( @@ -140,6 +161,71 @@ async def handle_non_streaming_response( headers=dict(response.headers), ) +async def convert_to_litellm_messages( + raw_response: httpx.Response, + request_body_json: dict[str, Any], + ) -> list[dict]: + import uuid + import time + import os + import tiktoken + from litellm import ProviderConfigManager + from litellm.utils import ModelResponse, LlmProviders + from litellm.litellm_core_utils.litellm_logging import Logging + + model = request_body_json.get("model") + model_response = ModelResponse() + setattr(model_response, "usage", litellm.Usage()) + messages = request_body_json.get("messages") + data = request_body_json + api_key = os.environ.get("ANTHROPIC_API_KEY") + + call_type = "completion" + stream = False + custom_llm_provider="anthropic" + + litellm_call_id = str(uuid.uuid4()) + function_id = "None" + optional_params = {} + litellm_params = {} + start_time = time.time() + encoding = tiktoken.get_encoding("cl100k_base") + json_mode = False + + logging_obj = Logging( + function_id=function_id, + model=model, + messages=messages, + call_type=call_type, + stream=stream, + start_time=start_time, + litellm_call_id=litellm_call_id, + ) + + logging_obj.update_environment_variables( + litellm_params=litellm_params, + optional_params=optional_params, + ) + + config = ProviderConfigManager.get_provider_chat_config( + model=model, + provider=LlmProviders(custom_llm_provider), + ) + + response = config.transform_response( + model=model, + raw_response=raw_response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=data, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + json_mode=json_mode, + ) + return response async def handle_streaming_response( client: httpx.AsyncClient,