mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-03-16 23:37:35 +00:00
call services from API instead of adapters
This commit is contained in:
@@ -16,7 +16,7 @@ from src.text_generation.adapters.llm.abstract_language_model import AbstractLan
|
||||
from src.text_generation.adapters.llm.text_generation_foundation_model import TextGenerationFoundationModel
|
||||
|
||||
|
||||
class Phi3LanguageModelWithRag(AbstractLanguageModel):
|
||||
class LanguageModelWithRag(AbstractLanguageModel):
|
||||
|
||||
def __init__(self):
|
||||
logger = logging.getLogger()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from src.text_generation.services.language_models.language_model_response_service import TextGenerationResponseService
|
||||
from src.text_generation.adapters.llm.language_model_with_rag import Phi3LanguageModelWithRag
|
||||
from src.text_generation.services.language_models.text_generation_response_service import TextGenerationResponseService
|
||||
from src.text_generation.services.language_models.retrieval_augmented_generation_response_service import RetrievalAugmentedGenerationResponseService
|
||||
from src.text_generation.services.logging.file_logging_service import FileLoggingService
|
||||
|
||||
class HttpApiController:
|
||||
@@ -12,7 +12,7 @@ class HttpApiController:
|
||||
# Register routes
|
||||
self.register_routes()
|
||||
self.text_generation_svc = TextGenerationResponseService()
|
||||
self.llm_rag_svc = Phi3LanguageModelWithRag()
|
||||
self.rag_svc = RetrievalAugmentedGenerationResponseService()
|
||||
|
||||
def register_routes(self):
|
||||
"""Register all API routes"""
|
||||
@@ -84,7 +84,7 @@ class HttpApiController:
|
||||
start_response('400 Bad Request', response_headers)
|
||||
return [response_body]
|
||||
|
||||
response_text = self.text_generation_svc.invoke(user_prompt=prompt)
|
||||
response_text = self.rag_svc.invoke(user_prompt=prompt)
|
||||
response_body = self.format_response(response_text)
|
||||
|
||||
http_status_code = 200 # make enum
|
||||
|
||||
@@ -5,7 +5,8 @@ from wsgiref.simple_server import make_server
|
||||
|
||||
class RestApiServer:
|
||||
def __init__(self):
|
||||
self.logger = FileLoggingService(filename='text_generation.server.log')
|
||||
logging_service = FileLoggingService(filename='text_generation.server.log')
|
||||
self.logger = logging_service.logger
|
||||
|
||||
def listen(self):
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
from src.text_generation.adapters.llm.language_model_with_rag import LanguageModelWithRag
|
||||
from src.text_generation.services.language_models.abstract_language_model_response_service import AbstractLanguageModelResponseService
|
||||
|
||||
|
||||
class RetrievalAugmentedGenerationResponseService(AbstractLanguageModelResponseService):
|
||||
|
||||
def invoke(self, user_prompt: str) -> str:
|
||||
|
||||
if not user_prompt:
|
||||
raise ValueError(f"Parameter 'user_prompt' cannot be empty or None")
|
||||
|
||||
rag = LanguageModelWithRag()
|
||||
response = rag.invoke(user_prompt=user_prompt)
|
||||
return response
|
||||
Reference in New Issue
Block a user