mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-03-19 16:54:05 +00:00
single instance of LLM service objects
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
|
||||
|
||||
from src.llm.llm import Phi3LanguageModel
|
||||
from src.llm.llm_rag import Phi3LanguageModelWithRag
|
||||
|
||||
@@ -9,6 +11,8 @@ class ApiController:
|
||||
self.routes = {}
|
||||
# Register routes
|
||||
self.register_routes()
|
||||
self.llm_svc = Phi3LanguageModel() # TODO: rename this as a service
|
||||
self.llm_rag_svc = Phi3LanguageModelWithRag()
|
||||
|
||||
def register_routes(self):
|
||||
"""Register all API routes"""
|
||||
@@ -21,13 +25,11 @@ class ApiController:
|
||||
return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')]
|
||||
|
||||
def get_service_response(self, prompt):
|
||||
service = Phi3LanguageModel()
|
||||
response = service.invoke(user_input=prompt)
|
||||
response = self.llm_svc.invoke(user_input=prompt)
|
||||
return response
|
||||
|
||||
def get_service_response_with_rag(self, prompt):
|
||||
service = Phi3LanguageModelWithRag()
|
||||
response = service.invoke(user_input=prompt)
|
||||
response = self.llm_rag_svc.invoke(user_input=prompt)
|
||||
return response
|
||||
|
||||
def format_response(self, data):
|
||||
|
||||
@@ -22,9 +22,6 @@ from langchain_core.runnables import RunnablePassthrough
|
||||
from optimum.onnxruntime import ORTModelForCausalLM
|
||||
from transformers import AutoTokenizer, pipeline
|
||||
|
||||
# ------------------------------------------------------
|
||||
# 1. LOAD THE LOCAL PHI-3 MODEL
|
||||
# ------------------------------------------------------
|
||||
|
||||
class Phi3LanguageModel:
|
||||
|
||||
@@ -34,14 +31,10 @@ class Phi3LanguageModel:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(handler)
|
||||
self.logger = logger
|
||||
self.configure_model()
|
||||
|
||||
def extract_assistant_response(self, text):
|
||||
if "<|assistant|>" in text:
|
||||
return text.split("<|assistant|>")[-1].strip()
|
||||
return text
|
||||
def configure_model(self):
|
||||
|
||||
|
||||
def invoke(self, user_input: str) -> str:
|
||||
# Set up paths to the local model
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(base_dir, "cpu_and_mobile", "cpu-int4-rtn-block-32-acc-level-4")
|
||||
@@ -70,6 +63,7 @@ class Phi3LanguageModel:
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.1,
|
||||
use_fast=True,
|
||||
do_sample=True
|
||||
)
|
||||
|
||||
@@ -86,18 +80,25 @@ class Phi3LanguageModel:
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
|
||||
# Create a chain using LCEL
|
||||
chain = (
|
||||
self.chain = (
|
||||
{"question": RunnablePassthrough()}
|
||||
| prompt
|
||||
| llm
|
||||
| StrOutputParser()
|
||||
| self.extract_assistant_response
|
||||
)
|
||||
|
||||
|
||||
def extract_assistant_response(self, text):
|
||||
if "<|assistant|>" in text:
|
||||
return text.split("<|assistant|>")[-1].strip()
|
||||
return text
|
||||
|
||||
|
||||
def invoke(self, user_input: str) -> str:
|
||||
try:
|
||||
# Get response from the chain
|
||||
self.logger.debug(f'===Prompt: {user_input}\n\n')
|
||||
response = chain.invoke(user_input)
|
||||
response = self.chain.invoke(user_input)
|
||||
# Print the answer
|
||||
self.logger.debug(f'===Response: {response}\n\n')
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user