Use a single instance of the LLM service objects instead of constructing them per request

This commit is contained in:
Adam Wilson
2025-05-20 20:06:07 -06:00
parent d8a6609b8b
commit ffacc738b7
2 changed files with 19 additions and 16 deletions

View File

@@ -1,6 +1,8 @@
import json
import time
import traceback
from src.llm.llm import Phi3LanguageModel
from src.llm.llm_rag import Phi3LanguageModelWithRag
@@ -9,6 +11,8 @@ class ApiController:
self.routes = {}
# Register routes
self.register_routes()
self.llm_svc = Phi3LanguageModel() # TODO: rename this as a service
self.llm_rag_svc = Phi3LanguageModelWithRag()
def register_routes(self):
"""Register all API routes"""
@@ -21,13 +25,11 @@ class ApiController:
return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')]
def get_service_response(self, prompt):
service = Phi3LanguageModel()
response = service.invoke(user_input=prompt)
response = self.llm_svc.invoke(user_input=prompt)
return response
def get_service_response_with_rag(self, prompt):
service = Phi3LanguageModelWithRag()
response = service.invoke(user_input=prompt)
response = self.llm_rag_svc.invoke(user_input=prompt)
return response
def format_response(self, data):

View File

@@ -22,9 +22,6 @@ from langchain_core.runnables import RunnablePassthrough
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer, pipeline
# ------------------------------------------------------
# 1. LOAD THE LOCAL PHI-3 MODEL
# ------------------------------------------------------
class Phi3LanguageModel:
@@ -34,14 +31,10 @@ class Phi3LanguageModel:
handler = logging.StreamHandler(sys.stdout)
logger.addHandler(handler)
self.logger = logger
self.configure_model()
def extract_assistant_response(self, text):
if "<|assistant|>" in text:
return text.split("<|assistant|>")[-1].strip()
return text
def configure_model(self):
def invoke(self, user_input: str) -> str:
# Set up paths to the local model
base_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(base_dir, "cpu_and_mobile", "cpu-int4-rtn-block-32-acc-level-4")
@@ -70,6 +63,7 @@ class Phi3LanguageModel:
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
use_fast=True,
do_sample=True
)
@@ -86,18 +80,25 @@ class Phi3LanguageModel:
prompt = PromptTemplate.from_template(template)
# Create a chain using LCEL
chain = (
self.chain = (
{"question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
| self.extract_assistant_response
)
def extract_assistant_response(self, text):
if "<|assistant|>" in text:
return text.split("<|assistant|>")[-1].strip()
return text
def invoke(self, user_input: str) -> str:
try:
# Get response from the chain
self.logger.debug(f'===Prompt: {user_input}\n\n')
response = chain.invoke(user_input)
response = self.chain.invoke(user_input)
# Print the answer
self.logger.debug(f'===Response: {response}\n\n')
return response