OmniSafeBench-MM/models/base_model.py
Liao, Jie 04a1cbe8d1 refactor: consolidate duplicate code and fix deadlock bug
- Extract duplicate attack/defense config merging into _merge_component_configs()
- Extract duplicate lazy loading logic into _get_component()
- Move content policy detection to BaseModel base class
- Fix BatchSaveManager deadlock by splitting flush logic
- Add TypeError to ValueError conversion for consistent config errors
- Move _determine_load_model() to BaseComponent (explicit field only)
2026-01-23 12:55:34 +08:00


import inspect
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union

import backoff

from core.base_classes import BaseModel as CoreBaseModel
from core.data_formats import TestCase, ModelResponse

class BaseModel(CoreBaseModel):
    """Base class for all model implementations."""

    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "$ERROR$"
    API_CONTENT_REJECTION_OUTPUT = (
        "[ERROR] Prompt detected as harmful content, refusing to answer"
    )
    API_QUERY_SLEEP = 0.5
    API_MAX_RETRY = 3
    API_TIMEOUT = 600

    # Content policy detection keywords (common across providers)
    CONTENT_POLICY_KEYWORDS = [
        "content policy",
        "safety",
        "harmful",
        "unsafe",
        "violation",
        "moderation",
        "data_inspection_failed",
        "inappropriate content",
    ]

    # Provider-specific additional keywords (subclasses can override)
    PROVIDER_SPECIFIC_KEYWORDS = []

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        # Call the parent constructor with an empty configuration
        super().__init__(config={})
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.model_type = self._determine_model_type()

    @classmethod
    def from_config(cls, name: str, config: Optional[Dict[str, Any]] = None):
        """
        Unified factory method: build a model instance from a config dict.

        This is intentionally signature-aware: it only passes the kwargs that the
        concrete model class' __init__ accepts, so providers with different
        constructor signatures won't break.
        """
        config = config or {}
        # model_name in config takes priority; fall back to the alias name
        model_name = config.get("model_name", name)
        api_key = config.get("api_key", "")
        base_url = config.get("base_url", None)
        init_kwargs = {
            "model_name": model_name,
            "api_key": api_key,
            "base_url": base_url,
        }
        # Filter kwargs by the concrete __init__ signature
        try:
            sig = inspect.signature(cls.__init__)
            accepted = {
                p.name
                for p in sig.parameters.values()
                if p.name != "self"
                and p.kind
                in (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                )
            }
            filtered_kwargs = {k: v for k, v in init_kwargs.items() if k in accepted}
            return cls(**filtered_kwargs)
        except Exception:
            # Fallback: best-effort init with the standard keyword arguments
            return cls(model_name=model_name, api_key=api_key, base_url=base_url)
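
    # Usage sketch (hypothetical subclass and config; not part of this module):
    # a provider whose __init__ only takes `model_name` still works, because
    # from_config drops the kwargs its signature does not accept.
    #
    #     class LocalOnlyModel(BaseModel):
    #         def __init__(self, model_name: str):
    #             super().__init__(model_name)
    #
    #     model = LocalOnlyModel.from_config(
    #         "local-alias", {"model_name": "my-model", "api_key": "ignored"}
    #     )  # api_key/base_url are filtered out before __init__ is called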

    def _determine_model_type(self):
        """Determine the model type: "api" (remote API call) or "local" (loaded locally)."""
        # Default heuristic: treat the model as API-backed if an api_key or
        # base_url was provided. Subclasses can override this method.
        if self.api_key or self.base_url:
            return "api"
        else:
            return "local"

    def _is_content_policy_rejection(self, error: Exception) -> bool:
        """Check if an exception represents a content policy rejection.

        Args:
            error: The exception to check

        Returns:
            True if the error indicates a content policy rejection
        """
        error_str = str(error).lower()
        # Combine common and provider-specific keywords
        all_keywords = self.CONTENT_POLICY_KEYWORDS + self.PROVIDER_SPECIFIC_KEYWORDS
        return any(keyword in error_str for keyword in all_keywords)
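
    # Example (hypothetical provider): subclasses extend detection by overriding
    # PROVIDER_SPECIFIC_KEYWORDS rather than re-implementing the check.
    #
    #     class AcmeModel(BaseModel):
    #         PROVIDER_SPECIFIC_KEYWORDS = ["acme_policy_block", "request_blocked"]
    #
    # Note that keywords are substring-matched against the lowercased error text,
    # so very generic terms (e.g. "safety") can also match unrelated errors.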

    def _handle_content_rejection(self) -> str:
        """Return the standard content rejection output."""
        return self.API_CONTENT_REJECTION_OUTPUT

    def _handle_content_rejection_stream(self):
        """Return a generator that yields the content rejection placeholder."""

        def error_generator():
            yield self.API_CONTENT_REJECTION_OUTPUT

        return error_generator()

    def _handle_api_error(self) -> str:
        """Return the standard API error output."""
        return self.API_ERROR_OUTPUT

    def _handle_api_error_stream(self):
        """Return a generator that yields the API error placeholder."""

        def error_generator():
            yield self.API_ERROR_OUTPUT

        return error_generator()

    def _retry_with_backoff(self, func, *args, **kwargs):
        """Execute a function with retries and exponential backoff (backoff library)."""

        @backoff.on_exception(
            backoff.expo,
            Exception,
            max_tries=self.API_MAX_RETRY,
            max_time=self.API_TIMEOUT * 2,
            on_backoff=lambda details: print(
                f"Attempt {details['tries']} failed: "
                f"{details['exception'].__class__.__name__}: {details['exception']}"
            ),
            on_giveup=lambda details: print(
                f"Final attempt failed after {details['tries']} tries: "
                f"{details['exception'].__class__.__name__}: {details['exception']}"
            ),
        )
        def _execute():
            return func(*args, **kwargs)

        # The decorated call re-raises the last exception on give-up, so no
        # extra try/except wrapper is needed here.
        return _execute()
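
    # Usage sketch (hypothetical call site inside a subclass): wrap the raw SDK
    # call so transient failures are retried with exponential backoff.
    #
    #     def _generate_single(self, messages, **kwargs):
    #         return self._retry_with_backoff(
    #             self._call_provider_api,  # hypothetical helper
    #             messages,
    #             **kwargs,
    #         )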

    @abstractmethod
    def _generate_single(
        self,
        messages: List[dict],
        **kwargs,
    ) -> str:
        """Generate response for a single prompt."""
        pass

    @abstractmethod
    def _generate_stream(
        self,
        messages: List[dict],
        **kwargs,
    ):
        """Generate streaming response for a single prompt."""
        pass

    def generate_response(self, test_case: TestCase, **kwargs) -> ModelResponse:
        """
        Generate a model response for a single test case.

        Args:
            test_case: The test case to answer
            **kwargs: Additional model-specific parameters

        Returns:
            ModelResponse: The model's response, with the test case metadata attached
        """
        # Convert the TestCase into a message list
        messages = self._test_case_to_messages(test_case)
        # Generate the response and extract plain text from the provider object
        response = self.generate(messages, **kwargs)
        response_text = self._extract_text(response)
        return ModelResponse(
            test_case_id=test_case.test_case_id,
            model_response=response_text,
            model_name=self.model_name,
            metadata=test_case.metadata,
        )
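
    # Example (hypothetical values; TestCase field names as used in this module,
    # constructor signature assumed):
    #
    #     case = TestCase(
    #         test_case_id="tc-001",
    #         prompt="Describe this image.",
    #         image_path="images/sample.png",
    #         metadata={},
    #     )
    #     result = model.generate_response(case)
    #     print(result.model_response)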

    def _extract_text(self, response_obj: Any) -> str:
        """
        Extract plain text from various provider SDK response objects.

        Why this exists:
        - Some providers return OpenAI-like objects with `.choices[0].message.content`
        - Some return `.text` (e.g., certain Gemini SDK objects)
        - Some return `content` blocks (e.g., Anthropic)
        - Our retry/error handling may return a string placeholder directly
        """
        if response_obj is None:
            return ""
        # Placeholder / already a string
        if isinstance(response_obj, str):
            return response_obj
        # Dict-like responses
        if isinstance(response_obj, dict):
            # OpenAI-like dict
            if "choices" in response_obj and response_obj["choices"]:
                try:
                    choice0 = response_obj["choices"][0]
                    # message.content
                    if isinstance(choice0, dict):
                        msg = choice0.get("message")
                        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
                            return msg["content"]
                        delta = choice0.get("delta")
                        if isinstance(delta, dict) and isinstance(delta.get("content"), str):
                            return delta["content"]
                except Exception:
                    pass
            if isinstance(response_obj.get("text"), str):
                return response_obj["text"]
            content = response_obj.get("content")
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                texts = []
                for block in content:
                    if isinstance(block, str):
                        texts.append(block)
                    elif isinstance(block, dict) and isinstance(block.get("text"), str):
                        texts.append(block["text"])
                if texts:
                    return "".join(texts)
            return str(response_obj)
        # OpenAI-like objects: response.choices[0].message.content
        try:
            choices = getattr(response_obj, "choices", None)
            if choices and len(choices) > 0:
                choice0 = choices[0]
                msg = getattr(choice0, "message", None)
                if msg is not None:
                    content = getattr(msg, "content", None)
                    if isinstance(content, str):
                        return content
                # Some SDKs expose delta for streaming chunks
                delta = getattr(choice0, "delta", None)
                if delta is not None:
                    delta_content = getattr(delta, "content", None)
                    if isinstance(delta_content, str):
                        return delta_content
        except Exception:
            pass
        # Gemini-like objects: response.text
        try:
            text = getattr(response_obj, "text", None)
            if isinstance(text, str):
                return text
        except Exception:
            pass
        # Anthropic-like objects: response.content is a list of blocks with .text
        try:
            content = getattr(response_obj, "content", None)
            if isinstance(content, str):
                return content
            if isinstance(content, list):
                texts = []
                for block in content:
                    if isinstance(block, str):
                        texts.append(block)
                    else:
                        block_text = getattr(block, "text", None)
                        if isinstance(block_text, str):
                            texts.append(block_text)
                        elif isinstance(block, dict) and isinstance(block.get("text"), str):
                            texts.append(block["text"])
                if texts:
                    return "".join(texts)
        except Exception:
            pass
        # Last resort: stringify
        return str(response_obj)
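
    # Example (shapes taken from the branches above, values hypothetical):
    #
    #     model._extract_text("already text")                      # -> "already text"
    #     model._extract_text(
    #         {"choices": [{"message": {"content": "hi"}}]}
    #     )                                                        # -> "hi"
    #     model._extract_text({"content": [{"text": "a"}, "b"]})   # -> "ab"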

    def generate_responses_batch(
        self, test_cases: List[TestCase], **kwargs
    ) -> List[ModelResponse]:
        """
        Generate model responses for a batch of test cases (for locally loaded models).

        Args:
            test_cases: List of test cases
            **kwargs: Additional model-specific parameters

        Returns:
            List[ModelResponse]: List of model responses, in input order
        """
        # Default implementation: loop over single-case generation.
        # Local models should override this method to implement true batch inference.
        responses = []
        for test_case in test_cases:
            response = self.generate_response(test_case, **kwargs)
            responses.append(response)
        return responses
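
    # Override sketch (hypothetical local model): batch the underlying forward
    # pass instead of looping, but keep the same signature and return order.
    #
    #     class MyLocalModel(BaseModel):
    #         def generate_responses_batch(self, test_cases, **kwargs):
    #             message_lists = [self._test_case_to_messages(tc) for tc in test_cases]
    #             texts = self.generate(message_lists, **kwargs)  # list of strings
    #             return [
    #                 ModelResponse(
    #                     test_case_id=tc.test_case_id,
    #                     model_response=text,
    #                     model_name=self.model_name,
    #                     metadata=tc.metadata,
    #                 )
    #                 for tc, text in zip(test_cases, texts)
    #             ]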

    def _test_case_to_messages(self, test_case: TestCase) -> List[Dict[str, Any]]:
        """Convert a TestCase into an OpenAI-style message list."""
        messages = []
        # If there is an image, build a multimodal user message
        if test_case.image_path:
            try:
                import base64
                from io import BytesIO
                from pathlib import Path

                from PIL import Image

                # Check that the image file exists before trying to open it
                image_path = Path(test_case.image_path)
                if not image_path.exists():
                    raise FileNotFoundError(f"Image not found: {image_path}")
                # Load the image and re-encode it as base64 PNG
                image = Image.open(image_path)
                buffered = BytesIO()
                image.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
                # Embed the image as a data URL
                data_url = f"data:image/png;base64,{img_str}"
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": test_case.prompt},
                            {
                                "type": "image_url",
                                "image_url": {"url": data_url},
                            },
                        ],
                    }
                )
            except Exception as e:
                self.logger.error(f"Error processing image: {e}")
                # Fall back to plain text on error
                messages.append({"role": "user", "content": test_case.prompt})
        else:
            # Plain text message
            messages.append({"role": "user", "content": test_case.prompt})
        return messages
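
    # Resulting shape (hypothetical prompt and truncated data URL; OpenAI-style
    # chat format as built above):
    #
    #     [{"role": "user", "content": [
    #         {"type": "text", "text": "Describe this image."},
    #         {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    #     ]}]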

    def generate(
        self,
        messages: Union[List[dict], List[List[dict]]],
        use_tqdm: bool = False,
        stream: bool = False,
        **kwargs,
    ) -> Union[str, List[str]]:
        """Generate responses for one or many prompts.

        Args:
            messages: A single message list, or a list of message lists
            use_tqdm: Whether to show a progress bar (batch input only)
            stream: Whether to use streaming output
            **kwargs: Additional model-specific parameters

        Returns:
            A single response string or a list of responses; with stream=True,
            a generator (or list of generators) yielding text chunks instead
        """
        # A flat list of dicts is a single conversation; a list of lists is a batch
        if isinstance(messages, list) and all(
            isinstance(msg, dict) for msg in messages
        ):
            if stream:
                return self._generate_stream(messages, **kwargs)
            else:
                return self._generate_single(messages, **kwargs)
        if use_tqdm:
            from tqdm import tqdm

            messages = tqdm(messages)
        if stream:
            # For multiple messages with streaming, return a list of generators
            return [self._generate_stream(msg_set, **kwargs) for msg_set in messages]
        else:
            return [self._generate_single(msg_set, **kwargs) for msg_set in messages]
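
# Usage sketch (hypothetical subclass instance; message content assumed):
#
#     msgs = [{"role": "user", "content": "Hello"}]
#     text = model.generate(msgs)                          # single conversation -> str
#     texts = model.generate([msgs, msgs], use_tqdm=True)  # batch -> List[str]
#     for chunk in model.generate(msgs, stream=True):      # streaming -> generator
#         print(chunk, end="")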