From ffacc738b7546457b2df41b36064121eaac0185e Mon Sep 17 00:00:00 2001
From: Adam Wilson
Date: Tue, 20 May 2025 20:06:07 -0600
Subject: [PATCH] Use single instances of the LLM service objects

Construct the Phi-3 service objects once in ApiController.__init__
instead of creating a new one on every request. In Phi3LanguageModel,
model loading and LCEL chain construction move out of invoke() into a
new configure_model() method that runs once from __init__, so invoke()
only executes the prebuilt chain.
---
 src/api/controller.py | 10 ++++++----
 src/llm/llm.py        | 25 +++++++++++++------------
 2 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/src/api/controller.py b/src/api/controller.py
index c67d16c9f..e0723ebc2 100644
--- a/src/api/controller.py
+++ b/src/api/controller.py
@@ -1,6 +1,8 @@
 import json
+import time
 import traceback
+
 from src.llm.llm import Phi3LanguageModel
 from src.llm.llm_rag import Phi3LanguageModelWithRag
 
 
@@ -9,6 +11,8 @@ class ApiController:
         self.routes = {}
         # Register routes
         self.register_routes()
+        self.llm_svc = Phi3LanguageModel()  # TODO: rename this as a service
+        self.llm_rag_svc = Phi3LanguageModelWithRag()
 
     def register_routes(self):
         """Register all API routes"""
@@ -21,13 +25,11 @@ class ApiController:
             return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')]
 
     def get_service_response(self, prompt):
-        service = Phi3LanguageModel()
-        response = service.invoke(user_input=prompt)
+        response = self.llm_svc.invoke(user_input=prompt)
         return response
 
     def get_service_response_with_rag(self, prompt):
-        service = Phi3LanguageModelWithRag()
-        response = service.invoke(user_input=prompt)
+        response = self.llm_rag_svc.invoke(user_input=prompt)
         return response
 
     def format_response(self, data):
diff --git a/src/llm/llm.py b/src/llm/llm.py
index 9dca789a1..30a9cb108 100644
--- a/src/llm/llm.py
+++ b/src/llm/llm.py
@@ -22,9 +22,6 @@ from langchain_core.runnables import RunnablePassthrough
 from optimum.onnxruntime import ORTModelForCausalLM
 from transformers import AutoTokenizer, pipeline
 
-# ------------------------------------------------------
-# 1. LOAD THE LOCAL PHI-3 MODEL
-# ------------------------------------------------------
 
 class Phi3LanguageModel:
 
@@ -34,14 +31,10 @@ class Phi3LanguageModel:
         handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(handler)
         self.logger = logger
+        self.configure_model()
 
-    def extract_assistant_response(self, text):
-        if "<|assistant|>" in text:
-            return text.split("<|assistant|>")[-1].strip()
-        return text
+    def configure_model(self):
 
-
-    def invoke(self, user_input: str) -> str:
         # Set up paths to the local model
         base_dir = os.path.dirname(os.path.abspath(__file__))
         model_path = os.path.join(base_dir, "cpu_and_mobile", "cpu-int4-rtn-block-32-acc-level-4")
@@ -70,6 +63,7 @@ class Phi3LanguageModel:
             temperature=0.7,
             top_p=0.9,
             repetition_penalty=1.1,
+            use_fast=True,
             do_sample=True
         )
 
@@ -86,18 +80,25 @@ class Phi3LanguageModel:
         prompt = PromptTemplate.from_template(template)
 
         # Create a chain using LCEL
-        chain = (
+        self.chain = (
             {"question": RunnablePassthrough()}
             | prompt
             | llm
             | StrOutputParser()
             | self.extract_assistant_response
         )
-
+
+    def extract_assistant_response(self, text):
+        if "<|assistant|>" in text:
+            return text.split("<|assistant|>")[-1].strip()
+        return text
+
+
+    def invoke(self, user_input: str) -> str:
         try:
             # Get response from the chain
             self.logger.debug(f'===Prompt: {user_input}\n\n')
-            response = chain.invoke(user_input)
+            response = self.chain.invoke(user_input)
             # Print the answer
             self.logger.debug(f'===Response: {response}\n\n')
             return response