From 55db93c8d339c56bc003a5b3100230eb5cea6604 Mon Sep 17 00:00:00 2001 From: Luca Beurer-Kellner Date: Fri, 28 Mar 2025 19:41:38 +0100 Subject: [PATCH] simplify request instrumentation --- .../{guardails.py => guardrails.py} | 241 ++++------ gateway/routes/anthropic.py | 2 +- gateway/routes/gemini.py | 2 +- gateway/routes/open_ai.py | 450 ++++++++++-------- gateway/serve.py | 13 + 5 files changed, 359 insertions(+), 349 deletions(-) rename gateway/integrations/{guardails.py => guardrails.py} (56%) diff --git a/gateway/integrations/guardails.py b/gateway/integrations/guardrails.py similarity index 56% rename from gateway/integrations/guardails.py rename to gateway/integrations/guardrails.py index feabfb8..6b3719a 100644 --- a/gateway/integrations/guardails.py +++ b/gateway/integrations/guardrails.py @@ -6,7 +6,13 @@ import time from typing import Any, Dict, List from functools import wraps +from fastapi.responses import StreamingResponse import httpx +<<<<<<< HEAD:gateway/integrations/guardails.py +======= +from zmq import IO_THREADS +from common.request_context_data import RequestContextData +>>>>>>> 91684ce (simplify request instrumentation):gateway/integrations/guardrails.py DEFAULT_API_URL = "https://explorer.invariantlabs.ai" @@ -99,102 +105,59 @@ async def preload_guardrails(context: "RequestContextData") -> None: print(f"Error scheduling preload_guardrails task: {e}") -class YieldException(Exception): +class ExtraItem: """ - Raise this exception in stream instrumentor listeners to - end the stream early, or to emit additional items in a stream. + Return this class in a instrumented stream callback, to yield an extra item in the resulting stream. """ def __init__(self, value, end_of_stream=False): - super().__init__(value) self.value = value self.end_of_stream = end_of_stream def __str__(self): - return f"YieldException: {self.value}" + return f"" -class StreamInstrumentor: - """ - A class to instrument async iterables with hooks for processing - chunks, before processing, and on completion. - - Use `@on('chunk')`, `@on('start')`, and `@on('end')` decorators - to register listeners for different events. - - Listeners can simply process data, or alternatively raise a designated - YieldException to yield additional values or stop the stream. - - Example usage: - - ``` - instrumentor = StreamInstrumentor() - - @instrumentor.on('chunk') - async def process_chunk(chunk): - # Process the chunk - print(f"Processing chunk: {chunk}") - - if some_condition: - # Yield an additional value that will be interleaved in the stream - # Pass `end_of_stream=True` to stop the stream after yielding - # Pass `end_of_stream=False` to continue the stream after the interleaved value - raise YieldException("Extra value", end_of_stream=True) - ``` - """ - +class InstrumentedStreamingResponse: def __init__(self): - # called on every chunk (async) - self.on_chunk_listeners = [] - # called once before the first chunk is processed, or even earlier (async) - self.before_listeners = [] - # called once on stream completion (async) - self.on_complete_listeners = [] - + # request statistics self.stat_token_times = [] self.stat_before_time = None self.stat_after_time = None self.stat_first_item_time = None - # decorator - def on(self, event: str): + async def on_chunk(self, chunk: Any) -> ExtraItem | None: """ - Decorator to register listeners for different events. + This called will be called on every chunk (async). + """ + pass + + async def on_start(self) -> ExtraItem | None: + """ + Decorator to register a listener for start events. + """ + pass + + async def on_end(self) -> ExtraItem | None: + """ + Decorator to register a listener for end events. + """ + pass + + async def event_generator(self): + """ + Streams the async iterable and invokes all instrumented hooks. Args: - event (str): The event to listen for. Can be 'on_chunk', - 'before', or 'on_complete'. + async_iterable: An async iterable to stream. - Returns: - Callable: A decorator to register the listener. + Yields: + The streamed data. """ + raise NotImplementedError("This method should be implemented in a subclass.") - def decorator(func): - assert asyncio.iscoroutinefunction( - func - ), "Listener must be an async function" - - if event == "chunk": - if self.on_chunk_listeners is None: - self.on_chunk_listeners = [] - self.on_chunk_listeners.append(func) - elif event == "start": - if self.before_listeners is None: - self.before_listeners = [] - self.before_listeners.append(func) - elif event == "end": - if self.on_complete_listeners is None: - self.on_complete_listeners = [] - self.on_complete_listeners.append(func) - else: - raise ValueError("Invalid event type. Use 'chunk', 'before', or 'end'.") - - return func - - return decorator - - async def stream(self, async_iterable): + async def instrumented_event_generator(self): """ Streams the async iterable and invokes all instrumented hooks. @@ -207,14 +170,11 @@ class StreamInstrumentor: try: start = time.time() - # schedule all before listeners which can be run concurrently - before_tasks = [ - asyncio.create_task(listener(), name="instrumentor:start") - for listener in self.before_listeners - ] + # schedule on_start which can be run concurrently + start_task = asyncio.create_task(self.on_start(), name="instrumentor:start") # create async iterator from async_iterable - aiterable = aiter(async_iterable) + aiterable = aiter(self.event_generator()) # [STAT] capture start time of first item start_first_item_request = time.time() @@ -233,31 +193,23 @@ class StreamInstrumentor: wait_for_first_item(), name="instrumentor:next:first" ) - # wait for all before listeners to finish - has_end_of_stream = False - for before_task in before_tasks: - try: - await before_task - except YieldException as e: - # yield extra value before any real items - yield e.value - # stop the stream if end_of_stream is True - if e.end_of_stream: - # if first item is already available - if not next_item_task.done(): - # cancel the task - next_item_task.cancel() - # [STAT] capture time to first item to be now +0.01 - if self.stat_first_item_time is None: - self.stat_first_item_time = ( - time.time() - start_first_item_request - ) + 0.01 - has_end_of_stream = True - - # don't wait for the first item if end_of stream is True - if has_end_of_stream: - # if end_of_stream is True, stop the stream - return + # check if 'start_task' yields an extra item + if extra_item := await start_task: + # yield extra value before any real items + yield extra_item.value + # stop the stream if end_of_stream is True + if extra_item.end_of_stream: + # if first item is already available + if not next_item_task.done(): + # cancel the task + next_item_task.cancel() + # [STAT] capture time to first item to be now +0.01 + if self.stat_first_item_time is None: + self.stat_first_item_time = ( + time.time() - start_first_item_request + ) + 0.01 + # don't wait for the first item if end_of stream is True + return # [STAT] capture before time stamp self.stat_before_time = time.time() - start @@ -282,35 +234,20 @@ class StreamInstrumentor: time.time() - start - sum(self.stat_token_times) ) - # invoke on_chunk listeners - any_end_of_stream = False - for listener in self.on_chunk_listeners: - try: - await listener(item) - except YieldException as e: - yield e.value - # if end_of_stream is True, stop the stream - if e.end_of_stream: - any_end_of_stream = True - - # if end_of_stream is True, stop the stream - if any_end_of_stream: - return + if extra_item := await self.on_chunk(item): + yield extra_item.value + # if end_of_stream is True, stop the stream + if extra_item.end_of_stream: + return # yield item yield item - on_complete_tasks = [ - asyncio.create_task(listener(), name="instrumentor:end") - for listener in self.on_complete_listeners - ] - for result in asyncio.as_completed(on_complete_tasks): - try: - await result - except YieldException as e: - # yield extra value before any real items - yield e.value - # we ignore end_of_stream here, because we are already at the end + # run on_end, before closing the stream (may yield an extra value) + if extra_item := await self.on_end(): + # yield extra value before any real items + yield extra_item.value + # we ignore end_of_stream here, because we are already at the end # [STAT] capture after time stamp self.stat_after_time = time.time() - start @@ -344,29 +281,35 @@ class StreamInstrumentor: print(f" [total: {time.time() - start:.2f}s]") -class RequestInstrumentor(StreamInstrumentor): +class InstrumentedResponse(InstrumentedStreamingResponse): """ - Like 'StreamInstrumentor', but for non-streaming requests. - - Supports similar 'start', 'end' events, but not 'chunk', since everything is assumed - to be processed in one chunk (i.e., the request). + A class to instrument an async request with hooks for concurrent + pre-processing and post-processing (input and output guardrailing). """ - def on(self, event): - assert event in [ - "start", - "end", - ], "RequestInstrumentor does not support 'chunk' events" - return super().on(event) + async def event_generator(self): + """ + We implement the 'event_generator' as a single item stream, + where the item is the full result of the request. + """ + yield await self.request() - async def execute(self, request_task): - async def wrapped_request_task(): - yield await request_task + async def request(self): + """ + This method should be implemented in a subclass to perform the actual request. + """ + raise NotImplementedError("This method should be implemented in a subclass.") - # pretend the 'request_task' is an async iterable with a single item - result = [item async for item in self.stream(wrapped_request_task())] - assert len(result) >= 1, "RequestInstrumentor must yield at least one item" - return result[-1] + async def instrumented_request(self): + """ + Returns the 'Response' object of the request, after applying all instrumented hooks. + """ + results = [r async for r in self.instrumented_event_generator()] + assert len(results) >= 1, "InstrumentedResponse must yield at least one item" + + # we return the last item, in case the end callback yields an extra item. Then, + # don't return the actual result but the 'end' result, e.g. for output guardrailing. + return results[-1] async def check_guardrails( @@ -395,6 +338,10 @@ async def check_guardrails( "Accept": "application/json", }, ) + if not result.is_success: + raise Exception( + f"Guardrails check failed: {result.status_code} - {result.text}" + ) print(f"Guardrail check response: {result.json()}") return result.json() except Exception as e: diff --git a/gateway/routes/anthropic.py b/gateway/routes/anthropic.py index 01004c6..807be91 100644 --- a/gateway/routes/anthropic.py +++ b/gateway/routes/anthropic.py @@ -18,7 +18,7 @@ from converters.anthropic_to_invariant import ( ) from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData -from integrations.guardails import check_guardrails, preload_guardrails +from integrations.guardrails import check_guardrails, preload_guardrails gateway = APIRouter() diff --git a/gateway/routes/gemini.py b/gateway/routes/gemini.py index 4042b4e..25f94e4 100644 --- a/gateway/routes/gemini.py +++ b/gateway/routes/gemini.py @@ -16,7 +16,7 @@ from common.authorization import extract_authorization_from_headers from common.request_context_data import RequestContextData from converters.gemini_to_invariant import convert_request, convert_response from integrations.explorer import create_annotations_from_guardrails_errors, push_trace -from integrations.guardails import check_guardrails, preload_guardrails +from integrations.guardrails import check_guardrails, preload_guardrails gateway = APIRouter() diff --git a/gateway/routes/open_ai.py b/gateway/routes/open_ai.py index 0418e2d..08f10e4 100644 --- a/gateway/routes/open_ai.py +++ b/gateway/routes/open_ai.py @@ -13,10 +13,10 @@ from common.constants import ( IGNORED_HEADERS, ) from integrations.explorer import create_annotations_from_guardrails_errors, push_trace -from integrations.guardails import ( - RequestInstrumentor, - StreamInstrumentor, - YieldException, +from integrations.guardrails import ( + ExtraItem, + InstrumentedResponse, + InstrumentedStreamingResponse, check_guardrails, preload_guardrails, ) @@ -89,23 +89,131 @@ async def openai_chat_completions_gateway( return await handle_non_streaming_response(context, client, open_ai_request) -async def stream_response( - context: RequestContextData, - client: httpx.AsyncClient, - open_ai_request: httpx.Request, -) -> Response: - """ - Handles streaming the OpenAI response to the client while building a merged_response - The chunks are returned to the caller immediately - The merged_response is built from the chunks as they are received - It is sent to the Invariant Explorer at the end of the stream - """ +class InstrumentedOpenAIStreamResponse(InstrumentedStreamingResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + open_ai_request: httpx.Request, + ): + super().__init__() - async def request_and_stream(): + # request parameters + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.open_ai_request: httpx.Request = open_ai_request + + # guardrailing output (if any) + self.guardrails_execution_result: Optional[dict] = None + + # merged_response will be updated with the data from the chunks in the stream + # At the end of the stream, this will be sent to the explorer + self.merged_response = { + "id": None, + "object": "chat.completion", + "created": None, + "model": None, + "choices": [], + "usage": None, + } + + # Each chunk in the stream contains a list called "choices" each entry in the list + # has an index. + # A choice has a field called "delta" which may contain a list called "tool_calls". + # Maps the choice index in the stream to the index in the merged_response["choices"] list + self.choice_mapping_by_index = {} + # Combines the choice index and tool call index to uniquely identify a tool call + self.tool_call_mapping_by_index = {} + + async def on_start(self): + # check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing) + if self.context.config and self.context.config.guardrails: + # Block on the guardrails check + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, self.merged_response + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "error": { + "message": "[Invariant] The request did not pass the guardrails", + "details": self.guardrails_execution_result, + } + } + ) + + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.merged_response, + self.guardrails_execution_result, + ) + ) + + # if we find something, we end the stream prematurely (end_of_stream=True) + # and yield an error chunk instead of actually beginning the stream + return ExtraItem( + f"data: {error_chunk}\n\n".encode(), end_of_stream=True + ) + + async def on_chunk(self, chunk): + # process and check each chunk + chunk_text = chunk.decode().strip() + if not chunk_text: + return + + # Process the chunk + # This will update merged_response with the data from the chunk + process_chunk_text( + chunk_text, + self.merged_response, + self.choice_mapping_by_index, + self.tool_call_mapping_by_index, + ) + + # check guardrails at the end of the stream (on the '[DONE]' SSE chunk.) + if ( + "data: [DONE]" in chunk_text + and self.context.config + and self.context.config.guardrails + ): + # Block on the guardrails check + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, self.merged_response + ) + if self.guardrails_execution_result.get("errors", []): + error_chunk = json.dumps( + { + "error": { + "message": "[Invariant] The response did not pass the guardrails", + "details": self.guardrails_execution_result, + } + } + ) + + # yield an extra error chunk (without preventing the original chunk to go through after) + return ExtraItem(f"data: {error_chunk}\n\n".encode()) + + # push will happen in on_end + + async def on_end(self): + # Send full merged response to the explorer + # Don't block on the response from explorer + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, self.merged_response, self.guardrails_execution_result + ) + ) + + async def event_generator(self): """ - Sets off the request and then streams the result. + Actual OpenAI stream response. """ - response = await client.send(open_ai_request, stream=True) + + response = await self.client.send(self.open_ai_request, stream=True) if response.status_code != 200: error_content = await response.aread() try: @@ -119,123 +227,28 @@ async def stream_response( async for chunk in response.aiter_bytes(): yield chunk - async def event_generator() -> Any: - # merged_response will be updated with the data from the chunks in the stream - # At the end of the stream, this will be sent to the explorer - merged_response = { - "id": None, - "object": "chat.completion", - "created": None, - "model": None, - "choices": [], - "usage": None, - } - # Each chunk in the stream contains a list called "choices" each entry in the list - # has an index. - # A choice has a field called "delta" which may contain a list called "tool_calls". - # Maps the choice index in the stream to the index in the merged_response["choices"] list - choice_mapping_by_index = {} - # Combines the choice index and tool call index to uniquely identify a tool call - tool_call_mapping_by_index = {} +async def stream_response( + context: RequestContextData, + client: httpx.AsyncClient, + open_ai_request: httpx.Request, +) -> Response: + """ + Handles streaming the OpenAI response to the client while building a merged_response + The chunks are returned to the caller immediately + The merged_response is built from the chunks as they are received + It is sent to the Invariant Explorer at the end of the stream + """ - # prepare stream instrumentor - instrumentor = StreamInstrumentor() + response = InstrumentedOpenAIStreamResponse( + context, + client, + open_ai_request, + ) - @instrumentor.on("start") - async def precheck_guardrails() -> None: - # check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing) - if context.config and context.config.guardrails: - # Block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result( - context, merged_response - ) - if guardrails_execution_result.get("errors", []): - error_chunk = json.dumps( - { - "error": { - "message": "[Invariant] The request did not pass the guardrails", - "details": guardrails_execution_result, - } - } - ) - - # Push annotated trace to the explorer - don't block on its response - if context.dataset_name: - asyncio.create_task( - push_to_explorer( - context, - merged_response, - guardrails_execution_result, - ) - ) - - # if we find something, we end the stream prematurely (end_of_stream=True) - # and yield an error chunk instead of actually beginning the stream - raise YieldException( - f"data: {error_chunk}\n\n".encode(), end_of_stream=True - ) - - @instrumentor.on("chunk") - async def process_chunk(chunk: bytes) -> None: - # process and check each chunk - chunk_text = chunk.decode().strip() - if not chunk_text: - return - - # Process the chunk - # This will update merged_response with the data from the chunk - process_chunk_text( - chunk_text, - merged_response, - choice_mapping_by_index, - tool_call_mapping_by_index, - ) - - # check guardrails at the end of the stream (on the '[DONE]' SSE chunk.) - if ( - "data: [DONE]" in chunk_text - and context.config - and context.config.guardrails - ): - # Block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result( - context, merged_response - ) - if guardrails_execution_result.get("errors", []): - error_chunk = json.dumps( - { - "error": { - "message": "[Invariant] The response did not pass the guardrails", - "details": guardrails_execution_result, - } - } - ) - # Push annotated trace to the explorer - don't block on its response - if context.dataset_name: - asyncio.create_task( - push_to_explorer( - context, - merged_response, - guardrails_execution_result, - ) - ) - - # yield an extra error chunk (without preventing the original chunk to go through after) - raise YieldException(f"data: {error_chunk}\n\n".encode()) - - @instrumentor.on("end") - async def send_to_explorer() -> None: - # Send full merged response to the explorer - # Don't block on the response from explorer - if context.dataset_name: - asyncio.create_task(push_to_explorer(context, merged_response)) - - async for chunk in instrumentor.stream(request_and_stream()): - # Yield chunk to the client - yield chunk - - return StreamingResponse(event_generator(), media_type="text/event-stream") + return StreamingResponse( + response.instrumented_event_generator(), media_type="text/event-stream" + ) def initialize_merged_response() -> dict[str, Any]: @@ -436,73 +449,51 @@ async def get_guardrails_check_result( return guardrails_execution_result -async def handle_non_streaming_response( - context: RequestContextData, - client: httpx.AsyncClient, - open_ai_request: httpx.Request, -) -> Response: - """Handles non-streaming OpenAI responses""" +class InstrumentedOpenAIResponse(InstrumentedResponse): + def __init__( + self, + context: RequestContextData, + client: httpx.AsyncClient, + open_ai_request: httpx.Request, + ): + super().__init__() - instrumentor = RequestInstrumentor() + # request parameters + self.context: RequestContextData = context + self.client: httpx.AsyncClient = client + self.open_ai_request: httpx.Request = open_ai_request - # respond we get and its JSON decoded version - # available once the 'send_request' function has progressed to the point of - # being able to call 'response.json()' - response = None - json_response = None + # request outputs + self.response: Optional[httpx.Response] = None + self.json_response: Optional[dict[str, Any]] = None - async def send_request(): - nonlocal response, json_response + self.guardrails_execution_result: Optional[dict] = None - response = await client.send(open_ai_request) - - try: - json_response = response.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=response.status_code, - detail="Invalid JSON response received from OpenAI API", - ) from e - if response.status_code != 200: - raise HTTPException( - status_code=response.status_code, - detail=json_response.get("error", "Unknown error from OpenAI API"), - ) - - response_string = json.dumps(json_response) - response_code = response.status_code - - return Response( - content=response_string, - status_code=response_code, - media_type="application/json", - headers=dict(response.headers), - ) - - @instrumentor.on("start") - async def precheck_guardrails() -> None: + async def on_start(self): # check guardrails in a pipelined fashion, before processing the first chunk (for input guardrailing) - if context.config and context.config.guardrails: + if self.context.config and self.context.config.guardrails: # block on the guardrails check - guardrails_execution_result = await get_guardrails_check_result(context) - if guardrails_execution_result.get("errors", []): + self.guardrails_execution_result = await get_guardrails_check_result( + self.context + ) + if self.guardrails_execution_result.get("errors", []): # Push annotated trace to the explorer - don't block on its response - if context.dataset_name: + if self.context.dataset_name: asyncio.create_task( push_to_explorer( - context, + self.context, {}, - guardrails_execution_result, + self.guardrails_execution_result, ) ) # replace the response with the error message - raise YieldException( + return ExtraItem( Response( content=json.dumps( { "error": "[Invariant] The response did not pass the guardrails", - "details": guardrails_execution_result, + "details": self.guardrails_execution_result, } ), status_code=400, @@ -511,38 +502,77 @@ async def handle_non_streaming_response( end_of_stream=True, ) - @instrumentor.on("end") - async def postprocess_guardrails() -> None: + async def request(self): + """ + Actual OpenAI request. + """ + self.response = await self.client.send(self.open_ai_request) + + try: + self.json_response = self.response.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=self.response.status_code, + detail="Invalid JSON response received from OpenAI API", + ) from e + if self.response.status_code != 200: + raise HTTPException( + status_code=self.response.status_code, + detail=self.json_response.get("error", "Unknown error from OpenAI API"), + ) + + response_string = json.dumps(self.json_response) + response_code = self.response.status_code + + return Response( + content=response_string, + status_code=response_code, + media_type="application/json", + headers=dict(self.response.headers), + ) + + async def on_end(self): + """ + Postprocess the OpenAI response and potentially replace it with a guardrails error. + """ + # these two are guaranteed to be set by the time we reach this point (after self.request() was executed) + assert ( + self.response is not None + ), "on_end called before 'self.response' was available" + assert ( + self.json_response is not None + ), "on_end called before 'self.json_response' was available" + # at this point, we are guaranteed that 'send_request' has already been executed successfully - response_code = response.status_code + response_code = self.response.status_code # if we have guardrails, check the response - if context.config and context.config.guardrails: + if self.context.config and self.context.config.guardrails: # run guardrails again, this time on request + response - guardrails_execution_result = await get_guardrails_check_result( - context, json_response + self.guardrails_execution_result = await get_guardrails_check_result( + self.context, self.json_response ) - if guardrails_execution_result.get("errors", []): + if self.guardrails_execution_result.get("errors", []): response_string = json.dumps( { "error": "[Invariant] The response did not pass the guardrails", - "details": guardrails_execution_result, + "details": self.guardrails_execution_result, } ) response_code = 400 # Push annotated trace to the explorer - don't block on its response - if context.dataset_name: + if self.context.dataset_name: asyncio.create_task( push_to_explorer( - context, - json_response, - guardrails_execution_result, + self.context, + self.json_response, + self.guardrails_execution_result, ) ) # replace the response with the error message - raise YieldException( + return ExtraItem( Response( content=response_string, status_code=response_code, @@ -550,10 +580,30 @@ async def handle_non_streaming_response( ), ) - # if we don't have guardrails or if the response passed the guardrails (only then, we reach this point) - if context.dataset_name: - # Push to Explorer - don't block on its response - asyncio.create_task(push_to_explorer(context, json_response)) + # Push annotated trace to the explorer - don't block on its response + if self.context.dataset_name: + asyncio.create_task( + push_to_explorer( + self.context, + self.json_response, + self.guardrails_execution_result, + ) + ) - # execute instrumented request - return await instrumentor.execute(send_request()) + +async def handle_non_streaming_response( + context: RequestContextData, + client: httpx.AsyncClient, + open_ai_request: httpx.Request, +) -> Response: + """Handles non-streaming OpenAI responses""" + + # # execute instrumented request + # return await instrumentor.execute(send_request()) + response = InstrumentedOpenAIResponse( + context, + client, + open_ai_request, + ) + + return await response.instrumented_request() diff --git a/gateway/serve.py b/gateway/serve.py index e3683f9..0d73096 100644 --- a/gateway/serve.py +++ b/gateway/serve.py @@ -2,15 +2,28 @@ import fastapi import uvicorn +from common.config_manager import GatewayConfigManager from routes.anthropic import gateway as anthropic_gateway from routes.gemini import gateway as gemini_gateway from routes.open_ai import gateway as open_ai_gateway from starlette_compress import CompressMiddleware +from contextlib import asynccontextmanager + + +@asynccontextmanager +async def lifespan(app: fastapi.FastAPI): + """Lifespan event to load the config manager""" + gateway_config = GatewayConfigManager.get_config() + yield + # Cleanup if needed + del gateway_config + app = fastapi.app = fastapi.FastAPI( docs_url="/api/v1/gateway/docs", redoc_url="/api/v1/gateway/redoc", openapi_url="/api/v1/gateway/openapi.json", + lifespan=lifespan, ) app.add_middleware(CompressMiddleware)