mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-06-28 17:39:55 +02:00
refactoring LLM adapters and service layer
This commit is contained in:
+14
-19
@@ -1,18 +1,14 @@
|
||||
"""
|
||||
RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# LangChain imports
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
from src.text_generation.adapters.llm.text_generation_model import TextGenerationFoundationModel
|
||||
from src.text_generation.adapters.llm.abstract_language_model import AbstractLanguageModel
|
||||
from src.text_generation.adapters.llm.text_generation_foundation_model import TextGenerationFoundationModel
|
||||
|
||||
|
||||
class Phi3LanguageModel:
|
||||
class LanguageModel(AbstractLanguageModel):
|
||||
|
||||
def __init__(self):
|
||||
logger = logging.getLogger()
|
||||
@@ -20,9 +16,14 @@ class Phi3LanguageModel:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(handler)
|
||||
self.logger = logger
|
||||
self.configure_model()
|
||||
self._configure_model()
|
||||
|
||||
def configure_model(self):
|
||||
def _extract_assistant_response(self, text):
|
||||
if "<|assistant|>" in text:
|
||||
return text.split("<|assistant|>")[-1].strip()
|
||||
return text
|
||||
|
||||
def _configure_model(self):
|
||||
|
||||
# Create the LangChain LLM
|
||||
llm = TextGenerationFoundationModel().build()
|
||||
@@ -42,21 +43,15 @@ class Phi3LanguageModel:
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| self.extract_assistant_response
|
||||
| self._extract_assistant_response
|
||||
)
|
||||
|
||||
def extract_assistant_response(self, text):
    """Strip the prompt scaffolding from a raw completion.

    When the ``<|assistant|>`` marker is present, the text after its last
    occurrence is returned with surrounding whitespace removed; otherwise
    the input is handed back untouched.
    """
    marker = "<|assistant|>"
    if marker not in text:
        return text
    segments = text.split(marker)
    return segments[-1].strip()
|
||||
|
||||
|
||||
def invoke(self, user_prompt: str) -> str:
    """Run the prompt through the configured chain and return its output.

    Args:
        user_prompt: Prompt text passed to ``self.chain.invoke``.

    Returns:
        Whatever the chain produces for the prompt.

    Raises:
        Exception: Any error raised by the chain is logged and re-raised
            unchanged, so callers never receive an exception object as a
            return value.
    """
    try:
        # Get response from the chain
        return self.chain.invoke(user_prompt)
    except Exception as e:
        self.logger.error(f"Failed: {e}")
        # A bare ``raise`` keeps the original traceback intact.
        raise
|
||||
|
||||
+7
-15
@@ -1,11 +1,6 @@
|
||||
"""
|
||||
RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# LangChain imports
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.vectorstores import FAISS
|
||||
@@ -17,10 +12,11 @@ from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain.chains import RetrievalQA
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import Document
|
||||
from src.text_generation.adapters.llm.text_generation_model import TextGenerationFoundationModel
|
||||
from src.text_generation.adapters.llm.abstract_language_model import AbstractLanguageModel
|
||||
from src.text_generation.adapters.llm.text_generation_foundation_model import TextGenerationFoundationModel
|
||||
|
||||
|
||||
class Phi3LanguageModelWithRag:
|
||||
class Phi3LanguageModelWithRag(AbstractLanguageModel):
|
||||
|
||||
def __init__(self):
|
||||
logger = logging.getLogger()
|
||||
@@ -131,9 +127,9 @@ class Phi3LanguageModelWithRag:
|
||||
return raw_answer.strip()
|
||||
|
||||
|
||||
def invoke(self, user_input: str) -> str:
|
||||
def invoke(self, user_prompt: str) -> str:
|
||||
|
||||
context_docs = self.vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(user_input)
|
||||
context_docs = self.vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(user_prompt)
|
||||
context = self.format_docs(context_docs)
|
||||
|
||||
# PROMPT_TEMPLATE = """<|system|>
|
||||
@@ -173,14 +169,10 @@ class Phi3LanguageModelWithRag:
|
||||
chain = prompt | self.llm | StrOutputParser()
|
||||
raw_answer = chain.invoke({
|
||||
"context": context,
|
||||
"question": user_input
|
||||
"question": user_prompt
|
||||
})
|
||||
|
||||
# Clean up the answer (remove any remaining template artifacts)
|
||||
assistant_answer = self.parse_assistant_answer(raw_answer)
|
||||
|
||||
return {
|
||||
# "question": user_input,
|
||||
# "context": context,
|
||||
"answer": assistant_answer
|
||||
}
|
||||
return assistant_answer
|
||||
+1
-8
@@ -1,15 +1,8 @@
|
||||
"""
|
||||
RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
# LangChain imports
|
||||
from langchain_huggingface import HuggingFacePipeline
|
||||
|
||||
# HuggingFace and ONNX imports
|
||||
from optimum.onnxruntime import ORTModelForCausalLM
|
||||
from transformers import AutoTokenizer, pipeline
|
||||
|
||||
@@ -35,7 +28,7 @@ class TextGenerationFoundationModel:
|
||||
|
||||
self.logger.debug(f'model_base_dir: {model_base_dir}')
|
||||
self.logger.debug(f'model_cpu_dir: {model_cpu_dir}')
|
||||
self.logger.debug(f"Loading Phi-3 model from: {model_path}")
|
||||
self.logger.debug(f'Loading Phi-3 model from: {model_path}')
|
||||
|
||||
# Load the tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from src.text_generation.adapters.llm.llm import Phi3LanguageModel
|
||||
from src.text_generation.adapters.llm.llm_rag import Phi3LanguageModelWithRag
|
||||
from src.text_generation.services.language_models.language_model_response_service import TextGenerationResponseService
|
||||
from src.text_generation.adapters.llm.language_model_with_rag import Phi3LanguageModelWithRag
|
||||
from src.text_generation.services.logging.file_logging_service import FileLoggingService
|
||||
|
||||
class HttpApiController:
|
||||
@@ -11,7 +11,7 @@ class HttpApiController:
|
||||
self.routes = {}
|
||||
# Register routes
|
||||
self.register_routes()
|
||||
self.llm_svc = Phi3LanguageModel() # TODO: rename this as a service
|
||||
self.text_generation_svc = TextGenerationResponseService()
|
||||
self.llm_rag_svc = Phi3LanguageModelWithRag()
|
||||
|
||||
def register_routes(self):
|
||||
@@ -25,14 +25,6 @@ class HttpApiController:
|
||||
start_response('415 Unsupported Media Type', response_headers)
|
||||
return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')]
|
||||
|
||||
def get_service_response(self, prompt):
    """Return the answer of ``self.llm_svc`` for ``prompt``.

    The prompt is forwarded to the service's ``invoke`` method as the
    ``user_input`` keyword argument.
    """
    return self.llm_svc.invoke(user_input=prompt)
|
||||
|
||||
def get_service_response_with_rag(self, prompt):
    """Return the answer of ``self.llm_rag_svc`` for ``prompt``.

    The prompt is forwarded to the service's ``invoke`` method as the
    ``user_input`` keyword argument.
    """
    return self.llm_rag_svc.invoke(user_input=prompt)
|
||||
|
||||
def format_response(self, data):
|
||||
"""Format response data as JSON with 'response' key"""
|
||||
response_data = {'response': data}
|
||||
@@ -66,8 +58,8 @@ class HttpApiController:
|
||||
start_response('400 Bad Request', response_headers)
|
||||
return [response_body]
|
||||
|
||||
data = self.get_service_response(prompt)
|
||||
response_body = self.format_response(data)
|
||||
response_text = self.text_generation_svc.invoke(user_prompt=prompt)
|
||||
response_body = self.format_response(response_text)
|
||||
|
||||
http_status_code = 200 # make enum
|
||||
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
|
||||
@@ -92,8 +84,8 @@ class HttpApiController:
|
||||
start_response('400 Bad Request', response_headers)
|
||||
return [response_body]
|
||||
|
||||
data = self.get_service_response_with_rag(prompt)
|
||||
response_body = self.format_response(data)
|
||||
response_text = self.text_generation_svc.invoke(user_prompt=prompt)
|
||||
response_body = self.format_response(response_text)
|
||||
|
||||
http_status_code = 200 # make enum
|
||||
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
|
||||
@@ -101,7 +93,7 @@ class HttpApiController:
|
||||
self.logger.info('RAG response', request_body, http_status_code, response_body)
|
||||
return [response_body]
|
||||
|
||||
def __http_200_ok(self, env, start_response):
|
||||
def _http_200_ok(self, env, start_response):
|
||||
"""Default handler for other routes"""
|
||||
try:
|
||||
request_body_size = int(env.get('CONTENT_LENGTH', 0))
|
||||
@@ -124,7 +116,7 @@ class HttpApiController:
|
||||
path = env.get('PATH_INFO')
|
||||
|
||||
try:
|
||||
handler = self.routes.get((method, path), self.__http_200_ok)
|
||||
handler = self.routes.get((method, path), self._http_200_ok)
|
||||
return handler(env, start_response)
|
||||
except json.JSONDecodeError as e:
|
||||
response_body = json.dumps({'error': f"Invalid JSON: {e.msg}"}).encode('utf-8')
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import abc
|
||||
from src.text_generation.services.language_models.abstract_language_model_response_service import AbstractLanguageModelResponseService
|
||||
from src.text_generation.adapters.llm.language_model import LanguageModel
|
||||
|
||||
class AbstractLanguageModelResponseService(abc.ABC):
    """Interface for services that turn a user's text into a model response.

    The class cannot be instantiated until ``invoke`` is overridden.
    """

    @abc.abstractmethod
    def invoke(self, user_input: str) -> str:
        """Produce the response for ``user_input``; subclasses must override."""
        raise NotImplementedError()
|
||||
|
||||
class LanguageModelResponseService(AbstractLanguageModelResponseService):
    """Placeholder response service whose call operator does nothing.

    NOTE(review): ``invoke`` is not overridden here, so this class is
    presumably not instantiable while the base declares it abstract.
    """

    def __call__(self, *args, **kwds):
        """Accept any arguments, perform no work and return ``None``."""
        return None
|
||||
class TextGenerationResponseService(AbstractLanguageModelResponseService):
    """Service that answers a user prompt using ``LanguageModel``."""

    def invoke(self, user_prompt: str) -> str:
        """Generate a response for ``user_prompt``.

        Args:
            user_prompt: Prompt text forwarded to the language model.

        Returns:
            The value returned by ``LanguageModel.invoke``.

        Raises:
            ValueError: If ``user_prompt`` is empty or ``None``.
        """
        if not user_prompt:
            raise ValueError("Parameter 'user_prompt' cannot be empty or None")

        # Build the adapter once per service instance and reuse it instead
        # of constructing a new LanguageModel on every request.
        # NOTE(review): assumes LanguageModel keeps no per-request state —
        # confirm against the adapter.
        llm = getattr(self, "_llm", None)
        if llm is None:
            llm = self._llm = LanguageModel()
        return llm.invoke(user_prompt=user_prompt)
|
||||
Reference in New Issue
Block a user