mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-03-18 16:23:37 +00:00
support testing malicious prompts with no guidelines
This commit is contained in:
@@ -1,30 +1,24 @@
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from src.text_generation.services.nlp.text_generation_response_service import TextGenerationResponseService
|
||||
from src.text_generation.services.nlp.retrieval_augmented_generation_response_service import RetrievalAugmentedGenerationResponseService
|
||||
from src.text_generation.services.logging.file_logging_service import FileLoggingService
|
||||
from src.text_generation.services.guardrails.generated_text_guardrail_service import GeneratedTextGuardrailService
|
||||
from src.text_generation.services.logging.abstract_web_traffic_logging_service import AbstractWebTrafficLoggingService
|
||||
from src.text_generation.services.nlp.abstract_language_model_response_service import AbstractLanguageModelResponseService
|
||||
from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService
|
||||
|
||||
|
||||
|
||||
class HttpApiController:
|
||||
def __init__(
|
||||
self,
|
||||
logging_service: FileLoggingService,
|
||||
text_generation_response_service: TextGenerationResponseService,
|
||||
rag_response_service: RetrievalAugmentedGenerationResponseService,
|
||||
generated_text_guardrail_service: GeneratedTextGuardrailService
|
||||
logging_service: AbstractWebTrafficLoggingService,
|
||||
text_generation_response_service: AbstractLanguageModelResponseService,
|
||||
rag_response_service: AbstractLanguageModelResponseService,
|
||||
generated_text_guardrail_service: AbstractGeneratedTextGuardrailService
|
||||
):
|
||||
self.logger = logging_service.logger
|
||||
|
||||
# TODO: temp debug
|
||||
self.original_info = self.logger.info
|
||||
self.logger.info = self.debug_info
|
||||
|
||||
self.logging_service = logging_service
|
||||
self.text_generation_response_service = text_generation_response_service
|
||||
self.rag_response_service = rag_response_service
|
||||
self.generated_text_guardrail_service = generated_text_guardrail_service
|
||||
|
||||
self.routes = {}
|
||||
self.register_routes()
|
||||
|
||||
@@ -39,23 +33,15 @@ class HttpApiController:
|
||||
|
||||
|
||||
def register_routes(self):
|
||||
"""Register all API routes"""
|
||||
self.routes[('GET', '/')] = self.health_check
|
||||
self.routes[('POST', '/api/conversations')] = self.handle_conversations
|
||||
self.routes[('POST', '/api/rag_conversations')] = self.handle_conversations_with_rag
|
||||
|
||||
def __http_415_notsupported(self, env, start_response):
|
||||
response_headers = [('Content-Type', 'application/json')]
|
||||
start_response('415 Unsupported Media Type', response_headers)
|
||||
return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')]
|
||||
self.routes[('POST', '/api/completions')] = self.handle_conversations
|
||||
self.routes[('POST', '/api/completions/rag-guided')] = self.handle_conversations_with_rag
|
||||
|
||||
def format_response(self, data):
|
||||
"""Format response data as JSON with 'response' key"""
|
||||
response_data = {'response': data}
|
||||
try:
|
||||
response_body = json.dumps(response_data).encode('utf-8')
|
||||
except:
|
||||
# If serialization fails, convert data to string first
|
||||
response_body = json.dumps({'response': str(data)}).encode('utf-8')
|
||||
return response_body
|
||||
|
||||
@@ -66,7 +52,7 @@ class HttpApiController:
|
||||
return [response_body]
|
||||
|
||||
def handle_conversations(self, env, start_response):
|
||||
"""Handle POST requests to /api/conversations"""
|
||||
"""POST /api/completions"""
|
||||
try:
|
||||
request_body_size = int(env.get('CONTENT_LENGTH', 0))
|
||||
except ValueError:
|
||||
@@ -89,11 +75,11 @@ class HttpApiController:
|
||||
http_status_code = 200 # make enum
|
||||
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
|
||||
start_response(f'{http_status_code} OK', response_headers)
|
||||
self.logger.info('non-RAG response: request body %s | status: %s | response: %s', request_body, http_status_code, response_body)
|
||||
self.logging_service.log_request_response(request=prompt, response=response_text)
|
||||
return [response_body]
|
||||
|
||||
def handle_conversations_with_rag(self, env, start_response):
|
||||
"""Handle POST requests to /api/rag_conversations with RAG functionality"""
|
||||
"""POST /api/completions/rag-guided"""
|
||||
try:
|
||||
request_body_size = int(env.get('CONTENT_LENGTH', 0))
|
||||
except ValueError:
|
||||
@@ -116,7 +102,7 @@ class HttpApiController:
|
||||
http_status_code = 200 # make enum
|
||||
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
|
||||
start_response(f'{http_status_code} OK', response_headers)
|
||||
self.logger.info('RAG response: request body %s | status: %s | response: %s', request_body, http_status_code, response_body)
|
||||
self.logging_service.log_request_response(request=prompt, response=response_text)
|
||||
return [response_body]
|
||||
|
||||
def _http_200_ok(self, env, start_response):
|
||||
|
||||
Reference in New Issue
Block a user