refactoring LLM adapters and service layer

This commit is contained in:
Adam Wilson
2025-06-12 20:25:18 -06:00
parent 0cdce03879
commit 3aca7df000
5 changed files with 43 additions and 67 deletions
@@ -1,18 +1,14 @@
"""
RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings
"""
import logging
import sys
# LangChain imports
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from src.text_generation.adapters.llm.text_generation_model import TextGenerationFoundationModel
from src.text_generation.adapters.llm.abstract_language_model import AbstractLanguageModel
from src.text_generation.adapters.llm.text_generation_foundation_model import TextGenerationFoundationModel
class Phi3LanguageModel:
class LanguageModel(AbstractLanguageModel):
def __init__(self):
logger = logging.getLogger()
@@ -20,9 +16,14 @@ class Phi3LanguageModel:
handler = logging.StreamHandler(sys.stdout)
logger.addHandler(handler)
self.logger = logger
self.configure_model()
self._configure_model()
def configure_model(self):
def _extract_assistant_response(self, text):
if "<|assistant|>" in text:
return text.split("<|assistant|>")[-1].strip()
return text
def _configure_model(self):
# Create the LangChain LLM
llm = TextGenerationFoundationModel().build()
@@ -42,21 +43,15 @@ class Phi3LanguageModel:
| prompt
| llm
| StrOutputParser()
| self.extract_assistant_response
| self._extract_assistant_response
)
def extract_assistant_response(self, text):
if "<|assistant|>" in text:
return text.split("<|assistant|>")[-1].strip()
return text
def invoke(self, user_input: str) -> str:
def invoke(self, user_prompt: str) -> str:
try:
# Get response from the chain
response = self.chain.invoke(user_input)
response = self.chain.invoke(user_prompt)
return response
except Exception as e:
self.logger.error(f"Failed: {e}")
return e
raise e
@@ -1,11 +1,6 @@
"""
RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings
"""
import logging
import sys
# LangChain imports
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
@@ -17,10 +12,11 @@ from langchain_core.output_parsers import StrOutputParser
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from src.text_generation.adapters.llm.text_generation_model import TextGenerationFoundationModel
from src.text_generation.adapters.llm.abstract_language_model import AbstractLanguageModel
from src.text_generation.adapters.llm.text_generation_foundation_model import TextGenerationFoundationModel
class Phi3LanguageModelWithRag:
class Phi3LanguageModelWithRag(AbstractLanguageModel):
def __init__(self):
logger = logging.getLogger()
@@ -131,9 +127,9 @@ class Phi3LanguageModelWithRag:
return raw_answer.strip()
def invoke(self, user_input: str) -> str:
def invoke(self, user_prompt: str) -> str:
context_docs = self.vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(user_input)
context_docs = self.vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(user_prompt)
context = self.format_docs(context_docs)
# PROMPT_TEMPLATE = """<|system|>
@@ -173,14 +169,10 @@ class Phi3LanguageModelWithRag:
chain = prompt | self.llm | StrOutputParser()
raw_answer = chain.invoke({
"context": context,
"question": user_input
"question": user_prompt
})
# Clean up the answer (remove any remaining template artifacts)
assistant_answer = self.parse_assistant_answer(raw_answer)
return {
# "question": user_input,
# "context": context,
"answer": assistant_answer
}
return assistant_answer
@@ -1,15 +1,8 @@
"""
RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings
"""
import logging
import os
import sys
# LangChain imports
from langchain_huggingface import HuggingFacePipeline
# HuggingFace and ONNX imports
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer, pipeline
@@ -35,7 +28,7 @@ class TextGenerationFoundationModel:
self.logger.debug(f'model_base_dir: {model_base_dir}')
self.logger.debug(f'model_cpu_dir: {model_cpu_dir}')
self.logger.debug(f"Loading Phi-3 model from: {model_path}")
self.logger.debug(f'Loading Phi-3 model from: {model_path}')
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(
@@ -1,8 +1,8 @@
import json
import traceback
from src.text_generation.adapters.llm.llm import Phi3LanguageModel
from src.text_generation.adapters.llm.llm_rag import Phi3LanguageModelWithRag
from src.text_generation.services.language_models.language_model_response_service import TextGenerationResponseService
from src.text_generation.adapters.llm.language_model_with_rag import Phi3LanguageModelWithRag
from src.text_generation.services.logging.file_logging_service import FileLoggingService
class HttpApiController:
@@ -11,7 +11,7 @@ class HttpApiController:
self.routes = {}
# Register routes
self.register_routes()
self.llm_svc = Phi3LanguageModel() # TODO: rename this as a service
self.text_generation_svc = TextGenerationResponseService()
self.llm_rag_svc = Phi3LanguageModelWithRag()
def register_routes(self):
@@ -25,14 +25,6 @@ class HttpApiController:
start_response('415 Unsupported Media Type', response_headers)
return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')]
def get_service_response(self, prompt):
response = self.llm_svc.invoke(user_input=prompt)
return response
def get_service_response_with_rag(self, prompt):
response = self.llm_rag_svc.invoke(user_input=prompt)
return response
def format_response(self, data):
"""Format response data as JSON with 'response' key"""
response_data = {'response': data}
@@ -66,8 +58,8 @@ class HttpApiController:
start_response('400 Bad Request', response_headers)
return [response_body]
data = self.get_service_response(prompt)
response_body = self.format_response(data)
response_text = self.text_generation_svc.invoke(user_prompt=prompt)
response_body = self.format_response(response_text)
http_status_code = 200 # make enum
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
@@ -92,8 +84,8 @@ class HttpApiController:
start_response('400 Bad Request', response_headers)
return [response_body]
data = self.get_service_response_with_rag(prompt)
response_body = self.format_response(data)
response_text = self.text_generation_svc.invoke(user_prompt=prompt)
response_body = self.format_response(response_text)
http_status_code = 200 # make enum
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
@@ -101,7 +93,7 @@ class HttpApiController:
self.logger.info('RAG response', request_body, http_status_code, response_body)
return [response_body]
def __http_200_ok(self, env, start_response):
def _http_200_ok(self, env, start_response):
"""Default handler for other routes"""
try:
request_body_size = int(env.get('CONTENT_LENGTH', 0))
@@ -124,7 +116,7 @@ class HttpApiController:
path = env.get('PATH_INFO')
try:
handler = self.routes.get((method, path), self.__http_200_ok)
handler = self.routes.get((method, path), self._http_200_ok)
return handler(env, start_response)
except json.JSONDecodeError as e:
response_body = json.dumps({'error': f"Invalid JSON: {e.msg}"}).encode('utf-8')
@@ -1,10 +1,14 @@
import abc
from src.text_generation.services.language_models.abstract_language_model_response_service import AbstractLanguageModelResponseService
from src.text_generation.adapters.llm.language_model import LanguageModel
class AbstractLanguageModelResponseService(abc.ABC):
@abc.abstractmethod
def invoke(self, user_input: str) -> str:
raise NotImplementedError
class LanguageModelResponseService(AbstractLanguageModelResponseService):
def __call__(self, *args, **kwds):
pass
class TextGenerationResponseService(AbstractLanguageModelResponseService):
    """Service that produces a text response for a user prompt.

    The prompt is validated and then passed to a ``LanguageModel`` adapter,
    whose ``invoke`` result is returned to the caller unchanged.
    """

    # Adapter instance, created on first use and then reused by this service.
    # Building a new LanguageModel on every request would repeat whatever
    # setup its constructor performs each time a prompt is handled.
    _llm = None

    def invoke(self, user_prompt: str) -> str:
        """Generate a response for ``user_prompt``.

        Args:
            user_prompt: The prompt text to send to the language model.

        Returns:
            Whatever ``LanguageModel.invoke`` returns for the prompt.

        Raises:
            ValueError: If ``user_prompt`` is empty or ``None``.
        """
        if not user_prompt:
            # Plain string literal: the message has no placeholders to format.
            raise ValueError("Parameter 'user_prompt' cannot be empty or None")

        if self._llm is None:
            self._llm = LanguageModel()
        return self._llm.invoke(user_prompt=user_prompt)