diff --git a/src/text_generation/adapters/llm/llm.py b/src/text_generation/adapters/llm/language_model.py similarity index 70% rename from src/text_generation/adapters/llm/llm.py rename to src/text_generation/adapters/llm/language_model.py index d4bbb386e..617b8e29f 100644 --- a/src/text_generation/adapters/llm/llm.py +++ b/src/text_generation/adapters/llm/language_model.py @@ -1,18 +1,14 @@ -""" -RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings -""" - import logging import sys -# LangChain imports from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnablePassthrough -from src.text_generation.adapters.llm.text_generation_model import TextGenerationFoundationModel +from src.text_generation.adapters.llm.abstract_language_model import AbstractLanguageModel +from src.text_generation.adapters.llm.text_generation_foundation_model import TextGenerationFoundationModel -class Phi3LanguageModel: +class LanguageModel(AbstractLanguageModel): def __init__(self): logger = logging.getLogger() @@ -20,9 +16,14 @@ class Phi3LanguageModel: handler = logging.StreamHandler(sys.stdout) logger.addHandler(handler) self.logger = logger - self.configure_model() + self._configure_model() - def configure_model(self): + def _extract_assistant_response(self, text): + if "<|assistant|>" in text: + return text.split("<|assistant|>")[-1].strip() + return text + + def _configure_model(self): # Create the LangChain LLM llm = TextGenerationFoundationModel().build() @@ -42,21 +43,15 @@ class Phi3LanguageModel: | prompt | llm | StrOutputParser() - | self.extract_assistant_response + | self._extract_assistant_response ) - def extract_assistant_response(self, text): - if "<|assistant|>" in text: - return text.split("<|assistant|>")[-1].strip() - return text - - - def invoke(self, user_input: str) -> str: + def invoke(self, user_prompt: str) -> str: try: # Get response from the chain - response = self.chain.invoke(user_input) + response = self.chain.invoke(user_prompt) return response except Exception as e: self.logger.error(f"Failed: {e}") - return e + raise e diff --git a/src/text_generation/adapters/llm/llm_rag.py b/src/text_generation/adapters/llm/language_model_with_rag.py similarity index 94% rename from src/text_generation/adapters/llm/llm_rag.py rename to src/text_generation/adapters/llm/language_model_with_rag.py index 77f1639e8..a92e87eee 100644 --- a/src/text_generation/adapters/llm/llm_rag.py +++ b/src/text_generation/adapters/llm/language_model_with_rag.py @@ -1,11 +1,6 @@ -""" -RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings -""" - import logging import sys -# LangChain imports from langchain_huggingface import HuggingFaceEmbeddings from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import FAISS @@ -17,10 +12,11 @@ from langchain_core.output_parsers import StrOutputParser from langchain.chains import RetrievalQA from langchain.prompts import PromptTemplate from langchain.schema import Document -from src.text_generation.adapters.llm.text_generation_model import TextGenerationFoundationModel +from src.text_generation.adapters.llm.abstract_language_model import AbstractLanguageModel +from src.text_generation.adapters.llm.text_generation_foundation_model import TextGenerationFoundationModel -class Phi3LanguageModelWithRag: +class Phi3LanguageModelWithRag(AbstractLanguageModel): def __init__(self): logger = logging.getLogger() @@ -131,9 +127,9 @@ class Phi3LanguageModelWithRag: return raw_answer.strip() - def invoke(self, user_input: str) -> str: + def invoke(self, user_prompt: str) -> str: - context_docs = self.vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(user_input) + context_docs = self.vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(user_prompt) context = self.format_docs(context_docs) # PROMPT_TEMPLATE = """<|system|> @@ -173,14 +169,10 @@ class Phi3LanguageModelWithRag: chain = prompt | self.llm | StrOutputParser() raw_answer = chain.invoke({ "context": context, - "question": user_input + "question": user_prompt }) # Clean up the answer (remove any remaining template artifacts) assistant_answer = self.parse_assistant_answer(raw_answer) - return { - # "question": user_input, - # "context": context, - "answer": assistant_answer - } + return assistant_answer diff --git a/src/text_generation/adapters/llm/text_generation_model.py b/src/text_generation/adapters/llm/text_generation_foundation_model.py similarity index 90% rename from src/text_generation/adapters/llm/text_generation_model.py rename to src/text_generation/adapters/llm/text_generation_foundation_model.py index 12fce00e9..f6e451336 100644 --- a/src/text_generation/adapters/llm/text_generation_model.py +++ b/src/text_generation/adapters/llm/text_generation_foundation_model.py @@ -1,15 +1,8 @@ -""" -RAG implementation with local Phi-3-mini-4k-instruct-onnx and embeddings -""" - import logging import os import sys -# LangChain imports from langchain_huggingface import HuggingFacePipeline - -# HuggingFace and ONNX imports from optimum.onnxruntime import ORTModelForCausalLM from transformers import AutoTokenizer, pipeline @@ -35,7 +28,7 @@ class TextGenerationFoundationModel: self.logger.debug(f'model_base_dir: {model_base_dir}') self.logger.debug(f'model_cpu_dir: {model_cpu_dir}') - self.logger.debug(f"Loading Phi-3 model from: {model_path}") + self.logger.debug(f'Loading Phi-3 model from: {model_path}') # Load the tokenizer and model tokenizer = AutoTokenizer.from_pretrained( diff --git a/src/text_generation/entrypoints/http_api_controller.py b/src/text_generation/entrypoints/http_api_controller.py index 15a1dcb31..0f16b470b 100644 --- a/src/text_generation/entrypoints/http_api_controller.py +++ b/src/text_generation/entrypoints/http_api_controller.py @@ -1,8 +1,8 @@ import json import traceback -from src.text_generation.adapters.llm.llm import Phi3LanguageModel -from src.text_generation.adapters.llm.llm_rag import Phi3LanguageModelWithRag +from src.text_generation.services.language_models.language_model_response_service import TextGenerationResponseService +from src.text_generation.adapters.llm.language_model_with_rag import Phi3LanguageModelWithRag from src.text_generation.services.logging.file_logging_service import FileLoggingService class HttpApiController: @@ -11,7 +11,7 @@ class HttpApiController: self.routes = {} # Register routes self.register_routes() - self.llm_svc = Phi3LanguageModel() # TODO: rename this as a service + self.text_generation_svc = TextGenerationResponseService() self.llm_rag_svc = Phi3LanguageModelWithRag() def register_routes(self): @@ -25,14 +25,6 @@ class HttpApiController: start_response('415 Unsupported Media Type', response_headers) return [json.dumps({'error': 'Unsupported Content-Type'}).encode('utf-8')] - def get_service_response(self, prompt): - response = self.llm_svc.invoke(user_input=prompt) - return response - - def get_service_response_with_rag(self, prompt): - response = self.llm_rag_svc.invoke(user_input=prompt) - return response - def format_response(self, data): """Format response data as JSON with 'response' key""" response_data = {'response': data} @@ -66,8 +58,8 @@ class HttpApiController: start_response('400 Bad Request', response_headers) return [response_body] - data = self.get_service_response(prompt) - response_body = self.format_response(data) + response_text = self.text_generation_svc.invoke(user_prompt=prompt) + response_body = self.format_response(response_text) http_status_code = 200 # make enum response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] @@ -92,8 +84,8 @@ class HttpApiController: start_response('400 Bad Request', response_headers) return [response_body] - data = self.get_service_response_with_rag(prompt) - response_body = self.format_response(data) + response_text = self.text_generation_svc.invoke(user_prompt=prompt) + response_body = self.format_response(response_text) http_status_code = 200 # make enum response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))] @@ -101,7 +93,7 @@ class HttpApiController: self.logger.info('RAG response', request_body, http_status_code, response_body) return [response_body] - def __http_200_ok(self, env, start_response): + def _http_200_ok(self, env, start_response): """Default handler for other routes""" try: request_body_size = int(env.get('CONTENT_LENGTH', 0)) @@ -124,7 +116,7 @@ class HttpApiController: path = env.get('PATH_INFO') try: - handler = self.routes.get((method, path), self.__http_200_ok) + handler = self.routes.get((method, path), self._http_200_ok) return handler(env, start_response) except json.JSONDecodeError as e: response_body = json.dumps({'error': f"Invalid JSON: {e.msg}"}).encode('utf-8') diff --git a/src/text_generation/services/language_models/language_model_response_service.py b/src/text_generation/services/language_models/language_model_response_service.py index 891fca818..dab85f1e2 100644 --- a/src/text_generation/services/language_models/language_model_response_service.py +++ b/src/text_generation/services/language_models/language_model_response_service.py @@ -1,10 +1,14 @@ -import abc +from src.text_generation.services.language_models.abstract_language_model_response_service import AbstractLanguageModelResponseService +from src.text_generation.adapters.llm.language_model import LanguageModel -class AbstractLanguageModelResponseService(abc.ABC): - @abc.abstractmethod - def invoke(self, user_input: str) -> str: - raise NotImplementedError -class LanguageModelResponseService(AbstractLanguageModelResponseService): - def __call__(self, *args, **kwds): - pass \ No newline at end of file +class TextGenerationResponseService(AbstractLanguageModelResponseService): + + def invoke(self, user_prompt: str) -> str: + + if not user_prompt: + raise ValueError(f"Parameter 'user_prompt' cannot be empty or None") + + llm = LanguageModel() + response = llm.invoke(user_prompt=user_prompt) + return response \ No newline at end of file