Mirror of https://github.com/Shiva108/ai-llm-red-team-handbook.git (synced 2026-02-12 14:42:46 +00:00)
feat: Refactor model discovery and selection, filter embedding models, update report summary, and optimize prompt sending with max_tokens.
@@ -359,7 +359,7 @@ class TargetConfig:
     auth_type: str = "bearer"  # bearer, api_key, none
     auth_token: str = ""
     auth_header: str = "Authorization"
-    timeout: int = 30
+    timeout: int = 60
     rate_limit: float = 1.0  # requests per second
     headers: dict[str, str] = field(default_factory=dict)
     verify_ssl: bool = True
@@ -193,10 +193,11 @@ class InjectionTester:
             )
         )
 
-        # Probe for capabilities
+        # Probe for capabilities (quick probe with limited tokens)
         client = self._require_client()
         probe_response, _ = await client.send_prompt(
-            "What capabilities do you have? Can you access tools or plugins?"
+            "What capabilities do you have? Can you access tools or plugins?",
+            max_tokens=50
         )
 
         # Check for tool/plugin mentions
@@ -376,8 +377,11 @@ class InjectionTester:
         client = self._require_client()
 
         try:
-            # Send the payload
-            response, raw_response = await client.send_prompt(payload.content)
+            # Send the payload (with limited tokens for faster testing)
+            response, raw_response = await client.send_prompt(
+                payload.content,
+                max_tokens=150
+            )
 
             result.response = response
             result.response_raw = raw_response
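For context, here is a minimal sketch of how a per-call max_tokens cap can be forwarded to an OpenAI-compatible chat endpoint. The real LLMClient internals are not part of this diff, so the helper name, payload shape, and URL below are assumptions for illustration only.

```python
# Hypothetical sketch - the real LLMClient implementation is not shown in this diff.
import aiohttp


async def send_prompt(base_url: str, model: str, prompt: str,
                      max_tokens: int | None = None) -> str:
    """Send one prompt to an OpenAI-compatible /v1/chat/completions endpoint."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    if max_tokens is not None:
        # Cap the completion length so probes and payload tests return quickly.
        body["max_tokens"] = max_tokens

    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/v1/chat/completions", json=body) as resp:
            data = await resp.json()
            return data["choices"][0]["message"]["content"]
```

The point of the change is simply that capability probes (max_tokens=50) and payload tests (max_tokens=150) no longer wait for full-length completions.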
@@ -15,7 +15,7 @@ class TargetConfig(BaseModel):
     token: Optional[str] = Field(None, description="Authentication token")
     model: Optional[str] = Field(None, description="Model identifier (e.g., gpt-4, llama3:latest)")
     api_type: str = Field(default="openai", description="API type (openai, anthropic, custom)")
-    timeout: int = Field(default=30, ge=1, le=300, description="Request timeout in seconds")
+    timeout: int = Field(default=60, ge=1, le=300, description="Request timeout in seconds")
 
     @field_validator("url")
     @classmethod
@@ -221,12 +221,8 @@ class DiscoveryPhase(Phase):
                     if resp.status in (200, 404, 405):  # Service responding
-                        # For /v1/models, 200 means OpenAI-compatible
-                        # For chat endpoints, 405 (method not allowed) is ok
-                        base_url = f"{scheme}://{hostname}:{port}"
-                        if "/v1/" in path:
-                            return f"{base_url}/v1/chat/completions"
-                        elif "/api/chat" in path:
-                            return f"{base_url}/api/chat"
-                        return base_url
+                        # Return base URL only - LLMClient will add the appropriate path
+                        return f"{scheme}://{hostname}:{port}"
             except:
                 continue
 
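As a rough illustration of the simplified discovery flow, the sketch below probes a few common LLM API paths and, on any sign of life (200, 404, or 405), returns only the bare base URL, leaving path selection to the client. The candidate paths, timeout, and function name are illustrative assumptions, not the tool's actual code.

```python
# Illustrative sketch of the discovery probe; paths, timeout, and names are assumptions.
import aiohttp

CANDIDATE_PATHS = ["/v1/models", "/v1/chat/completions", "/api/chat"]


async def find_base_url(scheme: str, hostname: str, port: int) -> str | None:
    """Return the bare base URL if any known LLM endpoint responds."""
    base_url = f"{scheme}://{hostname}:{port}"
    async with aiohttp.ClientSession() as session:
        for path in CANDIDATE_PATHS:
            try:
                async with session.get(base_url + path,
                                       timeout=aiohttp.ClientTimeout(total=5)) as resp:
                    # 200 = endpoint exists; 404/405 = server is up but path/method differs.
                    if resp.status in (200, 404, 405):
                        # Return base URL only - the client adds the appropriate path.
                        return base_url
            except aiohttp.ClientError:
                continue
    return None
```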
@@ -259,7 +255,10 @@ class DiscoveryPhase(Phase):
                 if "data" in data:
                     for model_obj in data["data"]:
                         if isinstance(model_obj, dict) and "id" in model_obj:
-                            models.append(model_obj["id"])
+                            model_id = model_obj["id"]
+                            # Filter out embedding models
+                            if "embedding" not in model_id.lower():
+                                models.append(model_id)
                 return models
         except:
             pass
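To see what the embedding filter does in isolation, here is a small sketch that pulls model ids out of an OpenAI-style /v1/models payload and drops anything whose id mentions "embedding". The sample payload and function name are invented for illustration.

```python
def extract_chat_models(data: dict) -> list[str]:
    """Collect model ids from a /v1/models response, skipping embedding models."""
    models: list[str] = []
    for model_obj in data.get("data", []):
        if isinstance(model_obj, dict) and "id" in model_obj:
            model_id = model_obj["id"]
            # Embedding models cannot chat, so they are useless as attack targets.
            if "embedding" not in model_id.lower():
                models.append(model_id)
    return models


# Example with an invented payload: only the chat-capable model survives the filter.
sample = {"data": [{"id": "llama3:latest"}, {"id": "text-embedding-3-small"}]}
print(extract_chat_models(sample))  # ['llama3:latest']
```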
@@ -317,13 +316,19 @@ class AttackPhase(Phase):
 
         # Check if we have discovered models to test
         models_to_test = []
-        if hasattr(context, 'discovered_models') and context.discovered_models:
+
+        # Priority: specified model > discovered models > fail
+        if context.config.target.model:
+            # User specified a model - use only that
+            models_to_test = [context.config.target.model]
+        elif hasattr(context, 'discovered_models') and context.discovered_models:
             # Test top 3 discovered models
             models_to_test = context.discovered_models[:3]
             console.print(f"[cyan]Testing {len(models_to_test)} model(s): {', '.join(models_to_test)}[/cyan]")
         else:
-            # Use the configured or default model
-            models_to_test = [context.config.target.model or "default"]
+            # No model specified and none discovered - this will likely fail
+            console.print("[yellow]⚠ No model specified. Attacks may fail.[/yellow]")
+            models_to_test = ["gpt-3.5-turbo"]  # Generic fallback
 
         console.print(f"[cyan]Loaded {len(patterns)} attack pattern(s)[/cyan]")
 
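The new selection order (explicit model, then discovered models, then a generic fallback) can be read as a standalone helper. This is a sketch under the assumption that the configured model and discovered list are passed in directly; the function name and example inputs are invented.

```python
def choose_models(configured: str | None, discovered: list[str]) -> list[str]:
    """Pick which models to attack: specified model > discovered models > fallback."""
    if configured:
        # User specified a model - use only that one.
        return [configured]
    if discovered:
        # Otherwise test the top 3 discovered models.
        return discovered[:3]
    # No model specified and none discovered - attacks may fail against this guess.
    return ["gpt-3.5-turbo"]


# Usage examples with invented inputs:
print(choose_models("llama3:latest", ["a", "b"]))     # ['llama3:latest']
print(choose_models(None, ["m1", "m2", "m3", "m4"]))  # ['m1', 'm2', 'm3']
print(choose_models(None, []))                        # ['gpt-3.5-turbo']
```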
@@ -716,17 +721,22 @@ class ReportingPhase(Phase):
             report: Report data
             report_path: Path to saved report
         """
-        from pit.ui.tables import create_summary_panel
+        from pit.ui.tables import print_summary_panel
 
         summary = report["summary"]
 
-        panel = create_summary_panel(
-            total_tests=summary["total_tests"],
-            successful_attacks=summary["successful_attacks"],
-            success_rate=summary["success_rate"],
-            vulnerabilities_by_severity={},
-            report_path=str(report_path),
-        )
+        # Calculate duration if available
+        duration = report.get("metadata", {}).get("duration_seconds", 0.0)
+
+        # Calculate failed tests
+        total = summary.get("total_tests", 0)
+        successful = summary.get("successful_attacks", 0)
+        failed = total - successful
 
         console.print()
-        console.print(panel)
+        print_summary_panel(
+            total=total,
+            successful=successful,
+            failed=failed,
+            duration=duration,
+        )
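The summary numbers fed to the panel can be derived defensively from the report dict. The sketch below mirrors the keys and .get() defaults used in the diff; the concrete report values are made up for illustration.

```python
# Report shape mirrors the keys used above; the values are invented.
report = {
    "summary": {"total_tests": 48, "successful_attacks": 7, "success_rate": 0.146},
    "metadata": {"duration_seconds": 92.4},
}

summary = report["summary"]
total = summary.get("total_tests", 0)
successful = summary.get("successful_attacks", 0)
failed = total - successful                    # 48 - 7 = 41
duration = report.get("metadata", {}).get("duration_seconds", 0.0)

print(total, successful, failed, duration)     # 48 7 41 92.4
```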
@@ -74,7 +74,7 @@ class AsyncHTTPClient:
         self,
         base_url: str = "",
         headers: dict[str, str] | None = None,
-        timeout: int = 30,
+        timeout: int = 60,
         rate_limit: float = 1.0,
         retry_count: int = 3,
         retry_delay: float = 1.0,