mypy fixes

This commit is contained in:
Adam Wilson
2025-08-19 15:39:16 -06:00
parent 16dac722f3
commit 31cd6fe1fc
9 changed files with 279 additions and 282 deletions
@@ -8,5 +8,5 @@ class MetaLlamaConfig(BaseModelConfig):
"""meta-llama/Llama-3.2-3B-Instruct configuration"""
use_flash_attention: bool = False
rope_scaling: Optional[Dict[str, Any]] = None
trust_remote_code: True,
torch_dtype: "auto"
trust_remote_code = True
torch_dtype = "auto"
@@ -26,7 +26,7 @@ class TextGenerationCompletionResult(AbstractTextGenerationCompletionResult):
self.alternate_result = alternate_result
self.final_completion_text = ''
def finalize_completion_text(self) -> str:
def finalize_completion_text(self):
"""
Returns the current completion text based on priority order:
1. guardrails_result.completion_text (if not empty)
@@ -74,12 +74,12 @@ class HttpApiController:
configured_service = service_configurator(self.text_generation_response_service)
result: TextGenerationCompletionResult = configured_service.invoke(user_prompt=prompt)
response_body = self.format_response(result.final)
response_body = self.format_response(result.final_completion_text)
http_status_code = 200
response_headers = [('Content-Type', 'application/json'), ('Content-Length', str(len(response_body)))]
start_response(f'{http_status_code} OK', response_headers)
self.logging_service.log_request_response(request=prompt, response=result.final)
self.logging_service.log_request_response(request=prompt, response=result.final_completion_text)
return [response_body]
def handle_conversations(self, env, start_response):
@@ -13,5 +13,5 @@ class AbstractSecurityGuidelinesService(abc.ABC):
class AbstractSecurityGuidelinesConfigurationBuilder(abc.ABC):
@abc.abstractmethod
def get_prompt_template(self, template_id: str, user_prompt: str) -> StringPromptTemplate:
def get_prompt_template(self, template_id: str, user_prompt: str):
raise NotImplementedError
@@ -20,7 +20,7 @@ class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
prompt_injection_example_repository: AbstractPromptInjectionExampleRepository
):
self.constants = Constants()
self.embedding_model: EmbeddingModel = embedding_model
self.embedding_model = embedding_model
self.prompt_template_service = prompt_template_service
self.prompt_injection_example_repository = prompt_injection_example_repository
self.vectorstore = self._init_vectorstore()
@@ -60,7 +60,7 @@ class RetrievalAugmentedGenerationSecurityGuidelinesConfigurationBuilder(
context_parts.append(f"Example {i}:\n{doc.page_content}")
return "\n\n".join(context_parts)
def get_prompt_template(self, template_id: str, user_prompt: str) -> PromptTemplate:
def get_prompt_template(self, template_id: str, user_prompt: str):
"""Get the base template from the template service and fill in RAG context"""
# Get the base template from the template service
base_template = self.prompt_template_service.get(id=template_id)