diff --git a/.github/scripts/cleanup.sh b/.github/scripts/cleanup.sh index f6e131094..7ed9736eb 100755 --- a/.github/scripts/cleanup.sh +++ b/.github/scripts/cleanup.sh @@ -1,5 +1,7 @@ #!/bin/bash +cd $GITHUB_WORKSPACE + echo "Cleaning up processes..." # Kill the monitoring process if it exists diff --git a/.github/scripts/health_check.sh b/.github/scripts/health_check.sh index eeea6fbb5..0bef6f6f3 100755 --- a/.github/scripts/health_check.sh +++ b/.github/scripts/health_check.sh @@ -1,6 +1,8 @@ #!/bin/bash set -e # Exit on error +cd $GITHUB_WORKSPACE + echo "Waiting for API to be ready..." max_attempts=10 attempt=1 @@ -8,7 +10,7 @@ attempt=1 while [ $attempt -le $max_attempts ]; do echo "Health check attempt $attempt of $max_attempts..." if curl -s -f -i http://localhost:9999/ > logs/health_check_$attempt.log 2>&1; then - echo "Health check succeeded!" + echo "Health check succeeded" break else echo "Health check failed, waiting 5 seconds..." diff --git a/.github/scripts/run_garak.sh b/.github/scripts/run_garak.sh index 0e5d98f66..8f551264b 100755 --- a/.github/scripts/run_garak.sh +++ b/.github/scripts/run_garak.sh @@ -1,6 +1,8 @@ #!/bin/bash # Don't use set -e here as we want to capture and handle errors ourselves +cd $GITHUB_WORKSPACE + # Make sure garak report directory exists GARAK_REPORTS_DIR="/home/runner/.local/share/garak/garak_runs" mkdir -p $GARAK_REPORTS_DIR diff --git a/.github/scripts/start_api.sh b/.github/scripts/start_api.sh index 0f42950dc..60859b70c 100755 --- a/.github/scripts/start_api.sh +++ b/.github/scripts/start_api.sh @@ -1,6 +1,8 @@ #!/bin/bash set -e # Exit on error +cd $GITHUB_WORKSPACE + echo "Starting API server with logging..." nohup python -m src.api.server > logs/api.log 2>&1 & API_PID=$! diff --git a/.github/scripts/start_monitoring.sh b/.github/scripts/start_monitoring.sh index c919a1031..983510973 100755 --- a/.github/scripts/start_monitoring.sh +++ b/.github/scripts/start_monitoring.sh @@ -2,6 +2,8 @@ echo "Starting system monitoring..." +cd $GITHUB_WORKSPACE + # Read API PID from file API_PID=$(cat api_pid.txt) echo "Monitoring API process with PID: $API_PID" diff --git a/.github/scripts/test_api.sh b/.github/scripts/test_api.sh index 6de9c1d70..84a2ebe76 100755 --- a/.github/scripts/test_api.sh +++ b/.github/scripts/test_api.sh @@ -1,16 +1,18 @@ #!/bin/bash set -e # Exit on error +cd $GITHUB_WORKSPACE + echo "Making API request..." curl -X POST -i http://localhost:9999/api/conversations \ -d '{ "prompt": "describe a random planet in our solar system in 10 words or less" }' \ -H "Content-Type: application/json" > logs/test_request.log 2>&1 if [ $? -ne 0 ]; then - echo "Test API request failed!" + echo "Test API request failed" cat logs/test_request.log exit 1 else - echo "Test API request succeeded!" + echo "Test API request succeeded" cat logs/test_request.log fi \ No newline at end of file diff --git a/src/api/controller.backup.py b/src/api/controller.backup.py deleted file mode 100644 index c67d16c9f..000000000 --- a/src/api/controller.backup.py +++ /dev/null @@ -1,133 +0,0 @@ -import json -import traceback - -from src.llm.llm import Phi3LanguageModel -from src.llm.llm_rag import Phi3LanguageModelWithRag - -class ApiController: - def __init__(self): - self.routes = {} - # Register routes - self.register_routes() - - def register_routes(self): - """Register all API routes""" - self.routes[('POST', '/api/conversations')] = self.handle_conversations - self.routes[('POST', '/api/rag_conversations')] = self.handle_conversations_with_rag - - def __http_415_notsupported(self, env, start_response): - response_headers = [('Content-Type', 'application/json')] - start_response('415 Unsupported Media Type', response_headers) - return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')] - - def get_service_response(self, prompt): - service = Phi3LanguageModel() - response = service.invoke(user_input=prompt) - return response - - def get_service_response_with_rag(self, prompt): - service = Phi3LanguageModelWithRag() - response = service.invoke(user_input=prompt) - return response - - def format_response(self, data): - """Format response data as JSON with 'response' key""" - response_data = {'response': data} - try: - response_body = json.dumps(response_data).encode('utf-8') - except: - # If serialization fails, convert data to string first - response_body = json.dumps({'response': str(data)}).encode('utf-8') - return response_body - - def handle_conversations(self, env, start_response): - """Handle POST requests to /api/conversations""" - try: - request_body_size = int(env.get('CONTENT_LENGTH', 0)) - except ValueError: - request_body_size = 0 - - request_body = env['wsgi.input'].read(request_body_size) - request_json = json.loads(request_body.decode('utf-8')) - prompt = request_json.get('prompt') - - if not prompt: - response_body = json.dumps({'error': 'Missing prompt in request body'}).encode('utf-8') - response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] - start_response('400 Bad Request', response_headers) - return [response_body] - - data = self.get_service_response(prompt) - response_body = self.format_response(data) - - response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] - start_response('200 OK', response_headers) - return [response_body] - - def handle_conversations_with_rag(self, env, start_response): - """Handle POST requests to /api/rag_conversations with RAG functionality""" - try: - request_body_size = int(env.get('CONTENT_LENGTH', 0)) - except ValueError: - request_body_size = 0 - - request_body = env['wsgi.input'].read(request_body_size) - request_json = json.loads(request_body.decode('utf-8')) - prompt = request_json.get('prompt') - - if not prompt: - response_body = json.dumps({'error': 'Missing prompt in request body'}).encode('utf-8') - response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] - start_response('400 Bad Request', response_headers) - return [response_body] - - data = self.get_service_response_with_rag(prompt) - response_body = self.format_response(data) - - response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] - start_response('200 OK', response_headers) - return [response_body] - - def __http_200_ok(self, env, start_response): - """Default handler for other routes""" - try: - request_body_size = int(env.get('CONTENT_LENGTH', 0)) - except (ValueError): - request_body_size = 0 - - request_body = env['wsgi.input'].read(request_body_size) - request_json = json.loads(request_body.decode('utf-8')) - prompt = request_json.get('prompt') - - data = self.get_service_response(prompt) - response_body = self.format_response(data) - - response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] - start_response('200 OK', response_headers) - return [response_body] - - def __call__(self, env, start_response): - method = env.get('REQUEST_METHOD').upper() - path = env.get('PATH_INFO') - - if method != 'POST': - return self.__http_415_notsupported(env, start_response) - - try: - handler = self.routes.get((method, path), self.__http_200_ok) - return handler(env, start_response) - except json.JSONDecodeError as e: - response_body = json.dumps({'error': f"Invalid JSON: {e.msg}"}).encode('utf-8') - response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] - start_response('400 Bad Request', response_headers) - return [response_body] - except Exception as e: - # Log to stdout so it shows in GitHub Actions - print("Exception occurred:") - traceback.print_exc() - - # Return more detailed error response (would not do this in Production) - error_response = json.dumps({'error': f"Internal Server Error: {str(e)}"}).encode('utf-8') - response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(error_response)))] - start_response('500 Internal Server Error', response_headers) - return [error_response] \ No newline at end of file diff --git a/src/api/controller.flask.py b/src/api/controller.flask.py new file mode 100644 index 000000000..3ff759964 --- /dev/null +++ b/src/api/controller.flask.py @@ -0,0 +1,26 @@ +import logging +from flask import Flask, jsonify, request +from waitress import serve +from src.llm.llm import Phi3LanguageModel +from src.llm.llm_rag import Phi3LanguageModelWithRag + +app = Flask(__name__) + +@app.route('/', methods=['GET']) +def health_check(): + return f"Server is running\n", 200 + +@app.route('/api/conversations', methods=['POST']) +def get_llm_response(): + prompt = request.json['prompt'] + service = Phi3LanguageModel() + response = service.invoke(user_input=prompt) + return jsonify({'response': response}), 201 + +if __name__ == '__main__': + logger = logging.Logger(name='Flask API', level=logging.DEBUG) + print('test') + logger.debug('running...') + + # TODO set up port # as env var + serve(app, host='0.0.0.0', port=9999) \ No newline at end of file diff --git a/src/api/controller.py b/src/api/controller.py index 3ff759964..c67d16c9f 100644 --- a/src/api/controller.py +++ b/src/api/controller.py @@ -1,26 +1,133 @@ -import logging -from flask import Flask, jsonify, request -from waitress import serve +import json +import traceback + from src.llm.llm import Phi3LanguageModel from src.llm.llm_rag import Phi3LanguageModelWithRag -app = Flask(__name__) +class ApiController: + def __init__(self): + self.routes = {} + # Register routes + self.register_routes() -@app.route('/', methods=['GET']) -def health_check(): - return f"Server is running\n", 200 + def register_routes(self): + """Register all API routes""" + self.routes[('POST', '/api/conversations')] = self.handle_conversations + self.routes[('POST', '/api/rag_conversations')] = self.handle_conversations_with_rag -@app.route('/api/conversations', methods=['POST']) -def get_llm_response(): - prompt = request.json['prompt'] - service = Phi3LanguageModel() - response = service.invoke(user_input=prompt) - return jsonify({'response': response}), 201 + def __http_415_notsupported(self, env, start_response): + response_headers = [('Content-Type', 'application/json')] + start_response('415 Unsupported Media Type', response_headers) + return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')] -if __name__ == '__main__': - logger = logging.Logger(name='Flask API', level=logging.DEBUG) - print('test') - logger.debug('running...') + def get_service_response(self, prompt): + service = Phi3LanguageModel() + response = service.invoke(user_input=prompt) + return response + + def get_service_response_with_rag(self, prompt): + service = Phi3LanguageModelWithRag() + response = service.invoke(user_input=prompt) + return response - # TODO set up port # as env var - serve(app, host='0.0.0.0', port=9999) \ No newline at end of file + def format_response(self, data): + """Format response data as JSON with 'response' key""" + response_data = {'response': data} + try: + response_body = json.dumps(response_data).encode('utf-8') + except: + # If serialization fails, convert data to string first + response_body = json.dumps({'response': str(data)}).encode('utf-8') + return response_body + + def handle_conversations(self, env, start_response): + """Handle POST requests to /api/conversations""" + try: + request_body_size = int(env.get('CONTENT_LENGTH', 0)) + except ValueError: + request_body_size = 0 + + request_body = env['wsgi.input'].read(request_body_size) + request_json = json.loads(request_body.decode('utf-8')) + prompt = request_json.get('prompt') + + if not prompt: + response_body = json.dumps({'error': 'Missing prompt in request body'}).encode('utf-8') + response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] + start_response('400 Bad Request', response_headers) + return [response_body] + + data = self.get_service_response(prompt) + response_body = self.format_response(data) + + response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] + start_response('200 OK', response_headers) + return [response_body] + + def handle_conversations_with_rag(self, env, start_response): + """Handle POST requests to /api/rag_conversations with RAG functionality""" + try: + request_body_size = int(env.get('CONTENT_LENGTH', 0)) + except ValueError: + request_body_size = 0 + + request_body = env['wsgi.input'].read(request_body_size) + request_json = json.loads(request_body.decode('utf-8')) + prompt = request_json.get('prompt') + + if not prompt: + response_body = json.dumps({'error': 'Missing prompt in request body'}).encode('utf-8') + response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] + start_response('400 Bad Request', response_headers) + return [response_body] + + data = self.get_service_response_with_rag(prompt) + response_body = self.format_response(data) + + response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] + start_response('200 OK', response_headers) + return [response_body] + + def __http_200_ok(self, env, start_response): + """Default handler for other routes""" + try: + request_body_size = int(env.get('CONTENT_LENGTH', 0)) + except (ValueError): + request_body_size = 0 + + request_body = env['wsgi.input'].read(request_body_size) + request_json = json.loads(request_body.decode('utf-8')) + prompt = request_json.get('prompt') + + data = self.get_service_response(prompt) + response_body = self.format_response(data) + + response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] + start_response('200 OK', response_headers) + return [response_body] + + def __call__(self, env, start_response): + method = env.get('REQUEST_METHOD').upper() + path = env.get('PATH_INFO') + + if method != 'POST': + return self.__http_415_notsupported(env, start_response) + + try: + handler = self.routes.get((method, path), self.__http_200_ok) + return handler(env, start_response) + except json.JSONDecodeError as e: + response_body = json.dumps({'error': f"Invalid JSON: {e.msg}"}).encode('utf-8') + response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] + start_response('400 Bad Request', response_headers) + return [response_body] + except Exception as e: + # Log to stdout so it shows in GitHub Actions + print("Exception occurred:") + traceback.print_exc() + + # Return more detailed error response (would not do this in Production) + error_response = json.dumps({'error': f"Internal Server Error: {str(e)}"}).encode('utf-8') + response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(error_response)))] + start_response('500 Internal Server Error', response_headers) + return [error_response] \ No newline at end of file