mirror of
https://github.com/lightbroker/llmsecops-research.git
synced 2026-06-01 04:31:40 +02:00
mypy fixes
This commit is contained in:
@@ -8,5 +8,5 @@ class MetaLlamaConfig(BaseModelConfig):
|
||||
"""meta-llama/Llama-3.2-3B-Instruct configuration"""
|
||||
use_flash_attention: bool = False
|
||||
rope_scaling: Optional[Dict[str, Any]] = None
|
||||
trust_remote_code: True,
|
||||
torch_dtype: "auto"
|
||||
trust_remote_code = True
|
||||
torch_dtype = "auto"
|
||||
@@ -26,7 +26,7 @@ class TextGenerationCompletionResult(AbstractTextGenerationCompletionResult):
|
||||
self.alternate_result = alternate_result
|
||||
self.final_completion_text = ''
|
||||
|
||||
def finalize_completion_text(self) -> str:
|
||||
def finalize_completion_text(self):
|
||||
"""
|
||||
Returns the current completion text based on priority order:
|
||||
1. guardrails_result.completion_text (if not empty)
|
||||
|
||||
@@ -74,12 +74,12 @@ class HttpApiController:
|
||||
configured_service = service_configurator(self.text_generation_response_service)
|
||||
result: TextGenerationCompletionResult = configured_service.invoke(user_prompt=prompt)
|
||||
|
||||
response_body = self.format_response(result.final)
|
||||
response_body = self.format_response(result.final_completion_text)
|
||||
http_status_code = 200
|
||||
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
|
||||
start_response(f'{http_status_code} OK', response_headers)
|
||||
|
||||
self.logging_service.log_request_response(request=prompt, response=result.final)
|
||||
self.logging_service.log_request_response(request=prompt, response=result.final_completion_text)
|
||||
return [response_body]
|
||||
|
||||
def handle_conversations(self, env, start_response):
|
||||
|
||||
@@ -13,5 +13,5 @@ class AbstractSecurityGuidelinesService(abc.ABC):
|
||||
|
||||
class AbstractSecurityGuidelinesConfigurationBuilder(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def get_prompt_template(self, template_id: str, user_prompt: str) -> StringPromptTemplate:
|
||||
def get_prompt_template(self, template_id: str, user_prompt: str):
|
||||
raise NotImplementedError
|
||||
|
||||
+2
-2
@@ -20,7 +20,7 @@ class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
|
||||
prompt_injection_example_repository: AbstractPromptInjectionExampleRepository
|
||||
):
|
||||
self.constants = Constants()
|
||||
self.embedding_model: EmbeddingModel = embedding_model
|
||||
self.embedding_model = embedding_model
|
||||
self.prompt_template_service = prompt_template_service
|
||||
self.prompt_injection_example_repository = prompt_injection_example_repository
|
||||
self.vectorstore = self._init_vectorstore()
|
||||
@@ -60,7 +60,7 @@ class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
|
||||
context_parts.append(f"Example {i}:\n{doc.page_content}")
|
||||
return "\n\n".join(context_parts)
|
||||
|
||||
def get_prompt_template(self, template_id: str, user_prompt: str) -> PromptTemplate:
|
||||
def get_prompt_template(self, template_id: str, user_prompt: str):
|
||||
"""Get the base template from the template service and fill in RAG context"""
|
||||
# Get the base template from the template service
|
||||
base_template = self.prompt_template_service.get(id=template_id)
|
||||
|
||||
Reference in New Issue
Block a user