From b074b388beb4fb88dba887d9a7606cb8a468692d Mon Sep 17 00:00:00 2001 From: Hemang Date: Thu, 30 Jan 2025 10:07:05 +0100 Subject: [PATCH] Add validation on LLM provider param. --- proxy/routes/proxy.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/proxy/routes/proxy.py b/proxy/routes/proxy.py index 1af476f..b460a06 100644 --- a/proxy/routes/proxy.py +++ b/proxy/routes/proxy.py @@ -1,11 +1,32 @@ """Proxy service to forward requests to the appropriate language model provider""" -from fastapi import APIRouter +from enum import Enum + +from fastapi import APIRouter, HTTPException proxy = APIRouter() +class LLMProvider(str, Enum): + """Supported language model providers""" + + OPEN_AI = "openai" + ANTHROPIC = "anthropic" + + @classmethod + def is_valid(cls, provider: str) -> bool: + """Check if a provider is a valid LLM provider""" + return provider in {provider.value for provider in cls} + + @proxy.post("/{username}/{dataset_name}/{llm_provider}") async def chat_completion(username: str, dataset_name: str, llm_provider: str): """Proxy call to a language model provider""" + + if not LLMProvider.is_valid(llm_provider): + raise HTTPException( + status_code=400, + detail=f"Unsupported LLM provider '{llm_provider}'.", + ) + return {"message": f"Upload {dataset_name} for {username} to {llm_provider}"}