This commit is contained in:
LanWeiFRJ
2025-10-29 14:02:04 +08:00
parent c6b8c4bbd8
commit e334c4f764
52 changed files with 12251 additions and 0 deletions

src/managers/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# Export all modules
from .image_builder import *
from .log import *
from .llm_api import *
from .prompts import *
from .loop import *
from .decorators import *
__all__ = [
"image_builder",
"log",
"llm_api",
"prompts",
"loop",
"decorators"
]

View File

@@ -0,0 +1,13 @@
import threading
def singleton(cls):
instances = {}
lock = threading.Lock()
def get_instance(*args, **kwargs):
if cls not in instances:
with lock:
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
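# Usage sketch (illustrative, not part of this commit): the decorator makes
# repeated constructions of a class return the same instance.
#
# @singleton
# class ConfigStore:
#     def __init__(self):
#         self.values = {}
#
# assert ConfigStore() is ConfigStore()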

View File

@@ -0,0 +1,8 @@
# Export all modules
from .build_image import *
from .dockerfiles import *
__all__ = [
"build_image",
"dockerfiles"
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,8 @@
# User image Dockerfile template
# image_key: instance_image_key corresponding to instance_id
# apt_packages: space-separated list of packages to install in the Ubuntu base image
_DOCKERFILE_USER_IMAGE_PY = r"""
FROM {image_key}
RUN apt-get update && apt-get install -y --no-install-recommends {apt_packages}
"""

View File

@@ -0,0 +1,128 @@
"""
User defined logger for swebench
"""
import logging
from typing import Union, Optional
from swebench.harness.docker_utils import remove_image
from src.managers.log.logger import Logger as CustomLogger
# Import the original swebench functions
from swebench.harness.docker_build import build_instance_image, run_threadpool
def patched_build_instance_images(
client,
dataset: list,
force_rebuild: bool = False,
max_workers: int = 4,
namespace: str = None,
tag: str = None,
env_image_tag: str = None,
custom_logger: Optional[Union[logging.Logger, CustomLogger]] = None,  # new parameter
):
"""
Monkey patched version of build_instance_images that supports custom logger
Args:
client: Docker client
dataset: List of test specs or dataset to build images for
force_rebuild: Whether to force rebuild the images even if they already exist
max_workers: Maximum number of worker threads
namespace: Namespace for images
tag: Tag for images
env_image_tag: Environment image tag
custom_logger: Custom logger to use (instead of creating new ones)
"""
from swebench.harness.docker_build import make_test_spec, build_env_images
test_specs = []
for instance in dataset:
spec = make_test_spec(
instance,
namespace=namespace,
instance_image_tag=tag,
env_image_tag=env_image_tag,
)
test_specs.append(spec)
if force_rebuild:
for spec in test_specs:
remove_image(client, spec.instance_image_key, "quiet")
_, env_failed = build_env_images(client, test_specs, force_rebuild, max_workers)
if len(env_failed) > 0:
# Don't build images for instances that depend on failed-to-build env images
dont_run_specs = [
spec for spec in test_specs if spec.env_image_key in env_failed
]
test_specs = [
spec for spec in test_specs if spec.env_image_key not in env_failed
]
if custom_logger:
custom_logger.info(
f"Skipping {len(dont_run_specs)} instances - due to failed env image builds"
)
else:
print(f"Skipping {len(dont_run_specs)} instances - due to failed env image builds")
if custom_logger:
custom_logger.info(f"Building instance images for {len(test_specs)} instances")
else:
print(f"Building instance images for {len(test_specs)} instances")
successful, failed = [], []
if custom_logger:
payloads = [(spec, client, custom_logger, False) for spec in test_specs]
else:
payloads = [(spec, client, None, False) for spec in test_specs]
successful, failed = run_threadpool(build_instance_image, payloads, max_workers)
if len(failed) == 0:
if custom_logger:
custom_logger.info("All instance images built successfully.")
else:
print("All instance images built successfully.")
else:
if custom_logger:
custom_logger.warning(f"{len(failed)} instance images failed to build.")
else:
print(f"{len(failed)} instance images failed to build.")
return successful, failed
def apply_logger_patch():
"""应用 logger patch"""
import swebench.harness.docker_build as docker_build_module
original_build_instance_images = docker_build_module.build_instance_images
docker_build_module.build_instance_images = patched_build_instance_images
return original_build_instance_images
def restore_logger_patch(original_function):
"""Recover original logger actions"""
import swebench.harness.docker_build as docker_build_module
docker_build_module.build_instance_images = original_function
# Context manager version
class LoggerPatch:
"""Context manager"""
def __init__(self, logger: Optional[Union[logging.Logger, CustomLogger]] = None):
self.logger = logger
self.original_function = None
def __enter__(self):
self.original_function = apply_logger_patch()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.original_function:
restore_logger_patch(self.original_function)
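# Usage sketch (assumed, not part of this commit): temporarily patch swebench's
# build_instance_images so instance-image builds report through our logger.
#
# import swebench.harness.docker_build as docker_build
# with LoggerPatch(logger=my_logger):
#     docker_build.build_instance_images(client, dataset, custom_logger=my_logger)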

View File

@@ -0,0 +1,132 @@
"""
Redirect print output of a third party repo to a specific log
"""
import sys
import logging
import re
from typing import Union, Set, Pattern
from typing import Optional, TextIO
from pathlib import Path
from src.managers.log.logger import Logger as CustomLogger
class StreamWrapper:
"""Simple stream wrapperfor redirecting stdout/stderr"""
def __init__(self, original_stream, write_func):
self.original_stream = original_stream
self.write_func = write_func
self.buffer = ""
def write(self, text):
self.buffer += text
while '\n' in self.buffer:
line, self.buffer = self.buffer.split('\n', 1)
if line.strip():
self.write_func(line + '\n')
return len(text)
def flush(self):
if self.buffer:
if self.buffer.strip():
self.write_func(self.buffer)
self.buffer = ""
if hasattr(self.original_stream, 'flush'):
self.original_stream.flush()
def __getattr__(self, name):
return getattr(self.original_stream, name)
class PrintRedirector:
"""Redirect print output of a third party repo to a specific log"""
TRACEBACK_START_PATTERN = re.compile(r'Traceback\s*\(most recent call last\):', re.IGNORECASE)
EXCEPTION_END_PATTERN = re.compile(
r'\b(Error|Exception|KeyError|ValueError|TypeError|AttributeError|'
r'ImportError|BuildError|DockerError|IndentationError|SyntaxError|'
r'RuntimeError|OSError|FileNotFoundError|PermissionError)\s*:',
re.IGNORECASE
)
ERROR_KEYWORDS = {'error', 'failed', 'exception', 'fatal', 'critical'}
WARNING_KEYWORDS = {'warning', 'warn', 'deprecated'}
SKIP_KEYWORDS = {'skipping', 'skip', 'ignoring', 'ignore'}
def __init__(self, logger: Union[logging.Logger, CustomLogger]):
self.logger = logger
self.original_print = None
self.original_stdout = None
self.original_stderr = None
self.traceback_buffer = []
self.in_traceback = False
def __enter__(self):
self.start_redirect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_redirect()
def start_redirect(self):
if isinstance(__builtins__, dict):
self.original_print = __builtins__['print']
else:
self.original_print = getattr(__builtins__, 'print')
if isinstance(__builtins__, dict):
__builtins__['print'] = self._redirected_print
else:
setattr(__builtins__, 'print', self._redirected_print)
def stop_redirect(self):
if self.original_print:
if isinstance(__builtins__, dict):
__builtins__['print'] = self.original_print
else:
setattr(__builtins__, 'print', self.original_print)
def _redirected_print(self, *args, **kwargs):
if not args:
return
output = ' '.join(str(arg) for arg in args)
if self.TRACEBACK_START_PATTERN.search(output):
self.in_traceback = True
self.traceback_buffer = [output]
return
if self.in_traceback:
self.traceback_buffer.append(output)
if self.EXCEPTION_END_PATTERN.search(output):
self._log_traceback_and_reset()
return
self._log_by_level(output)
def _log_traceback_and_reset(self):
full_traceback = '\n'.join(self.traceback_buffer)
self.logger.error(f"[Third party repo exception stack]\n{full_traceback}")
self.in_traceback = False
self.traceback_buffer.clear()
def _log_by_level(self, output: str):
output_lower = output.lower()
if any(keyword in output_lower for keyword in self.ERROR_KEYWORDS):
self.logger.error(f"[Third party] {output}")
elif any(keyword in output_lower for keyword in self.WARNING_KEYWORDS):
self.logger.warning(f"[Third party] {output}")
elif any(keyword in output_lower for keyword in self.SKIP_KEYWORDS):
self.logger.info(f"[Third party] {output}")
else:
self.logger.info(f"[Third party] {output}")
def redirect_swebench_prints(logger: Union[logging.Logger, CustomLogger]):
return PrintRedirector(logger)
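# Usage sketch (assumed): route any print() output emitted by third-party code
# (e.g. a swebench build step) into our logger, classified by keyword level.
#
# with redirect_swebench_prints(my_logger):
#     third_party_build_function()  # hypothetical call whose prints get logged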

View File

@@ -0,0 +1,98 @@
"""
LLM API Management Module
Provide a standard OpenAI format LLM API interface that supports:
- Unified Chat Completions endpoint
- Synchronous and asynchronous operations
- Streaming responses
- Tool Calling
- Error handling and retry mechanism(s)
Supported providers:
- OpenAI: OpenAI API and compatible services
- Anthropic: Claude models
- DeepSeek: DeepSeek models
- Private: Private models (vLLM, TGI, Ollama, etc.)
Usage example::
# 1: Use the general manager (recommended)
from llm_api import LLMAPIManager
# create manager
manager = LLMAPIManager(
client_name="openai",
stream=False
)
# send a message (messages use the OpenAI chat format; the model falls back to the configured default if omitted)
response = manager.chat(
messages=[{"role": "user", "content": "Hello, world!"}],
model="gpt-3.5-turbo",
)
print(response)
# 2: Use client directly
from llm_api import OpenAIClient, ChatMessage, MessageRole
# create client
client = OpenAIClient(api_key="your-api-key")
# create messages
messages = [
ChatMessage(role=MessageRole.USER, content="Hello, world!")
]
# send request
request = client.create_request(messages=messages, model="gpt-3.5-turbo")
response = client.chat_completions_create(request)
print(response.choices[0].message.content)
"""
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
MessageRole,
Choice,
Usage,
)
# Client implementations
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
# API manager
from src.managers.llm_api.api_manager import (
LLMAPIManager,
create_manager,
create_common_manager,
COMMON_CONFIGS,
)
__all__ = [
"BaseLLMAPI",
"ChatMessage",
"ChatCompletionRequest",
"ChatCompletionResponse",
"ChatCompletionChunk",
"MessageRole",
"Choice",
"Usage",
"OpenAIClient",
"AnthropicClient",
"DeepSeekClient",
"OpenRouterClient",
"PrivateModelClient",
"LLMAPIManager",
"create_manager",
"create_common_manager",
"COMMON_CONFIGS",
]
__version__ = "1.0.0"
__author__ = "Tokfinity Team"
__description__ = "标准 OpenAI 格式的 LLM API 基类库"

View File

@@ -0,0 +1,479 @@
"""
LLM API manager
"""
import os
import time
from typing import Dict, List, Any, Optional, Union, Generator
from dotenv import load_dotenv
import yaml
from traceback import format_exc
from .base_client import (
BaseLLMAPI,
ChatMessage,
MessageRole,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
EmbeddingRequest,
EmbeddingResponse,
)
# Import all clients
from .clients.openai.openai_client import OpenAIClient
from .clients.anthropic.anthropic_client import AnthropicClient
from .clients.deepseek.deepseek_client import DeepSeekClient
from .clients.openrouter.openrouter_client import OpenRouterClient
from .clients.private.private_client import PrivateModelClient
class LLMAPIManager:
SUPPORTED_CLIENTS = {
"openai": OpenAIClient,
"anthropic": AnthropicClient,
"deepseek": DeepSeekClient,
"openrouter": OpenRouterClient,
"private": PrivateModelClient,
}
DEFAULT_CONFIGS = {
"openai": {"base_url": None, "api_key_env": "OPENAI_API_KEY"},
"anthropic": {"base_url": None, "api_key_env": "ANTHROPIC_API_KEY"},
"deepseek": {"base_url": None, "api_key_env": "DEEPSEEK_API_KEY"},
"openrouter": {
"base_url": None,
"api_key_env": "OPENROUTER_API_KEY",
"extra_config": {
"app_name": "tokfinity-llm-client",
"site_url": "https://github.com/your-repo",
},
},
"private": {
"base_url": "http://localhost:8000/v1",
"api_key_env": "PRIVATE_API_KEY",
"extra_config": {"deployment_type": "vllm"},
},
}
def __init__(
self,
client_name: Optional[str] = None,
stream: bool = False,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
auto_load_env: bool = True,
logger: Optional[Any] = None,
config: Optional[Dict[str, Any]] = None,
**kwargs,
):
# if client not provided, get first provider from default providers in config
if client_name is None:
default_client, default_model = (
self._load_default_client_and_model_from_config()
)
self.client_name = default_client
self.default_model = default_model
else:
self.client_name = client_name.lower()
self.default_model = None
self.stream = stream
self.timeout = timeout
self.max_retries = max_retries
self.logger = logger
self.logger.info(f"[LLMAPIManager]: Using client: {self.client_name}, model: {self.default_model}.")
self.config = config
if auto_load_env:
self._load_environment()
if self.client_name not in self.SUPPORTED_CLIENTS:
raise ValueError(
f"Unsupported client: {client_name}"
f"Support client: {list(self.SUPPORTED_CLIENTS.keys())}"
)
self.client = self._create_client(api_key, base_url, logger, **kwargs)
def _load_environment(self) -> None:
"""load environment variables"""
env_paths = [".env", "../.env", "../../.env", "../../../.env"]
for env_path in env_paths:
if os.path.exists(env_path):
load_dotenv(env_path)
break
def _create_client(
self, api_key: Optional[str] = None, base_url: Optional[str] = None, logger: Optional[Any] = None, **kwargs
) -> BaseLLMAPI:
client_class = self.SUPPORTED_CLIENTS[self.client_name]
config = self.DEFAULT_CONFIGS[self.client_name]
if api_key is None:
api_key = os.getenv(config["api_key_env"])
if not api_key:
if self.client_name == "private":
api_key = "EMPTY" # private mode may not need a key
else:
raise ValueError(
f"Fail to find env variable, please set: {config['api_key_env']} "
f"or upload ai_key parameter when initialize"
)
if base_url is None:
env_key = f"{self.client_name.upper()}_BASE_URL"
if self.client_name == "private":
env_key = "PRIVATE_URL"
base_url = os.getenv(env_key)
if base_url is None:
base_url = config.get("base_url")
client_kwargs = {
"api_key": api_key,
"timeout": self.timeout,
"max_retries": self.max_retries,
}
if base_url:
client_kwargs["base_url"] = base_url
if logger is not None:
client_kwargs["logger"] = logger
extra_config = config.get("extra_config", {})
client_kwargs.update(extra_config)
client_kwargs.update(kwargs)
if self.client_name == "openrouter":
client_kwargs.setdefault("app_name", "tokfinity-llm-client")
elif self.client_name == "private":
client_kwargs.setdefault("deployment_type", "vllm")
return client_class(**client_kwargs)
def _load_default_client_and_model_from_config(self) -> tuple[str, Optional[str]]:
"""
Take the first provider listed in config/config.yaml as the default client
and its first model as the default model.
"""
# Parse the root of config file relative to the project root
# This file is in src/managers/llm_api/api_manager.py
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
config_path = os.path.join(base_dir, "config", "config.yaml")
if not os.path.exists(config_path):
raise FileNotFoundError(f"No config file: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
providers = cfg.get("providers")
if not isinstance(providers, dict) or len(providers) == 0:
raise ValueError("config.yaml lack of providers config or format error")
first_provider_name = next(iter(providers.keys()))
models = providers.get(first_provider_name) or []
first_model = (
models[0] if isinstance(models, list) and len(models) > 0 else None
)
client_key = first_provider_name.strip().lower()
if client_key not in self.SUPPORTED_CLIENTS:
raise ValueError(
f"Default provider '{first_provider_name}' in config not registered in SUPPORTED_CLIENTS"
)
return client_key, first_model
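# Example config/config.yaml layout this loader expects (hypothetical values);
# the first provider and its first model become the defaults:
#
# providers:
#   anthropic:
#     - claude-3-5-sonnet-20241022
#   openai:
#     - gpt-4o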
def chat(
self,
messages: List[Union[Dict[str, Any], ChatMessage]],
model: Optional[str] = None,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
**kwargs,
) -> Union[
ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None
]:
"""
Send chat message and get response
Args:
model: model name (optional, default self.default_model)
messages: Concatenated message list.
- If list of dict: [{"role": "system|user|assistant|tool", "content": "..."}]
- Or `ChatMessage` list
temperature: Temperature
max_tokens: Max tokens
timeout: Request timeout, use the value from initialization if not specified
retry: Max retries, use the value from initialization if not specified
tools: Tools description list (OpenAI tools format)
tool_choice: Tool choice strategy: "auto" | "none" | {"type": ..., ...}
**kwargs: Other request args
Returns:
Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None]:
- Non-streaming: Return the complete ChatCompletionResponse object
- Stream: Return a streaming response generator
- Fail: Return None
"""
actual_timeout = timeout if timeout is not None else self.timeout
actual_retry = retry if retry is not None else self.max_retries
actual_model = model if model is not None else self.default_model
if not actual_model:
raise ValueError("Model unprovided, and cannot get default model from config")
normalized_messages: List[ChatMessage] = []
for msg in messages:
if isinstance(msg, ChatMessage):
if msg.content is None:
msg.content = ""
normalized_messages.append(msg)
else:
role_value = msg.get("role")
content_value = msg.get("content")
if content_value is None:
content_value = ""
role_enum = MessageRole(role_value)
normalized_messages.append(
ChatMessage(
role=role_enum,
content=content_value,
name=msg.get("name"),
tool_calls=msg.get("tool_calls"),
tool_call_id=msg.get("tool_call_id"),
)
)
last_exception = None
for attempt in range(actual_retry + 1):
try:
request = self.client.create_request(
messages=normalized_messages,
model=actual_model,
temperature=temperature,
max_tokens=max_tokens,
stream=self.stream,
tools=tools,
tool_choice=tool_choice,
config=self.config,
**kwargs,
)
response = self.client.chat_completions_create(request)
if self.stream:
# Streaming response: returned as-is for the caller to consume
return response
else:
# Non-streaming response: return complete ChatCompletionResponse
return response
except Exception as e:
last_exception = e
if attempt < actual_retry:
delay = min(2**attempt, 30)
if self.logger:
self.logger.warning(
f"{attempt + 1}th try failedretry after {delay} seconds: {str(e)}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(
f"All {actual_retry + 1} tries failed: {str(e)}"
)
return None
def chat_simple(
self,
model: str,
messages: List[Union[Dict[str, Any], ChatMessage]],
temperature: float = 0.7,
max_tokens: Optional[int] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
**kwargs,
) -> Optional[str]:
# Disable streaming for this call
original_stream = self.stream
self.stream = False
try:
response = self.chat(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
retry=retry,
tools=tools,
tool_choice=tool_choice,
**kwargs,
)
if response and hasattr(response, "choices") and response.choices:
return response.choices[0].message.content
return None
finally:
# Restore the original stream setting
self.stream = original_stream
def create_embeddings(
self,
input_text: Union[str, List[str]],
model: str,
encoding_format: str = "float",
dimensions: Optional[int] = None,
user: Optional[str] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
**kwargs,
) -> Optional[EmbeddingResponse]:
actual_timeout = timeout if timeout is not None else self.timeout
actual_retry = retry if retry is not None else self.max_retries
if isinstance(input_text, str):
text_list = [input_text]
single_input = True
else:
text_list = input_text
single_input = False
if self.logger:
self.logger.debug(
f"Start embedding request - client: {self.client_name}, model: {model}, text num: {len(text_list)}"
)
if not hasattr(self.client, "create_embeddings"):
error_msg = f"Client {self.client_name} not support embedding"
if self.logger:
self.logger.error(error_msg)
return None
last_exception = None
for attempt in range(actual_retry + 1):
try:
if self.logger:
self.logger.debug(f"{attempt + 1}th try to create embedding vector")
response = self.client.create_embeddings(
input_text=text_list,
model=model,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
timeout=actual_timeout,
max_retries=1,
**kwargs,
)
if response:
return response
else:
raise Exception("Client return empty response")
except Exception as e:
last_exception = e
if attempt < actual_retry:
delay = min(2**attempt, 30)
if self.logger:
self.logger.warning(
f"{attempt + 1}th embedding request failedretry after {delay}s: {str(e)}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(
f"All {actual_retry + 1} embedding tries failed: {str(e)}, traceback: {format_exc()}."
)
return None
def get_client_info(self) -> Dict[str, Any]:
return {
"client_name": self.client_name,
"stream": self.stream,
"client_info": (
self.client.get_model_info()
if hasattr(self.client, "get_model_info")
else {}
),
}
def get_model_name(self) -> str:
return self.default_model or "unknown"
def close(self) -> None:
if hasattr(self.client, "close"):
self.client.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"LLMAPIManager(client={self.client_name}, stream={self.stream})"
def __repr__(self) -> str:
return (
f"LLMAPIManager("
f"client_name='{self.client_name}', "
f"stream={self.stream})"
)
# Convenience functions
def create_manager(
client_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
) -> LLMAPIManager:
return LLMAPIManager(
client_name=client_name, stream=stream, logger=logger, **kwargs
)
COMMON_CONFIGS = {
"openai_gpt4": {"client_name": "openai", "model_name": "gpt-4o"},
"openai_gpt35": {"client_name": "openai", "model_name": "gpt-3.5-turbo"},
"claude_sonnet": {
"client_name": "anthropic",
"model_name": "claude-3-5-sonnet-20241022",
},
"claude_haiku": {
"client_name": "anthropic",
"model_name": "claude-3-haiku-20240307",
},
"deepseek_chat": {"client_name": "deepseek", "model_name": "deepseek-chat"},
"deepseek_coder": {"client_name": "deepseek", "model_name": "deepseek-coder"},
}
def create_common_manager(
config_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
) -> LLMAPIManager:
if config_name not in COMMON_CONFIGS:
raise ValueError(
f"Unknown config: {config_name}. Available configs: {list(COMMON_CONFIGS.keys())}"
)
config = COMMON_CONFIGS[config_name]
return LLMAPIManager(
client_name=config["client_name"], stream=stream, logger=logger, **kwargs
)
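# Usage sketch (assumed values): build a manager for a provider and send a
# non-streaming chat request; the model name here is illustrative.
#
# manager = create_manager("openai", stream=False, logger=my_logger)
# text = manager.chat_simple(
#     model="gpt-4o",
#     messages=[{"role": "user", "content": "Summarize this diff."}],
# )
# print(text)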

View File

@@ -0,0 +1,594 @@
"""
LLM API base class - standard OpenAI format
Universal API supporting multiple LLM providers
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
from dataclasses import dataclass, field
from enum import Enum
import json
import time
import logging
import requests
import asyncio
from traceback import format_exc
from src.managers.log.logger import Logger
class MessageRole(Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@dataclass
class ChatMessage:
role: MessageRole
content: Union[str, List[Dict[str, Any]]]
name: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_call_id: Optional[str] = None
@dataclass
class ChatCompletionRequest:
messages: List[ChatMessage]
model: str
temperature: float = 0.7
max_tokens: Optional[int] = None
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
user: Optional[str] = None
@dataclass
class Usage:
prompt_tokens: int
completion_tokens: int
total_tokens: int
@dataclass
class Choice:
index: int
message: ChatMessage
finish_reason: Optional[str] = None
logprobs: Optional[Dict[str, Any]] = None
@dataclass
class ChatCompletionResponse:
id: str
object: str
created: int
model: str
choices: List[Choice]
usage: Optional[Usage] = None
system_fingerprint: Optional[str] = None
@dataclass
class ChatCompletionChunk:
id: str
object: str
created: int
model: str
choices: List[Dict[str, Any]]
system_fingerprint: Optional[str] = None
@dataclass
class EmbeddingRequest:
input: Union[str, List[str]]
model: str
encoding_format: str = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
@dataclass
class EmbeddingData:
object: str
embedding: List[float]
index: int
@dataclass
class EmbeddingUsage:
prompt_tokens: int
total_tokens: int
@dataclass
class EmbeddingResponse:
object: str
data: List[EmbeddingData]
model: str
usage: EmbeddingUsage
class BaseLLMAPI(ABC):
"""
LLM API base class
Provide standard OpenAI format API, support:
- Synchronous/Asynchronous Chat Completions
- Streaming Responses
- Tool Calling
- Error handling and Retry
- Usage Statistics
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
self.api_key = api_key
self.base_url = base_url
self.timeout = timeout
self.max_retries = max_retries
self.retry_delay = retry_delay
self.extra_config = kwargs
self.logger = logger
self.session: Optional[requests.Session] = None
self._initialize_client()
def _create_http_clients(self, headers: Dict[str, str]) -> None:
self.session = requests.Session()
self.session.headers.update(headers)
self.session.timeout = self.timeout
@abstractmethod
def _initialize_client(self) -> None:
pass
@abstractmethod
def _get_chat_endpoint(self) -> str:
pass
@abstractmethod
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
pass
@abstractmethod
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
pass
@abstractmethod
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
pass
def chat_completions_create(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None]]:
self.validate_request(request)
def _make_request():
payload = self._build_request_payload(request)
endpoint = self._get_chat_endpoint()
if request.stream:
return self._stream_chat_completion(payload, endpoint)
else:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout
)
if response.status_code != 200:
error_msg = f"API Error {response.status_code}: {response.text}"
print(error_msg)
raise requests.exceptions.HTTPError(error_msg, response=response)
response.raise_for_status()
return self._parse_response(response.json())
return self._retry_with_backoff(_make_request)
async def achat_completions_create(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, AsyncGenerator[ChatCompletionChunk, None]]:
self.validate_request(request)
def _make_async_request():
payload = self._build_request_payload(request)
endpoint = self._get_chat_endpoint()
if request.stream:
return self._stream_chat_completion(
payload, endpoint
)
else:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout
)
response.raise_for_status()
return self._parse_response(response.json())
import asyncio
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, _make_async_request)
def create_message(self, role: MessageRole, content: str, config: Optional[Dict[str, Any]] = None, **kwargs) -> ChatMessage:
return ChatMessage(role=role, content=content, **kwargs)
def _make_token_cache_request(self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None) -> list[ChatMessage]:
"""
Insert cache_control blocks for Claude/Anthropic models according to config, preserving compatibility
- Only activated when the model is Claude/Anthropic
- Preserve the content as is for other providers to avoid breaking compatibility.
- If token_cache.role_turns is configured, inject it matching the turn and role; otherwise, only inject it into the first system message.
"""
role_turns = self._load_role_turns(config)
#if self.logger:
# self.logger.debug(f"In _make_token_cache_request, got role_turns: {role_turns}.")
if not self._is_claude_like(model):
return messages
turn_to_roles: Dict[int, List[str]] = {}
try:
for rt in role_turns or []:
try:
turn = int(rt.get("turn"))
role = (rt.get("role") or "").strip().lower()
if role:
if turn not in turn_to_roles:
turn_to_roles[turn] = []
turn_to_roles[turn].append(role)
except Exception:
continue
except Exception:
turn_to_roles = {}
current_turn = 0
result_messages: List[ChatMessage] = []
for msg in messages:
msg_blocks = self._to_claude_blocks(msg.content)
# Turn-0 rule: for the first user message, add the cache flag to its content and append it to the result.
if current_turn == 0 and msg.role == MessageRole.USER:
cached_blocks = self._add_cache_flag(msg_blocks)
result_messages.append(
ChatMessage(
role=msg.role,
content=cached_blocks,
name=msg.name,
tool_calls=msg.tool_calls,
tool_call_id=msg.tool_call_id,
)
)
#if self.logger:
# self.logger.debug(
# f"Applied cache to initial user message at turn {current_turn}."
# )
continue
# Other messages: just add it to result
result_messages.append(
ChatMessage(
role=msg.role,
content=msg_blocks,
name=msg.name,
tool_calls=msg.tool_calls,
tool_call_id=msg.tool_call_id,
)
)
# Anchor point: when the current turn's configuration includes "assistant" and the current message is from the assistant,
# append an extra system message to refresh the cache.
roles_for_turn = turn_to_roles.get(current_turn, [])
if msg.role == MessageRole.ASSISTANT and roles_for_turn and ("assistant" in roles_for_turn):
refresh_text = f"refresh cache tokens."
refresh_blocks = self._to_claude_blocks(refresh_text)
refresh_blocks = self._add_cache_flag(refresh_blocks)
result_messages.append(
ChatMessage(
role=MessageRole.SYSTEM,
content=refresh_blocks,
)
)
#if self.logger:
# self.logger.debug(
# f"Appended system refresh after assistant at turn {current_turn}, refresh_blocks: {refresh_blocks}."
# )
# an assistant message advances the turn counter
if msg.role == MessageRole.ASSISTANT:
current_turn += 1
return result_messages
def _to_claude_blocks(self, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
if isinstance(content, list):
return content
return [{"type": "text", "text": content}]
def _add_cache_flag(self, blocks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
out: List[Dict[str, Any]] = []
added = False
for blk in blocks:
if not added and isinstance(blk, dict) and blk.get("type") == "text":
out.append({**blk, "cache_control": {"type": "ephemeral"}})
added = True
else:
out.append(blk)
return out
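# Example (illustrative): _add_cache_flag([{"type": "text", "text": "hi"}]) returns
# [{"type": "text", "text": "hi", "cache_control": {"type": "ephemeral"}}];
# only the first text block receives the flag.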
def _is_claude_like(self, model: str) -> bool:
try:
model_lc = (model or "").lower()
return ("claude" in model_lc) or ("anthropic" in model_lc)
except Exception:
return False
def _load_role_turns(self, config: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""Load token_cache.role_turnsPriorityInvoker config > default value"""
try:
role_turns = None
# 1) Read from invoker's config
if config and isinstance(config, dict):
token_cache_cfg = config.get("token_cache")
if isinstance(token_cache_cfg, dict):
role_turns = token_cache_cfg.get("role_turns")
# 2) If still missing, fall back to the default example values
if not role_turns:
role_turns = [
{"turn": 0, "role": "user"},
{"turn": 25, "role": "assistant"},
{"turn": 50, "role": "assistant"},
{"turn": 75, "role": "assistant"},
]
return role_turns
except Exception as e:
if self.logger:
self.logger.warning(f"Read token_cache.role_turns failuse default value: {e}")
return [
{"turn": 0, "role": "user"},
{"turn": 25, "role": "assistant"},
{"turn": 50, "role": "assistant"},
{"turn": 75, "role": "assistant"},
]
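# Example config fragment consumed here (hypothetical values); each entry marks a
# conversation turn and role after which a cache anchor is injected:
#
# config = {"token_cache": {"role_turns": [
#     {"turn": 0, "role": "user"},
#     {"turn": 25, "role": "assistant"},
# ]}}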
def create_request(
self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None, **kwargs
) -> ChatCompletionRequest:
messages = self._make_token_cache_request(messages, model, config)
return ChatCompletionRequest(messages=messages, model=model, **kwargs)
def _handle_error(self, error: Exception) -> None:
if self.logger:
self.logger.error(f"API request failed: {error}")
else:
print(f"API request failed: {error}")
raise error
def _retry_with_backoff(self, func, *args, **kwargs):
last_exception = None
for attempt in range(self.max_retries + 1):
try:
return func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt)  # exponential backoff
if self.logger:
self.logger.warning(
f"{attempt + 1}th try failedretry after {delay}s: {e}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(f"All retries failed: {e}, traceback: {format_exc()}.")
raise last_exception
def get_model_info(self) -> Dict[str, Any]:
return {
"provider": self.__class__.__name__,
"base_url": self.base_url,
"timeout": self.timeout,
"max_retries": self.max_retries,
}
def validate_request(self, request: ChatCompletionRequest) -> bool:
if not request.messages:
raise ValueError("Message list is empty")
if not request.model:
raise ValueError("Model name is empty")
for idx, message in enumerate(request.messages):
if not message.content and not message.tool_calls:
try:
msg_info = {
"index": idx,
"role": getattr(message.role, "value", str(message.role)),
"content_len": (
len(message.content)
if isinstance(message.content, str)
else (len(message.content) if isinstance(message.content, list) else 0)
),
"has_tool_calls": bool(message.tool_calls),
"tool_calls": message.tool_calls,
}
if self.logger:
self.logger.warning(
f"Request validation failed: Invalid message exists (lacking both content and tool_calls): {json.dumps(msg_info, ensure_ascii=False)}"
)
except Exception as e:
if self.logger:
self.logger.warning(
f"Request validation failed: Invalid message exists (lacking both content and tool_calls), index={idx}, error: {e}, traceback: {format_exc()}."
)
raise ValueError("Cannot lacking both content and tool_calls")
return True
def format_messages_for_api(
self, messages: List[ChatMessage]
) -> List[Dict[str, Any]]:
formatted_messages = []
for message in messages:
msg_dict = {"role": message.role.value, "content": message.content}
if message.name:
msg_dict["name"] = message.name
if message.tool_calls:
msg_dict["tool_calls"] = message.tool_calls
if message.tool_call_id:
msg_dict["tool_call_id"] = message.tool_call_id
formatted_messages.append(msg_dict)
return formatted_messages
def _stream_chat_completion(
self, payload: Dict[str, Any], endpoint: str
) -> Generator[ChatCompletionChunk, None, None]:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout, stream=True
)
response.raise_for_status()
try:
line_count = 0
for line in response.iter_lines():
if line:
line_count += 1
line_str = line.decode("utf-8")
chunk = self._parse_stream_line(line_str)
if chunk:
yield chunk
else:
print(f"Jump invalid line")
print(f"Streaming request process done, {line_count} processed")
finally:
response.close()
async def _astream_chat_completion(
self, payload: Dict[str, Any], endpoint: str
) -> AsyncGenerator[ChatCompletionChunk, None]:
import asyncio
def _sync_stream():
return list(self._stream_chat_completion(payload, endpoint))
loop = asyncio.get_event_loop()
chunks = await loop.run_in_executor(None, _sync_stream)
for chunk in chunks:
yield chunk
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
# Process standard SSE format
if line.startswith("data: "):
data = line[6:]
if data.strip() == "[DONE]":
return None
try:
chunk_data = json.loads(data)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
self.logger.warning(f"Unable to parse streaming data: {data}")
return None
return None
def close(self) -> None:
if self.session:
self.session.close()
async def aclose(self) -> None:
if self.session:
self.session.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
def __str__(self) -> str:
return f"{self.__class__.__name__}(base_url={self.base_url})"
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"base_url={self.base_url}, "
f"timeout={self.timeout}, "
f"max_retries={self.max_retries})"
)

View File

@@ -0,0 +1,23 @@
"""
LLM Client implement module
Supported LLM providers:
- OpenAI: OpenAI API and compatible services
- Anthropic: Claude models
- DeepSeek: DeepSeek models
- Private: Private models (vLLM, TGI, Ollama, etc.)
"""
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
__all__ = [
"OpenAIClient",
"AnthropicClient",
"DeepSeekClient",
"OpenRouterClient",
"PrivateModelClient",
]

View File

@@ -0,0 +1,7 @@
"""
Anthropic Claude Client Module
"""
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
__all__ = ["AnthropicClient"]

View File

@@ -0,0 +1,288 @@
"""
Anthropic Claude Client Module
"""
import json
import time
from typing import Dict, List, Any, Optional, Union
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class AnthropicClient(BaseLLMAPI):
"""
Anthropic Claude API Client
Anthropic Claude models supported
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
anthropic_version: str = "2023-06-01",
logger: Optional[Logger] = None,
**kwargs
):
"""
Initialize Anthropic client
Args:
api_key: Anthropic API key
base_url: API base URL (defaults to the Anthropic API)
timeout: timeout in seconds
max_retries: max retries
retry_delay: retry delay in seconds
anthropic_version: API version
**kwargs: other args
"""
self.anthropic_version = anthropic_version
if base_url is None:
base_url = "https://api.anthropic.com"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs
)
def _initialize_client(self) -> None:
"""Initialize HTTP client"""
headers = {
"x-api-key": self.api_key,
"Content-Type": "application/json",
"anthropic-version": self.anthropic_version,
"User-Agent": "tokfinity-llm-client/1.0",
}
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/v1/messages"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
"""Build Anthropic API request payload"""
messages, system_prompt = self._convert_messages_to_anthropic_format(
request.messages
)
payload = {
"model": request.model,
"messages": messages,
"max_tokens": request.max_tokens or 1024,
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
if system_prompt:
payload["system"] = system_prompt
if request.stop is not None:
payload["stop_sequences"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
if request.tools is not None:
payload["tools"] = self._convert_tools_to_anthropic_format(request.tools)
if request.tool_choice is not None:
payload["tool_choice"] = self._convert_tool_choice_to_anthropic_format(
request.tool_choice
)
return payload
def _convert_messages_to_anthropic_format(
self, messages: List[ChatMessage]
) -> tuple[List[Dict[str, Any]], Optional[str]]:
"""Convert messages to anthropic format"""
anthropic_messages = []
system_prompt = None
for message in messages:
if message.role == MessageRole.SYSTEM:
system_prompt = message.content
elif message.role in [MessageRole.USER, MessageRole.ASSISTANT]:
anthropic_messages.append(
{"role": message.role.value, "content": message.content}
)
return anthropic_messages, system_prompt
def _convert_tools_to_anthropic_format(
self, tools: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Convert OpenAI format tools to anthropic format"""
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
function_def = tool.get("function", {})
anthropic_tool = {
"name": function_def.get("name", ""),
"description": function_def.get("description", ""),
"input_schema": function_def.get("parameters", {}),
}
anthropic_tools.append(anthropic_tool)
return anthropic_tools
def _convert_tool_choice_to_anthropic_format(
self, tool_choice: Union[str, Dict[str, Any]]
) -> Union[str, Dict[str, Any]]:
"""Convert OpenAI format tool choice to anthropic format"""
if isinstance(tool_choice, str):
if tool_choice == "auto":
return "auto"
elif tool_choice == "none":
return "none"
else:
return {"type": "tool", "name": tool_choice}
elif isinstance(tool_choice, dict):
if tool_choice.get("type") == "function":
return {
"type": "tool",
"name": tool_choice.get("function", {}).get("name", ""),
}
elif tool_choice.get("type") == "tool":
return tool_choice
return "auto"
def _convert_anthropic_tool_calls(
self, content_list: List[Dict[str, Any]]
) -> Optional[List[Dict[str, Any]]]:
"""Convert anthropic tool calls to OpenAI format"""
tool_calls = []
for item in content_list:
if item.get("type") == "tool_use":
tool_call = {
"id": item.get("id", ""),
"type": "function",
"function": {
"name": item.get("name", ""),
"arguments": item.get("input", {}),
},
}
tool_calls.append(tool_call)
return tool_calls if tool_calls else None
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
"""Convert anthropic API response to OpenAI format"""
content = ""
tool_calls = None
if response_data.get("content"):
content_data = (
response_data["content"][0] if response_data["content"] else {}
)
content = content_data.get("text", "")
if content_data.get("type") == "tool_use":
tool_calls = self._convert_anthropic_tool_calls(
response_data.get("content", [])
)
message = ChatMessage(
role=MessageRole.ASSISTANT, content=content, tool_calls=tool_calls
)
choice = Choice(
index=0,
message=message,
finish_reason=self._convert_stop_reason(response_data.get("stop_reason")),
)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("input_tokens", 0),
completion_tokens=usage_data.get("output_tokens", 0),
total_tokens=usage_data.get("input_tokens", 0)
+ usage_data.get("output_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object="chat.completion",
created=int(time.time()),
model=response_data.get("model", ""),
choices=[choice],
usage=usage,
)
def _convert_stop_reason(self, stop_reason: Optional[str]) -> Optional[str]:
"""Convert anthropic stop reason to OpenAI format"""
if stop_reason == "end_turn":
return "stop"
elif stop_reason == "max_tokens":
return "length"
elif stop_reason == "stop_sequence":
return "stop"
else:
return stop_reason
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""Convert anthropic steam response to OpenAI format"""
event_type = chunk_data.get("type")
if event_type == "content_block_delta":
delta = chunk_data.get("delta", {})
text = delta.get("text", "")
if text:
choices = [
{"index": 0, "delta": {"content": text}, "finish_reason": None}
]
return ChatCompletionChunk(
id=chunk_data.get("message", {}).get("id", ""),
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("message", {}).get("model", ""),
choices=choices,
)
elif event_type == "message_stop":
choices = [{"index": 0, "delta": {}, "finish_reason": "stop"}]
return ChatCompletionChunk(
id=chunk_data.get("message", {}).get("id", ""),
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("message", {}).get("model", ""),
choices=choices,
)
return None
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
try:
chunk_data = json.loads(line)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
# if not json, try SSE
return super()._parse_stream_line(line)

View File

@@ -0,0 +1,7 @@
"""
DeepSeek Client Module
"""
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
__all__ = ["DeepSeekClient"]

View File

@@ -0,0 +1,164 @@
"""
DeepSeek Client Module
"""
import time
from typing import Dict, List, Any, Optional
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class DeepSeekClient(BaseLLMAPI):
"""
DeepSeek API Client
DeepSeek models supported, including DeepSeek-Coder, DeepSeek-Chat, etc.
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
if base_url is None:
base_url = "https://api.deepseek.com/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
# optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""Parse stream chunk"""
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
"""Get available models"""
response = self.session.get(f"{self.base_url}/models", timeout=self.timeout)
response.raise_for_status()
return response.json()
async def alist_models(self) -> Dict[str, Any]:
"""Get available models (async wrapper around list_models)"""
import asyncio
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.list_models)

View File

@@ -0,0 +1,7 @@
"""
OpenAI Client Module
"""
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
__all__ = ["OpenAIClient"]

View File

@@ -0,0 +1,279 @@
"""
OpenAI Client
"""
import time
from typing import Dict, List, Any, Optional, Union
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingData,
EmbeddingUsage,
)
class OpenAIClient(BaseLLMAPI):
"""
OpenAI API Client
OpenAI API supported, and other API services compatible with the OpenAI format
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
organization: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
self.organization = organization
if base_url is None:
base_url = "https://api.openai.com/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
if self.organization:
headers["OpenAI-Organization"] = self.organization
self._create_http_clients(headers)
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"frequency_penalty": request.frequency_penalty,
"presence_penalty": request.presence_penalty,
"stream": request.stream,
}
# Optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
"""Get available models"""
response = self.session.get(f"{self.base_url}/models", timeout=self.timeout)
response.raise_for_status()
return response.json()
async def alist_models(self) -> Dict[str, Any]:
"""Get available models (async wrapper around list_models)"""
import asyncio
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self.list_models)
def create_embeddings(
self,
input_text: Union[str, List[str]],
model: str,
encoding_format: str = "float",
dimensions: Optional[int] = None,
user: Optional[str] = None,
timeout: Optional[int] = None,
max_retries: Optional[int] = None,
retry_delay: Optional[float] = None,
) -> EmbeddingResponse:
actual_timeout = timeout if timeout is not None else self.timeout
actual_max_retries = (
max_retries if max_retries is not None else self.max_retries
)
actual_retry_delay = (
retry_delay if retry_delay is not None else self.retry_delay
)
request = EmbeddingRequest(
input=input_text,
model=model,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
payload = self._build_embedding_request_payload(request)
for attempt in range(actual_max_retries + 1):
try:
print(f"Debug: Payload: {payload}")
response = self.session.post(
f"{self.base_url}/embeddings", json=payload, timeout=actual_timeout
)
if response.status_code == 200:
response_data = response.json()
return self._parse_embedding_response(response_data)
else:
error_msg = f"embedding failed (try {attempt + 1}): HTTP {response.status_code}"
if hasattr(response, "text"):
error_msg += f" - {response.text}"
print(f"Debug: {error_msg}")
if attempt < actual_max_retries:
print(f"Debug: wait {actual_retry_delay} and retry...")
time.sleep(actual_retry_delay)
continue
else:
raise Exception(f"All retries failed: {error_msg}")
except Exception as e:
error_msg = f"Embedding request failed (try {attempt + 1}): {str(e)}, traceback: {format_exc()}."
print(f"Debug: {error_msg}")
if attempt < actual_max_retries:
print(f"Debug: wait {actual_retry_delay} and retry...")
time.sleep(actual_retry_delay)
continue
else:
raise Exception(f"All retries failed: {str(e)}, traceback: {format_exc()}.")
raise Exception("Unknown error")
def _build_embedding_request_payload(
self, request: EmbeddingRequest
) -> Dict[str, Any]:
payload = {
"input": request.input,
"model": request.model,
"encoding_format": request.encoding_format,
}
if request.dimensions is not None:
payload["dimensions"] = request.dimensions
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_embedding_response(
self, response_data: Dict[str, Any]
) -> EmbeddingResponse:
embedding_data_list = []
for data_item in response_data.get("data", []):
embedding_data = EmbeddingData(
object=data_item.get("object", "embedding"),
embedding=data_item.get("embedding", []),
index=data_item.get("index", 0),
)
embedding_data_list.append(embedding_data)
usage_data = response_data.get("usage", {})
usage = EmbeddingUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return EmbeddingResponse(
object=response_data.get("object", "list"),
data=embedding_data_list,
model=response_data.get("model", ""),
usage=usage,
)
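# Usage sketch (assumed; the API key and model name are illustrative):
#
# client = OpenAIClient(api_key="sk-...")
# resp = client.create_embeddings("hello world", model="text-embedding-3-small")
# vector = resp.data[0].embedding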

View File

@@ -0,0 +1,8 @@
"""
OpenRouter Client
"""
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
__all__ = ["OpenRouterClient"]

View File

@@ -0,0 +1,329 @@
"""
OpenRouter Client
"""
import time
from typing import Dict, List, Any, Optional
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class OpenRouterClient(BaseLLMAPI):
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: Optional[str] = None,
site_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
"""
Initialize OpenRouter Client
Args:
api_key: OpenRouter API key
base_url: API base URL, default to OpenRouter official API
app_name: Application name (optional, for statistics)
site_url: Website URL (optional, for statistics)
timeout: Request timeout
max_retries: Maximum number of retries
retry_delay: Retry delay
**kwargs: Other configuration parameters
"""
self.app_name = app_name or "tokfinity-llm-client"
self.site_url = site_url
if base_url is None:
base_url = "https://openrouter.ai/api/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
"X-Title": self.app_name,
}
if self.site_url:
headers["HTTP-Referer"] = self.site_url
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
"""
Parse OpenRouter API response
Args:
response_data: API response data
Returns:
ChatCompletionResponse: Parsed response object
"""
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""
Parse stream chunk
Args:
chunk_data: Raw chunk data
Returns:
Optional[ChatCompletionChunk]: Parsed chunk
"""
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
"""
Get available models
Returns:
Dict[str, Any]: Model list response
"""
response = self.client.get("/models")
response.raise_for_status()
return response.json()
async def alist_models(self) -> Dict[str, Any]:
"""
Async get available models
Returns:
Dict[str, Any]: Model list response
"""
response = await self.async_client.get("/models")
response.raise_for_status()
return response.json()
def get_generation_info(self, generation_id: str) -> Dict[str, Any]:
"""
Get specific generation request details
Args:
generation_id: Generation request ID
Returns:
Dict[str, Any]: Generation information
"""
response = self.client.get(f"/generation?id={generation_id}")
response.raise_for_status()
return response.json()
async def aget_generation_info(self, generation_id: str) -> Dict[str, Any]:
"""
Async get specific generation request details
Args:
generation_id: Generation request ID
Returns:
Dict[str, Any]: Generation information
"""
response = await self.async_client.get(f"/generation?id={generation_id}")
response.raise_for_status()
return response.json()
def get_account_credits(self) -> Dict[str, Any]:
"""
Get account credits
Returns:
Dict[str, Any]: Account credits information
"""
response = self.client.get("/auth/key")
response.raise_for_status()
return response.json()
async def aget_account_credits(self) -> Dict[str, Any]:
"""
Async get account credits
Returns:
Dict[str, Any]: Account credits information
"""
response = await self.async_client.get("/auth/key")
response.raise_for_status()
return response.json()
@staticmethod
def get_popular_models() -> List[str]:
"""
Get popular models
Returns:
List[str]: Popular model names
"""
return [
# OpenAI
"openai/gpt-4-turbo-preview",
"openai/gpt-4",
"openai/gpt-3.5-turbo",
# Anthropic
"anthropic/claude-3-opus",
"anthropic/claude-3-sonnet",
"anthropic/claude-3-haiku",
# Google
"google/gemini-pro",
"google/gemini-pro-vision",
# Meta
"meta-llama/llama-2-70b-chat",
"meta-llama/llama-2-13b-chat",
# Mistral
"mistralai/mixtral-8x7b-instruct",
"mistralai/mistral-7b-instruct",
# Open Source
"microsoft/wizardlm-2-8x22b",
"databricks/dbrx-instruct",
"cohere/command-r-plus",
]
def get_model_info(self, model_name: str) -> Dict[str, Any]:
"""
Get specific model details
Args:
model_name: Model name
Returns:
Dict[str, Any]: Model information
"""
models_response = self.list_models()
models = models_response.get("data", [])
for model in models:
if model.get("id") == model_name:
return model
raise ValueError(f"Model {model_name} not found")
async def aget_model_info(self, model_name: str) -> Dict[str, Any]:
"""
Async get specific model details
Args:
model_name: Model name
Returns:
Dict[str, Any]: Model information
"""
models_response = await self.alist_models()
models = models_response.get("data", [])
for model in models:
if model.get("id") == model_name:
return model
raise ValueError(f"Model {model_name} not found")

View File

@@ -0,0 +1,7 @@
"""
Private model client
"""
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
__all__ = ["PrivateModelClient"]

View File

@@ -0,0 +1,321 @@
"""
Private model client
Implementation of a private/self-hosted model API client based on BaseLLMAPI
Supports vLLM, Text Generation Inference (TGI) and Ollama deployments
"""
import json
import time
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class PrivateModelClient(BaseLLMAPI):
def __init__(
self,
api_key: Optional[str] = None,
base_url: str = "http://localhost:8000/v1",
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
deployment_type: str = "vllm",
custom_headers: Optional[Dict[str, str]] = None,
supports_tools: bool = True,
logger: Optional[Logger] = None,
**kwargs,
):
self.deployment_type = deployment_type.lower()
self.custom_headers = custom_headers or {}
self.supports_tools = supports_tools
super().__init__(
api_key=api_key or "EMPTY",
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
headers.update(self.custom_headers)
if self.deployment_type == "ollama":
pass
elif self.deployment_type == "tgi":
pass
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
if self.deployment_type == "ollama":
return "/api/chat"
elif self.deployment_type == "tgi":
return "/generate"
else:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
if self.deployment_type == "ollama":
return self._build_ollama_payload(request)
elif self.deployment_type == "tgi":
return self._build_tgi_payload(request)
else:
return self._build_openai_payload(request)
def _build_openai_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
# Optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if self.supports_tools and request.tools is not None:
payload["tools"] = request.tools
if self.supports_tools and request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
return payload
def _build_ollama_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"stream": request.stream,
"options": {
"temperature": request.temperature,
"top_p": request.top_p,
},
}
if request.max_tokens is not None:
payload["options"]["num_predict"] = request.max_tokens
if request.stop is not None:
payload["options"]["stop"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
return payload
def _build_tgi_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
prompt = self._messages_to_prompt(request.messages)
payload = {
"inputs": prompt,
"parameters": {
"temperature": request.temperature,
"top_p": request.top_p,
"do_sample": True,
"stream": request.stream,
},
}
if request.max_tokens is not None:
payload["parameters"]["max_new_tokens"] = request.max_tokens
if request.stop is not None:
payload["parameters"]["stop_sequences"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
return payload
def _messages_to_prompt(self, messages: List[ChatMessage]) -> str:
prompt_parts = []
for message in messages:
if message.role == MessageRole.SYSTEM:
prompt_parts.append(f"System: {message.content}")
elif message.role == MessageRole.USER:
prompt_parts.append(f"User: {message.content}")
elif message.role == MessageRole.ASSISTANT:
prompt_parts.append(f"Assistant: {message.content}")
prompt_parts.append("Assistant:")
return "\n\n".join(prompt_parts)
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
if self.deployment_type == "ollama":
return self._parse_ollama_response(response_data)
elif self.deployment_type == "tgi":
return self._parse_tgi_response(response_data)
else:
return self._parse_openai_response(response_data)
def _parse_openai_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=message_data.get("tool_calls"),
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
)
def _parse_ollama_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
content = response_data.get("message", {}).get("content", "")
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
choice = Choice(
index=0,
message=message,
finish_reason="stop" if response_data.get("done", False) else None,
)
return ChatCompletionResponse(
id=f"ollama-{int(time.time())}",
object="chat.completion",
created=int(time.time()),
model=response_data.get("model", ""),
choices=[choice],
)
def _parse_tgi_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
if isinstance(response_data, list) and len(response_data) > 0:
response_data = response_data[0]
content = response_data.get("generated_text", "")
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
choice = Choice(
index=0,
message=message,
finish_reason=response_data.get("finish_reason", "stop"),
)
return ChatCompletionResponse(
id=f"tgi-{int(time.time())}",
object="chat.completion",
created=int(time.time()),
model="tgi-model",
choices=[choice],
)
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
if self.deployment_type == "ollama":
try:
chunk_data = json.loads(line)
return self._parse_ollama_chunk(chunk_data)
except json.JSONDecodeError:
return None
else:
# Standard SSE format
if line.startswith("data: "):
data = line[6:]
if data.strip() == "[DONE]":
return None
try:
chunk_data = json.loads(data)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
return None
return None
def _parse_ollama_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
content = chunk_data.get("message", {}).get("content", "")
done = chunk_data.get("done", False)
choices = [
{
"index": 0,
"delta": {"content": content} if content else {},
"finish_reason": "stop" if done else None,
}
]
return ChatCompletionChunk(
id=f"ollama-{int(time.time())}",
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("model", ""),
choices=choices,
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
)
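A minimal sketch of the deployment-specific behaviour above, assuming a local Ollama server on its default port; the messages are illustrative.
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
from src.managers.llm_api.base_client import ChatMessage, MessageRole

client = PrivateModelClient(
    base_url="http://localhost:11434",  # default Ollama port (assumption)
    deployment_type="ollama",
    supports_tools=False,
)
print(client._get_chat_endpoint())  # "/api/chat" for Ollama, "/chat/completions" for vLLM

# _messages_to_prompt flattens the conversation (used when building TGI payloads)
msgs = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are concise."),
    ChatMessage(role=MessageRole.USER, content="Say hi."),
]
print(client._messages_to_prompt(msgs))  # "System: ...\n\nUser: ...\n\nAssistant:"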

View File

@@ -0,0 +1,43 @@
"""
Log management module
Provides log management categorized by timestamp directory and log level
Usage example:
from src.managers.log import Logger, create_logger
# Basic Usage
logger = Logger("logs", "my_app")
logger.info("This is an info")
logger.error("This is an error")
logger.close()
# Use the context manager
with Logger("logs", "my_app") as logger:
logger.info("Automatically manage resources")
# Convenience function
logger = create_logger("logs", "my_app")
logger.info("Created via the convenience function")
logger.close()
"""
from .logger import (
Logger,
create_logger,
init_global_logger,
get_global_logger,
set_global_logger,
)
__all__ = [
"Logger",
"create_logger",
"init_global_logger",
"get_global_logger",
"set_global_logger",
]
__version__ = "1.0.0"
__author__ = "Tokfinity Team"
__description__ = "Timestamp:Directory and Level-Based Log Manager "

357
src/managers/log/logger.py Normal file
View File

@@ -0,0 +1,357 @@
"""
Log Manager
Creates a timestamped directory and writes level-based log files
"""
import os
import logging
from datetime import datetime
from typing import Optional
from pathlib import Path
"""
Module-level Self-defined NOTICE level log (between INFO and WARNING)
"""
NOTICE_LEVEL = 25
if not hasattr(logging, "NOTICE"):
logging.addLevelName(NOTICE_LEVEL, "NOTICE")
def notice(self, message, *args, **kwargs):
if self.isEnabledFor(NOTICE_LEVEL):
self._log(NOTICE_LEVEL, message, args, **kwargs)
logging.Logger.notice = notice # type: ignore[attr-defined]
class ImageBuilderLogger:
"""
ImageBuilder log manager
Function:
- Creates a sub-directory named after the current timestamp
- Creates debug.log, info.log, warning.log, error.log and all.log
- Provides a standard log format that records the caller's code location
- Console and file outputs
"""
def __init__(self, log_base_path: str, console_output: bool = True):
self.log_base_path = Path(log_base_path)
self.console_output = console_output
self.log_dir = self._create_log_dir()
self.file_handlers = {}
self.logger = self._setup_logger()
def _create_log_dir(self) -> Path:
timestamp = datetime.now().strftime("%Y%m%d%H%M")
log_dir = self.log_base_path / f"build_images/{timestamp}"
log_dir.mkdir(parents=True, exist_ok=True)
return log_dir
def _setup_logger(self) -> logging.Logger:
logger = logging.getLogger("image_builder_logger")
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
log_format = (
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
)
formatter = logging.Formatter(log_format)
if self.console_output:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
self._add_file_handlers(logger, formatter)
return logger
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
}
for level_name, level_value in log_levels.items():
log_file = self.log_dir / f"{level_name}.log"
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
if level_name == "debug":
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
elif level_name == "info":
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
elif level_name == "warning":
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
elif level_name == "error":
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
logger.addHandler(file_handler)
self.file_handlers[level_name] = file_handler
all_log_file = self.log_dir / "all.log"
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
all_file_handler.setFormatter(formatter)
all_file_handler.setLevel(logging.DEBUG)
logger.addHandler(all_file_handler)
self.file_handlers["all"] = all_file_handler
def debug(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.debug(message, *args, **kwargs)
def info(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.info(message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.warning(message, *args, **kwargs)
def error(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.error(message, *args, **kwargs)
@property
def log_file(self) -> str:
return str(self.log_dir / "all.log")
def get_log_dir(self) -> str:
return str(self.log_dir.absolute())
def get_log_files(self) -> dict:
return {
"debug": str(self.log_dir / "debug.log"),
"info": str(self.log_dir / "info.log"),
"warning": str(self.log_dir / "warning.log"),
"error": str(self.log_dir / "error.log"),
"all": str(self.log_dir / "all.log"),
}
def close(self):
for handler in self.file_handlers.values():
handler.close()
for handler in self.logger.handlers[:]:
self.logger.removeHandler(handler)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"ImageBuilderLogger(dir={self.log_dir})"
def __repr__(self) -> str:
return (
f"ImageBuilderLogger("
f"base_path='{self.log_base_path}', "
f"log_dir='{self.log_dir}', "
f"console_output={self.console_output})"
)
class Logger:
"""
Log manager
Function:
- Creates a sub-directory named after the instance id (or the current timestamp)
- Creates debug.log, info.log, notice.log, warning.log, error.log and all.log
- Provides a standard log format with caller location
- Console and file outputs
"""
def __init__(
self,
log_base_path: str,
logger_name: str = "tokfinity_logger",
console_output: bool = True,
log_format: Optional[str] = None,
instance_id: Optional[str] = None,
):
self.log_base_path = Path(log_base_path)
self.logger_name = logger_name
self.console_output = console_output
self.instance_id = instance_id
self.log_format = log_format or (
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
)
self.log_dir = self._create_log_dir()
self.file_handlers = {}
self.logger = self._setup_logger()
def _create_log_dir(self) -> Path:
if self.instance_id:
log_dir = self.log_base_path / self.instance_id
else:
timestamp = datetime.now().strftime("%Y%m%d%H%M")
log_dir = self.log_base_path / timestamp
log_dir.mkdir(parents=True, exist_ok=True)
return log_dir
def _setup_logger(self) -> logging.Logger:
logger = logging.getLogger(self.logger_name)
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
formatter = logging.Formatter(self.log_format)
if self.console_output:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
self._add_file_handlers(logger, formatter)
return logger
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"notice": NOTICE_LEVEL,
"warning": logging.WARNING,
"error": logging.ERROR,
}
for level_name, level_value in log_levels.items():
log_file = self.log_dir / f"{level_name}.log"
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
if level_name == "debug":
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
elif level_name == "info":
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
elif level_name == "notice":
file_handler.addFilter(lambda record: record.levelno == NOTICE_LEVEL)
elif level_name == "warning":
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
elif level_name == "error":
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
logger.addHandler(file_handler)
self.file_handlers[level_name] = file_handler
all_log_file = self.log_dir / "all.log"
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
all_file_handler.setFormatter(formatter)
all_file_handler.setLevel(logging.DEBUG)
logger.addHandler(all_file_handler)
self.file_handlers["all"] = all_file_handler
def debug(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.debug(message, *args, **kwargs)
def info(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.info(message, *args, **kwargs)
def notice(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.log(NOTICE_LEVEL, message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.warning(message, *args, **kwargs)
def error(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.error(message, *args, **kwargs)
def get_log_dir(self) -> str:
return str(self.log_dir.absolute())
def get_log_files(self) -> dict:
return {
"debug": str(self.log_dir / "debug.log"),
"info": str(self.log_dir / "info.log"),
"notice": str(self.log_dir / "notice.log"),
"warning": str(self.log_dir / "warning.log"),
"error": str(self.log_dir / "error.log"),
"all": str(self.log_dir / "all.log"),
}
@property
def log_file(self) -> str:
return str(self.log_dir / "all.log")
def close(self):
for handler in self.file_handlers.values():
handler.close()
for handler in self.logger.handlers[:]:
self.logger.removeHandler(handler)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"Logger(name={self.logger_name}, dir={self.log_dir})"
def __repr__(self) -> str:
return (
f"Logger("
f"name='{self.logger_name}', "
f"base_path='{self.log_base_path}', "
f"log_dir='{self.log_dir}', "
f"console_output={self.console_output})"
)
def create_logger(
log_base_path: str,
logger_name: str = "tokfinity_logger",
console_output: bool = True,
) -> Logger:
return Logger(
log_base_path=log_base_path,
logger_name=logger_name,
console_output=console_output,
)
# Global log manager instance (optional)
_global_logger: Optional[Logger] = None
def get_global_logger() -> Optional[Logger]:
return _global_logger
def set_global_logger(logger: Logger):
global _global_logger
_global_logger = logger
def init_global_logger(
log_base_path: str, logger_name: str = "global_logger"
) -> Logger:
global _global_logger
_global_logger = Logger(log_base_path, logger_name)
return _global_logger
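A minimal usage sketch for Logger, assuming a writable ./logs directory; the instance id is illustrative and simply names the log sub-directory.
from src.managers.log.logger import Logger

with Logger("logs", "demo_logger", instance_id="astropy__astropy-12907") as logger:
    logger.info("build started")
    logger.notice("custom NOTICE level, between INFO and WARNING")
    logger.error("something failed")
    print(logger.get_log_files()["notice"])  # logs/astropy__astropy-12907/notice.log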

275
src/managers/loop/base.py Normal file
View File

@@ -0,0 +1,275 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.tools.base import (
ToolExecutor,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
from src.managers.loop.types import ToolStats, LLMUsage
class BaseLoop:
def __init__(self, instance_id: str, instance_data: Dict[str, Any], logger: Logger, prompts_manager: PromptsManager | None, llm_manager: LLMAPIManager | None, tool_executor: ToolExecutor, config: Dict[str, Any] | None = None):
self.instance_id = instance_id
self.instance_data = instance_data
self.logger = logger
self.prompts_manager = prompts_manager
self.llm_manager = llm_manager
self.tool_executor = tool_executor
self.config = config or {}
self.component_name = self.__class__.__name__
def _make_assistant(
self, content: str | None, tool_calls: Any, messages: List[Dict[str, Any]]
) -> bool:
"""
Construct an assistant message based on the current content and tool calls, and append it to the messages.
"""
safe_content = content or ""
if not safe_content and not tool_calls:
self.logger.warning(
f"[{self.component_name}] Assistant returned an empty message with no tool calls; skipping this message and prompting to continue"
)
messages.append(
{"role": "user", "content": "请继续分析问题并使用工具来解决问题。"}
)
return False
assistant_message: Dict[str, Any] = {"role": "assistant"}
if tool_calls and not safe_content:
assistant_message["content"] = ""
elif safe_content:
assistant_message["content"] = safe_content
if tool_calls:
assistant_message["tool_calls"] = tool_calls
messages.append(assistant_message)
return True
def _make_tool_response(
self, tool_results: List[Any], messages: List[Dict[str, Any]]
) -> None:
"""Convert tool execution results into standard tool messages (role=tool) and append them to the messages.
- Generate content per result: use prompts_manager.tool_response_prompts([{...}]) to produce the content
- Set tool_call_id: prefer ToolResult.id; fallback to ToolResult.call_id
"""
if not tool_results:
return
for result in tool_results:
single_dict = [
{
"name": getattr(result, "name", "unknown"),
"success": getattr(result, "success", False),
"result": getattr(result, "result", None) or "",
"error": getattr(result, "error", None) or "",
}
]
content_text = (
self.prompts_manager.tool_response_prompts(single_dict)
if self.prompts_manager
else ""
)
tool_call_id = getattr(result, "id", None) or getattr(
result, "call_id", None
)
messages.append(
{
"role": "tool",
"content": content_text,
"tool_call_id": tool_call_id,
}
)
def _response_log(
self, response: Any, first_content: str, first_tool_calls: Any, total_turns: int
) -> None:
"""notice log for the current turn's LLM output"""
try:
response_log: Dict[str, Any] = {}
if hasattr(response, "usage") and response.usage:
response_log["usage"] = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", None),
"completion_tokens": getattr(response.usage, "completion_tokens", None),
"total_tokens": getattr(response.usage, "total_tokens", None),
}
if hasattr(response, "choices") and response.choices:
response_log["choice"] = {
"message": {
"content": first_content,
"tool_calls": first_tool_calls,
}
}
if response_log:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn output: {json.dumps(response_log, ensure_ascii=False)}"
)
else:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn output: {str(response)}"
)
except Exception:
self.logger.notice(
f"[{self.component_name}] 第 {total_turns} 轮: LLM 输出序列化失败,使用字符串表示: {str(response)}, traceback: {format_exc()}."
)
def _debug_messages(
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
) -> None:
"""debug log for the messages to be sent to the model"""
try:
self.logger.debug(f"[{self.component_name}] msg:")
recent_messages = messages[-2:] if len(messages) > 2 else messages
base_index = len(messages) - len(recent_messages)
for offset, msg in enumerate(recent_messages):
idx = base_index + offset
role = msg.get("role")
content = msg.get("content")
content_str = content if isinstance(content, str) else ""
preview = content_str[:prefix_len]
content_len = len(content_str)
extra = ""
if role == "assistant":
tool_calls = msg.get("tool_calls")
has_tool = tool_calls is not None and tool_calls != []
try:
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_messages function, fail: {format_exc()}, tool calls: {tool_calls}."
)
tool_calls_json = str(tool_calls)
extra = f", has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
elif role == "tool":
tool_call_id = msg.get("tool_call_id")
extra = f", tool_call_id={tool_call_id}"
self.logger.debug(
f"[{self.component_name}] {turn+1}th, msg#{idx}: role={role}, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}{extra}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_messages function, fail msg: {format_exc()}."
)
def _debug_last_message(
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
) -> None:
"""debug last turn msg"""
try:
if not messages:
return
last_assistant_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "assistant":
last_assistant_idx = i
break
if last_assistant_idx is None:
return
msg = messages[last_assistant_idx]
content = msg.get("content")
content_str = content if isinstance(content, str) else ""
preview = content_str[:prefix_len]
content_len = len(content_str)
tool_calls = msg.get("tool_calls")
has_tool = tool_calls is not None and tool_calls != []
try:
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_last_message function, fail: {format_exc()}, tool calls: {tool_calls}."
)
tool_calls_json = str(tool_calls)
self.logger.debug(
f"[{self.component_name}] {turn+1}th turn, output_preview: role=assistant, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}, has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_last_message function, last turn fail: {format_exc()}."
)
def _debug_tools(self, tools: List[Dict[str, Any]]) -> None:
"""debug tools msg"""
try:
self.logger.debug(f"[{self.component_name}] tools num: {len(tools)}")
for i, tool in enumerate(tools):
try:
tool_json = json.dumps(tool, ensure_ascii=False)
self.logger.debug(f"[{self.component_name}] tool #{i+1}: {tool_json}")
except Exception:
self.logger.debug(
f"[{self.component_name}] tool #{i+1} fail: {format_exc()}, string: {str(tool)}."
)
except Exception:
try:
self.logger.warning(
f"[{self.component_name}] fail; traceback: {format_exc()}."
)
self.logger.warning(f"[{self.component_name}] tools string: {str(tools)}")
except Exception:
pass
def _get_tools(self) -> List[Dict[str, Any]]:
"""Return the tool definitions for this loop; subclasses must override."""
raise NotImplementedError
def _is_bash_tool(self, tool_name: str) -> bool:
return BASH_TOOL_NAME in tool_name
def _is_edit_tool(self, tool_name: str) -> bool:
return "edit" in tool_name or "str_replace" in tool_name or STR_REPLACE_BASED_EDIT_TOOL_NAME in tool_name
def _is_search_tool(self, tool_name: str) -> bool:
return SEARCH_TOOL_NAME in tool_name or "search" in tool_name
def _is_submit_result_tool(self, tool_name: str) -> bool:
return SUBMIT_RESULT_TOOL_NAME in tool_name
def _update_usage(self, response: Any, usage_stats: LLMUsage) -> None:
if hasattr(response, "usage") and response.usage:
usage_stats.prompt_tokens += int(getattr(response.usage, "prompt_tokens", 0) or 0)
usage_stats.completion_tokens += int(
getattr(response.usage, "completion_tokens", 0) or 0
)
usage_stats.total_tokens += int(getattr(response.usage, "total_tokens", 0) or 0)
def _init_usage_stats(self) -> LLMUsage:
return LLMUsage()
def _init_tools_stats(self) -> ToolStats:
return ToolStats()
def _update_tool_call_statistic(
self, tool_results: List[Any], tool_stats: ToolStats
) -> None:
for result in tool_results:
try:
tool_name = getattr(result, "name", "")
tool_name = tool_name.lower() if isinstance(tool_name, str) else ""
success = bool(getattr(result, "success", False))
if self._is_bash_tool(tool_name):
tool_stats.bash["count"] += 1
if not success:
tool_stats.bash["failed"] += 1
elif self._is_edit_tool(tool_name):
tool_stats.edit["count"] += 1
if not success:
tool_stats.edit["failed"] += 1
elif self._is_search_tool(tool_name):
tool_stats.search["count"] += 1
if not success:
tool_stats.search["failed"] += 1
elif self._is_submit_result_tool(tool_name):
tool_stats.submit_result["count"] += 1
if not success:
tool_stats.submit_result["failed"] += 1
except Exception:
continue
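A minimal sketch of how tool results feed the statistics helpers above, assuming `loop` is any constructed BaseLoop subclass and that the tool-name constants are lowercase; SimpleNamespace stands in for real ToolResult objects.
from types import SimpleNamespace
from src.tools.base import BASH_TOOL_NAME, STR_REPLACE_BASED_EDIT_TOOL_NAME

stats = loop._init_tools_stats()
fake_results = [
    SimpleNamespace(name=BASH_TOOL_NAME, success=True),                    # counted as bash, not failed
    SimpleNamespace(name=STR_REPLACE_BASED_EDIT_TOOL_NAME, success=False), # counted as edit, failed
]
loop._update_tool_call_statistic(fake_results, stats)
print(stats.to_dict())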

View File

@@ -0,0 +1,339 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.managers.loop.base import BaseLoop
from src.tools.base import (
ToolExecutor,
ToolResult,
SubmitToolResult,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
class PatchGenerator(BaseLoop):
def __init__(
self,
instance_id: str,
instance_data: Dict[str, Any],
logger: Logger,
prompts_manager: PromptsManager | None,
llm_manager: LLMAPIManager | None,
tool_executor: ToolExecutor,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
async def _submit_all_tool_calls(
self, other_tool_calls: List[Dict[str, Any]]
) -> List[Any]:
"""execute tool calls, return tool execution results list"""
if not other_tool_calls:
return []
from src.tools.base import ToolCall
tool_call_objects = []
for tool_call_dict in other_tool_calls:
raw_args = tool_call_dict.get("function", {}).get("arguments", {})
parsed_args = raw_args
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except Exception as e:
self.logger.warning(f"[{self.component_name}] In _submit_all_tool_calls function, fail: {e}, traceback: {format_exc()}, args: {raw_args}.")
parsed_args = {}
tool_call_obj = ToolCall(
name=tool_call_dict.get("function", {}).get("name", ""),
call_id=tool_call_dict.get("id", ""),
arguments=parsed_args,
id=tool_call_dict.get("id", ""),
)
tool_call_objects.append(tool_call_obj)
return await self.tool_executor.container_sequential_tool_call(
tool_call_objects
)
def _process_submit_result_tool_result(
self,
submit_result: ToolResult,
golden_patch: List[Dict[str, Any]],
) -> None:
"""process submit_result tool call, fill golden_patch and log"""
if not submit_result.success or not submit_result.result:
self.logger.warning(f"[{self.component_name}] submit_result failed and no result.")
return
try:
submit_tool_result = SubmitToolResult.from_string(submit_result.result)
if submit_tool_result.output:
patch_info = {
"patch_content": submit_tool_result.output,
"test_status": submit_tool_result.test_status,
"reasoning": submit_tool_result.reasoning,
}
golden_patch.clear()
golden_patch.append(patch_info)
self.logger.info(
f"[{self.component_name}] patch len: {len(submit_tool_result.output)}."
)
self.logger.info(
f"[{self.component_name}] test status: {submit_tool_result.test_status}."
)
self.logger.info(
f"[{self.component_name}] reasoning: {submit_tool_result.reasoning[:100]}..."
)
else:
self.logger.warning(
f"[{self.component_name}] submit_result success but no patch content."
)
except Exception as e:
self.logger.error(f"[{self.component_name}] parse submit_result result fail: {e}, traceback: {format_exc()}.")
def _get_tools(self) -> List[Dict[str, Any]]:
tools = []
# OpenAI-style tool definitions are forced for now (see _should_use_openai_format)
use_openai_format = True
for tool in self.tool_executor.tools.values():
if use_openai_format:
tool_def = tool._definition_for_openai_fmt()
else:
tool_def = tool._definition_for_claude_fmt()
tools.append(tool_def)
return tools
def _should_use_openai_format(self) -> bool:
if not self.llm_manager or not hasattr(self.llm_manager, "get_model_name"):
return True  # OpenAI format by default
model_name = self.llm_manager.get_model_name().lower()
return "claude" not in model_name
def _get_issue_prompt(self) -> str:
"""generate issue prompt based on instance data"""
if not self.prompts_manager:
self.logger.warning("PromptsManager not initialized, cannot generate issue prompt.")
return ""
#instance_id = self.instance_data.get("instance_id", "")
#repo = self.instance_data.get("repo", "")
created_at = self.instance_data.get("created_at", "")
base_commit = self.instance_data.get("base_commit", "")
environment_setup_commit = self.instance_data.get(
"environment_setup_commit", ""
)
version = self.instance_data.get("version", "")
problem_statement = self.instance_data.get("problem_statement", "")
difficulty = self.instance_data.get("difficulty", "")
return self.prompts_manager.format_issue_prompt(
created_at=created_at,
base_commit=base_commit,
environment_setup_commit=environment_setup_commit,
version=version,
problem_statement=problem_statement,
difficulty=difficulty,
)
async def _generate_patch(self) -> Dict[str, Any] | None:
"""main loop logic for generating candidate patch"""
usage_stats = self._init_usage_stats()
tool_stats = self._init_tools_stats()
if not self.llm_manager or not self.prompts_manager:
self.logger.error(f"[{self.component_name}] LLM manager or prompts manager not initialized.")
return {
"success": False,
"golden_patch": [],
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": 0,
}
tools = self._get_tools()
self._debug_tools(tools)
root_path = self.config.get("builder", {}).get("repo_root_path", "")
max_turn = (
self.config.get("runner", {}).get("generator_loop", {}).get("max_turn", 10)
)
temperature = (
self.config.get("runner", {})
.get("generator_loop", {})
.get("temperature", 0.2)
)
issue_prompt = self._get_issue_prompt()
user_prompt = self.prompts_manager.get_generator_user(root_path, issue_prompt)
system_prompt = self.prompts_manager.get_generator_system(root_path)
total_turns = 0
golden_patch = []
try:
self.logger.info(
f"[{self.component_name}] {self.instance_id}: start generating candidate patch, max turn: {max_turn}"
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
)
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
)
for turn in range(max_turn):
total_turns = turn + 1
self.logger.info(f"[{self.component_name}] The {total_turns}th turn started.")
try:
current_input_msg = messages[-1] if messages else None
if current_input_msg is not None:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
)
except Exception as e:
self.logger.warning(
f"[{self.component_name}] {total_turns}th turn: LLM input fail: {messages[-1] if messages else None}, error: {e}, traceback: {format_exc()}."
)
self._debug_messages(turn, messages)
response = self.llm_manager.chat(
messages=messages,
tools=tools,
tool_choice="auto",
temperature=temperature,
)
first_content: str = ""
first_tool_calls: Any = None
if hasattr(response, "choices") and response.choices:
ch0 = response.choices[0]
first_content = (
getattr(getattr(ch0, "message", None), "content", None) or ""
)
first_tool_calls = getattr(
getattr(ch0, "message", None), "tool_calls", None
)
self._response_log(
response, first_content, first_tool_calls, total_turns
)
self._update_usage(response, usage_stats)
if hasattr(response, "choices") and response.choices:
content = first_content
tool_calls = first_tool_calls
if not self._make_assistant(content, tool_calls, messages):
continue
if tool_calls:
self.logger.info(
f"[{self.component_name}] {total_turns}th turn: call {len(tool_calls)} tools."
)
tool_results = await self._submit_all_tool_calls(tool_calls)
self._update_tool_call_statistic(tool_results, tool_stats)
if tool_results:
submit_result = None
other_tool_results = []
for tool_result in tool_results:
tool_name = getattr(tool_result, "name", "")
if tool_name == SUBMIT_RESULT_TOOL_NAME:
submit_result = tool_result
else:
other_tool_results.append(tool_result)
if submit_result:
self.logger.debug(
f"[{self.component_name}] {total_turns}th turn: got submit_result tool call."
)
self.logger.debug(f"[{self.component_name}] {total_turns}th turn: submit_result result: {submit_result}")
self._process_submit_result_tool_result(
submit_result, golden_patch
)
self._debug_last_message(turn, messages)
break
if other_tool_results:
self._make_tool_response(other_tool_results, messages)
else:
messages.append(
{
"role": "user",
"content": "请继续分析问题并使用工具来解决问题。",
}
)
self.logger.debug(f"[{self.component_name}] final golden_patch: {golden_patch}")
success = (
len(golden_patch) > 0 and golden_patch[0].get("patch_content", "") != ""
)
self.logger.info(
f"[{self.component_name}] status={success}, total_turns={total_turns}, tools_stats={tool_stats}"
)
result_payload = {
"success": success,
"golden_patch": golden_patch,
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": total_turns,
}
try:
self.logger.notice(
f"[{self.component_name}] final output: {json.dumps(result_payload, ensure_ascii=False)}"
)
except Exception as e:
self.logger.warning(
f"[{self.component_name}] output: {str(result_payload)}, error: {e}, traceback: {format_exc()}."
)
return result_payload
except Exception as e:
self.logger.error(f"[{self.component_name}] fail: {e}, traceback: {format_exc()}.")
result_payload = {
"success": False,
"golden_patch": [],
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": total_turns,
}
try:
self.logger.notice(
f"[{self.component_name}] 最终返回数据(失败): {json.dumps(result_payload, ensure_ascii=False)}"
)
except Exception as e:
self.logger.notice(
f"[{self.component_name}] 最终返回数据(失败, 字符串回退): {str(result_payload)}, error: {e}, traceback: {format_exc()}."
)
return result_payload
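A hedged sketch of driving the generator loop above; constructing the logger, managers and tool executor is omitted and assumed to happen in the surrounding runner.
import asyncio

async def run_one(generator: "PatchGenerator") -> None:
    result = await generator._generate_patch()
    if result and result.get("success"):
        patch = result["golden_patch"][0]["patch_content"]
        print(f"got a patch of {len(patch)} chars in {result['total_turns']} turns")
    else:
        print("no patch produced")

# asyncio.run(run_one(generator))  # `generator` is assumed to be constructed elsewhere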

View File

@@ -0,0 +1,338 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.managers.loop.types import GeneratorResult, SelectorResult, LLMUsage, ToolStats, PatchInfo
from src.tools.base import ToolExecutor, ToolCall, ToolResult
from src.managers.loop.base import BaseLoop
SELECTOR_SUBMIT_TOOL_NAME = "submit_result"
class PatchSelector(BaseLoop):
def __init__(
self,
instance_id: str,
instance_data: Dict[str, Any],
logger: Logger,
prompts_manager: PromptsManager | None,
llm_manager: LLMAPIManager | None,
tool_executor: ToolExecutor,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
def _get_submit_result_tool_name(self):
return SELECTOR_SUBMIT_TOOL_NAME
def _definition_for_submit_tool(self, use_openai_format: bool) -> Dict[str, Any]:
"""submit_result tool"""
if use_openai_format:
return {
"type": "function",
"function": {
"name": self._get_submit_result_tool_name(),
"description": "Submit the final selected patch index and reasoning.",
"parameters": {
"type": "object",
"properties": {
"index": {
"type": "integer",
"description": "The chosen patch index (0-based).",
},
"reason": {
"type": "string",
"description": "Detailed reasoning for the selection.",
},
},
"required": ["index", "reason"],
},
},
}
return {
"type": "function",
"function": {
"name": self._get_submit_result_tool_name(),
"description": "Submit the final selected patch index and reasoning.",
"parameters": {
"type": "object",
"properties": {
"index": {
"type": "integer",
"description": "The chosen patch index (0-based).",
},
"reason": {
"type": "string",
"description": "Detailed reasoning for the selection.",
},
},
"required": ["index", "reason"],
},
},
}
def _build_user_prompt(self, candidates: List[GeneratorResult], root_path: str) -> str:
if not self.prompts_manager:
return ""
return self.prompts_manager.get_selector_user(self.instance_data, candidates, root_path)
def _get_system_prompt(self, patches_count: int, root_path: str) -> str:
if not self.prompts_manager:
return ""
return self.prompts_manager.get_selector_system(patches_count, root_path)
def _get_tools(self) -> List[Dict[str, Any]]:
tool_defs: List[Dict[str, Any]] = []
try:
for tool in self.tool_executor.tools.values():
try:
tool_defs.append(tool._definition_for_openai_fmt())
except Exception:
continue
except Exception:
pass
tool_defs.append(self._definition_for_submit_tool(True))
return tool_defs
def _extract_submit_choice(self, tool_call: Dict[str, Any]) -> Dict[str, Any] | None:
if not tool_call:
return None
fn = tool_call.get("function", {})
if fn.get("name") != self._get_submit_result_tool_name():
return None
raw_args = fn.get("arguments", {})
try:
args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
except Exception:
args = {}
index = args.get("index")
reason = args.get("reason")
if isinstance(index, int) and index >= 0:
return {"index": index, "reason": reason or ""}
return None
async def _submit_other_tool_calls(
self, tool_calls: List[Dict[str, Any]]
) -> List[ToolResult]:
if not tool_calls:
return []
to_run: List[ToolCall] = []
for tool_call_dict in tool_calls:
fn = tool_call_dict.get("function", {})
name = fn.get("name", "")
if name == SELECTOR_SUBMIT_TOOL_NAME:
continue
raw_args = fn.get("arguments", {})
parsed_args = raw_args
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except Exception:
parsed_args = {}
to_run.append(
ToolCall(
name=name,
call_id=tool_call_dict.get("id", ""),
arguments=parsed_args,
id=tool_call_dict.get("id", ""),
)
)
if not to_run:
return []
results: List[ToolResult] = await self.tool_executor.container_sequential_tool_call(to_run)
return results
async def _select_patch(self, candidates: List[GeneratorResult]) -> SelectorResult:
if not candidates:
raise ValueError("No candidates provided")
if not self.llm_manager:
raise ValueError("LLM manager is not initialized")
tools = self._get_tools()
self._debug_tools(tools)
root_path = self.config.get("builder", {}).get("repo_root_path", "")
system_prompt = self._get_system_prompt(len(candidates), root_path)
user_prompt = self._build_user_prompt(candidates, root_path)
messages: List[Dict[str, Any]] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
try:
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
)
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] Initial fail in selector loop: SP={str(messages[0])}, UP={str(messages[1])}, traceback: {format_exc()}."
)
max_turn = int(
self.config.get("runner", {})
.get("selector_loop", {})
.get("max_turn", 200)
)
temperature = (
self.config.get("runner", {})
.get("selector_loop", {})
.get("temperature", 0.2)
)
usage_stats = self._init_usage_stats()
tool_stats = self._init_tools_stats()
total_turns = 0
chosen_index: int | None = None
select_reason: str = ""
for turn in range(max_turn):
try:
try:
current_input_msg = messages[-1] if messages else None
if current_input_msg is not None:
self.logger.notice(
f"[{self.component_name}] The {turn+1}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] {turn+1}th turn fail: {messages[-1] if messages else None}, traceback: {format_exc()}."
)
self._debug_messages(turn, messages)
response = self.llm_manager.chat(
messages=messages,
tools=tools,
tool_choice="auto",
temperature=temperature,
)
first_tool_calls = None
if hasattr(response, "choices") and response.choices:
ch0 = response.choices[0]
first_tool_calls = getattr(getattr(ch0, "message", None), "tool_calls", None)
first_content = getattr(getattr(ch0, "message", None), "content", None) or ""
else:
first_content = ""
total_turns = turn + 1
self._response_log(response, first_content, first_tool_calls, turn + 1)
self._update_usage(response, usage_stats)
if first_tool_calls:
if not self._make_assistant(first_content, first_tool_calls, messages):
messages.append(
{
"role": "user",
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
}
)
continue
submit_found = False
for tc in first_tool_calls:
choice = self._extract_submit_choice(tc)
if choice is not None:
chosen_index = choice["index"]
reason = choice.get("reason", "")
self.logger.info(
f"[{self.component_name}] choose: index={chosen_index}, reason={reason}"
)
select_reason = reason or ""
submit_found = True
self._debug_last_message(turn, messages)
break
if not submit_found:
results = await self._submit_other_tool_calls(first_tool_calls)
self._make_tool_response(results, messages)
self._update_tool_call_statistic(results, tool_stats)
else:
messages.append(
{
"role": "user",
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
}
)
if chosen_index is not None:
break
except Exception as e:
self.logger.warning(
f"[{self.component_name}] fail: {e}, traceback: {format_exc()}"
)
break
if chosen_index is None:
# If the model provides no choice, fallback: pick the first successful one; otherwise the first
for i, r in enumerate(candidates):
try:
if r.success:
chosen_index = i
break
except Exception:
continue
if chosen_index is None:
chosen_index = 0
if not (0 <= chosen_index < len(candidates)):
chosen_index = 0
selected = candidates[chosen_index]
try:
gp = selected.golden_patch[0] if selected.golden_patch else None
if gp is None:
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
else:
patch_info = PatchInfo(
patch_content=gp.patch_content,
test_status=gp.test_status,
reasoning=gp.reasoning,
)
except Exception:
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
selector_result = SelectorResult(
instance_id=selected.instance_id,
generator_id=selected.generator_id,
image=selected.image,
success=True,
golden_patch=patch_info,
llm_usage=usage_stats,
tool_stats=tool_stats,
total_turns=total_turns,
select_reason=select_reason,
error=None,
)
return selector_result
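A minimal sketch of the tool-call shape the submit_result handling above expects, assuming `selector` is a constructed PatchSelector; the id and reasoning text are illustrative.
tool_call = {
    "id": "call_1",
    "function": {
        "name": "submit_result",
        "arguments": '{"index": 1, "reason": "patch 1 also keeps the regression test passing"}',
    },
}
print(selector._extract_submit_choice(tool_call))
# {'index': 1, 'reason': 'patch 1 also keeps the regression test passing'}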

254
src/managers/loop/types.py Normal file
View File

@@ -0,0 +1,254 @@
"""
This module defines the GeneratorResult data structure for patch generation results.
"""
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
from src.tools.base import (
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
@dataclass
class LLMUsage:
"""LLM usage statistics."""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
def to_dict(self) -> Dict[str, int]:
"""Serialize LLMUsage to a plain dictionary."""
return {
"prompt_tokens": int(self.prompt_tokens),
"completion_tokens": int(self.completion_tokens),
"total_tokens": int(self.total_tokens),
}
@dataclass
class ToolStats:
"""Tool usage statistics per tool.
Each tool is represented by a small map with two fields:
- count: total invocation count
- failed: failed invocation count
"""
bash: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
edit: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
search: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
submit_result: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
def to_dict(self) -> Dict[str, Dict[str, int]]:
"""Serialize ToolStats to a plain dictionary."""
return {
BASH_TOOL_NAME: {"count": int(self.bash.get("count", 0)), "failed": int(self.bash.get("failed", 0))},
STR_REPLACE_BASED_EDIT_TOOL_NAME: {"count": int(self.edit.get("count", 0)), "failed": int(self.edit.get("failed", 0))},
SEARCH_TOOL_NAME: {"count": int(self.search.get("count", 0)), "failed": int(self.search.get("failed", 0))},
SUBMIT_RESULT_TOOL_NAME: {"count": int(self.submit_result.get("count", 0)), "failed": int(self.submit_result.get("failed", 0))},
}
@dataclass
class PatchInfo:
"""Information about a generated patch."""
patch_content: str
test_status: str
reasoning: str
@dataclass
class GeneratorResult:
"""Result from a patch generator."""
instance_id: str
generator_id: int
image: str
success: bool
golden_patch: List[
PatchInfo
]
llm_usage: LLMUsage
tool_stats: ToolStats
total_turns: int
error: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GeneratorResult":
"""Create GeneratorResult from dictionary."""
# Handle golden_patch conversion
golden_patch = []
if data.get("golden_patch"):
for patch_data in data["golden_patch"]:
if isinstance(patch_data, dict):
golden_patch.append(
PatchInfo(
patch_content=patch_data.get("patch_content", ""),
test_status=patch_data.get("test_status", ""),
reasoning=patch_data.get("reasoning", ""),
)
)
else:
# Legacy format: just patch content string
golden_patch.append(
PatchInfo(
patch_content=str(patch_data), test_status="", reasoning=""
)
)
# Handle LLM usage
llm_usage_data = data.get("llm_usage", {})
llm_usage = LLMUsage(
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
completion_tokens=llm_usage_data.get("completion_tokens", 0),
total_tokens=llm_usage_data.get("total_tokens", 0),
)
# Handle tool stats
tool_stats_data = data.get("tool_stats", {})
tool_stats = ToolStats(
bash=tool_stats_data.get(BASH_TOOL_NAME, {"count": 0, "failed": 0}),
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, {"count": 0, "failed": 0}),
search=tool_stats_data.get(SEARCH_TOOL_NAME, {"count": 0, "failed": 0}),
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, {"count": 0, "failed": 0}),
)
return cls(
instance_id=data.get("instance_id", ""),
generator_id=data.get("generator_id", 0),
image=data.get("image", ""),
success=data.get("success", False),
golden_patch=golden_patch,
llm_usage=llm_usage,
tool_stats=tool_stats,
total_turns=data.get("total_turns", 0),
error=data.get("error"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert GeneratorResult to dictionary."""
return {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image,
"success": self.success,
"golden_patch": [
{
"patch_content": patch.patch_content,
"test_status": patch.test_status,
"reasoning": patch.reasoning,
}
for patch in self.golden_patch
],
"llm_usage": {
"prompt_tokens": self.llm_usage.prompt_tokens,
"completion_tokens": self.llm_usage.completion_tokens,
"total_tokens": self.llm_usage.total_tokens,
},
"tool_stats": {
BASH_TOOL_NAME: self.tool_stats.bash,
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
SEARCH_TOOL_NAME: self.tool_stats.search,
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
},
"total_turns": self.total_turns,
"error": self.error,
}
@dataclass
class SelectorResult:
"""Result from a patch selector.
"""
instance_id: str
generator_id: int
image: str
success: bool
golden_patch: PatchInfo
llm_usage: LLMUsage
tool_stats: ToolStats
total_turns: int
select_reason: str
error: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SelectorResult":
"""Create SelectorResult from dictionary."""
gp_data = data.get("golden_patch", {})
if isinstance(gp_data, dict):
golden_patch = PatchInfo(
patch_content=gp_data.get("patch_content", ""),
test_status=gp_data.get("test_status", ""),
reasoning=gp_data.get("reasoning", ""),
)
else:
golden_patch = PatchInfo(
patch_content=str(gp_data) if gp_data is not None else "",
test_status="",
reasoning="",
)
# LLM usage
llm_usage_data = data.get("llm_usage", {})
llm_usage = LLMUsage(
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
completion_tokens=llm_usage_data.get("completion_tokens", 0),
total_tokens=llm_usage_data.get("total_tokens", 0),
)
# Tool stats (defaults must be dicts to match the ToolStats field types)
tool_stats_data = data.get("tool_stats", {})
tool_stats = ToolStats(
bash=tool_stats_data.get(BASH_TOOL_NAME, {"count": 0, "failed": 0}),
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, {"count": 0, "failed": 0}),
search=tool_stats_data.get(SEARCH_TOOL_NAME, {"count": 0, "failed": 0}),
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, {"count": 0, "failed": 0}),
)
return cls(
instance_id=data.get("instance_id", ""),
generator_id=data.get("generator_id", 0),
image=data.get("image", ""),
success=data.get("success", False),
golden_patch=golden_patch,
llm_usage=llm_usage,
tool_stats=tool_stats,
total_turns=data.get("total_turns", 0),
select_reason=data.get("select_reason", ""),
error=data.get("error"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert SelectorResult to dictionary."""
return {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image,
"success": self.success,
"golden_patch": {
"patch_content": self.golden_patch.patch_content,
"test_status": self.golden_patch.test_status,
"reasoning": self.golden_patch.reasoning,
},
"llm_usage": {
"prompt_tokens": self.llm_usage.prompt_tokens,
"completion_tokens": self.llm_usage.completion_tokens,
"total_tokens": self.llm_usage.total_tokens,
},
"tool_stats": {
BASH_TOOL_NAME: self.tool_stats.bash,
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
SEARCH_TOOL_NAME: self.tool_stats.search,
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
},
"total_turns": self.total_turns,
"select_reason": self.select_reason,
"error": self.error,
}
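# Minimal round-trip sketch (illustrative only; the instance id, image name, and patch text
# below are hypothetical, and this assumes the module can be executed directly).
if __name__ == "__main__":
    _demo = SelectorResult.from_dict(
        {
            "instance_id": "demo__instance-1",
            "generator_id": 0,
            "image": "demo-image:latest",
            "success": True,
            "golden_patch": {"patch_content": "diff --git a/x b/x", "test_status": "passed", "reasoning": "fix"},
            "select_reason": "only passing candidate",
        }
    )
    # to_dict() followed by from_dict() should reproduce the same serialized payload.
    assert SelectorResult.from_dict(_demo.to_dict()).to_dict() == _demo.to_dict()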

View File

@@ -0,0 +1,268 @@
from typing import Any, List, Dict
class PromptsManager:
def __init__(self, config):
self.candidate_length = config.get("runner", {}).get("generator_concurrency", 5)
def get_generator_system(self, root_path: str | None = None):
return f"""
# You are a highly skilled expert in software engineering focused on resolving complex GitHub issues by effectively analyzing codebases, implementing fixes, and ensuring code reliability through rigorous testing.
## Skills
1. Code Analysis and Debugging
- Issue Exploration: Ability to explore and comprehend codebases within repositories.
- Workflow Tracing: Skilled in using debugging techniques to trace issues through code.
- Root Cause Identification: Proficient in pinpointing the underlying causes of software issues.
- Test Creation: Expertise in establishing tests that replicate and validate issues.
2. Solution Implementation and Testing
- Fix Implementation: Experience in crafting precise and minimal code patches.
- Comprehensive Testing: Skilled in running and analyzing both existing and newly created tests.
- Regression Prevention: Ensures all changes maintain overall code stability.
- Continuous Improvement: Iterates on solutions based on test results to achieve optimal functionality.
## Task
Your task is to resolve the given GitHub issue by understanding the repository and the issue, implementing a fix, and checking your changes against existing tests and your own test(s).
Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project.
For example, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/testbed`.
Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and I activated the virtual environment for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
Follow these steps:
1. Problem Analysis:
- Read the issue description carefully to fully grasp the issue and explore the repository (source code, tests, examples) to understand expected behavior of relevant components.
- Identify the full scope. Does the issue mention multiple components, backends, or functions? Your solution must address all of them.
2. Reproduce the issue (IMPORTANT):
- Create a test that reproduces the issue as a baseline for verification.
- Check that the output of your test matches your understanding of the issue in step 1.
3. Identify the root cause:
- Go through relevant files, create debugging scripts with print statements, or use other methods if necessary to trace the workflow and the exact cause of the issue.
- **Trace the problem to its root cause.** Do not just patch the symptom where the error appears. Trace the data and execution flow upstream to find where the problem originates.
4. Implement a Fix:
- Once you have identified the root cause, develop a precise and targeted fix, then apply it as a minimal patch using the `str_replace_based_edit_tool`.
5. Test comprehensively:
- Verify the Fix: Run your initial reproduction script to confirm that the bug is resolved.
- Prevent Regressions:
--Identify the right tests: Once you have verified your fix, identify the most relevant tests within the project's existing test suite that correspond to your code changes.
--Run the tests: Then you **must** run these tests to ensure that your fix does not introduce any new bugs.
--Analyze failures carefully:
---If tests fail, do not immediately assume your fix is wrong. Critically analyze the failure.
---Is it a **regression**? Did your change break existing, valid functionality? If so, you must refine your fix.
---Is it an **unrelated failure**? It could be an environmental issue (e.g., missing dependency, network error) or a pre-existing flaky test. If you suspect this, try to run a more focused test and note the issue in your final reasoning.
---Is the **test now obsolete**? If your fix improves behavior in a way that makes an old test's assertions incorrect, you should **update the test** to match the new, correct behavior and explain why in your reasoning.
- Write New Tests: Create new, specific test cases (e.g., using `pytest`) that cover the original bug scenario.
- Consider Edge Cases: Think about and test potential edge cases related to your changes.
6. Revisit steps 1 through 5 if unexpected behavior occurs, then call `submit_result` to submit the reliable, verified solution patch after successful testing and validation.
**Mandatory Workflow** As a senior engineer, ensure solution correctness and safety. Upon successful verification, immediately conclude the task by calling `submit_result`.
"""
def format_issue_prompt(
self,
created_at: str,
base_commit: str,
environment_setup_commit: str,
version: str,
problem_statement: str,
difficulty: str,
) -> str:
template = f"""
[📝 Issue Description]
**Created at**: {created_at}
**Base commit**: {base_commit}
---
### 📌 Problem Statement
{problem_statement}
---
### ⚙️ Difficulty Level
{difficulty}
---
"""
return template.strip()
def get_generator_user(self, root_path: str, issue_text: str):
return (
f"""
[Project root path]:
{root_path}
[Issue Information]:
{issue_text}
"""
+ self.get_generator_notice()
)
def get_generator_notice(self):
return """
[notice]
1. Use the available tools to locate the root cause.
2. Prioritize the `search_tool` to pinpoint the precise location of key information in the project.
3. Collect supporting evidence: stack traces, logs, configs, recent changes, related modules.
"""
def get_selector_system(self, patches_count: int, root_path: str):
return f"""
# ROLE:
*You are a highly proficient software engineer tasked with evaluating and selecting optimal code patches to resolve specific issues within a given project.
*Your colleagues worked on {patches_count} potential patches for a GitHub issue. Select the ONE correct patch that solves the issue.
*Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and the virtual environment has been activated for you. You can start analyzing the issue, searching and reading relevant files, and verifying the candidate patches directly.
*Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project. For instance, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/testbed`.
# WORKFLOWS:
*Follow these steps without any skipping:
1.Problem Analysis:
- Read the issue description and the current code that needs to be fixed. Explore the repository (source code, tests, examples) to understand expected behavior of relevant components, and gather comprehensive information about the problem area
2.Conduct a thorough review of each patch:
- Scrutinize all code modifications.
- Decipher the core logic and problem-solving methodology.
- Evaluate potential edge cases and unintended consequences.
- Validate that each patch fully addresses the initial issue specifications.
3.Verify Your Analysis
- Use the available tools to verify that your analysis of this issue holds.
- Test your conclusions against relevant code sections.
- Ensure full contextual understanding.
4.Proceed with Your Decision
- Upon completion of the preceding three steps, utilize the `submit_result` tool with your detailed reasoning.
#RULES:
1.It is MANDATORY to utilize the available tools prior to finalizing any selection:
-- Start with `bash` to explore the codebase structure;
-- Employ the `str_replace_based_edit_tool` to inspect the current code;
-- Use `search_tool` to search related code and files;
2.You MUST first explore the codebase before using the `submit_result` tool.
3.Substantiate your reasoning with evidence from your analysis.
4.Only selections made after employing the tools will be accepted.
#FINAL DECISION:
Upon completion of your tool-based analysis, finalize the process by submitting your choice via the `submit_result` tool.
#NOTICE:
1. Tool usage is MANDATORY - do not skip this step.
2. Failing to make a decision after completing your analysis is not permitted.
3. Never generate new patches on your own; just make the selection.
4. Always provide detailed reasoning for the selection, based on your tool-based investigation.
"""
def get_selector_user(
self, instance_data: Dict[str, Any] | None = None, candidates: List[Any] | None = None, root_path: str | None = None
) -> str:
"""
Generate the selector user prompt, including the issue information and the first golden patch of each candidate.
- instance_data: current instance metadata (issue description, etc.)
- candidates: list of candidate results (each supporting .to_dict()); only golden_patch[0] of each candidate is used
"""
if not instance_data or not candidates:
return ""
created_at = instance_data.get("created_at", "")
base_commit = instance_data.get("base_commit", "")
environment_setup_commit = instance_data.get("environment_setup_commit", "")
version = instance_data.get("version", "")
problem_statement = instance_data.get("problem_statement", "")
difficulty = instance_data.get("difficulty", "")
issue_block = self.format_issue_prompt(
created_at=created_at,
base_commit=base_commit,
environment_setup_commit=environment_setup_commit,
version=version,
problem_statement=problem_statement,
difficulty=difficulty,
)
root_path_block = f"""
[Project root path]:
{root_path or "/testbed"}
"""
parts: List[str] = [root_path_block, issue_block, "\n[🔎 Candidates]\n"]
for idx, r in enumerate(candidates):
try:
data = r.to_dict() if hasattr(r, "to_dict") else {}
except Exception:
data = {}
golden_patch = data.get("golden_patch", [])
patch_content = golden_patch[0].get("patch_content", "") if golden_patch else ""
test_status = golden_patch[0].get("test_status", "") if golden_patch else ""
reasoning = golden_patch[0].get("reasoning", "") if golden_patch else ""
parts.append(self.format_selector_candidate(idx, patch_content, test_status, reasoning))
parts.append(
"\nPlease analyze the candidates, then call the submit_result tool with the final index and reasoning."
)
return "\n".join(parts)
def get_terminal_response(self, exit_code: int, output: str, timeout_status: bool):
# timeout_status is True when the command timed out, so report the timeout first.
if timeout_status:
return """[Terminal response]
Terminal timed out."""
return f"""[Terminal response]
Exit code: {exit_code}
Output: {output}"""
def tool_response_prompts(self, tool_results: list) -> str:
if not tool_results:
return ""
response_parts = ["[tool_response]"]
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
output = result.get("result", "")
error = result.get("error", "")
response_parts.append(f"Tool {i}: {tool_name}")
response_parts.append(f"Success: {success}")
if success and output:
response_parts.append(f"Output:\n{output}")
elif error:
response_parts.append(f"Error: {error}")
else:
response_parts.append("No output")
response_parts.append("") # 空行分隔
return "\n".join(response_parts)
def format_selector_candidate(self, index: int, patch_content: str, test_status: str, reasoning: str) -> str:
"""
Generate the description of one selector candidate, including key information from its first golden_patch.
- index: candidate index (0-based)
- patch_content: golden_patch[0].patch_content
- test_status: golden_patch[0].test_status, the test status from the generation stage
- reasoning: golden_patch[0].reasoning, the model's reasoning from the generation stage
"""
header = f"- Candidate #{index}:"
patch_block = patch_content or ""
status_block = test_status or ""
reasoning_block = reasoning or ""
return (
f"--{header}\n"
f"--Patch content (the proposed fix):\n{patch_block}\n\n"
f"--Test status during generation: {status_block}\n\n"
f"--Reasoning during generation (model's logic):\n{reasoning_block}"
)
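# Minimal usage sketch (illustrative only; the config shape and issue text are assumptions
# based on how __init__ and the prompt builders read their inputs).
if __name__ == "__main__":
    pm = PromptsManager({"runner": {"generator_concurrency": 3}})
    issue = pm.format_issue_prompt(
        created_at="2024-01-01",
        base_commit="abc123",
        environment_setup_commit="def456",
        version="1.0",
        problem_statement="Example: calling foo() raises TypeError instead of returning None.",
        difficulty="<15 min fix",
    )
    print(pm.get_generator_user(root_path="/testbed", issue_text=issue))
    print(pm.get_selector_system(patches_count=3, root_path="/testbed"))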

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict
import json
class ResultBuilder:
"""
Build preds.json
Iterate through JSON files named by instance_id in the directory specified by `runner.selector_result_dump_path` in the configuration.
- Parse the `golden_patch` field from each JSON file and extract the patch text as `model_patch`.
- Read the first top-level provider key from `providers` and the first model listed under it, then concatenate them to form `model_name_or_path`.
- The output location is `{workspace.path}/{result.preds.path}/preds.json`.
"""
def __init__(self, config: Dict[str, Any]):
self.config = config or {}
def _get_selector_dump_dir(self) -> Path:
runner_cfg = self.config.get("runner", {}) if isinstance(self.config, dict) else {}
dump_dir_str = runner_cfg.get(
"selector_result_dump_path", "workspace/selector_result_dump"
)
return Path(dump_dir_str)
def _get_preds_output_dir(self) -> Path:
workspace_cfg = self.config.get("workspace", {}) if isinstance(self.config, dict) else {}
result_cfg = self.config.get("result", {}) if isinstance(self.config, dict) else {}
preds_cfg = result_cfg.get("preds", {}) if isinstance(result_cfg, dict) else {}
workspace_path = workspace_cfg.get("path", "workspace")
preds_path = preds_cfg.get("path", "result")
return Path(workspace_path) / preds_path
def _get_model_name_or_path(self) -> str:
providers = self.config.get("providers", {}) if isinstance(self.config, dict) else {}
if not isinstance(providers, dict) or not providers:
return ""
first_provider_name = next(iter(providers.keys()))
first_models = providers.get(first_provider_name, [])
if isinstance(first_models, list) and first_models:
first_model = first_models[0]
else:
first_model = ""
return f"{first_provider_name}/{first_model}" if first_provider_name and first_model else ""
@staticmethod
def _extract_model_patch(golden_patch: Any) -> str:
"""
Extract patch content from golden_patch
Forms supported:
- dict: prefer 'patch_content', then fall back to 'model_patch'
- str: returned directly
- anything else: returns an empty string
"""
if isinstance(golden_patch, dict):
if "patch_content" in golden_patch and isinstance(golden_patch["patch_content"], str):
return golden_patch["patch_content"]
if "model_patch" in golden_patch and isinstance(golden_patch["model_patch"], str):
return golden_patch["model_patch"]
return ""
if isinstance(golden_patch, str):
return golden_patch
return ""
def build_preds(self) -> Path:
dump_dir = self._get_selector_dump_dir()
output_dir = self._get_preds_output_dir()
output_dir.mkdir(parents=True, exist_ok=True)
output_file = output_dir / "preds.json"
model_name_or_path = self._get_model_name_or_path()
# SWE-bench evaluation expects: list[dict], each element includes instance_id / model_patch / model_name_or_path
predictions: list[dict[str, str]] = []
if dump_dir.exists() and dump_dir.is_dir():
for path in sorted(dump_dir.glob("*.json")):
try:
instance_id = path.stem
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
golden_patch = data.get("golden_patch", {}) if isinstance(data, dict) else {}
model_patch = self._extract_model_patch(golden_patch)
predictions.append(
{
"instance_id": instance_id,
"model_patch": model_patch,
"model_name_or_path": model_name_or_path,
}
)
except Exception:
continue
with open(output_file, "w", encoding="utf-8") as f:
json.dump(predictions, f, ensure_ascii=False, indent=2)
return output_file
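# Minimal usage sketch (illustrative only; the config keys mirror the defaults read above,
# and the provider, model, and instance id are placeholders).
if __name__ == "__main__":
    cfg = {
        "runner": {"selector_result_dump_path": "workspace/selector_result_dump"},
        "workspace": {"path": "workspace"},
        "result": {"preds": {"path": "result"}},
        "providers": {"demo_provider": ["demo_model"]},
    }
    dump_dir = Path(cfg["runner"]["selector_result_dump_path"])
    dump_dir.mkdir(parents=True, exist_ok=True)
    # One selector dump file named by instance_id, with a golden_patch payload.
    with open(dump_dir / "demo__instance-1.json", "w", encoding="utf-8") as f:
        json.dump({"golden_patch": {"patch_content": "diff --git a/x b/x"}}, f)
    print(ResultBuilder(cfg).build_preds())  # -> workspace/result/preds.json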