mirror of
https://github.com/dongdongunique/EvoSynth.git
synced 2026-03-20 10:23:31 +00:00
- Add loop mechanism to reduce history_depth if output exceeds max_total_chars
- Prevents context overflow when accessing run context history
- Prints reduction info and returns depth_reduced field in result
286 lines
12 KiB
Python
286 lines
12 KiB
Python
from ..base_model import BaseModel
|
|
from ...core.registry import model_registry
|
|
import openai
|
|
import time
|
|
import base64
|
|
from io import BytesIO
|
|
from typing import Any, Optional, List, Dict, Union
|
|
from PIL import Image
|
|
|
|
@model_registry.register("openai")
class OpenAIModel(BaseModel):
    """
    OpenAI API model wrapper for jailbreak toolbox.

    Supports various OpenAI models like gpt-3.5-turbo and gpt-4.

    The instance keeps a running ``conversation_history`` (a list of
    ``{"role": ..., "content": ...}`` dicts) that always starts with the
    current system message. ``query`` only reads/extends that history when
    called with ``maintain_history=True``.
    """

    # Single source of truth for the default system prompt; used both as the
    # constructor default and by reset_system_message().
    _DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."

    # Keyword arguments forwarded to openai.OpenAI(...).
    _CLIENT_PARAMS = frozenset({
        'timeout', 'max_retries', 'default_headers', 'default_query',
        'http_client', 'api_key', 'base_url', 'organization', 'project',
    })

    # Keyword arguments forwarded to chat.completions.create(...).
    _CHAT_PARAMS = frozenset({
        'frequency_penalty', 'logit_bias', 'logprobs', 'top_logprobs',
        'max_tokens', 'n', 'presence_penalty', 'response_format',
        'seed', 'stop', 'stream', 'temperature', 'top_p', 'tools', 'tool_choice',
        'parallel_tool_calls', 'user',
    })

    # Keyword arguments forwarded to embeddings.create(...).
    _EMBEDDING_PARAMS = frozenset({
        'encoding_format', 'dimensions', 'user',
    })

    # Image modes Pillow's JPEG writer accepts directly; anything else
    # (e.g. RGBA, LA, P) must be converted before saving as JPEG.
    _JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

    def __init__(self,
                 api_key: str,
                 base_url: Optional[str] = None,
                 model_name: str = "gpt-3.5-turbo",
                 temperature: float = 0.7,
                 max_tokens: Optional[int] = None,
                 retry_attempts: int = 3,
                 retry_delay: float = 2.0,
                 system_message: str = _DEFAULT_SYSTEM_MESSAGE,
                 embedding_model: str = "text-embedding-3-small",
                 **kwargs):
        """
        Initialize the OpenAI model.

        Args:
            api_key: OpenAI API key
            base_url: Optional base URL for OpenAI API
            model_name: Name of the OpenAI model to use (e.g., "gpt-3.5-turbo", "gpt-4")
            temperature: Sampling temperature (0.0-2.0)
            max_tokens: Maximum tokens in the response (omitted from requests when None)
            retry_attempts: Number of attempts made by ``query`` before giving up
            retry_delay: Delay between retry attempts in seconds. Stored only;
                ``query`` does not sleep between attempts.
            system_message: System message to set the assistant's behavior
            embedding_model: Model to use for generating embeddings
            **kwargs: Passed unchanged to ``BaseModel.__init__`` and additionally
                filtered into client / chat / embedding keyword arguments.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.retry_attempts = retry_attempts
        self.retry_delay = retry_delay
        self.system_message = system_message
        self.embedding_model = embedding_model

        # Initialize conversation history
        self.conversation_history = [{"role": "system", "content": system_message}]

        # Filter kwargs for different OpenAI API functions
        self.client_kwargs = self._filter_client_kwargs(kwargs)
        self.chat_kwargs = self._filter_chat_kwargs(kwargs)
        self.embedding_kwargs = self._filter_embedding_kwargs(kwargs)

        # Configure OpenAI client. The module-level attributes are set in
        # addition to the explicit client below (kept for backward compatibility).
        openai.api_key = api_key
        if base_url:
            openai.base_url = base_url
        self.client = openai.OpenAI(api_key=api_key, base_url=base_url, **self.client_kwargs)
        print(f"Initialized OpenAI model: {model_name}")

    def _filter_client_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Filter kwargs suitable for OpenAI client initialization"""
        return {k: v for k, v in kwargs.items() if k in self._CLIENT_PARAMS}

    def _filter_chat_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Filter kwargs suitable for chat.completions.create()"""
        return {k: v for k, v in kwargs.items() if k in self._CHAT_PARAMS}

    def _filter_embedding_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Filter kwargs suitable for embeddings.create()"""
        return {k: v for k, v in kwargs.items() if k in self._EMBEDDING_PARAMS}

    def _encode_image_to_base64(self, image_input: Union[str, Any]) -> str:
        """Encode an image to base64 string. Supports both file paths and PIL Image objects.

        Args:
            image_input: Path to image file or PIL Image object

        Returns:
            Base64 encoded string of the image. File paths are encoded
            byte-for-byte; PIL images are re-encoded as JPEG first.

        Raises:
            ValueError: If image_input is neither a string nor a PIL Image.
            OSError: If the file cannot be read or the image cannot be saved.
        """
        # Check if it's a file path (string)
        if isinstance(image_input, str):
            with open(image_input, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')

        # Check if it's a PIL Image object
        if isinstance(image_input, Image.Image):
            # JPEG has no alpha/palette support: saving e.g. an RGBA or P image
            # raises, which previously made query() silently drop the image.
            # Converting to RGB discards any transparency information.
            if image_input.mode not in self._JPEG_WRITABLE_MODES:
                image_input = image_input.convert("RGB")
            buffered = BytesIO()
            image_input.save(buffered, format="JPEG")
            return base64.b64encode(buffered.getvalue()).decode('utf-8')

        raise ValueError("image_input must be either a file path (string) or a PIL Image object")

    def query(self, text_input: Union[str, List[Dict]] = "", image_input: Any = None, maintain_history: bool = False) -> str:
        """
        Send a query to the OpenAI API and return the response.

        Args:
            text_input: The prompt text to send (can be string or list of message dicts).
                A list is sent as-is as the full message list; in that case the
                system message, the stored history and ``image_input`` are not added.
            image_input: Path to image file or PIL Image object for vision models.
                If it cannot be encoded, a warning is printed and the query
                falls back to text-only.
            maintain_history: Whether to add this exchange to conversation history.
                When True (and text_input is not a list) the request contains
                the whole stored history followed by the new user message.

        Returns:
            The model's response as a string. Errors are not raised: after the
            last failed attempt the string ``"Error: <exception text>"`` is returned.
        """
        messages = []

        # Handle image input
        if image_input is not None:
            try:
                image_base64 = self._encode_image_to_base64(image_input)
            except Exception as e:
                print(f"Warning: Failed to process image input: {str(e)}")
                # Fall back to text-only input
                image_input = None
            else:
                if isinstance(text_input, list):
                    # If text_input is already a list of messages, use it directly
                    messages = text_input
                elif text_input is not None:
                    # Create a user message carrying both the text and the image.
                    user_content = [
                        {
                            "type": "text",
                            "text": str(text_input)
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image_base64}"
                            },
                        },
                    ]
                    if maintain_history:
                        # Record the user turn and send the whole conversation,
                        # mirroring the text-only path. (Previously the request
                        # contained neither system prompt nor history, and only
                        # the assistant reply was stored.)
                        self.add_user_message(user_content)
                        messages = self.conversation_history
                    else:
                        messages = [{"role": "system", "content": self.system_message},
                                    {"role": "user", "content": user_content}]

        # Handle text-only input or fallback from image processing failure
        if image_input is None:
            if isinstance(text_input, list):
                # If text_input is already a list of messages, use it directly
                messages = text_input
            elif text_input is not None:
                text_input = str(text_input)
                # Add user message to history if maintaining history
                if maintain_history:
                    self.add_user_message(text_input)
                    messages = self.conversation_history
                else:
                    # For single-turn interactions without affecting history
                    messages = [{"role": "system", "content": self.system_message},
                                {"role": "user", "content": text_input}]

        # Request parameters do not change between attempts, so build them once.
        chat_params = {
            'model': self.model_name,
            'messages': messages,
            **self.chat_kwargs
        }
        # Explicit constructor settings take precedence over filtered kwargs.
        if self.temperature is not None:
            chat_params['temperature'] = self.temperature
        if self.max_tokens is not None:
            chat_params['max_tokens'] = self.max_tokens

        for attempt in range(self.retry_attempts):
            try:
                response = self.client.chat.completions.create(**chat_params)
                response_text = response.choices[0].message.content

                # Add assistant response to history if maintaining history
                if maintain_history:
                    self.add_assistant_message(response_text)

                return response_text

            except Exception as e:
                print(f"Unexpected error: {str(e)}")
                if attempt == self.retry_attempts - 1:
                    return f"Error: {str(e)}"
                # Note: time.sleep removed to prevent blocking in asyncio environments
                # For retry delays, the calling async code should handle timing

        # Only reached when retry_attempts <= 0 (the loop body never ran).
        return "Error: Failed to get response from model"

    def add_user_message(self, content: Union[str, List[Dict]]) -> None:
        """Add a user message to the conversation history."""
        self.conversation_history.append({"role": "user", "content": content})

    def add_assistant_message(self, content: Union[str, List[Dict]]) -> None:
        """Add an assistant message to the conversation history."""
        self.conversation_history.append({"role": "assistant", "content": content})

    def add_system_message(self, content: str) -> None:
        """Add a system message to the conversation history."""
        self.conversation_history.append({"role": "system", "content": content})

    def remove_last_turn(self) -> None:
        """
        Remove the last turn of conversation (last user message and its corresponding assistant message).
        Assumes conversation is stored sequentially as messages.

        Everything from the most recent user message to the end of the history
        is dropped. If the history contains no user message it is left unchanged.
        """
        if not self.conversation_history:
            return

        # Scan backwards for the most recent user message and truncate there.
        for idx in range(len(self.conversation_history) - 1, -1, -1):
            if self.conversation_history[idx]["role"] == "user":
                self.conversation_history = self.conversation_history[:idx]
                break

    def reset_conversation(self) -> None:
        """Reset the conversation history to only include the initial system message."""
        self.conversation_history = [{"role": "system", "content": self.system_message}]

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Get the current conversation history (the live list, not a copy)."""
        return self.conversation_history

    def set_system_message(self, system_message: str) -> None:
        """Set a new system message and reset the conversation."""
        self.system_message = system_message
        self.reset_conversation()

    def reset_system_message(self) -> None:
        """Reset the system message to the default and reset the conversation."""
        self.system_message = self._DEFAULT_SYSTEM_MESSAGE
        self.reset_conversation()

    def get_embedding(self, text_input: str, model: Optional[str] = None) -> Union["torch.Tensor", List[float]]:
        """
        Get embedding vector for the given text using OpenAI embedding model.

        Args:
            text_input: The input text to embed (newlines are replaced by spaces)
            model: The embedding model to use (default: uses self.embedding_model)

        Returns:
            On success, a ``torch.Tensor`` built from the embedding vector.
            On any failure (including ``torch`` not being importable) the error
            is printed and an empty list is returned.
        """
        try:
            # Imported lazily so the rest of the class works without torch.
            import torch
            clean_text = text_input.replace("\n", " ")

            # Use specified model or default embedding model
            embedding_model = model or self.embedding_model

            # Prepare embedding parameters
            embedding_params = {
                'input': clean_text,
                'model': embedding_model,
                **self.embedding_kwargs
            }

            response = self.client.embeddings.create(**embedding_params)
            embedding = response.data[0].embedding
            return torch.tensor(embedding)

        except Exception as e:
            print(f"Error while generating embedding: {str(e)}")
            return []