Add type annotations to functions and methods for improved clarity and maintainabiliy

This commit is contained in:
nemanjaASE
2025-03-09 13:50:18 +01:00
parent 21180b53e5
commit 71787c6ec9
4 changed files with 32 additions and 21 deletions
+13 -6
View File
@@ -4,11 +4,18 @@ from asyncio import Event, Queue
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse
from agentic_security.http_spec import LLMSpec
from typing import Any, Dict
tools_inbox: Queue = Queue()
stop_event: Event = Event()
current_run: str = {"spec": "", "id": ""}
_secrets = {}
_secrets: dict[str, str] = {}
current_run: Dict[str, int | LLMSpec] = {
"spec": "",
"id": ""
}
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
@@ -26,29 +33,29 @@ def get_stop_event() -> Event:
return stop_event
def get_current_run() -> str:
def get_current_run() -> Dict[str, int | LLMSpec]:
"""Get the current run id."""
return current_run
def set_current_run(spec):
def set_current_run(spec : LLMSpec) -> Dict[str, int | LLMSpec]:
"""Set the current run id."""
current_run["id"] = hash(id(spec))
current_run["spec"] = spec
return current_run
def get_secrets():
def get_secrets() -> dict[str, str]:
return _secrets
def set_secrets(secrets):
def set_secrets(secrets : dict[str, str]) -> dict[str, str]:
_secrets.update(secrets)
expand_secrets(_secrets)
return _secrets
def expand_secrets(secrets):
def expand_secrets(secrets : dict[str, str]) -> None:
for key in secrets:
val = secrets[key]
if val.startswith("$"):
+9 -8
View File
@@ -1,5 +1,6 @@
from pyfiglet import Figlet, FontNotFound
from termcolor import colored
from typing import Optional
try:
from importlib.metadata import version
@@ -8,14 +9,14 @@ except ImportError:
def generate_banner(
title="Agentic Security",
font="slant",
version="v2.1.0",
tagline="Proactive Threat Detection & Automated Security Protocols",
author="Developed by: [Security Team]",
website="Website: https://github.com/msoedov/agentic_security",
warning="",
):
title: str = "Agentic Security",
font: str = "slant",
version: str = "v2.1.0",
tagline: str = "Proactive Threat Detection & Automated Security Protocols",
author: str = "Developed by: [Security Team]",
website: str = "Website: https://github.com/msoedov/agentic_security",
warning: Optional[str] = "", # Using Optional for warning since it might be None
) -> str:
"""Generate a visually enhanced banner with dynamic width and borders."""
# Define the text elements
+4 -2
View File
@@ -1,5 +1,6 @@
import io
import string
from typing import List
import matplotlib.pyplot as plt
import numpy as np
@@ -7,8 +8,9 @@ import pandas as pd
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize
from .primitives import Table
def plot_security_report(table):
def plot_security_report(table: Table) -> io.BytesIO:
# Data preprocessing
data = pd.DataFrame(table)
@@ -141,7 +143,7 @@ def plot_security_report(table):
return buf
def generate_identifiers(data):
def generate_identifiers(data : pd.DataFrame) -> List[str]:
data_length = len(data)
alphabet = string.ascii_uppercase
num_letters = len(alphabet)
+6 -5
View File
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any, Generator
from fastapi import (
APIRouter,
@@ -24,7 +25,7 @@ router = APIRouter()
@router.post("/verify")
async def verify(
info: LLMInfo, secrets: InMemorySecrets = Depends(get_in_memory_secrets)
):
) -> dict[str, int | str | float]:
spec = LLMSpec.from_string(info.spec)
try:
r = await spec.verify()
@@ -42,7 +43,7 @@ async def verify(
)
def streaming_response_generator(scan_parameters: Scan):
def streaming_response_generator(scan_parameters: Scan) -> Generator[str, Any, None]:
request_factory = LLMSpec.from_string(scan_parameters.llmSpec)
set_current_run(request_factory)
@@ -63,7 +64,7 @@ async def scan(
scan_parameters: Scan,
background_tasks: BackgroundTasks,
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
):
) -> StreamingResponse:
scan_parameters.with_secrets(secrets)
return StreamingResponse(
streaming_response_generator(scan_parameters), media_type="application/json"
@@ -71,7 +72,7 @@ async def scan(
@router.post("/stop")
async def stop_scan():
async def stop_scan() -> dict[str, str]:
get_stop_event().set()
return {"status": "Scan stopped"}
@@ -85,7 +86,7 @@ async def scan_csv(
maxBudget: int = Query(10_000),
enableMultiStepAttack: bool = Query(False),
secrets: InMemorySecrets = Depends(get_in_memory_secrets),
):
) -> StreamingResponse:
# TODO: content dataset to fuzzer
content = await file.read() # noqa
llm_spec = await llmSpec.read()