support testing malicious prompts with no guidelines

This commit is contained in:
Adam Wilson
2025-06-28 12:18:35 -06:00
parent 036d36bf4f
commit cb1be6746f
18 changed files with 447 additions and 76 deletions

View File

@@ -1,30 +1,24 @@
import json
import traceback
from src.text_generation.services.nlp.text_generation_response_service import TextGenerationResponseService
from src.text_generation.services.nlp.retrieval_augmented_generation_response_service import RetrievalAugmentedGenerationResponseService
from src.text_generation.services.logging.file_logging_service import FileLoggingService
from src.text_generation.services.guardrails.generated_text_guardrail_service import GeneratedTextGuardrailService
from src.text_generation.services.logging.abstract_web_traffic_logging_service import AbstractWebTrafficLoggingService
from src.text_generation.services.nlp.abstract_language_model_response_service import AbstractLanguageModelResponseService
from src.text_generation.services.guardrails.abstract_generated_text_guardrail_service import AbstractGeneratedTextGuardrailService
class HttpApiController:
def __init__(
self,
logging_service: FileLoggingService,
text_generation_response_service: TextGenerationResponseService,
rag_response_service: RetrievalAugmentedGenerationResponseService,
generated_text_guardrail_service: GeneratedTextGuardrailService
logging_service: AbstractWebTrafficLoggingService,
text_generation_response_service: AbstractLanguageModelResponseService,
rag_response_service: AbstractLanguageModelResponseService,
generated_text_guardrail_service: AbstractGeneratedTextGuardrailService
):
self.logger = logging_service.logger
# TODO: temp debug
self.original_info = self.logger.info
self.logger.info = self.debug_info
self.logging_service = logging_service
self.text_generation_response_service = text_generation_response_service
self.rag_response_service = rag_response_service
self.generated_text_guardrail_service = generated_text_guardrail_service
self.routes = {}
self.register_routes()
@@ -39,23 +33,15 @@ class HttpApiController:
def register_routes(self):
"""Register all API routes"""
self.routes[('GET', '/')] = self.health_check
self.routes[('POST', '/api/conversations')] = self.handle_conversations
self.routes[('POST', '/api/rag_conversations')] = self.handle_conversations_with_rag
def __http_415_notsupported(self, env, start_response):
response_headers = [('Content-Type', 'application/json')]
start_response('415 Unsupported Media Type', response_headers)
return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')]
self.routes[('POST', '/api/completions')] = self.handle_conversations
self.routes[('POST', '/api/completions/rag-guided')] = self.handle_conversations_with_rag
def format_response(self, data):
"""Format response data as JSON with 'response' key"""
response_data = {'response': data}
try:
response_body = json.dumps(response_data).encode('utf-8')
except:
# If serialization fails, convert data to string first
response_body = json.dumps({'response': str(data)}).encode('utf-8')
return response_body
@@ -66,7 +52,7 @@ class HttpApiController:
return [response_body]
def handle_conversations(self, env, start_response):
"""Handle POST requests to /api/conversations"""
"""POST /api/completions"""
try:
request_body_size = int(env.get('CONTENT_LENGTH', 0))
except ValueError:
@@ -89,11 +75,11 @@ class HttpApiController:
http_status_code = 200 # make enum
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
start_response(f'{http_status_code} OK', response_headers)
self.logger.info('non-RAG response: request body %s | status: %s | response: %s', request_body, http_status_code, response_body)
self.logging_service.log_request_response(request=prompt, response=response_text)
return [response_body]
def handle_conversations_with_rag(self, env, start_response):
"""Handle POST requests to /api/rag_conversations with RAG functionality"""
"""POST /api/completions/rag-guided"""
try:
request_body_size = int(env.get('CONTENT_LENGTH', 0))
except ValueError:
@@ -116,7 +102,7 @@ class HttpApiController:
http_status_code = 200 # make enum
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
start_response(f'{http_status_code} OK', response_headers)
self.logger.info('RAG response: request body %s | status: %s | response: %s', request_body, http_status_code, response_body)
self.logging_service.log_request_response(request=prompt, response=response_text)
return [response_body]
def _http_200_ok(self, env, start_response):