Files
llmsecops-research/src/text_generation/services/guidelines/base_security_guidelines_service.py
T
2025-07-18 12:33:51 -06:00

49 lines
2.4 KiB
Python

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from src.text_generation.common.constants import Constants
from src.text_generation.domain.abstract_guidelines_processed_completion import AbstractGuidelinesProcessedCompletion
from src.text_generation.domain.guidelines_result import GuidelinesResult
from src.text_generation.ports.abstract_foundation_model import AbstractFoundationModel
from src.text_generation.services.guidelines.abstract_security_guidelines_service import AbstractSecurityGuidelinesService
from src.text_generation.services.nlp.abstract_prompt_template_service import AbstractPromptTemplateService
from src.text_generation.services.utilities.abstract_response_processing_service import AbstractResponseProcessingService
class BaseSecurityGuidelinesService(AbstractSecurityGuidelinesService):
    """Base service for security guidelines implementations.

    Builds a LangChain pipeline (prompt template -> foundation model ->
    string parser -> response post-processing) and runs a user prompt
    through it. Subclasses choose which prompt template is used by
    overriding ``_get_template_id()``.
    """

    def __init__(
            self,
            foundation_model: AbstractFoundationModel,
            response_processing_service: AbstractResponseProcessingService,
            prompt_template_service: AbstractPromptTemplateService) -> None:
        """Store collaborators and create the model pipeline.

        Args:
            foundation_model: Source of the generation pipeline; its
                ``create_pipeline()`` is called once, here.
            response_processing_service: Supplies
                ``process_text_generation_output``, the final chain step.
            prompt_template_service: Used to look up the prompt template
                by id each time guidelines are applied.
        """
        super().__init__()
        self.constants = Constants()
        # Created once at construction and reused by every chain built later.
        self.foundation_model_pipeline = foundation_model.create_pipeline()
        self.response_processing_service = response_processing_service
        self.prompt_template_service = prompt_template_service

    def _create_chain(self, prompt_template: PromptTemplate):
        """Compose the runnable chain for ``prompt_template``.

        The raw input is passed through unchanged under the key named by
        ``Constants.INPUT_VARIABLE_TOKEN``, rendered by the prompt template,
        sent to the foundation model pipeline, parsed to a string, and
        finally handed to the response processing service.
        """
        # The leading dict maps the template's input variable to the value
        # given to ``invoke``; the ``|`` operator chains the steps in order.
        return (
            {f"{self.constants.INPUT_VARIABLE_TOKEN}": RunnablePassthrough()}
            | prompt_template
            | self.foundation_model_pipeline
            | StrOutputParser()
            | self.response_processing_service.process_text_generation_output
        )

    def _get_template_id(self) -> str:
        """Return the id of the prompt template to use.

        Raises:
            NotImplementedError: Always, in this base class; subclasses
                must override this hook.
        """
        raise NotImplementedError("Subclasses must implement _get_template_id()")

    def apply_guidelines(self, user_prompt: str) -> AbstractGuidelinesProcessedCompletion:
        """Run ``user_prompt`` through the guidelines chain.

        Args:
            user_prompt: The prompt text; must be truthy (not ``None`` and
                not an empty string).

        Returns:
            The result of invoking the chain, i.e. whatever
            ``process_text_generation_output`` produces for the model output.

        Raises:
            ValueError: If ``user_prompt`` is empty or ``None``.
            NotImplementedError: If the subclass did not override
                ``_get_template_id()``.

        Any exception raised by the template lookup or by the chain itself
        propagates to the caller unchanged.
        """
        if not user_prompt:
            raise ValueError("Parameter 'user_prompt' cannot be empty or None")

        # No try/except here: the previous handler only re-raised the same
        # exception, so letting errors propagate directly is equivalent.
        template_id = self._get_template_id()
        prompt_template: PromptTemplate = self.prompt_template_service.get(id=template_id)
        chain = self._create_chain(prompt_template)
        return chain.invoke(user_prompt)