mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-06-07 15:33:56 +02:00
49 lines
2.4 KiB
Python
from langchain_core.output_parsers import StrOutputParser
|
|
from langchain_core.prompts import PromptTemplate
|
|
from langchain_core.runnables import RunnablePassthrough
|
|
|
|
from src.text_generation.common.constants import Constants
|
|
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
|
|
from src.text_generation.domain.guidelines_result import GuidelinesResult
|
|
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
|
|
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesService
|
|
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
|
|
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
|
|
|
|
class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
    """Base service for security guidelines implementations.

    Builds a LangChain runnable chain of the form::

        {<input token>: passthrough} | prompt template | model pipeline
            | StrOutputParser() | response post-processing

    and invokes it on the user prompt. Subclasses choose which prompt
    template is used by overriding ``_get_template_id``.
    """

    def __init__(
            self,
            foundation_model: AbstractFoundationModel,
            response_processing_service: AbstractResponseProcessingService,
            prompt_template_service: AbstractPromptTemplateService):
        """Store collaborators and build the model pipeline.

        Args:
            foundation_model: Model adapter; ``create_pipeline()`` is called
                once here and the result is reused as a step in every chain
                this instance creates.
            response_processing_service: Provides
                ``process_text_generation_output``, the final chain step.
            prompt_template_service: Used to fetch a prompt template by id
                in ``apply_guidelines``.
        """
        super().__init__()
        self.constants = Constants()
        # Created once per service instance rather than once per request.
        self.foundation_model_pipeline = foundation_model.create_pipeline()
        self.response_processing_service = response_processing_service
        self.prompt_template_service = prompt_template_service

    def _create_chain(self, prompt_template: PromptTemplate):
        """Compose the runnable chain for ``prompt_template``.

        The leading dict maps the key ``Constants.INPUT_VARIABLE_TOKEN`` to
        the value later passed to ``invoke`` (via ``RunnablePassthrough``);
        presumably this key is the template's input variable name — confirm
        against the templates served by ``prompt_template_service``.

        Args:
            prompt_template: Template that formats the model input.

        Returns:
            The composed LangChain runnable.
        """
        return (
            {f"{self.constants.INPUT_VARIABLE_TOKEN}": RunnablePassthrough()}
            | prompt_template
            | self.foundation_model_pipeline
            | StrOutputParser()
            | self.response_processing_service.process_text_generation_output
        )

    def _get_template_id(self) -> str:
        """Return the id of the prompt template this service should use.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("Subclasses must implement _get_template_id()")

    def apply_guidelines(self, user_prompt: str) -> AbstractGuidelinesProcessedCompletion:
        """Run ``user_prompt`` through the guidelines chain.

        Args:
            user_prompt: The raw user prompt; must be non-empty.

        Returns:
            Whatever ``process_text_generation_output`` returns for the model
            output. NOTE(review): the annotated return type is not enforced
            here — confirm that this step really yields an
            ``AbstractGuidelinesProcessedCompletion``.

        Raises:
            ValueError: If ``user_prompt`` is empty or ``None``.
            NotImplementedError: If the subclass does not override
                ``_get_template_id``.
        """
        if not user_prompt:
            raise ValueError("Parameter 'user_prompt' cannot be empty or None")

        # Errors from template lookup, chain construction or invocation
        # propagate to the caller unchanged; there is nothing to handle here.
        template_id = self._get_template_id()
        prompt_template: PromptTemplate = self.prompt_template_service.get(id=template_id)
        chain = self._create_chain(prompt_template)
        return chain.invoke(user_prompt)