Merge pull request #310 from zhanz5/fix/cost-calculation-model-aware

fix: make cost calculation model-aware instead of hardcoded to deepseek-chat
This commit is contained in:
Alexander Myasoedov
2026-06-05 10:12:41 +03:00
committed by GitHub
6 changed files with 38 additions and 5 deletions
+13
View File
@@ -1,4 +1,5 @@
import base64
import json
from enum import Enum
from urllib.parse import urlparse
@@ -145,6 +146,18 @@ class LLMSpec(BaseModel):
fn = probe
@property
def model_name(self) -> str:
"""Extract the model name from the request body (JSON).
Returns the value of the 'model' field if present, otherwise 'unknown'.
"""
try:
body_json = json.loads(self.body)
return body_json.get("model", "unknown")
except (json.JSONDecodeError, TypeError):
return "unknown"
@property
def modality(self) -> Modality:
if self.has_image:
+1 -1
View File
@@ -42,7 +42,7 @@ class Scan(BaseModel):
class ScanResult(BaseModel):
module: str
tokens: float | int
cost: float
cost: float | None
progress: float
status: bool = False
failureRate: float = 0.0
+10 -2
View File
@@ -1,3 +1,5 @@
from agentic_security.logutils import logger
# API pricing, USD per token. Values are dollars per 1M tokens / 1_000_000.
# Verified against vendor pricing pages on 2026-06-03.
PRICING = {
@@ -21,13 +23,19 @@ PRICING = {
DEFAULT_MODEL = "claude-sonnet"
def calculate_cost(tokens: int, model: str = DEFAULT_MODEL) -> float:
def calculate_cost(tokens: int, model: str = DEFAULT_MODEL) -> float | None:
"""Calculate API cost in USD for a total token count.
Assumes a 1:1 input/output split, since callers only track a combined total.
Returns:
float | None: Cost in USD, or None if the model pricing is unknown.
"""
if model not in PRICING:
raise ValueError(f"Unknown model: {model}")
logger.warning(
f"Unknown model '{model}': pricing not available, cost will not be estimated."
)
return None
half = max(tokens, 0) / 2
rates = PRICING[model]
+6 -2
View File
@@ -273,7 +273,9 @@ async def scan_module(
failure_rate = module_failures / max(module_prompts, 1)
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
cost = calculate_cost(
tokens, model=getattr(request_factory, "model_name", "unknown")
)
response_text = fuzzer_state.get_last_output(prompt) or ""
@@ -557,7 +559,9 @@ async def perform_many_shot_scan(
failure_rate = module_failures / max(processed_prompts, 1)
failure_rates.append(failure_rate)
cost = calculate_cost(tokens)
cost = calculate_cost(
tokens, model=getattr(request_factory, "model_name", "unknown")
)
yield ScanResult(
module=module.dataset_name,
@@ -131,6 +131,10 @@ class RequestAdapter:
if not llm_spec.has_audio:
raise ValueError("LLMSpec must have an image")
@property
def model_name(self) -> str:
return self.llm_spec.model_name
async def probe(
self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={}
) -> httpx.Response:
@@ -131,6 +131,10 @@ class RequestAdapter:
if not llm_spec.has_image:
raise ValueError("LLMSpec must have an image")
@property
def model_name(self) -> str:
return self.llm_spec.model_name
async def probe(
self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={}
) -> httpx.Response: