feat: Refactor model discovery and selection, filter embedding models, update report summary, and optimize prompt sending with max_tokens.

Author: shiva108
Date:   2026-01-27 10:59:44 +01:00
parent 1723769ee0
commit 23b94c5038
5 changed files with 40 additions and 26 deletions

View File

@@ -359,7 +359,7 @@ class TargetConfig:
     auth_type: str = "bearer"  # bearer, api_key, none
     auth_token: str = ""
     auth_header: str = "Authorization"
-    timeout: int = 30
+    timeout: int = 60
     rate_limit: float = 1.0  # requests per second
     headers: dict[str, str] = field(default_factory=dict)
     verify_ssl: bool = True
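
One note on the surrounding context lines: headers uses field(default_factory=dict) because dataclasses reject mutable defaults. A minimal sketch of the behavior this buys (the class name here is illustrative):

    from dataclasses import dataclass, field

    @dataclass
    class Example:
        # A bare `headers: dict[str, str] = {}` default would raise
        # ValueError at class-definition time; default_factory builds
        # a fresh dict for each instance instead.
        headers: dict[str, str] = field(default_factory=dict)

    a, b = Example(), Example()
    a.headers["X-Test"] = "1"
    assert b.headers == {}  # no shared state between instances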

View File

@@ -193,10 +193,11 @@ class InjectionTester:
             )
         )
-        # Probe for capabilities
+        # Probe for capabilities (quick probe with limited tokens)
         client = self._require_client()
         probe_response, _ = await client.send_prompt(
-            "What capabilities do you have? Can you access tools or plugins?"
+            "What capabilities do you have? Can you access tools or plugins?",
+            max_tokens=50
         )
         # Check for tool/plugin mentions
@@ -376,8 +377,11 @@
         client = self._require_client()
         try:
-            # Send the payload
-            response, raw_response = await client.send_prompt(payload.content)
+            # Send the payload (with limited tokens for faster testing)
+            response, raw_response = await client.send_prompt(
+                payload.content,
+                max_tokens=150
+            )
             result.response = response
             result.response_raw = raw_response
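
Both call sites now pass max_tokens into send_prompt. That method's body is not part of this diff; below is a hedged sketch of a signature that would satisfy these call sites, assuming an OpenAI-compatible chat-completions payload (the _post helper is hypothetical):

    async def send_prompt(self, prompt: str, max_tokens: int = 1024) -> tuple[str, dict]:
        """Send a single-turn prompt; return (text, raw JSON response)."""
        body = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,  # caps completion length, so probes return faster
        }
        raw = await self._post("/v1/chat/completions", json=body)  # hypothetical helper
        text = raw["choices"][0]["message"]["content"]
        return text, raw

The tuple return matches the unpacking at both call sites (probe_response, _ and response, raw_response).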

View File

@@ -15,7 +15,7 @@ class TargetConfig(BaseModel):
     token: Optional[str] = Field(None, description="Authentication token")
     model: Optional[str] = Field(None, description="Model identifier (e.g., gpt-4, llama3:latest)")
     api_type: str = Field(default="openai", description="API type (openai, anthropic, custom)")
-    timeout: int = Field(default=30, ge=1, le=300, description="Request timeout in seconds")
+    timeout: int = Field(default=60, ge=1, le=300, description="Request timeout in seconds")
 
     @field_validator("url")
     @classmethod
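
The ge=1/le=300 bounds mean Pydantic rejects out-of-range timeouts at construction time, so the new default of 60 stays well inside the validated range. A quick sketch of that behavior (Pydantic v2, consistent with the field_validator usage in the context lines; the model name is illustrative):

    from pydantic import BaseModel, Field, ValidationError

    class Demo(BaseModel):
        timeout: int = Field(default=60, ge=1, le=300)

    Demo()             # ok: timeout == 60
    Demo(timeout=300)  # ok: the le bound is inclusive
    try:
        Demo(timeout=301)
    except ValidationError:
        pass  # rejected: violates le=300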

View File

@@ -221,12 +221,8 @@ class DiscoveryPhase(Phase):
                     if resp.status in (200, 404, 405):  # Service responding
                         # For /v1/models, 200 means OpenAI-compatible
                         # For chat endpoints, 405 (method not allowed) is ok
-                        base_url = f"{scheme}://{hostname}:{port}"
-                        if "/v1/" in path:
-                            return f"{base_url}/v1/chat/completions"
-                        elif "/api/chat" in path:
-                            return f"{base_url}/api/chat"
-                        return base_url
+                        # Return base URL only - LLMClient will add the appropriate path
+                        return f"{scheme}://{hostname}:{port}"
             except:
                 continue
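
Discovery now returns only scheme://host:port and defers path selection to LLMClient. The diff does not show how LLMClient picks the path; here is a sketch of one plausible approach keyed on the configured api_type (the mapping and function are illustrative assumptions, not the project's code):

    # Hypothetical mapping from api_type to chat endpoint path.
    CHAT_PATHS = {
        "openai": "/v1/chat/completions",
        "ollama": "/api/chat",
    }

    def chat_endpoint(base_url: str, api_type: str) -> str:
        # base_url is the bare scheme://host:port returned by discovery
        return base_url.rstrip("/") + CHAT_PATHS.get(api_type, "/v1/chat/completions")

    assert chat_endpoint("http://localhost:11434", "ollama") == "http://localhost:11434/api/chat"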
@@ -259,7 +255,10 @@
                 if "data" in data:
                     for model_obj in data["data"]:
                         if isinstance(model_obj, dict) and "id" in model_obj:
-                            models.append(model_obj["id"])
+                            model_id = model_obj["id"]
+                            # Filter out embedding models
+                            if "embedding" not in model_id.lower():
+                                models.append(model_id)
             return models
         except:
             pass
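
Run against a typical /v1/models payload, the filter keeps chat models and drops embedding models (sample data, not from the project):

    data = {
        "data": [
            {"id": "gpt-4"},
            {"id": "text-embedding-3-small"},
            {"id": "llama3:latest"},
        ]
    }

    models = []
    for model_obj in data["data"]:
        if isinstance(model_obj, dict) and "id" in model_obj:
            model_id = model_obj["id"]
            if "embedding" not in model_id.lower():
                models.append(model_id)

    assert models == ["gpt-4", "llama3:latest"]

Note the substring check only catches ids that literally contain "embedding"; an id such as nomic-embed-text would still pass the filter.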
@@ -317,13 +316,19 @@ class AttackPhase(Phase):
         # Check if we have discovered models to test
         models_to_test = []
-        if hasattr(context, 'discovered_models') and context.discovered_models:
+        # Priority: specified model > discovered models > fail
+        if context.config.target.model:
+            # User specified a model - use only that
+            models_to_test = [context.config.target.model]
+        elif hasattr(context, 'discovered_models') and context.discovered_models:
             # Test top 3 discovered models
             models_to_test = context.discovered_models[:3]
             console.print(f"[cyan]Testing {len(models_to_test)} model(s): {', '.join(models_to_test)}[/cyan]")
         else:
-            # Use the configured or default model
-            models_to_test = [context.config.target.model or "default"]
+            # No model specified and none discovered - this will likely fail
+            console.print("[yellow]⚠ No model specified. Attacks may fail.[/yellow]")
+            models_to_test = ["gpt-3.5-turbo"]  # Generic fallback
 
         console.print(f"[cyan]Loaded {len(patterns)} attack pattern(s)[/cyan]")
@@ -716,17 +721,22 @@ class ReportingPhase(Phase):
             report: Report data
             report_path: Path to saved report
         """
-        from pit.ui.tables import create_summary_panel
+        from pit.ui.tables import print_summary_panel
 
         summary = report["summary"]
-        panel = create_summary_panel(
-            total_tests=summary["total_tests"],
-            successful_attacks=summary["successful_attacks"],
-            success_rate=summary["success_rate"],
-            vulnerabilities_by_severity={},
-            report_path=str(report_path),
-        )
+        # Calculate duration if available
+        duration = report.get("metadata", {}).get("duration_seconds", 0.0)
+        # Calculate failed tests
+        total = summary.get("total_tests", 0)
+        successful = summary.get("successful_attacks", 0)
+        failed = total - successful
 
         console.print()
-        console.print(panel)
+        print_summary_panel(
+            total=total,
+            successful=successful,
+            failed=failed,
+            duration=duration,
+        )
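
The call site now passes plain counts and a duration instead of a prebuilt panel. print_summary_panel itself is not shown in this diff; a sketch of a body that would satisfy this signature using Rich (only the keyword arguments are taken from the diff; the rendering is an assumption):

    from rich.console import Console
    from rich.panel import Panel

    console = Console()

    def print_summary_panel(total: int, successful: int, failed: int, duration: float) -> None:
        # Signature matches the call site above; the layout is illustrative.
        body = (
            f"Total tests:        {total}\n"
            f"Successful attacks: {successful}\n"
            f"Failed:             {failed}\n"
            f"Duration:           {duration:.1f}s"
        )
        console.print(Panel(body, title="Summary"))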

View File

@@ -74,7 +74,7 @@ class AsyncHTTPClient:
         self,
         base_url: str = "",
         headers: dict[str, str] | None = None,
-        timeout: int = 30,
+        timeout: int = 60,
         rate_limit: float = 1.0,
         retry_count: int = 3,
         retry_delay: float = 1.0,
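
How the integer timeout typically reaches the transport layer depends on the HTTP library; a sketch assuming aiohttp (the project's actual transport is not visible in this diff, so treat the session handling as illustrative):

    import aiohttp

    class AsyncHTTPClientSketch:
        def __init__(self, base_url: str = "", timeout: int = 60) -> None:
            # Total per-request time budget; the new default matches the diff's 60s.
            self._timeout = aiohttp.ClientTimeout(total=timeout)
            self._base_url = base_url

        async def get_text(self, path: str) -> str:
            # A real client would reuse one session instead of opening one per call.
            async with aiohttp.ClientSession(timeout=self._timeout) as session:
                async with session.get(self._base_url + path) as resp:
                    return await resp.text()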