This commit is contained in:
LanWeiFRJ
2025-10-29 14:02:04 +08:00
parent c6b8c4bbd8
commit e334c4f764
52 changed files with 12251 additions and 0 deletions

1
src/__init__.py Normal file
View File

@@ -0,0 +1 @@
# src directory marker file

16
src/managers/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# Export all modules
from .image_builder import *
from .log import *
from .llm_api import *
from .prompts import *
from .loop import *
from .decorators import *
__all__ = [
"image_builder",
"log",
"llm_api",
"prompts",
"loop",
"decorators"
]

View File

@@ -0,0 +1,13 @@
import threading
def singleton(cls):
instances = {}
lock = threading.Lock()
def get_instance(*args, **kwargs):
if cls not in instances:
with lock:
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
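
A minimal usage sketch of the decorator above; the Config class and its field are hypothetical. Repeated calls return the same instance, and the double-checked lock keeps the first construction thread-safe.

# Hypothetical example: both calls below resolve to the same object.
@singleton
class Config:
    def __init__(self, path: str = "config/config.yaml"):
        self.path = path

a = Config()
b = Config("ignored-after-first-call.yaml")
assert a is b  # only the first constructor call ever runs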

View File

@@ -0,0 +1,8 @@
# Export all modules
from .build_image import *
from .dockerfiles import *
__all__ = [
"build_image",
"dockerfiles"
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,8 @@
# User image Dockerfile template
# image_key: instance_image_key corresponding to instance_id
# apt_packages: space-separated list of packages to install in the Ubuntu-based image
_DOCKERFILE_USER_IMAGE_PY = r"""
FROM {image_key}
RUN apt-get update && apt-get install -y --no-install-recommends {apt_packages}
"""

View File

@@ -0,0 +1,128 @@
"""
User-defined logger support for swebench image builds
"""
import logging
from typing import Union, Optional
from swebench.harness.docker_utils import remove_image
from src.managers.log.logger import Logger as CustomLogger
# Import the original swebench functions
from swebench.harness.docker_build import build_instance_image, run_threadpool
def patched_build_instance_images(
client,
dataset: list,
force_rebuild: bool = False,
max_workers: int = 4,
namespace: str = None,
tag: str = None,
env_image_tag: str = None,
custom_logger: Optional[Union[logging.Logger, CustomLogger]] = None,  # new parameter
):
"""
Monkey-patched version of build_instance_images that supports a custom logger
Args:
client: Docker client
dataset: List of test specs or dataset to build images for
force_rebuild: Whether to force rebuild the images even if they already exist
max_workers: Maximum number of worker threads
namespace: Namespace for images
tag: Tag for images
env_image_tag: Environment image tag
custom_logger: Custom logger to use (instead of creating new ones)
"""
from swebench.harness.docker_build import make_test_spec, build_env_images
test_specs = []
for instance in dataset:
spec = make_test_spec(
instance,
namespace=namespace,
instance_image_tag=tag,
env_image_tag=env_image_tag,
)
test_specs.append(spec)
if force_rebuild:
for spec in test_specs:
remove_image(client, spec.instance_image_key, "quiet")
_, env_failed = build_env_images(client, test_specs, force_rebuild, max_workers)
if len(env_failed) > 0:
# Don't build images for instances that depend on failed-to-build env images
dont_run_specs = [
spec for spec in test_specs if spec.env_image_key in env_failed
]
test_specs = [
spec for spec in test_specs if spec.env_image_key not in env_failed
]
if custom_logger:
custom_logger.info(
f"Skipping {len(dont_run_specs)} instances - due to failed env image builds"
)
else:
print(f"Skipping {len(dont_run_specs)} instances - due to failed env image builds")
if custom_logger:
custom_logger.info(f"Building instance images for {len(test_specs)} instances")
else:
print(f"Building instance images for {len(test_specs)} instances")
successful, failed = [], []
if custom_logger:
payloads = [(spec, client, custom_logger, False) for spec in test_specs]
else:
payloads = [(spec, client, None, False) for spec in test_specs]
successful, failed = run_threadpool(build_instance_image, payloads, max_workers)
if len(failed) == 0:
if custom_logger:
custom_logger.info("All instance images built successfully.")
else:
print("All instance images built successfully.")
else:
if custom_logger:
custom_logger.warning(f"{len(failed)} instance images failed to build.")
else:
print(f"{len(failed)} instance images failed to build.")
return successful, failed
def apply_logger_patch():
"""应用 logger patch"""
import swebench.harness.docker_build as docker_build_module
original_build_instance_images = docker_build_module.build_instance_images
docker_build_module.build_instance_images = patched_build_instance_images
return original_build_instance_images
def restore_logger_patch(original_function):
"""Recover original logger actions"""
import swebench.harness.docker_build as docker_build_module
docker_build_module.build_instance_images = original_function
# Context manager version
class LoggerPatch:
"""Context manager"""
def __init__(self, logger: Optional[Union[logging.Logger, CustomLogger]] = None):
self.logger = logger
self.original_function = None
def __enter__(self):
self.original_function = apply_logger_patch()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.original_function:
restore_logger_patch(self.original_function)
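
A sketch of how the context manager above might be used; the dataset and logger objects are assumed to already exist.

# Hypothetical usage: inside the with-block, the module-level
# build_instance_images is the patched version defined above.
import docker
import swebench.harness.docker_build as docker_build

client = docker.from_env()
with LoggerPatch(logger):
    docker_build.build_instance_images(client, dataset, custom_logger=logger)
# On exit, the original function is restored.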

View File

@@ -0,0 +1,132 @@
"""
Redirect print output of a third party repo to a specific log
"""
import sys
import logging
import re
from typing import Union, Set, Pattern
from typing import Optional, TextIO
from pathlib import Path
from src.managers.log.logger import Logger as CustomLogger
class StreamWrapper:
"""Simple stream wrapperfor redirecting stdout/stderr"""
def __init__(self, original_stream, write_func):
self.original_stream = original_stream
self.write_func = write_func
self.buffer = ""
def write(self, text):
self.buffer += text
while '\n' in self.buffer:
line, self.buffer = self.buffer.split('\n', 1)
if line.strip():
self.write_func(line + '\n')
return len(text)
def flush(self):
if self.buffer:
if self.buffer.strip():
self.write_func(self.buffer)
self.buffer = ""
if hasattr(self.original_stream, 'flush'):
self.original_stream.flush()
def __getattr__(self, name):
return getattr(self.original_stream, name)
class PrintRedirector:
"""Redirect print output of a third party repo to a specific log"""
TRACEBACK_START_PATTERN = re.compile(r'Traceback\s*\(most recent call last\):', re.IGNORECASE)
EXCEPTION_END_PATTERN = re.compile(
r'\b(Error|Exception|KeyError|ValueError|TypeError|AttributeError|'
r'ImportError|BuildError|DockerError|IndentationError|SyntaxError|'
r'RuntimeError|OSError|FileNotFoundError|PermissionError)\s*:',
re.IGNORECASE
)
ERROR_KEYWORDS = {'error', 'failed', 'exception', 'fatal', 'critical'}
WARNING_KEYWORDS = {'warning', 'warn', 'deprecated'}
SKIP_KEYWORDS = {'skipping', 'skip', 'ignoring', 'ignore'}
def __init__(self, logger: Union[logging.Logger, CustomLogger]):
self.logger = logger
self.original_print = None
self.original_stdout = None
self.original_stderr = None
self.traceback_buffer = []
self.in_traceback = False
def __enter__(self):
self.start_redirect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_redirect()
def start_redirect(self):
if isinstance(__builtins__, dict):
self.original_print = __builtins__['print']
else:
self.original_print = getattr(__builtins__, 'print')
if isinstance(__builtins__, dict):
__builtins__['print'] = self._redirected_print
else:
setattr(__builtins__, 'print', self._redirected_print)
def stop_redirect(self):
if self.original_print:
if isinstance(__builtins__, dict):
__builtins__['print'] = self.original_print
else:
setattr(__builtins__, 'print', self.original_print)
def _redirected_print(self, *args, **kwargs):
if not args:
return
output = ' '.join(str(arg) for arg in args)
if self.TRACEBACK_START_PATTERN.search(output):
self.in_traceback = True
self.traceback_buffer = [output]
return
if self.in_traceback:
self.traceback_buffer.append(output)
if self.EXCEPTION_END_PATTERN.search(output):
self._log_traceback_and_reset()
return
self._log_by_level(output)
def _log_traceback_and_reset(self):
full_traceback = '\n'.join(self.traceback_buffer)
self.logger.error(f"[Third party repo exception stack]\n{full_traceback}")
self.in_traceback = False
self.traceback_buffer.clear()
def _log_by_level(self, output: str):
output_lower = output.lower()
if any(keyword in output_lower for keyword in self.ERROR_KEYWORDS):
self.logger.error(f"[Third party] {output}")
elif any(keyword in output_lower for keyword in self.WARNING_KEYWORDS):
self.logger.warning(f"[Third party] {output}")
elif any(keyword in output_lower for keyword in self.SKIP_KEYWORDS):
self.logger.info(f"[Third party] {output}")
else:
self.logger.info(f"[Third party] {output}")
def redirect_swebench_prints(logger: Union[logging.Logger, CustomLogger]):
return PrintRedirector(logger)
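
A sketch of the intended usage, assuming a CustomLogger instance named logger: prints from third-party code inside the block are routed to the logger by keyword, and tracebacks are buffered and emitted as a single error record.

# Hypothetical usage of the redirector above.
with redirect_swebench_prints(logger):
    print("Building environment image...")       # routed to logger.info
    print("WARNING: package xyz is deprecated")  # routed to logger.warning
    print("Build failed: timeout")               # routed to logger.error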

View File

@@ -0,0 +1,98 @@
"""
LLM API Management Module
Provide a standard OpenAI format LLM API interface that supports:
- Unified Chat Completions endpoint
- Synchronous and asynchronous operations
- Streaming responses
- Tool Calling
- Error handling and retry mechanism(s)
Supported providers:
- OpenAI: OpenAI API and compatible services
- Anthropic: Claude models
- DeepSeek: DeepSeek models
- Private: privately deployed models (vLLM, TGI, Ollama, etc.)
Usage example::
# Option 1: use the general manager (recommended)
from llm_api import LLMAPIManager
# create manager
manager = LLMAPIManager(
client_name="openai",
model_name="gpt-3.5-turbo",
stream=False
)
# send a message
response = manager.chat("Hello world!")
print(response)
# Option 2: use a client directly
from llm_api import OpenAIClient, ChatMessage, MessageRole
# create a client
client = OpenAIClient(api_key="your-api-key")
# create messages
messages = [
    ChatMessage(role=MessageRole.USER, content="Hello, world!")
]
# send request
request = client.create_request(messages=messages, model="gpt-3.5-turbo")
response = client.chat_completions_create(request)
print(response.choices[0].message.content)
"""
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
MessageRole,
Choice,
Usage,
)
# Client implementations
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
# API manager
from src.managers.llm_api.api_manager import (
LLMAPIManager,
create_manager,
create_common_manager,
COMMON_CONFIGS,
)
__all__ = [
"BaseLLMAPI",
"ChatMessage",
"ChatCompletionRequest",
"ChatCompletionResponse",
"ChatCompletionChunk",
"MessageRole",
"Choice",
"Usage",
"OpenAIClient",
"AnthropicClient",
"DeepSeekClient",
"OpenRouterClient",
"PrivateModelClient",
"LLMAPIManager",
"create_manager",
"create_common_manager",
"COMMON_CONFIGS",
]
__version__ = "1.0.0"
__author__ = "Tokfinity Team"
__description__ = "标准 OpenAI 格式的 LLM API 基类库"

View File

@@ -0,0 +1,479 @@
"""
LLM API manager
"""
import os
import time
from typing import Dict, List, Any, Optional, Tuple, Union, Generator
from dotenv import load_dotenv
import yaml
from traceback import format_exc
from .base_client import (
BaseLLMAPI,
ChatMessage,
MessageRole,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
EmbeddingRequest,
EmbeddingResponse,
)
# Import all clients
from .clients.openai.openai_client import OpenAIClient
from .clients.anthropic.anthropic_client import AnthropicClient
from .clients.deepseek.deepseek_client import DeepSeekClient
from .clients.openrouter.openrouter_client import OpenRouterClient
from .clients.private.private_client import PrivateModelClient
class LLMAPIManager:
SUPPORTED_CLIENTS = {
"openai": OpenAIClient,
"anthropic": AnthropicClient,
"deepseek": DeepSeekClient,
"openrouter": OpenRouterClient,
"private": PrivateModelClient,
}
DEFAULT_CONFIGS = {
"openai": {"base_url": None, "api_key_env": "OPENAI_API_KEY"},
"anthropic": {"base_url": None, "api_key_env": "ANTHROPIC_API_KEY"},
"deepseek": {"base_url": None, "api_key_env": "DEEPSEEK_API_KEY"},
"openrouter": {
"base_url": None,
"api_key_env": "OPENROUTER_API_KEY",
"extra_config": {
"app_name": "tokfinity-llm-client",
"site_url": "https://github.com/your-repo",
},
},
"private": {
"base_url": "http://localhost:8000/v1",
"api_key_env": "PRIVATE_API_KEY",
"extra_config": {"deployment_type": "vllm"},
},
}
def __init__(
self,
client_name: Optional[str] = None,
stream: bool = False,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
auto_load_env: bool = True,
logger: Optional[Any] = None,
config: Optional[Dict[str, Any]] = None,
**kwargs,
):
# if client not provided, get first provider from default providers in config
if client_name is None:
default_client, default_model = (
self._load_default_client_and_model_from_config()
)
self.client_name = default_client
self.default_model = default_model
else:
self.client_name = client_name.lower()
self.default_model = None
self.stream = stream
self.timeout = timeout
self.max_retries = max_retries
self.logger = logger
self.logger.info(f"[LLMAPIManager]: Using client: {self.client_name}, model: {self.default_model}.")
self.config = config
if auto_load_env:
self._load_environment()
if self.client_name not in self.SUPPORTED_CLIENTS:
raise ValueError(
f"Unsupported client: {client_name}"
f"Support client: {list(self.SUPPORTED_CLIENTS.keys())}"
)
self.client = self._create_client(api_key, base_url, logger, **kwargs)
def _load_environment(self) -> None:
"""load environment variables"""
env_paths = [".env", "../.env", "../../.env", "../../../.env"]
for env_path in env_paths:
if os.path.exists(env_path):
load_dotenv(env_path)
break
def _create_client(
self, api_key: Optional[str] = None, base_url: Optional[str] = None, logger: Optional[Any] = None, **kwargs
) -> BaseLLMAPI:
client_class = self.SUPPORTED_CLIENTS[self.client_name]
config = self.DEFAULT_CONFIGS[self.client_name]
if api_key is None:
api_key = os.getenv(config["api_key_env"])
if not api_key:
if self.client_name == "private":
api_key = "EMPTY" # private mode may not need a key
else:
raise ValueError(
f"Fail to find env variable, please set: {config['api_key_env']} "
f"or upload ai_key parameter when initialize"
)
if base_url is None:
env_key = f"{self.client_name.upper()}_BASE_URL"
if self.client_name == "private":
env_key = "PRIVATE_URL"
base_url = os.getenv(env_key)
if base_url is None:
base_url = config.get("base_url")
client_kwargs = {
"api_key": api_key,
"timeout": self.timeout,
"max_retries": self.max_retries,
}
if base_url:
client_kwargs["base_url"] = base_url
if logger is not None:
client_kwargs["logger"] = logger
extra_config = config.get("extra_config", {})
client_kwargs.update(extra_config)
client_kwargs.update(kwargs)
if self.client_name == "openrouter":
client_kwargs.setdefault("app_name", "tokfinity-llm-client")
elif self.client_name == "private":
client_kwargs.setdefault("deployment_type", "vllm")
return client_class(**client_kwargs)
def _load_default_client_and_model_from_config(self) -> Tuple[str, Optional[str]]:
"""
Take the first provider entry in config/config.yaml as the default client,
and its first model as the default model.
"""
# Resolve the config file path relative to the project root
# This file is in src/managers/llm_api/api_manager.py
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
config_path = os.path.join(base_dir, "config", "config.yaml")
if not os.path.exists(config_path):
raise FileNotFoundError(f"No config file: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
providers = cfg.get("providers")
if not isinstance(providers, dict) or len(providers) == 0:
raise ValueError("config.yaml lack of providers config or format error")
first_provider_name = next(iter(providers.keys()))
models = providers.get(first_provider_name) or []
first_model = (
models[0] if isinstance(models, list) and len(models) > 0 else None
)
client_key = first_provider_name.strip().lower()
if client_key not in self.SUPPORTED_CLIENTS:
raise ValueError(
f"Default provider '{first_provider_name}' in config not registered in SUPPORTED_CLIENTS"
)
return client_key, first_model
def chat(
self,
messages: List[Union[Dict[str, Any], ChatMessage]],
model: Optional[str] = None,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
**kwargs,
) -> Union[
ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None
]:
"""
Send chat message and get response
Args:
model: model name (optional, default self.default_model)
messages: Concatenated message list.
- If list of dict: [{"role": "system|user|assistant|tool", "content": "..."}]
- Or `ChatMessage` list
temperature: Temperature
max_tokens: Max tokens
timeout: Request timeout, use the value from initialization if not specified
retry: Max retries, use the value from initialization if not specified
tools: Tools description list (OpenAI tools format)
tool_choice: Tool choice strategy: "auto" | "none" | {"type": ..., ...}
**kwargs: Other request args
Returns:
Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None]:
- Non-streaming: Return the complete ChatCompletionResponse object
- Stream: Return a streaming response generator
- Fail: Return None
"""
actual_timeout = timeout if timeout is not None else self.timeout
actual_retry = retry if retry is not None else self.max_retries
actual_model = model if model is not None else self.default_model
if not actual_model:
raise ValueError("Model unprovided, and cannot get default model from config")
normalized_messages: List[ChatMessage] = []
for msg in messages:
if isinstance(msg, ChatMessage):
if msg.content is None:
msg.content = ""
normalized_messages.append(msg)
else:
role_value = msg.get("role")
content_value = msg.get("content")
if content_value is None:
content_value = ""
role_enum = MessageRole(role_value)
normalized_messages.append(
ChatMessage(
role=role_enum,
content=content_value,
name=msg.get("name"),
tool_calls=msg.get("tool_calls"),
tool_call_id=msg.get("tool_call_id"),
)
)
last_exception = None
for attempt in range(actual_retry + 1):
try:
request = self.client.create_request(
messages=normalized_messages,
model=actual_model,
temperature=temperature,
max_tokens=max_tokens,
stream=self.stream,
tools=tools,
tool_choice=tool_choice,
config=self.config,
**kwargs,
)
response = self.client.chat_completions_create(request)
if self.stream:
# Streaming response: returned as a chunk generator, not post-processed here
return response
else:
# Non-streaming response: return complete ChatCompletionResponse
return response
except Exception as e:
last_exception = e
if attempt < actual_retry:
delay = min(2**attempt, 30)
if self.logger:
self.logger.warning(
f"{attempt + 1}th try failedretry after {delay} seconds: {str(e)}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(
f"All {actual_retry + 1} tries failed: {str(e)}"
)
return None
def chat_simple(
self,
model: str,
messages: List[Union[Dict[str, Any], ChatMessage]],
temperature: float = 0.7,
max_tokens: Optional[int] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
**kwargs,
) -> Optional[str]:
# Disable streaming for this call
original_stream = self.stream
self.stream = False
try:
response = self.chat(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
retry=retry,
tools=tools,
tool_choice=tool_choice,
**kwargs,
)
if response and hasattr(response, "choices") and response.choices:
return response.choices[0].message.content
return None
finally:
# Restore the original stream setting
self.stream = original_stream
def create_embeddings(
self,
input_text: Union[str, List[str]],
model: str,
encoding_format: str = "float",
dimensions: Optional[int] = None,
user: Optional[str] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
**kwargs,
) -> Optional[EmbeddingResponse]:
actual_timeout = timeout if timeout is not None else self.timeout
actual_retry = retry if retry is not None else self.max_retries
if isinstance(input_text, str):
text_list = [input_text]
single_input = True
else:
text_list = input_text
single_input = False
if self.logger:
self.logger.debug(
f"Start embedding request - client: {self.client_name}, model: {model}, text num: {len(text_list)}"
)
if not hasattr(self.client, "create_embeddings"):
error_msg = f"Client {self.client_name} not support embedding"
if self.logger:
self.logger.error(error_msg)
return None
last_exception = None
for attempt in range(actual_retry + 1):
try:
if self.logger:
self.logger.debug(f"{attempt + 1}th try to create embedding vector")
response = self.client.create_embeddings(
input_text=text_list,
model=model,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
timeout=actual_timeout,
max_retries=1,
**kwargs,
)
if response:
return response
else:
raise Exception("Client return empty response")
except Exception as e:
last_exception = e
if attempt < actual_retry:
delay = min(2**attempt, 30)
if self.logger:
self.logger.warning(
f"{attempt + 1}th embedding request failedretry after {delay}s: {str(e)}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(
f"All {actual_retry + 1} embedding tries failed: {str(e)}, traceback: {format_exc()}."
)
return None
def get_client_info(self) -> Dict[str, Any]:
return {
"client_name": self.client_name,
"stream": self.stream,
"client_info": (
self.client.get_model_info()
if hasattr(self.client, "get_model_info")
else {}
),
}
def get_model_name(self) -> str:
return self.default_model or "unknown"
def close(self) -> None:
if hasattr(self.client, "close"):
self.client.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"LLMAPIManager(client={self.client_name}, stream={self.stream})"
def __repr__(self) -> str:
return (
f"LLMAPIManager("
f"client_name='{self.client_name}', "
f"stream={self.stream})"
)
# Convenience functions
def create_manager(
client_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
) -> LLMAPIManager:
return LLMAPIManager(
client_name=client_name, stream=stream, logger=logger, **kwargs
)
COMMON_CONFIGS = {
"openai_gpt4": {"client_name": "openai", "model_name": "gpt-4o"},
"openai_gpt35": {"client_name": "openai", "model_name": "gpt-3.5-turbo"},
"claude_sonnet": {
"client_name": "anthropic",
"model_name": "claude-3-5-sonnet-20241022",
},
"claude_haiku": {
"client_name": "anthropic",
"model_name": "claude-3-haiku-20240307",
},
"deepseek_chat": {"client_name": "deepseek", "model_name": "deepseek-chat"},
"deepseek_coder": {"client_name": "deepseek", "model_name": "deepseek-coder"},
}
def create_common_manager(
config_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
) -> LLMAPIManager:
if config_name not in COMMON_CONFIGS:
raise ValueError(
f"Unknown config: {config_name}. Available configs: {list(COMMON_CONFIGS.keys())}"
)
config = COMMON_CONFIGS[config_name]
return LLMAPIManager(
client_name=config["client_name"], stream=stream, logger=logger, **kwargs
)
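
A short sketch of driving the manager defined above; the client name, model, and prompt are illustrative, and the matching *_API_KEY environment variable is assumed to be set.

# Hypothetical usage of create_manager / LLMAPIManager.
manager = create_manager("anthropic", stream=False)
reply = manager.chat_simple(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "Summarize the build log in one line."}],
    max_tokens=256,
)
print(reply)
manager.close()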

View File

@@ -0,0 +1,594 @@
"""
LLM API base class - standard OpenAI format
Provides a universal API across multiple LLM providers
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
from dataclasses import dataclass, field
from enum import Enum
import json
import time
import logging
import requests
import asyncio
from traceback import format_exc
from src.managers.log.logger import Logger
class MessageRole(Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@dataclass
class ChatMessage:
role: MessageRole
content: Union[str, List[Dict[str, Any]]]
name: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_call_id: Optional[str] = None
@dataclass
class ChatCompletionRequest:
messages: List[ChatMessage]
model: str
temperature: float = 0.7
max_tokens: Optional[int] = None
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
user: Optional[str] = None
@dataclass
class Usage:
prompt_tokens: int
completion_tokens: int
total_tokens: int
@dataclass
class Choice:
index: int
message: ChatMessage
finish_reason: Optional[str] = None
logprobs: Optional[Dict[str, Any]] = None
@dataclass
class ChatCompletionResponse:
id: str
object: str
created: int
model: str
choices: List[Choice]
usage: Optional[Usage] = None
system_fingerprint: Optional[str] = None
@dataclass
class ChatCompletionChunk:
id: str
object: str
created: int
model: str
choices: List[Dict[str, Any]]
system_fingerprint: Optional[str] = None
@dataclass
class EmbeddingRequest:
input: Union[str, List[str]]
model: str
encoding_format: str = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
@dataclass
class EmbeddingData:
object: str
embedding: List[float]
index: int
@dataclass
class EmbeddingUsage:
prompt_tokens: int
total_tokens: int
@dataclass
class EmbeddingResponse:
object: str
data: List[EmbeddingData]
model: str
usage: EmbeddingUsage
class BaseLLMAPI(ABC):
"""
LLM API base class
Provide standard OpenAI format API, support:
- Synchronous/Asynchronous Chat Completions
- Streaming Responses
- Tool Calling
- Error handling and Retry
- Usage statistics
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
self.api_key = api_key
self.base_url = base_url
self.timeout = timeout
self.max_retries = max_retries
self.retry_delay = retry_delay
self.extra_config = kwargs
self.logger = logger
self.session: Optional[requests.Session] = None
self._initialize_client()
def _create_http_clients(self, headers: Dict[str, str]) -> None:
self.session = requests.Session()
self.session.headers.update(headers)
# requests.Session has no session-wide timeout; self.timeout is passed on each request instead
@abstractmethod
def _initialize_client(self) -> None:
pass
@abstractmethod
def _get_chat_endpoint(self) -> str:
pass
@abstractmethod
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
pass
@abstractmethod
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
pass
@abstractmethod
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
pass
def chat_completions_create(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None]]:
self.validate_request(request)
def _make_request():
payload = self._build_request_payload(request)
endpoint = self._get_chat_endpoint()
if request.stream:
return self._stream_chat_completion(payload, endpoint)
else:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout
)
if response.status_code != 200:
error_msg = f"API Error {response.status_code}: {response.text}"
print(error_msg)
raise requests.exceptions.HTTPError(error_msg, response=response)
response.raise_for_status()
return self._parse_response(response.json())
return self._retry_with_backoff(_make_request)
async def achat_completions_create(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, AsyncGenerator[ChatCompletionChunk, None]]:
self.validate_request(request)
def _make_async_request():
payload = self._build_request_payload(request)
endpoint = self._get_chat_endpoint()
if request.stream:
return self._stream_chat_completion(
payload, endpoint
)
else:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout
)
response.raise_for_status()
return self._parse_response(response.json())
import asyncio
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, _make_async_request)
def create_message(self, role: MessageRole, content: str, config: Optional[Dict[str, Any]] = None, **kwargs) -> ChatMessage:
return ChatMessage(role=role, content=content, **kwargs)
def _make_token_cache_request(self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None) -> list[ChatMessage]:
"""
Insert cache_control blocks for Claude/Anthropic models according to config, in a backward-compatible way.
- Only activates when the model is Claude/Anthropic.
- Content is preserved as-is for other providers to avoid breaking compatibility.
- If token_cache.role_turns is configured, cache anchors are injected at the matching turns and roles; otherwise the defaults from _load_role_turns are used.
"""
role_turns = self._load_role_turns(config)
#if self.logger:
# self.logger.debug(f"In _make_token_cache_request, got role_turns: {role_turns}.")
if not self._is_claude_like(model):
return messages
turn_to_roles: Dict[int, List[str]] = {}
try:
for rt in role_turns or []:
try:
turn = int(rt.get("turn"))
role = (rt.get("role") or "").strip().lower()
if role:
if turn not in turn_to_roles:
turn_to_roles[turn] = []
turn_to_roles[turn].append(role)
except Exception:
continue
except Exception:
turn_to_roles = {}
current_turn = 0
result_messages: List[ChatMessage] = []
for msg in messages:
msg_blocks = self._to_claude_blocks(msg.content)
# Turn-0 rule: for the initial user message, add the cache flag to its content and append it to the result.
if current_turn == 0 and msg.role == MessageRole.USER:
cached_blocks = self._add_cache_flag(msg_blocks)
result_messages.append(
ChatMessage(
role=msg.role,
content=cached_blocks,
name=msg.name,
tool_calls=msg.tool_calls,
tool_call_id=msg.tool_call_id,
)
)
#if self.logger:
# self.logger.debug(
# f"Applied cache to initial user message at turn {current_turn}."
# )
continue
# Other messages: just add it to result
result_messages.append(
ChatMessage(
role=msg.role,
content=msg_blocks,
name=msg.name,
tool_calls=msg.tool_calls,
tool_call_id=msg.tool_call_id,
)
)
# Hit anchor point: When the current turn's configuration includes the assistant and the current message is from the assistant
# add an extra system turn.
roles_for_turn = turn_to_roles.get(current_turn, [])
if msg.role == MessageRole.ASSISTANT and roles_for_turn and ("assistant" in roles_for_turn):
refresh_text = f"refresh cache tokens."
refresh_blocks = self._to_claude_blocks(refresh_text)
refresh_blocks = self._add_cache_flag(refresh_blocks)
result_messages.append(
ChatMessage(
role=MessageRole.SYSTEM,
content=refresh_blocks,
)
)
#if self.logger:
# self.logger.debug(
# f"Appended system refresh after assistant at turn {current_turn}, refresh_blocks: {refresh_blocks}."
# )
# an assistant message advances the turn counter
if msg.role == MessageRole.ASSISTANT:
current_turn += 1
return result_messages
def _to_claude_blocks(self, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
if isinstance(content, list):
return content
return [{"type": "text", "text": content}]
def _add_cache_flag(self, blocks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
out: List[Dict[str, Any]] = []
added = False
for blk in blocks:
if not added and isinstance(blk, dict) and blk.get("type") == "text":
out.append({**blk, "cache_control": {"type": "ephemeral"}})
added = True
else:
out.append(blk)
return out
def _is_claude_like(self, model: str) -> bool:
try:
model_lc = (model or "").lower()
return ("claude" in model_lc) or ("anthropic" in model_lc)
except Exception:
return False
def _load_role_turns(self, config: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""Load token_cache.role_turnsPriorityInvoker config > default value"""
try:
role_turns = None
# 1) Read from invoker's config
if config and isinstance(config, dict):
token_cache_cfg = config.get("token_cache")
if isinstance(token_cache_cfg, dict):
role_turns = token_cache_cfg.get("role_turns")
# 2) If still missing, fall back to the defaults from the example config
if not role_turns:
role_turns = [
{"turn": 0, "role": "user"},
{"turn": 25, "role": "assistant"},
{"turn": 50, "role": "assistant"},
{"turn": 75, "role": "assistant"},
]
return role_turns
except Exception as e:
if self.logger:
self.logger.warning(f"Read token_cache.role_turns failuse default value: {e}")
return [
{"turn": 0, "role": "user"},
{"turn": 25, "role": "assistant"},
{"turn": 50, "role": "assistant"},
{"turn": 75, "role": "assistant"},
]
def create_request(
self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None, **kwargs
) -> ChatCompletionRequest:
messages = self._make_token_cache_request(messages, model, config)
return ChatCompletionRequest(messages=messages, model=model, **kwargs)
def _handle_error(self, error: Exception) -> None:
if self.logger:
self.logger.error(f"API request failed: {error}")
else:
print(f"API request failed: {error}")
raise error
def _retry_with_backoff(self, func, *args, **kwargs):
last_exception = None
for attempt in range(self.max_retries + 1):
try:
return func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt)  # exponential backoff
if self.logger:
self.logger.warning(
f"{attempt + 1}th try failedretry after {delay}s: {e}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(f"All retries failed: {e}, traceback: {format_exc()}.")
raise last_exception
def get_model_info(self) -> Dict[str, Any]:
return {
"provider": self.__class__.__name__,
"base_url": self.base_url,
"timeout": self.timeout,
"max_retries": self.max_retries,
}
def validate_request(self, request: ChatCompletionRequest) -> bool:
if not request.messages:
raise ValueError("Message list is empty")
if not request.model:
raise ValueError("Model name is empty")
for idx, message in enumerate(request.messages):
if not message.content and not message.tool_calls:
try:
msg_info = {
"index": idx,
"role": getattr(message.role, "value", str(message.role)),
"content_len": (
len(message.content)
if isinstance(message.content, str)
else (len(message.content) if isinstance(message.content, list) else 0)
),
"has_tool_calls": bool(message.tool_calls),
"tool_calls": message.tool_calls,
}
if self.logger:
self.logger.warning(
f"Request validation failed: Invalid message exists (lacking both content and tool_calls): {json.dumps(msg_info, ensure_ascii=False)}"
)
except Exception as e:
if self.logger:
self.logger.warning(
f"Request validation failed: Invalid message exists (lacking both content and tool_calls), index={idx}, error: {e}, traceback: {format_exc()}."
)
raise ValueError("Cannot lacking both content and tool_calls")
return True
def format_messages_for_api(
self, messages: List[ChatMessage]
) -> List[Dict[str, Any]]:
formatted_messages = []
for message in messages:
msg_dict = {"role": message.role.value, "content": message.content}
if message.name:
msg_dict["name"] = message.name
if message.tool_calls:
msg_dict["tool_calls"] = message.tool_calls
if message.tool_call_id:
msg_dict["tool_call_id"] = message.tool_call_id
formatted_messages.append(msg_dict)
return formatted_messages
def _stream_chat_completion(
self, payload: Dict[str, Any], endpoint: str
) -> Generator[ChatCompletionChunk, None, None]:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout, stream=True
)
response.raise_for_status()
try:
line_count = 0
for line in response.iter_lines():
if line:
line_count += 1
line_str = line.decode("utf-8")
chunk = self._parse_stream_line(line_str)
if chunk:
yield chunk
else:
print(f"Jump invalid line")
print(f"Streaming request process done, {line_count} processed")
finally:
response.close()
async def _astream_chat_completion(
self, payload: Dict[str, Any], endpoint: str
) -> AsyncGenerator[ChatCompletionChunk, None]:
import asyncio
def _sync_stream():
return list(self._stream_chat_completion(payload, endpoint))
loop = asyncio.get_running_loop()
chunks = await loop.run_in_executor(None, _sync_stream)
for chunk in chunks:
yield chunk
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
# Process standard SSE format
if line.startswith("data: "):
data = line[6:]
if data.strip() == "[DONE]":
return None
try:
chunk_data = json.loads(data)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
self.logger.warning(f"Unable to parse streaming data: {data}")
return None
return None
def close(self) -> None:
if self.session:
self.session.close()
async def aclose(self) -> None:
if self.session:
self.session.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
def __str__(self) -> str:
return f"{self.__class__.__name__}(base_url={self.base_url})"
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"base_url={self.base_url}, "
f"timeout={self.timeout}, "
f"max_retries={self.max_retries})"
)
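
To illustrate the _make_token_cache_request behavior above, a sketch of the config dict an invoker might pass to create_request, assuming a client instance and a messages list already exist; the turn numbers mirror the defaults in _load_role_turns and are illustrative, not required values.

# Hypothetical invoker-side config; cache anchors only apply to Claude/Anthropic models.
config = {
    "token_cache": {
        "role_turns": [
            {"turn": 0, "role": "user"},        # cache the initial user message
            {"turn": 25, "role": "assistant"},  # refresh the cache after these assistant turns
            {"turn": 50, "role": "assistant"},
        ]
    }
}
request = client.create_request(
    messages=messages, model="claude-3-5-sonnet-20241022", config=config
)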

View File

@@ -0,0 +1,23 @@
"""
LLM client implementation module
Supported LLM providers:
- OpenAI: OpenAI API and compatible services
- Anthropic: Claude models
- DeepSeek: DeepSeek models
- Private: privately deployed models (vLLM, TGI, Ollama, etc.)
"""
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
__all__ = [
"OpenAIClient",
"AnthropicClient",
"DeepSeekClient",
"OpenRouterClient",
"PrivateModelClient",
]

View File

@@ -0,0 +1,7 @@
"""
Anthropic Claude Client Module
"""
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
__all__ = ["AnthropicClient"]

View File

@@ -0,0 +1,288 @@
"""
Anthropic Claude Client Module
"""
import json
import time
from typing import Dict, List, Any, Optional, Union
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class AnthropicClient(BaseLLMAPI):
"""
Anthropic Claude API Client
Anthropic Claude models supported
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
anthropic_version: str = "2023-06-01",
logger: Optional[Logger] = None,
**kwargs
):
"""
Initialize Anthropic client
Args:
api_key: Anthropic API key
base_url: API base URL (defaults to the official Anthropic API)
timeout: timeout in seconds
max_retries: max retries
retry_delay: retry delay in seconds
anthropic_version: API version
**kwargs: other args
"""
self.anthropic_version = anthropic_version
if base_url is None:
base_url = "https://api.anthropic.com"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs
)
def _initialize_client(self) -> None:
"""Initialize HTTP client"""
headers = {
"x-api-key": self.api_key,
"Content-Type": "application/json",
"anthropic-version": self.anthropic_version,
"User-Agent": "tokfinity-llm-client/1.0",
}
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/v1/messages"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
"""Build Anthropic API request payload"""
messages, system_prompt = self._convert_messages_to_anthropic_format(
request.messages
)
payload = {
"model": request.model,
"messages": messages,
"max_tokens": request.max_tokens or 1024,
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
if system_prompt:
payload["system"] = system_prompt
if request.stop is not None:
payload["stop_sequences"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
if request.tools is not None:
payload["tools"] = self._convert_tools_to_anthropic_format(request.tools)
if request.tool_choice is not None:
payload["tool_choice"] = self._convert_tool_choice_to_anthropic_format(
request.tool_choice
)
return payload
def _convert_messages_to_anthropic_format(
self, messages: List[ChatMessage]
) -> tuple[List[Dict[str, Any]], Optional[str]]:
"""Convert messages to anthropic format"""
anthropic_messages = []
system_prompt = None
for message in messages:
if message.role == MessageRole.SYSTEM:
system_prompt = message.content
elif message.role in [MessageRole.USER, MessageRole.ASSISTANT]:
anthropic_messages.append(
{"role": message.role.value, "content": message.content}
)
return anthropic_messages, system_prompt
def _convert_tools_to_anthropic_format(
self, tools: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Convert OpenAI format tools to anthropic format"""
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
function_def = tool.get("function", {})
anthropic_tool = {
"name": function_def.get("name", ""),
"description": function_def.get("description", ""),
"input_schema": function_def.get("parameters", {}),
}
anthropic_tools.append(anthropic_tool)
return anthropic_tools
def _convert_tool_choice_to_anthropic_format(
self, tool_choice: Union[str, Dict[str, Any]]
) -> Union[str, Dict[str, Any]]:
"""Convert OpenAI format tool choice to anthropic format"""
if isinstance(tool_choice, str):
if tool_choice == "auto":
return "auto"
elif tool_choice == "none":
return "none"
else:
return {"type": "tool", "name": tool_choice}
elif isinstance(tool_choice, dict):
if tool_choice.get("type") == "function":
return {
"type": "tool",
"name": tool_choice.get("function", {}).get("name", ""),
}
elif tool_choice.get("type") == "tool":
return tool_choice
return "auto"
def _convert_anthropic_tool_calls(
self, content_list: List[Dict[str, Any]]
) -> Optional[List[Dict[str, Any]]]:
"""Convert anthropic tool calls to OpenAI format"""
tool_calls = []
for item in content_list:
if item.get("type") == "tool_use":
tool_call = {
"id": item.get("id", ""),
"type": "function",
"function": {
"name": item.get("name", ""),
"arguments": item.get("input", {}),
},
}
tool_calls.append(tool_call)
return tool_calls if tool_calls else None
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
"""Convert anthropic API response to OpenAI format"""
content = ""
tool_calls = None
if response_data.get("content"):
content_data = (
response_data["content"][0] if response_data["content"] else {}
)
content = content_data.get("text", "")
if content_data.get("type") == "tool_use":
tool_calls = self._convert_anthropic_tool_calls(
response_data.get("content", [])
)
message = ChatMessage(
role=MessageRole.ASSISTANT, content=content, tool_calls=tool_calls
)
choice = Choice(
index=0,
message=message,
finish_reason=self._convert_stop_reason(response_data.get("stop_reason")),
)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("input_tokens", 0),
completion_tokens=usage_data.get("output_tokens", 0),
total_tokens=usage_data.get("input_tokens", 0)
+ usage_data.get("output_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object="chat.completion",
created=int(time.time()),
model=response_data.get("model", ""),
choices=[choice],
usage=usage,
)
def _convert_stop_reason(self, stop_reason: Optional[str]) -> Optional[str]:
"""Convert anthropic stop reason to OpenAI format"""
if stop_reason == "end_turn":
return "stop"
elif stop_reason == "max_tokens":
return "length"
elif stop_reason == "stop_sequence":
return "stop"
else:
return stop_reason
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""Convert anthropic steam response to OpenAI format"""
event_type = chunk_data.get("type")
if event_type == "content_block_delta":
delta = chunk_data.get("delta", {})
text = delta.get("text", "")
if text:
choices = [
{"index": 0, "delta": {"content": text}, "finish_reason": None}
]
return ChatCompletionChunk(
id=chunk_data.get("message", {}).get("id", ""),
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("message", {}).get("model", ""),
choices=choices,
)
elif event_type == "message_stop":
choices = [{"index": 0, "delta": {}, "finish_reason": "stop"}]
return ChatCompletionChunk(
id=chunk_data.get("message", {}).get("id", ""),
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("message", {}).get("model", ""),
choices=choices,
)
return None
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
try:
chunk_data = json.loads(line)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
# if not json, try SSE
return super()._parse_stream_line(line)
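
For reference, a sketch of the tool-format conversion performed by the methods above; the weather tool and API key are placeholders.

# Hypothetical OpenAI-style tool definition converted to the Anthropic format.
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}
client = AnthropicClient(api_key="placeholder-key")
anthropic_tools = client._convert_tools_to_anthropic_format([openai_tool])
# -> [{"name": "get_weather", "description": "Look up the weather for a city",
#      "input_schema": {"type": "object", "properties": {"city": {"type": "string"}}}}]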

View File

@@ -0,0 +1,7 @@
"""
DeepSeek Client Module
"""
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
__all__ = ["DeepSeekClient"]

View File

@@ -0,0 +1,164 @@
"""
DeepSeek Client Module
"""
import time
from typing import Dict, List, Any, Optional
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class DeepSeekClient(BaseLLMAPI):
"""
DeepSeek API Client
Supports DeepSeek models, including DeepSeek-Coder, DeepSeek-Chat, etc.
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
if base_url is None:
base_url = "https://api.deepseek.com/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
# optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""Parse stream chunk"""
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
    """Get available models"""
    # The base class only provides a requests.Session, so call the full URL through it
    response = self.session.get(self.base_url.rstrip("/") + "/models", timeout=self.timeout)
    response.raise_for_status()
    return response.json()
async def alist_models(self) -> Dict[str, Any]:
    """Async variant: run the synchronous request in a worker thread"""
    import asyncio
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.list_models)

View File

@@ -0,0 +1,7 @@
"""
OpenAI Client Module
"""
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
__all__ = ["OpenAIClient"]

View File

@@ -0,0 +1,279 @@
"""
OpenAI Client
"""
import time
from typing import Dict, List, Any, Optional, Union
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingData,
EmbeddingUsage,
)
class OpenAIClient(BaseLLMAPI):
"""
OpenAI API Client
OpenAI API supported, and other API services compatible with the OpenAI format
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
organization: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
self.organization = organization
if base_url is None:
base_url = "https://api.openai.com/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
if self.organization:
headers["OpenAI-Organization"] = self.organization
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
    return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"frequency_penalty": request.frequency_penalty,
"presence_penalty": request.presence_penalty,
"stream": request.stream,
}
# Optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
    """Get available models"""
    # The base class only provides a requests.Session, so call the full URL through it
    response = self.session.get(self.base_url.rstrip("/") + "/models", timeout=self.timeout)
    response.raise_for_status()
    return response.json()
async def alist_models(self) -> Dict[str, Any]:
    """Async variant: run the synchronous request in a worker thread"""
    import asyncio
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.list_models)
def create_embeddings(
self,
input_text: Union[str, List[str]],
model: str,
encoding_format: str = "float",
dimensions: Optional[int] = None,
user: Optional[str] = None,
timeout: Optional[int] = None,
max_retries: Optional[int] = None,
retry_delay: Optional[float] = None,
) -> EmbeddingResponse:
actual_timeout = timeout if timeout is not None else self.timeout
actual_max_retries = (
max_retries if max_retries is not None else self.max_retries
)
actual_retry_delay = (
retry_delay if retry_delay is not None else self.retry_delay
)
request = EmbeddingRequest(
input=input_text,
model=model,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
payload = self._build_embedding_request_payload(request)
for attempt in range(actual_max_retries + 1):
try:
print(f"Debug: Payload: {payload}")
response = self.session.post(
f"{self.base_url}/embeddings", json=payload, timeout=actual_timeout
)
if response.status_code == 200:
response_data = response.json()
return self._parse_embedding_response(response_data)
else:
error_msg = f"embedding failed (try {attempt + 1}): HTTP {response.status_code}"
if hasattr(response, "text"):
error_msg += f" - {response.text}"
print(f"Debug: {error_msg}")
if attempt < actual_max_retries:
print(f"Debug: wait {actual_retry_delay} and retry...")
time.sleep(actual_retry_delay)
continue
else:
raise Exception(f"All retries failed: {error_msg}")
except Exception as e:
error_msg = f"Embedding request failed (try {attempt + 1}): {str(e)}, traceback: {format_exc()}."
print(f"Debug: {error_msg}")
if attempt < actual_max_retries:
print(f"Debug: wait {actual_retry_delay} and retry...")
time.sleep(actual_retry_delay)
continue
else:
raise Exception(f"All retries failed: {str(e)}, traceback: {format_exc()}.")
raise Exception("Unknown error")
def _build_embedding_request_payload(
self, request: EmbeddingRequest
) -> Dict[str, Any]:
payload = {
"input": request.input,
"model": request.model,
"encoding_format": request.encoding_format,
}
if request.dimensions is not None:
payload["dimensions"] = request.dimensions
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_embedding_response(
self, response_data: Dict[str, Any]
) -> EmbeddingResponse:
embedding_data_list = []
for data_item in response_data.get("data", []):
embedding_data = EmbeddingData(
object=data_item.get("object", "embedding"),
embedding=data_item.get("embedding", []),
index=data_item.get("index", 0),
)
embedding_data_list.append(embedding_data)
usage_data = response_data.get("usage", {})
usage = EmbeddingUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return EmbeddingResponse(
object=response_data.get("object", "list"),
data=embedding_data_list,
model=response_data.get("model", ""),
usage=usage,
)
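
# Usage sketch for create_embeddings(): a minimal illustration, assuming `client` is an
# already-constructed instance of this client class; the model id below is a placeholder
# and is not taken from this repository.
#
#   resp = client.create_embeddings(
#       input_text=["first sentence", "second sentence"],
#       model="example-embedding-model",  # hypothetical model id
#       dimensions=256,
#   )
#   for item in resp.data:
#       print(item.index, len(item.embedding))
#   print(resp.usage.total_tokens)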

View File

@@ -0,0 +1,8 @@
"""
OpenRouter Client
"""
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
__all__ = ["OpenRouterClient"]

View File

@@ -0,0 +1,329 @@
"""
OpenRouter Client
"""
import time
from typing import Dict, List, Any, Optional
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class OpenRouterClient(BaseLLMAPI):
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: Optional[str] = None,
site_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
"""
Initialize OpenRouter Client
Args:
api_key: OpenRouter API key
base_url: API base URL, default to OpenRouter official API
app_name: Application name (optional, for statistics)
site_url: Website URL (optional, for statistics)
timeout: Request timeout
max_retries: Maximum number of retries
retry_delay: Retry delay
**kwargs: Other configuration parameters
"""
self.app_name = app_name or "tokfinity-llm-client"
self.site_url = site_url
if base_url is None:
base_url = "https://openrouter.ai/api/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
"X-Title": self.app_name,
}
if self.site_url:
headers["HTTP-Referer"] = self.site_url
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
"""
Parse OpenRouter API response
Args:
response_data: API response data
Returns:
ChatCompletionResponse: Parsed response object
"""
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""
Parse stream chunk
Args:
chunk_data: Raw chunk data
Returns:
Optional[ChatCompletionChunk]: Parsed chunk
"""
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
"""
Get available models
Returns:
Dict[str, Any]: Model list response
"""
response = self.client.get("/models")
response.raise_for_status()
return response.json()
async def alist_models(self) -> Dict[str, Any]:
"""
Async get available models
Returns:
Dict[str, Any]: Model list response
"""
response = await self.async_client.get("/models")
response.raise_for_status()
return response.json()
def get_generation_info(self, generation_id: str) -> Dict[str, Any]:
"""
Get specific generation request details
Args:
generation_id: Generation request ID
Returns:
Dict[str, Any]: Generation information
"""
response = self.client.get(f"/generation?id={generation_id}")
response.raise_for_status()
return response.json()
async def aget_generation_info(self, generation_id: str) -> Dict[str, Any]:
"""
Async get specific generation request details
Args:
generation_id: Generation request ID
Returns:
Dict[str, Any]: Generation information
"""
response = await self.async_client.get(f"/generation?id={generation_id}")
response.raise_for_status()
return response.json()
def get_account_credits(self) -> Dict[str, Any]:
"""
Get account credits
Returns:
Dict[str, Any]: Account credits information
"""
response = self.client.get("/auth/key")
response.raise_for_status()
return response.json()
async def aget_account_credits(self) -> Dict[str, Any]:
"""
Async get account credits
Returns:
Dict[str, Any]: Account credits information
"""
response = await self.async_client.get("/auth/key")
response.raise_for_status()
return response.json()
@staticmethod
def get_popular_models() -> List[str]:
"""
Get popular models
Returns:
List[str]: Popular model names
"""
return [
# OpenAI
"openai/gpt-4-turbo-preview",
"openai/gpt-4",
"openai/gpt-3.5-turbo",
# Anthropic
"anthropic/claude-3-opus",
"anthropic/claude-3-sonnet",
"anthropic/claude-3-haiku",
# Google
"google/gemini-pro",
"google/gemini-pro-vision",
# Meta
"meta-llama/llama-2-70b-chat",
"meta-llama/llama-2-13b-chat",
# Mistral
"mistralai/mixtral-8x7b-instruct",
"mistralai/mistral-7b-instruct",
# Open Source
"microsoft/wizardlm-2-8x22b",
"databricks/dbrx-instruct",
"cohere/command-r-plus",
]
def get_model_info(self, model_name: str) -> Dict[str, Any]:
"""
Get specific model details
Args:
model_name: Model name
Returns:
Dict[str, Any]: Model information
"""
models_response = self.list_models()
models = models_response.get("data", [])
for model in models:
if model.get("id") == model_name:
return model
raise ValueError(f"Model {model_name} not found")
async def aget_model_info(self, model_name: str) -> Dict[str, Any]:
"""
Async get specific model details
Args:
model_name: Model name
Returns:
Dict[str, Any]: Model information
"""
models_response = await self.alist_models()
models = models_response.get("data", [])
for model in models:
if model.get("id") == model_name:
return model
raise ValueError(f"Model {model_name} not found")
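
# Usage sketch: exercises the client defined above. Assumptions: an OPENROUTER_API_KEY
# environment variable is only set when a real key is available, and the BaseLLMAPI
# constructor performs no network I/O on its own.
if __name__ == "__main__":
    import os

    # Static helper; no client instance or network access required.
    print("Popular models:", OpenRouterClient.get_popular_models()[:3])

    api_key = os.environ.get("OPENROUTER_API_KEY")
    if api_key:
        client = OpenRouterClient(api_key=api_key, app_name="example-app")  # placeholder app name
        models = client.list_models()
        print("Available models:", len(models.get("data", [])))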

View File

@@ -0,0 +1,7 @@
"""
Private model client
"""
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
__all__ = ["PrivateModelClient"]

View File

@@ -0,0 +1,321 @@
"""
Private model client
Implementation of a private model API client built on BaseLLMAPI.
Supports vLLM, Text Generation Inference (TGI), and Ollama.
"""
import json
import time
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class PrivateModelClient(BaseLLMAPI):
def __init__(
self,
api_key: Optional[str] = None,
base_url: str = "http://localhost:8000/v1",
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
deployment_type: str = "vllm",
custom_headers: Optional[Dict[str, str]] = None,
supports_tools: bool = True,
logger: Optional[Logger] = None,
**kwargs,
):
self.deployment_type = deployment_type.lower()
self.custom_headers = custom_headers or {}
self.supports_tools = supports_tools
super().__init__(
api_key=api_key or "EMPTY",
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
headers.update(self.custom_headers)
if self.deployment_type == "ollama":
pass
elif self.deployment_type == "tgi":
pass
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
if self.deployment_type == "ollama":
return "/api/chat"
elif self.deployment_type == "tgi":
return "/generate"
else:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
if self.deployment_type == "ollama":
return self._build_ollama_payload(request)
elif self.deployment_type == "tgi":
return self._build_tgi_payload(request)
else:
return self._build_openai_payload(request)
def _build_openai_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
# Optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if self.supports_tools and request.tools is not None:
payload["tools"] = request.tools
if self.supports_tools and request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
return payload
def _build_ollama_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"stream": request.stream,
"options": {
"temperature": request.temperature,
"top_p": request.top_p,
},
}
if request.max_tokens is not None:
payload["options"]["num_predict"] = request.max_tokens
if request.stop is not None:
payload["options"]["stop"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
return payload
def _build_tgi_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
prompt = self._messages_to_prompt(request.messages)
payload = {
"inputs": prompt,
"parameters": {
"temperature": request.temperature,
"top_p": request.top_p,
"do_sample": True,
"stream": request.stream,
},
}
if request.max_tokens is not None:
payload["parameters"]["max_new_tokens"] = request.max_tokens
if request.stop is not None:
payload["parameters"]["stop_sequences"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
return payload
def _messages_to_prompt(self, messages: List[ChatMessage]) -> str:
prompt_parts = []
for message in messages:
if message.role == MessageRole.SYSTEM:
prompt_parts.append(f"System: {message.content}")
elif message.role == MessageRole.USER:
prompt_parts.append(f"User: {message.content}")
elif message.role == MessageRole.ASSISTANT:
prompt_parts.append(f"Assistant: {message.content}")
prompt_parts.append("Assistant:")
return "\n\n".join(prompt_parts)
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
if self.deployment_type == "ollama":
return self._parse_ollama_response(response_data)
elif self.deployment_type == "tgi":
return self._parse_tgi_response(response_data)
else:
return self._parse_openai_response(response_data)
def _parse_openai_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=message_data.get("tool_calls"),
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
)
def _parse_ollama_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
content = response_data.get("message", {}).get("content", "")
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
choice = Choice(
index=0,
message=message,
finish_reason="stop" if response_data.get("done", False) else None,
)
return ChatCompletionResponse(
id=f"ollama-{int(time.time())}",
object="chat.completion",
created=int(time.time()),
model=response_data.get("model", ""),
choices=[choice],
)
def _parse_tgi_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
if isinstance(response_data, list) and len(response_data) > 0:
response_data = response_data[0]
content = response_data.get("generated_text", "")
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
choice = Choice(
index=0,
message=message,
finish_reason=response_data.get("finish_reason", "stop"),
)
return ChatCompletionResponse(
id=f"tgi-{int(time.time())}",
object="chat.completion",
created=int(time.time()),
model="tgi-model",
choices=[choice],
)
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
if self.deployment_type == "ollama":
try:
chunk_data = json.loads(line)
return self._parse_ollama_chunk(chunk_data)
except json.JSONDecodeError:
return None
else:
# Standard SSE format
if line.startswith("data: "):
data = line[6:]
if data.strip() == "[DONE]":
return None
try:
chunk_data = json.loads(data)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
return None
return None
def _parse_ollama_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
content = chunk_data.get("message", {}).get("content", "")
done = chunk_data.get("done", False)
choices = [
{
"index": 0,
"delta": {"content": content} if content else {},
"finish_reason": "stop" if done else None,
}
]
return ChatCompletionChunk(
id=f"ollama-{int(time.time())}",
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("model", ""),
choices=choices,
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
)
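
# Usage sketch: shows how deployment_type switches the endpoint and request format.
# Assumptions: the BaseLLMAPI constructor only builds HTTP client objects (no network I/O),
# and ChatCompletionRequest accepts the keyword arguments shown below.
if __name__ == "__main__":
    request = ChatCompletionRequest(
        model="example-local-model",  # hypothetical model name
        messages=[
            ChatMessage(role=MessageRole.SYSTEM, content="You are a helpful assistant."),
            ChatMessage(role=MessageRole.USER, content="Hello!"),
        ],
        temperature=0.2,
        top_p=1.0,
        stream=False,
        max_tokens=64,
    )
    for deployment in ("vllm", "ollama", "tgi"):
        client = PrivateModelClient(deployment_type=deployment)
        payload = client._build_request_payload(request)
        print(deployment, client._get_chat_endpoint(), json.dumps(payload)[:120])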

View File

@@ -0,0 +1,43 @@
"""
Log management module
Provides log management functionality organized by timestamp and level
Usage example:
from src.managers.log import Logger, create_logger
# Basic Usage
logger = Logger("logs", "my_app")
logger.info("This is an info")
logger.error("This is an error")
logger.close()
# Use the context manager
with Logger("logs", "my_app") as logger:
logger.info("Automatically manage resources")
# Convenient function
logger = create_logger("logs", "my_app")
logger.info("Convenient create")
logger.close()
"""
from .logger import (
Logger,
create_logger,
init_global_logger,
get_global_logger,
set_global_logger,
)
__all__ = [
"Logger",
"create_logger",
"init_global_logger",
"get_global_logger",
"set_global_logger",
]
__version__ = "1.0.0"
__author__ = "Tokfinity Team"
__description__ = "Timestamp:Directory and Level-Based Log Manager "

357
src/managers/log/logger.py Normal file
View File

@@ -0,0 +1,357 @@
"""
Log Manager
Creates a timestamped directory and stores level-based log files
"""
import os
import logging
from datetime import datetime
from typing import Optional
from pathlib import Path
"""
Module-level custom NOTICE log level (between INFO and WARNING)
"""
NOTICE_LEVEL = 25
if not hasattr(logging, "NOTICE"):
logging.addLevelName(NOTICE_LEVEL, "NOTICE")
def notice(self, message, *args, **kwargs):
if self.isEnabledFor(NOTICE_LEVEL):
self._log(NOTICE_LEVEL, message, args, **kwargs)
logging.Logger.notice = notice # type: ignore[attr-defined]
class ImageBuilderLogger:
"""
ImageBuilder log manager
    Features:
    - Create a timestamped subdirectory for each build
    - Create separate debug.log, info.log, warning.log and error.log files
    - Provide a standard log format that records the calling code location (two stack levels up)
    - Write output to both the console and log files
"""
def __init__(self, log_base_path: str, console_output: bool = True):
self.log_base_path = Path(log_base_path)
self.console_output = console_output
self.log_dir = self._create_log_dir()
self.file_handlers = {}
self.logger = self._setup_logger()
def _create_log_dir(self) -> Path:
timestamp = datetime.now().strftime("%Y%m%d%H%M")
log_dir = self.log_base_path / f"build_images/{timestamp}"
log_dir.mkdir(parents=True, exist_ok=True)
return log_dir
def _setup_logger(self) -> logging.Logger:
logger = logging.getLogger("image_builder_logger")
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
log_format = (
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
)
formatter = logging.Formatter(log_format)
if self.console_output:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
self._add_file_handlers(logger, formatter)
return logger
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
}
for level_name, level_value in log_levels.items():
log_file = self.log_dir / f"{level_name}.log"
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
if level_name == "debug":
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
elif level_name == "info":
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
elif level_name == "warning":
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
elif level_name == "error":
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
logger.addHandler(file_handler)
self.file_handlers[level_name] = file_handler
all_log_file = self.log_dir / "all.log"
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
all_file_handler.setFormatter(formatter)
all_file_handler.setLevel(logging.DEBUG)
logger.addHandler(all_file_handler)
self.file_handlers["all"] = all_file_handler
def debug(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.debug(message, *args, **kwargs)
def info(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.info(message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.warning(message, *args, **kwargs)
def error(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.error(message, *args, **kwargs)
@property
def log_file(self) -> str:
return str(self.log_dir / "all.log")
def get_log_dir(self) -> str:
return str(self.log_dir.absolute())
def get_log_files(self) -> dict:
return {
"debug": str(self.log_dir / "debug.log"),
"info": str(self.log_dir / "info.log"),
"warning": str(self.log_dir / "warning.log"),
"error": str(self.log_dir / "error.log"),
"all": str(self.log_dir / "all.log"),
}
def close(self):
for handler in self.file_handlers.values():
handler.close()
for handler in self.logger.handlers[:]:
self.logger.removeHandler(handler)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"ImageBuilderLogger(dir={self.log_dir})"
def __repr__(self) -> str:
return (
f"ImageBuilderLogger("
f"base_path='{self.log_base_path}', "
f"log_dir='{self.log_dir}', "
f"console_output={self.console_output})"
)
class Logger:
"""
Log manager
    Features:
    - Create a subdirectory named by timestamp or instance id
    - Create separate debug.log, info.log, notice.log, warning.log and error.log files
    - Provide a standard log format
    - Write output to both the console and log files
"""
def __init__(
self,
log_base_path: str,
logger_name: str = "tokfinity_logger",
console_output: bool = True,
log_format: Optional[str] = None,
instance_id: Optional[str] = None,
):
self.log_base_path = Path(log_base_path)
self.logger_name = logger_name
self.console_output = console_output
self.instance_id = instance_id
self.log_format = log_format or (
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
)
self.log_dir = self._create_log_dir()
self.file_handlers = {}
self.logger = self._setup_logger()
def _create_log_dir(self) -> Path:
if self.instance_id:
log_dir = self.log_base_path / self.instance_id
else:
timestamp = datetime.now().strftime("%Y%m%d%H%M")
log_dir = self.log_base_path / timestamp
log_dir.mkdir(parents=True, exist_ok=True)
return log_dir
def _setup_logger(self) -> logging.Logger:
logger = logging.getLogger(self.logger_name)
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
formatter = logging.Formatter(self.log_format)
if self.console_output:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
self._add_file_handlers(logger, formatter)
return logger
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"notice": NOTICE_LEVEL,
"warning": logging.WARNING,
"error": logging.ERROR,
}
for level_name, level_value in log_levels.items():
log_file = self.log_dir / f"{level_name}.log"
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
if level_name == "debug":
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
elif level_name == "info":
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
elif level_name == "notice":
file_handler.addFilter(lambda record: record.levelno == NOTICE_LEVEL)
elif level_name == "warning":
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
elif level_name == "error":
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
logger.addHandler(file_handler)
self.file_handlers[level_name] = file_handler
all_log_file = self.log_dir / "all.log"
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
all_file_handler.setFormatter(formatter)
all_file_handler.setLevel(logging.DEBUG)
logger.addHandler(all_file_handler)
self.file_handlers["all"] = all_file_handler
def debug(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.debug(message, *args, **kwargs)
def info(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.info(message, *args, **kwargs)
def notice(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.log(NOTICE_LEVEL, message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.warning(message, *args, **kwargs)
def error(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.error(message, *args, **kwargs)
def get_log_dir(self) -> str:
return str(self.log_dir.absolute())
def get_log_files(self) -> dict:
return {
"debug": str(self.log_dir / "debug.log"),
"info": str(self.log_dir / "info.log"),
"notice": str(self.log_dir / "notice.log"),
"warning": str(self.log_dir / "warning.log"),
"error": str(self.log_dir / "error.log"),
"all": str(self.log_dir / "all.log"),
}
@property
def log_file(self) -> str:
return str(self.log_dir / "all.log")
def close(self):
for handler in self.file_handlers.values():
handler.close()
for handler in self.logger.handlers[:]:
self.logger.removeHandler(handler)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"Logger(name={self.logger_name}, dir={self.log_dir})"
def __repr__(self) -> str:
return (
f"Logger("
f"name='{self.logger_name}', "
f"base_path='{self.log_base_path}', "
f"log_dir='{self.log_dir}', "
f"console_output={self.console_output})"
)
def create_logger(
log_base_path: str,
logger_name: str = "tokfinity_logger",
console_output: bool = True,
) -> Logger:
return Logger(
log_base_path=log_base_path,
logger_name=logger_name,
console_output=console_output,
)
# Global log manager instance (optional)
_global_logger: Optional[Logger] = None
def get_global_logger() -> Optional[Logger]:
return _global_logger
def set_global_logger(logger: Logger):
global _global_logger
_global_logger = logger
def init_global_logger(
log_base_path: str, logger_name: str = "global_logger"
) -> Logger:
global _global_logger
_global_logger = Logger(log_base_path, logger_name)
return _global_logger
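
# Usage sketch: exercises the Logger defined above; log files are written under ./logs_demo.
if __name__ == "__main__":
    with Logger("logs_demo", "demo_logger") as demo_logger:
        demo_logger.info("info goes to info.log and all.log")
        demo_logger.notice("notice goes to notice.log")
        demo_logger.error("error goes to error.log")
        print("log files:", demo_logger.get_log_files())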

275
src/managers/loop/base.py Normal file
View File

@@ -0,0 +1,275 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.tools.base import (
ToolExecutor,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
from src.managers.loop.types import ToolStats, LLMUsage
class BaseLoop:
def __init__(self, instance_id: str, instance_data: Dict[str, Any], logger: Logger, prompts_manager: PromptsManager | None, llm_manager: LLMAPIManager | None, tool_executor: ToolExecutor, config: Dict[str, Any] | None = None):
self.instance_id = instance_id
self.instance_data = instance_data
self.logger = logger
self.prompts_manager = prompts_manager
self.llm_manager = llm_manager
self.tool_executor = tool_executor
self.config = config or {}
self.component_name = self.__class__.__name__
def _make_assistant(
self, content: str | None, tool_calls: Any, messages: List[Dict[str, Any]]
) -> bool:
"""
Construct an assistant message based on the current content and tool calls, and append it to the messages.
"""
safe_content = content or ""
if not safe_content and not tool_calls:
self.logger.warning(
f"[{self.component_name}] Assistant returned an empty message with no tool calls; skipping this message and prompting to continue"
)
messages.append(
{"role": "user", "content": "请继续分析问题并使用工具来解决问题。"}
)
return False
assistant_message: Dict[str, Any] = {"role": "assistant"}
if tool_calls and not safe_content:
assistant_message["content"] = ""
elif safe_content:
assistant_message["content"] = safe_content
if tool_calls:
assistant_message["tool_calls"] = tool_calls
messages.append(assistant_message)
return True
def _make_tool_response(
self, tool_results: List[Any], messages: List[Dict[str, Any]]
) -> None:
"""Convert tool execution results into standard tool messages (role=tool) and append them to the messages.
- Generate content per result: use prompts_manager.tool_response_prompts([{...}]) to produce the content
- Set tool_call_id: prefer ToolResult.id; fallback to ToolResult.call_id
"""
if not tool_results:
return
for result in tool_results:
single_dict = [
{
"name": getattr(result, "name", "unknown"),
"success": getattr(result, "success", False),
"result": getattr(result, "result", None) or "",
"error": getattr(result, "error", None) or "",
}
]
content_text = (
self.prompts_manager.tool_response_prompts(single_dict)
if self.prompts_manager
else ""
)
tool_call_id = getattr(result, "id", None) or getattr(
result, "call_id", None
)
messages.append(
{
"role": "tool",
"content": content_text,
"tool_call_id": tool_call_id,
}
)
def _response_log(
self, response: Any, first_content: str, first_tool_calls: Any, total_turns: int
) -> None:
"""notice log for the current turn's LLM output"""
try:
response_log: Dict[str, Any] = {}
if hasattr(response, "usage") and response.usage:
response_log["usage"] = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", None),
"completion_tokens": getattr(response.usage, "completion_tokens", None),
"total_tokens": getattr(response.usage, "total_tokens", None),
}
if hasattr(response, "choices") and response.choices:
response_log["choice"] = {
"message": {
"content": first_content,
"tool_calls": first_tool_calls,
}
}
if response_log:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn output: {json.dumps(response_log, ensure_ascii=False)}"
)
else:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn output: {str(response)}"
)
except Exception:
self.logger.notice(
f"[{self.component_name}] 第 {total_turns} 轮: LLM 输出序列化失败,使用字符串表示: {str(response)}, traceback: {format_exc()}."
)
def _debug_messages(
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
) -> None:
"""debug log for the messages to be sent to the model"""
try:
self.logger.debug(f"[{self.component_name}] msg:")
recent_messages = messages[-2:] if len(messages) > 2 else messages
base_index = len(messages) - len(recent_messages)
            for offset, msg in enumerate(recent_messages):
idx = base_index + offset
role = msg.get("role")
content = msg.get("content")
content_str = content if isinstance(content, str) else ""
preview = content_str[:prefix_len]
content_len = len(content_str)
extra = ""
if role == "assistant":
tool_calls = msg.get("tool_calls")
has_tool = tool_calls is not None and tool_calls != []
try:
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_messages function, fail: {format_exc()}, tool calls: {tool_calls}."
)
tool_calls_json = str(tool_calls)
extra = f", has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
elif role == "tool":
tool_call_id = msg.get("tool_call_id")
extra = f", tool_call_id={tool_call_id}"
self.logger.debug(
f"[{self.component_name}] {turn+1}th, msg#{idx}: role={role}, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}{extra}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_messages function, fail msg: {format_exc()}."
)
def _debug_last_message(
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
) -> None:
"""debug last turn msg"""
try:
if not messages:
return
last_assistant_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "assistant":
last_assistant_idx = i
break
if last_assistant_idx is None:
return
msg = messages[last_assistant_idx]
content = msg.get("content")
content_str = content if isinstance(content, str) else ""
preview = content_str[:prefix_len]
content_len = len(content_str)
tool_calls = msg.get("tool_calls")
has_tool = tool_calls is not None and tool_calls != []
try:
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_last_message function, fail: {format_exc()}, tool calls: {tool_calls}."
)
tool_calls_json = str(tool_calls)
self.logger.debug(
f"[{self.component_name}] {turn+1}th turn, output_preview: role=assistant, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}, has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_last_message function, last turn fail: {format_exc()}."
)
def _debug_tools(self, tools: List[Dict[str, Any]]) -> None:
"""debug tools msg"""
try:
self.logger.debug(f"[{self.component_name}] tools num: {len(tools)}")
for i, tool in enumerate(tools):
try:
tool_json = json.dumps(tool, ensure_ascii=False)
self.logger.debug(f"[{self.component_name}] tool #{i+1}: {tool_json}")
except Exception:
self.logger.debug(
f"[{self.component_name}] tool #{i+1} fail: {format_exc()}, string: {str(tool)}."
)
except Exception:
try:
self.logger.warning(
f"[{self.component_name}] fail; traceback: {format_exc()}."
)
self.logger.warning(f"[{self.component_name}] tools string: {str(tools)}")
except Exception:
pass
    def _get_tools(self) -> List[Dict[str, Any]]:
        raise NotImplementedError("Subclasses must implement _get_tools().")
def _is_bash_tool(self, tool_name: str) -> bool:
return BASH_TOOL_NAME in tool_name
def _is_edit_tool(self, tool_name: str) -> bool:
return "edit" in tool_name or "str_replace" in tool_name or STR_REPLACE_BASED_EDIT_TOOL_NAME in tool_name
def _is_search_tool(self, tool_name: str) -> bool:
return SEARCH_TOOL_NAME in tool_name or "search" in tool_name
def _is_submit_result_tool(self, tool_name: str) -> bool:
return SUBMIT_RESULT_TOOL_NAME in tool_name
def _update_usage(self, response: Any, usage_stats: LLMUsage) -> None:
if hasattr(response, "usage") and response.usage:
usage_stats.prompt_tokens += int(getattr(response.usage, "prompt_tokens", 0) or 0)
usage_stats.completion_tokens += int(
getattr(response.usage, "completion_tokens", 0) or 0
)
usage_stats.total_tokens += int(getattr(response.usage, "total_tokens", 0) or 0)
def _init_usage_stats(self) -> LLMUsage:
return LLMUsage()
def _init_tools_stats(self) -> ToolStats:
return ToolStats()
def _update_tool_call_statistic(
self, tool_results: List[Any], tool_stats: ToolStats
) -> None:
for result in tool_results:
try:
tool_name = getattr(result, "name", "")
tool_name = tool_name.lower() if isinstance(tool_name, str) else ""
success = bool(getattr(result, "success", False))
if self._is_bash_tool(tool_name):
tool_stats.bash["count"] += 1
if not success:
tool_stats.bash["failed"] += 1
elif self._is_edit_tool(tool_name):
tool_stats.edit["count"] += 1
if not success:
tool_stats.edit["failed"] += 1
elif self._is_search_tool(tool_name):
tool_stats.search["count"] += 1
if not success:
tool_stats.search["failed"] += 1
elif self._is_submit_result_tool(tool_name):
tool_stats.submit_result["count"] += 1
if not success:
tool_stats.submit_result["failed"] += 1
except Exception:
continue
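
# Subclassing sketch: a concrete loop built on BaseLoop would typically assemble the messages,
# call the LLM, record the assistant turn with _make_assistant(), execute the returned tool
# calls, and feed the results back with _make_tool_response(). Assumes a concrete ToolExecutor
# whose tools expose _definition_for_openai_fmt(), as the loops in this package do.
#
#   class MyLoop(BaseLoop):
#       def _get_tools(self) -> List[Dict[str, Any]]:
#           return [tool._definition_for_openai_fmt() for tool in self.tool_executor.tools.values()]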

View File

@@ -0,0 +1,339 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.managers.loop.base import BaseLoop
from src.tools.base import (
ToolExecutor,
ToolResult,
SubmitToolResult,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
class PatchGenerator(BaseLoop):
def __init__(
self,
instance_id: str,
instance_data: Dict[str, Any],
logger: Logger,
prompts_manager: PromptsManager | None,
llm_manager: LLMAPIManager | None,
tool_executor: ToolExecutor,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
async def _submit_all_tool_calls(
self, other_tool_calls: List[Dict[str, Any]]
) -> List[Any]:
"""execute tool calls, return tool execution results list"""
if not other_tool_calls:
return []
from src.tools.base import ToolCall
tool_call_objects = []
for tool_call_dict in other_tool_calls:
raw_args = tool_call_dict.get("function", {}).get("arguments", {})
parsed_args = raw_args
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except Exception as e:
self.logger.warning(f"[{self.component_name}] In _submit_all_tool_calls function, fail: {e}, traceback: {format_exc()}, args: {raw_args}.")
parsed_args = {}
tool_call_obj = ToolCall(
name=tool_call_dict.get("function", {}).get("name", ""),
call_id=tool_call_dict.get("id", ""),
arguments=parsed_args,
id=tool_call_dict.get("id", ""),
)
tool_call_objects.append(tool_call_obj)
return await self.tool_executor.container_sequential_tool_call(
tool_call_objects
)
def _process_submit_result_tool_result(
self,
submit_result: ToolResult,
golden_patch: List[Dict[str, Any]],
) -> None:
"""process submit_result tool call, fill golden_patch and log"""
if not submit_result.success or not submit_result.result:
self.logger.warning(f"[{self.component_name}] submit_result failed and no result.")
return
try:
submit_tool_result = SubmitToolResult.from_string(submit_result.result)
if submit_tool_result.output:
patch_info = {
"patch_content": submit_tool_result.output,
"test_status": submit_tool_result.test_status,
"reasoning": submit_tool_result.reasoning,
}
golden_patch.clear()
golden_patch.append(patch_info)
self.logger.info(
f"[{self.component_name}] patch len: {len(submit_tool_result.output)}."
)
self.logger.info(
f"[{self.component_name}] test status: {submit_tool_result.test_status}."
)
self.logger.info(
f"[{self.component_name}] reasoning: {submit_tool_result.reasoning[:100]}..."
)
else:
self.logger.warning(
f"[{self.component_name}] submit_result success but no patch content."
)
except Exception as e:
self.logger.error(f"[{self.component_name}] parse submit_result result fail: {e}, traceback: {format_exc()}.")
def _get_tools(self) -> List[Dict[str, Any]]:
tools = []
#use_openai_format = self._should_use_openai_format()
use_openai_format = True
for tool in self.tool_executor.tools.values():
if use_openai_format:
tool_def = tool._definition_for_openai_fmt()
else:
tool_def = tool._definition_for_claude_fmt()
tools.append(tool_def)
return tools
def _should_use_openai_format(self) -> bool:
if not self.llm_manager or not hasattr(self.llm_manager, "get_model_name"):
return True # openAI format by default
model_name = self.llm_manager.get_model_name().lower()
return "claude" not in model_name
def _get_issue_prompt(self) -> str:
"""generate issue prompt based on instance data"""
if not self.prompts_manager:
self.logger.warning("PromptsManager not initialized, cannot generate issue prompt.")
return ""
#instance_id = self.instance_data.get("instance_id", "")
#repo = self.instance_data.get("repo", "")
created_at = self.instance_data.get("created_at", "")
base_commit = self.instance_data.get("base_commit", "")
environment_setup_commit = self.instance_data.get(
"environment_setup_commit", ""
)
version = self.instance_data.get("version", "")
problem_statement = self.instance_data.get("problem_statement", "")
difficulty = self.instance_data.get("difficulty", "")
return self.prompts_manager.format_issue_prompt(
created_at=created_at,
base_commit=base_commit,
environment_setup_commit=environment_setup_commit,
version=version,
problem_statement=problem_statement,
difficulty=difficulty,
)
async def _generate_patch(self) -> Dict[str, Any] | None:
"""main loop logic for generating candidate patch"""
usage_stats = self._init_usage_stats()
tool_stats = self._init_tools_stats()
if not self.llm_manager or not self.prompts_manager:
self.logger.error(f"[{self.component_name}] LLM manager or prompts manager not initialized.")
return {
"success": False,
"golden_patch": [],
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": 0,
}
tools = self._get_tools()
self._debug_tools(tools)
root_path = self.config.get("builder", {}).get("repo_root_path", "")
max_turn = (
self.config.get("runner", {}).get("generator_loop", {}).get("max_turn", 10)
)
temperature = (
self.config.get("runner", {})
.get("generator_loop", {})
.get("temperature", 0.2)
)
issue_prompt = self._get_issue_prompt()
user_prompt = self.prompts_manager.get_generator_user(root_path, issue_prompt)
system_prompt = self.prompts_manager.get_generator_system(root_path)
total_turns = 0
golden_patch = []
try:
self.logger.info(
f"[{self.component_name}] {self.instance_id}: start generating candidate patch, max turn: {max_turn}"
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
)
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
)
for turn in range(max_turn):
total_turns = turn + 1
self.logger.info(f"[{self.component_name}] The {total_turns}th turn started.")
try:
current_input_msg = messages[-1] if messages else None
if current_input_msg is not None:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
)
except Exception as e:
self.logger.warning(
f"[{self.component_name}] {total_turns}th turn: LLM input fail: {messages[-1] if messages else None}, error: {e}, traceback: {format_exc()}."
)
self._debug_messages(turn, messages)
response = self.llm_manager.chat(
messages=messages,
tools=tools,
tool_choice="auto",
temperature=temperature,
)
first_content: str = ""
first_tool_calls: Any = None
if hasattr(response, "choices") and response.choices:
ch0 = response.choices[0]
first_content = (
getattr(getattr(ch0, "message", None), "content", None) or ""
)
first_tool_calls = getattr(
getattr(ch0, "message", None), "tool_calls", None
)
self._response_log(
response, first_content, first_tool_calls, total_turns
)
self._update_usage(response, usage_stats)
if hasattr(response, "choices") and response.choices:
content = first_content
tool_calls = first_tool_calls
if not self._make_assistant(content, tool_calls, messages):
continue
if tool_calls:
self.logger.info(
f"[{self.component_name}] {total_turns}th turn: call {len(tool_calls)} tools."
)
tool_results = await self._submit_all_tool_calls(tool_calls)
self._update_tool_call_statistic(tool_results, tool_stats)
if tool_results:
submit_result = None
other_tool_results = []
for tool_result in tool_results:
tool_name = getattr(tool_result, "name", "")
if tool_name == SUBMIT_RESULT_TOOL_NAME:
submit_result = tool_result
else:
other_tool_results.append(tool_result)
if submit_result:
self.logger.debug(
f"[{self.component_name}] {total_turns}th turn: got submit_result tool call."
)
self.logger.debug(f"[{self.component_name}] {total_turns}th turn: submit_result result: {submit_result}")
self._process_submit_result_tool_result(
submit_result, golden_patch
)
self._debug_last_message(turn, messages)
break
if other_tool_results:
self._make_tool_response(other_tool_results, messages)
else:
messages.append(
{
"role": "user",
"content": "请继续分析问题并使用工具来解决问题。",
}
)
self.logger.debug(f"[{self.component_name}] final golden_patch: {golden_patch}")
success = (
len(golden_patch) > 0 and golden_patch[0].get("patch_content", "") != ""
)
self.logger.info(
f"[{self.component_name}] status={success}, total_turns={total_turns}, tools_stats={tool_stats}"
)
result_payload = {
"success": success,
"golden_patch": golden_patch,
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": total_turns,
}
try:
self.logger.notice(
f"[{self.component_name}] final output: {json.dumps(result_payload, ensure_ascii=False)}"
)
except Exception as e:
self.logger.warning(
f"[{self.component_name}] output: {str(result_payload)}, error: {e}, traceback: {format_exc()}."
)
return result_payload
except Exception as e:
self.logger.error(f"[{self.component_name}] fail: {e}, traceback: {format_exc()}.")
result_payload = {
"success": False,
"golden_patch": [],
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": total_turns,
}
try:
self.logger.notice(
f"[{self.component_name}] 最终返回数据(失败): {json.dumps(result_payload, ensure_ascii=False)}"
)
except Exception as e:
self.logger.notice(
f"[{self.component_name}] 最终返回数据(失败, 字符串回退): {str(result_payload)}, error: {e}, traceback: {format_exc()}."
)
return result_payload
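
# Usage sketch (assumes the logger, prompts manager, LLM manager, tool executor and config
# are constructed elsewhere by the runner; _generate_patch() is a coroutine):
#
#   import asyncio
#   generator = PatchGenerator(instance_id, instance_data, logger,
#                              prompts_manager, llm_manager, tool_executor, config)
#   result = asyncio.run(generator._generate_patch())
#   if result and result["success"]:
#       print(result["golden_patch"][0]["patch_content"][:200])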

View File

@@ -0,0 +1,338 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.managers.loop.types import GeneratorResult, SelectorResult, LLMUsage, ToolStats, PatchInfo
from src.tools.base import ToolExecutor, ToolCall, ToolResult
from src.managers.loop.base import BaseLoop
SELECTOR_SUBMIT_TOOL_NAME = "submit_result"
class PatchSelector(BaseLoop):
def __init__(
self,
instance_id: str,
instance_data: Dict[str, Any],
logger: Logger,
prompts_manager: PromptsManager | None,
llm_manager: LLMAPIManager | None,
tool_executor: ToolExecutor,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
def _get_submit_result_tool_name(self):
return SELECTOR_SUBMIT_TOOL_NAME
    def _definition_for_submit_tool(self, use_openai_format: bool) -> Dict[str, Any]:
        """submit_result tool definition (the OpenAI and Claude schemas are currently identical)."""
        return {
            "type": "function",
            "function": {
                "name": self._get_submit_result_tool_name(),
                "description": "Submit the final selected patch index and reasoning.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "index": {
                            "type": "integer",
                            "description": "The chosen patch index (0-based).",
                        },
                        "reason": {
                            "type": "string",
                            "description": "Detailed reasoning for the selection.",
                        },
                    },
                    "required": ["index", "reason"],
                },
            },
        }
def _build_user_prompt(self, candidates: List[GeneratorResult], root_path: str) -> str:
if not self.prompts_manager:
return ""
return self.prompts_manager.get_selector_user(self.instance_data, candidates, root_path)
def _get_system_prompt(self, patches_count: int, root_path: str) -> str:
if not self.prompts_manager:
return ""
return self.prompts_manager.get_selector_system(patches_count, root_path)
def _get_tools(self) -> List[Dict[str, Any]]:
tool_defs: List[Dict[str, Any]] = []
try:
for tool in self.tool_executor.tools.values():
try:
tool_defs.append(tool._definition_for_openai_fmt())
except Exception:
continue
except Exception:
pass
tool_defs.append(self._definition_for_submit_tool(True))
return tool_defs
def _extract_submit_choice(self, tool_call: Dict[str, Any]) -> Dict[str, Any] | None:
if not tool_call:
return None
fn = tool_call.get("function", {})
if fn.get("name") != self._get_submit_result_tool_name():
return None
raw_args = fn.get("arguments", {})
try:
args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
except Exception:
args = {}
index = args.get("index")
reason = args.get("reason")
if isinstance(index, int) and index >= 0:
return {"index": index, "reason": reason or ""}
return None
async def _submit_other_tool_calls(
self, tool_calls: List[Dict[str, Any]]
) -> List[ToolResult]:
if not tool_calls:
return []
to_run: List[ToolCall] = []
for tool_call_dict in tool_calls:
fn = tool_call_dict.get("function", {})
name = fn.get("name", "")
if name == SELECTOR_SUBMIT_TOOL_NAME:
continue
raw_args = fn.get("arguments", {})
parsed_args = raw_args
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except Exception:
parsed_args = {}
to_run.append(
ToolCall(
name=name,
call_id=tool_call_dict.get("id", ""),
arguments=parsed_args,
id=tool_call_dict.get("id", ""),
)
)
if not to_run:
return []
results: List[ToolResult] = await self.tool_executor.container_sequential_tool_call(to_run)
return results
async def _select_patch(self, candidates: List[GeneratorResult]) -> SelectorResult:
if not candidates:
raise ValueError("No candidates provided")
if not self.llm_manager:
raise ValueError("LLM manager is not initialized")
tools = self._get_tools()
self._debug_tools(tools)
root_path = self.config.get("builder", {}).get("repo_root_path", "")
system_prompt = self._get_system_prompt(len(candidates), root_path)
user_prompt = self._build_user_prompt(candidates, root_path)
messages: List[Dict[str, Any]] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
try:
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
)
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] Initial fail in selector loop: SP={str(messages[0])}, UP={str(messages[1])}, traceback: {format_exc()}."
)
max_turn = int(
self.config.get("runner", {})
.get("selector_loop", {})
.get("max_turn", 200)
)
temperature = (
self.config.get("runner", {})
.get("selector_loop", {})
.get("temperature", 0.2)
)
usage_stats = self._init_usage_stats()
tool_stats = self._init_tools_stats()
total_turns = 0
chosen_index: int | None = None
select_reason: str = ""
for turn in range(max_turn):
try:
try:
current_input_msg = messages[-1] if messages else None
if current_input_msg is not None:
self.logger.notice(
f"[{self.component_name}] The {turn+1}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] {turn+1}th turn fail: {messages[-1] if messages else None}, traceback: {format_exc()}."
)
self._debug_messages(turn, messages)
response = self.llm_manager.chat(
messages=messages,
tools=tools,
tool_choice="auto",
temperature=temperature,
)
first_tool_calls = None
if hasattr(response, "choices") and response.choices:
ch0 = response.choices[0]
first_tool_calls = getattr(getattr(ch0, "message", None), "tool_calls", None)
first_content = getattr(getattr(ch0, "message", None), "content", None) or ""
else:
first_content = ""
total_turns = turn + 1
self._response_log(response, first_content, first_tool_calls, turn + 1)
self._update_usage(response, usage_stats)
if first_tool_calls:
if not self._make_assistant(first_content, first_tool_calls, messages):
messages.append(
{
"role": "user",
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
}
)
continue
submit_found = False
for tc in first_tool_calls:
choice = self._extract_submit_choice(tc)
if choice is not None:
chosen_index = choice["index"]
reason = choice.get("reason", "")
self.logger.info(
f"[{self.component_name}] choose: index={chosen_index}, reason={reason}"
)
select_reason = reason or ""
submit_found = True
self._debug_last_message(turn, messages)
break
if not submit_found:
results = await self._submit_other_tool_calls(first_tool_calls)
self._make_tool_response(results, messages)
self._update_tool_call_statistic(results, tool_stats)
else:
messages.append(
{
"role": "user",
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
}
)
if chosen_index is not None:
break
except Exception as e:
self.logger.warning(
f"[{self.component_name}] fail: {e}, traceback: {format_exc()}"
)
break
if chosen_index is None:
# If the model provides no choice, fallback: pick the first successful one; otherwise the first
for i, r in enumerate(candidates):
try:
if r.success:
chosen_index = i
break
except Exception:
continue
if chosen_index is None:
chosen_index = 0
if not (0 <= chosen_index < len(candidates)):
chosen_index = 0
selected = candidates[chosen_index]
try:
gp = selected.golden_patch[0] if selected.golden_patch else None
if gp is None:
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
else:
patch_info = PatchInfo(
patch_content=gp.patch_content,
test_status=gp.test_status,
reasoning=gp.reasoning,
)
except Exception:
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
selector_result = SelectorResult(
instance_id=selected.instance_id,
generator_id=selected.generator_id,
image=selected.image,
success=True,
golden_patch=patch_info,
llm_usage=usage_stats,
tool_stats=tool_stats,
total_turns=total_turns,
select_reason=select_reason,
error=None,
)
return selector_result
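
# Usage sketch (assumes `candidates` is a list of GeneratorResult objects from the generator
# loop; _select_patch() is a coroutine and must be awaited):
#
#   import asyncio
#   selector = PatchSelector(instance_id, instance_data, logger,
#                            prompts_manager, llm_manager, tool_executor, config)
#   selection = asyncio.run(selector._select_patch(candidates))
#   print(selection.generator_id, selection.select_reason)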

254
src/managers/loop/types.py Normal file
View File

@@ -0,0 +1,254 @@
"""
This module defines the GeneratorResult data structure for patch generation results.
"""
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
from src.tools.base import (
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
@dataclass
class LLMUsage:
"""LLM usage statistics."""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
def to_dict(self) -> Dict[str, int]:
"""Serialize LLMUsage to a plain dictionary."""
return {
"prompt_tokens": int(self.prompt_tokens),
"completion_tokens": int(self.completion_tokens),
"total_tokens": int(self.total_tokens),
}
@dataclass
class ToolStats:
"""Tool usage statistics per tool.
Each tool is represented by a small map with two fields:
- count: total invocation count
- failed: failed invocation count
"""
bash: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
edit: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
search: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
submit_result: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
def to_dict(self) -> Dict[str, Dict[str, int]]:
"""Serialize ToolStats to a plain dictionary."""
return {
BASH_TOOL_NAME: {"count": int(self.bash.get("count", 0)), "failed": int(self.bash.get("failed", 0))},
STR_REPLACE_BASED_EDIT_TOOL_NAME: {"count": int(self.edit.get("count", 0)), "failed": int(self.edit.get("failed", 0))},
SEARCH_TOOL_NAME: {"count": int(self.search.get("count", 0)), "failed": int(self.search.get("failed", 0))},
SUBMIT_RESULT_TOOL_NAME: {"count": int(self.submit_result.get("count", 0)), "failed": int(self.submit_result.get("failed", 0))},
}
@dataclass
class PatchInfo:
"""Information about a generated patch."""
patch_content: str
test_status: str
reasoning: str
@dataclass
class GeneratorResult:
"""Result from a patch generator."""
instance_id: str
generator_id: int
image: str
success: bool
    golden_patch: List[PatchInfo]
llm_usage: LLMUsage
tool_stats: ToolStats
total_turns: int
error: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GeneratorResult":
"""Create GeneratorResult from dictionary."""
# Handle golden_patch conversion
golden_patch = []
if data.get("golden_patch"):
for patch_data in data["golden_patch"]:
if isinstance(patch_data, dict):
golden_patch.append(
PatchInfo(
patch_content=patch_data.get("patch_content", ""),
test_status=patch_data.get("test_status", ""),
reasoning=patch_data.get("reasoning", ""),
)
)
else:
# Legacy format: just patch content string
golden_patch.append(
PatchInfo(
patch_content=str(patch_data), test_status="", reasoning=""
)
)
# Handle LLM usage
llm_usage_data = data.get("llm_usage", {})
llm_usage = LLMUsage(
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
completion_tokens=llm_usage_data.get("completion_tokens", 0),
total_tokens=llm_usage_data.get("total_tokens", 0),
)
# Handle tool stats
tool_stats_data = data.get("tool_stats", {})
        # ToolStats fields are count/failed maps, so default missing entries to an empty map
        tool_stats = ToolStats(
            bash=tool_stats_data.get(BASH_TOOL_NAME, {"count": 0, "failed": 0}),
            edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, {"count": 0, "failed": 0}),
            search=tool_stats_data.get(SEARCH_TOOL_NAME, {"count": 0, "failed": 0}),
            submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, {"count": 0, "failed": 0}),
        )
return cls(
instance_id=data.get("instance_id", ""),
generator_id=data.get("generator_id", 0),
image=data.get("image", ""),
success=data.get("success", False),
golden_patch=golden_patch,
llm_usage=llm_usage,
tool_stats=tool_stats,
total_turns=data.get("total_turns", 0),
error=data.get("error"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert GeneratorResult to dictionary."""
return {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image,
"success": self.success,
"golden_patch": [
{
"patch_content": patch.patch_content,
"test_status": patch.test_status,
"reasoning": patch.reasoning,
}
for patch in self.golden_patch
],
"llm_usage": {
"prompt_tokens": self.llm_usage.prompt_tokens,
"completion_tokens": self.llm_usage.completion_tokens,
"total_tokens": self.llm_usage.total_tokens,
},
"tool_stats": {
BASH_TOOL_NAME: self.tool_stats.bash,
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
SEARCH_TOOL_NAME: self.tool_stats.search,
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
},
"total_turns": self.total_turns,
"error": self.error,
}
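# Illustrative round-trip sketch (the instance id and patch text are placeholders, and the
# legacy string form of golden_patch is exercised on purpose):
#   result = GeneratorResult.from_dict({
#       "instance_id": "demo__demo-1", "generator_id": 0, "image": "demo:latest",
#       "success": True, "golden_patch": ["diff --git a/f.py b/f.py"],
#       "llm_usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
#       "tool_stats": {}, "total_turns": 3})
#   result.golden_patch[0].patch_content   # -> "diff --git a/f.py b/f.py"
#   result.to_dict()["llm_usage"]["total_tokens"]  # -> 15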
@dataclass
class SelectorResult:
"""Result from a patch selector.
"""
instance_id: str
generator_id: int
image: str
success: bool
golden_patch: PatchInfo
llm_usage: LLMUsage
tool_stats: ToolStats
total_turns: int
select_reason: str
error: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SelectorResult":
"""Create SelectorResult from dictionary."""
gp_data = data.get("golden_patch", {})
if isinstance(gp_data, dict):
golden_patch = PatchInfo(
patch_content=gp_data.get("patch_content", ""),
test_status=gp_data.get("test_status", ""),
reasoning=gp_data.get("reasoning", ""),
)
else:
golden_patch = PatchInfo(
patch_content=str(gp_data) if gp_data is not None else "",
test_status="",
reasoning="",
)
# LLM usage
llm_usage_data = data.get("llm_usage", {})
llm_usage = LLMUsage(
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
completion_tokens=llm_usage_data.get("completion_tokens", 0),
total_tokens=llm_usage_data.get("total_tokens", 0),
)
# Tool stats
tool_stats_data = data.get("tool_stats", {})
tool_stats = ToolStats(
bash=tool_stats_data.get(BASH_TOOL_NAME, {"count": 0, "failed": 0}),
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, {"count": 0, "failed": 0}),
search=tool_stats_data.get(SEARCH_TOOL_NAME, {"count": 0, "failed": 0}),
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, {"count": 0, "failed": 0}),
)
return cls(
instance_id=data.get("instance_id", ""),
generator_id=data.get("generator_id", 0),
image=data.get("image", ""),
success=data.get("success", False),
golden_patch=golden_patch,
llm_usage=llm_usage,
tool_stats=tool_stats,
total_turns=data.get("total_turns", 0),
select_reason=data.get("select_reason", ""),
error=data.get("error"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert SelectorResult to dictionary."""
return {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image,
"success": self.success,
"golden_patch": {
"patch_content": self.golden_patch.patch_content,
"test_status": self.golden_patch.test_status,
"reasoning": self.golden_patch.reasoning,
},
"llm_usage": {
"prompt_tokens": self.llm_usage.prompt_tokens,
"completion_tokens": self.llm_usage.completion_tokens,
"total_tokens": self.llm_usage.total_tokens,
},
"tool_stats": {
BASH_TOOL_NAME: self.tool_stats.bash,
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
SEARCH_TOOL_NAME: self.tool_stats.search,
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
},
"total_turns": self.total_turns,
"select_reason": self.select_reason,
"error": self.error,
}

View File

@@ -0,0 +1,268 @@
from typing import Any, List, Dict
class PromptsManager:
def __init__(self, config):
self.candidate_length = config.get("runner", {}).get("generator_concurrency", 5)
def get_generator_system(self, root_path: str | None = None):
return f"""
# You are a highly skilled expert in software engineering focused on resolving complex GitHub issues by effectively analyzing codebases, implementing fixes, and ensuring code reliability through rigorous testing.
## Skills
1. Code Analysis and Debugging
- Issue Exploration: Ability to explore and comprehend codebases within repositories.
- Workflow Tracing: Skilled in using debugging techniques to trace issues through code.
- Root Cause Identification: Proficient in pinpointing the underlying causes of software issues.
- Test Creation: Expertise in establishing tests that replicate and validate issues.
2. Solution Implementation and Testing
- Fix Implementation: Experience in crafting precise and minimal code patches.
- Comprehensive Testing: Skilled in running and analyzing both existing and newly created tests.
- Regression Prevention: Ensures all changes maintain overall code stability.
- Continuous Improvement: Iterates on solutions based on test results to achieve optimal functionality.
## Task
Your task is to resolve the given GitHub issue by understanding the repository and the issue, implementing a fix, and checking your changes against existing tests and your own test(s).
Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project.
For example, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/testbed`.
Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and I activated the virtual environment for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
Follow these steps:
1. Problem Analysis:
- Read the issue description carefully to fully grasp the issue and explore the repository (source code, tests, examples) to understand expected behavior of relevant components.
- Identify the full scope. Does the issue mention multiple components, backends, or functions? Your solution must address all of them.
2. Reproduce the issue (IMPORTANT):
- Create a test that reproduces the issue as a baseline for verification.
- Check that the output of your test matches your understanding of the issue in step 1.
3. Identify the root cause:
- Go through relevant files, and create debugging scripts with print statements or use other methods if necessary, to trace the workflow and the exact cause of the issue.
- **Trace the problem to its root cause.** Do not just patch the symptom where the error appears. Trace the data and execution flow upstream to find where the problem originates.
4. Implement a Fix:
- Once you have identified the root cause, develop a precise and targeted fix and then apply it as a minimal patch using the `str_replace_based_edit_tool` tools.
5. Test comprehensively:
- Verify the Fix: Run your initial reproduction script to confirm that the bug is resolved.
- Prevent Regressions:
--Identify the right tests: Once you have verified your fix, identify the most relevant tests within the project's existing test suite that correspond to your code changes.
--Run the tests: Then you **must** run these tests to ensure that your fix does not introduce any new bugs.
--Analyze failures carefully:
---If tests fail, do not immediately assume your fix is wrong. Critically analyze the failure.
---Is it a **regression**? Did your change break existing, valid functionality? If so, you must refine your fix.
---Is it an **unrelated failure**? It could be an environmental issue (e.g., missing dependency, network error) or a pre-existing flaky test. If you suspect this, try to run a more focused test and note the issue in your final reasoning.
---Is the **test now obsolete**? If your fix improves behavior in a way that makes an old test's assertions incorrect, you should **update the test** to match the new, correct behavior and explain why in your reasoning.
- Write New Tests: Create new, specific test cases (e.g., using `pytest`) that cover the original bug scenario.
- Consider Edge Cases: Think about and test potential edge cases related to your changes.
6. Revisit steps 1 through 5 if unexpected behavior occurs, then call `submit_result` to submit the reliable, verified solution patch after successful testing and validation.
**Mandatory Workflow** As a senior engineer, ensure solution correctness and safety. Upon successful verification, immediately conclude the task by calling `submit_result`.
"""
def format_issue_prompt(
self,
created_at: str,
base_commit: str,
environment_setup_commit: str,
version: str,
problem_statement: str,
difficulty: str,
) -> str:
template = f"""
[📝 Issue Description]
**Created at**: {created_at}
**Base commit**: {base_commit}
---
### 📌 Problem Statement
{problem_statement}
---
### ⚙️ Difficulty Level
{difficulty}
---
"""
return template.strip()
def get_generator_user(self, root_path: str, issue_text: str):
return (
f"""
[Project root path]:
{root_path}
[Issue Information]:
{issue_text}
"""
+ self.get_generator_notice()
)
def get_generator_notice(self):
return """
[notice]
1. Use the available tools to locate the root cause.
2. Prioritize using the `search_tool` to retrieve and locate the precise location of key information in the project.
3. Collect supporting evidence: stack traces, logs, configs, recent changes, related modules.
"""
def get_selector_system(self, patches_count: int, root_path: str):
return f"""
# ROLE:
*You are a highly proficient software engineer tasked with evaluating and selecting optimal code patches to resolve specific issues within a given project.
*Your colleagues worked on {patches_count} potential patches for a GitHub issue. Select the ONE correct patch that solves the issue.
*Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and the virtual environment has been activated for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
*Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project. For instance, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/testbed`.
# WORKFLOWS:
*Follow these steps without any skipping:
1.Problem Analysis:
- Read the issue description and the current code that needs to be fixed. Explore the repository (source code, tests, examples) to understand expected behavior of relevant components, and gather comprehensive information about the problem area
2.Conduct a thorough review of each patch:
- Scrutinize all code modifications.
- Decipher the core logic and problem-solving methodology.
- Evaluate potential edge cases and unintended consequences.
- Validate that each patch fully addresses the initial issue specifications.
3.Verify Your Analysis
- Use the available tools to verify that your analysis of this issue is correct.
- Test your conclusions against relevant code sections.
- Ensure full contextual understanding.
4.Proceed with Your Decision
- Upon completion of the preceding three steps, utilize the `submit_result` tool with your detailed reasoning.
#RULES:
1.It is MANDATORY to utilize the available tools prior to finalizing any selection:
-- Start with `bash` to explore the codebase structure;
-- Employ the `str_replace_based_edit_tool` to inspect the current code;
-- Use `search_tool` to search related code and files;
2.You MUST first explore the codebase before using the `submit_result` tool.
3.Substantiate your reasoning with evidence from your analysis.
4.Only selections made after employing the tools will be accepted.
#FINAL DECISION:
Upon completion of your tool-based analysis, finalize the process by submitting your choice via the `submit_result` tool.
#NOTICE:
1. Tool usage is MANDATORY - do not skip this step.
2. Completing the analysis without making a decision is not permitted.
3. Never generate new patches by your own, just make the selection.
4. Always provide detailed reasoning for the selection based on your tool-based investigation
"""
def get_selector_user(
self, instance_data: Dict[str, Any] | None = None, candidates: List[Any] | None = None, root_path: str | None = None
) -> str:
"""
Generate the selector user prompt, including the issue information and the first golden patch of each candidate.
- instance_data: current instance metadata (issue description, etc.)
- candidates: list of candidates (each supporting .to_dict()); only the fields of golden_patch[0] are used
"""
if not instance_data or not candidates:
return ""
created_at = instance_data.get("created_at", "")
base_commit = instance_data.get("base_commit", "")
environment_setup_commit = instance_data.get("environment_setup_commit", "")
version = instance_data.get("version", "")
problem_statement = instance_data.get("problem_statement", "")
difficulty = instance_data.get("difficulty", "")
issue_block = self.format_issue_prompt(
created_at=created_at,
base_commit=base_commit,
environment_setup_commit=environment_setup_commit,
version=version,
problem_statement=problem_statement,
difficulty=difficulty,
)
root_path_block = f"""
[Project root path]:
{root_path or "/testbed"}
"""
parts: List[str] = [root_path_block, issue_block, "\n[🔎 Candidates]\n"]
for idx, r in enumerate(candidates):
try:
data = r.to_dict() if hasattr(r, "to_dict") else {}
except Exception:
data = {}
golden_patch = data.get("golden_patch", [])
patch_content = golden_patch[0].get("patch_content", "") if golden_patch else ""
test_status = golden_patch[0].get("test_status", "") if golden_patch else ""
reasoning = golden_patch[0].get("reasoning", "") if golden_patch else ""
parts.append(self.format_selector_candidate(idx, patch_content, test_status, reasoning))
parts.append(
"\nPlease analyze the candidates, then call the submit_result tool with the final index and reasoning."
)
return "\n".join(parts)
def get_terminal_response(self, exit_code: int, output: str, timeout_status: bool):
if timeout_status:
return f"""[Terminal response]
Exit code: {exit_code}
Output: {output}"""
else:
return """[Terminal response]
Terminal timed out."""
def tool_response_prompts(self, tool_results: list) -> str:
if not tool_results:
return ""
response_parts = ["[tool_response]"]
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
output = result.get("result", "")
error = result.get("error", "")
response_parts.append(f"Tool {i}: {tool_name}")
response_parts.append(f"Success: {success}")
if success and output:
response_parts.append(f"Output:\n{output}")
elif error:
response_parts.append(f"Error: {error}")
else:
response_parts.append("No output")
response_parts.append("") # 空行分隔
return "\n".join(response_parts)
def format_selector_candidate(self, index: int, patch_content: str, test_status: str, reasoning: str) -> str:
"""
Generate the description of a selector candidate, including key information from its first golden_patch.
- index: candidate index (0-based)
- patch_content: golden_patch[0].patch_content
- test_status: golden_patch[0].test_status, the test status from the generation stage
- reasoning: golden_patch[0].reasoning, the model's reasoning from the generation stage
"""
header = f"- Candidate #{index}:"
patch_block = patch_content or ""
status_block = test_status or ""
reasoning_block = reasoning or ""
return (
f"--{header}\n"
f"--Patch content (the proposed fix):\n{patch_block}\n\n"
f"--Test status during generation: {status_block}\n\n"
f"--Reasoning during generation (model's logic):\n{reasoning_block}"
)
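# Minimal usage sketch (illustrative; the config keys and issue fields below are assumptions,
# not a documented schema):
#   pm = PromptsManager({"runner": {"generator_concurrency": 3}})
#   system_prompt = pm.get_generator_system(root_path="/testbed")
#   issue_block = pm.format_issue_prompt(created_at="", base_commit="", environment_setup_commit="",
#                                        version="", problem_statement="Example issue text", difficulty="medium")
#   user_prompt = pm.get_generator_user("/testbed", issue_block)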

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict
import json
class ResultBuilder:
"""
Build preds.json.
Iterates over the JSON files (named by instance_id) in the directory given by `runner.selector_result_dump_path` in the configuration:
- Parses the `golden_patch` field of each JSON file and extracts the patch text as `model_patch`.
- Reads the first top-level provider under `providers` and the first model beneath it, joining them as `model_name_or_path`.
- Writes the output to `{workspace.path}/{result.preds.path}/preds.json`.
"""
def __init__(self, config: Dict[str, Any]):
self.config = config or {}
def _get_selector_dump_dir(self) -> Path:
runner_cfg = self.config.get("runner", {}) if isinstance(self.config, dict) else {}
dump_dir_str = runner_cfg.get(
"selector_result_dump_path", "workspace/selector_result_dump"
)
return Path(dump_dir_str)
def _get_preds_output_dir(self) -> Path:
workspace_cfg = self.config.get("workspace", {}) if isinstance(self.config, dict) else {}
result_cfg = self.config.get("result", {}) if isinstance(self.config, dict) else {}
preds_cfg = result_cfg.get("preds", {}) if isinstance(result_cfg, dict) else {}
workspace_path = workspace_cfg.get("path", "workspace")
preds_path = preds_cfg.get("path", "result")
return Path(workspace_path) / preds_path
def _get_model_name_or_path(self) -> str:
providers = self.config.get("providers", {}) if isinstance(self.config, dict) else {}
if not isinstance(providers, dict) or not providers:
return ""
first_provider_name = next(iter(providers.keys()))
first_models = providers.get(first_provider_name, [])
if isinstance(first_models, list) and first_models:
first_model = first_models[0]
else:
first_model = ""
return f"{first_provider_name}/{first_model}" if first_provider_name and first_model else ""
@staticmethod
def _extract_model_patch(golden_patch: Any) -> str:
"""
Extract the patch content from golden_patch.
Supported forms:
- dict: prefer 'patch_content', then fall back to 'model_patch'
- str: returned as-is
- anything else: returns an empty string
"""
if isinstance(golden_patch, dict):
if "patch_content" in golden_patch and isinstance(golden_patch["patch_content"], str):
return golden_patch["patch_content"]
if "model_patch" in golden_patch and isinstance(golden_patch["model_patch"], str):
return golden_patch["model_patch"]
return ""
if isinstance(golden_patch, str):
return golden_patch
return ""
def build_preds(self) -> Path:
dump_dir = self._get_selector_dump_dir()
output_dir = self._get_preds_output_dir()
output_dir.mkdir(parents=True, exist_ok=True)
output_file = output_dir / "preds.json"
model_name_or_path = self._get_model_name_or_path()
# SWE-bench evaluation expects: list[dict], each element includes instance_id / model_patch / model
predictions: list[dict[str, str]] = []
if dump_dir.exists() and dump_dir.is_dir():
for path in sorted(dump_dir.glob("*.json")):
try:
instance_id = path.stem
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
golden_patch = data.get("golden_patch", {}) if isinstance(data, dict) else {}
model_patch = self._extract_model_patch(golden_patch)
predictions.append(
{
"instance_id": instance_id,
"model_patch": model_patch,
"model_name_or_path": model_name_or_path,
}
)
except Exception:
continue
with open(output_file, "w", encoding="utf-8") as f:
json.dump(predictions, f, ensure_ascii=False, indent=2)
return output_file
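# Illustrative output sketch (instance id, patch text and model name are placeholders): with one
# <instance_id>.json file in the selector dump directory, preds.json is a JSON list such as
#   [{"instance_id": "demo__demo-1", "model_patch": "diff --git ...",
#     "model_name_or_path": "openai/gpt-4o"}]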

35
src/tools/__init__.py Normal file
View File

@@ -0,0 +1,35 @@
"""Tools module for Code Agent."""
from src.tools.base import (
Tool,
ToolCall,
ToolExecutor,
ToolResult,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
from src.tools.bash_tool import BashTool
from src.tools.edit_tool import TextEditorTool
from src.tools.search_tool import SearchTool
from src.tools.submit_result_tool import SubmitResultTool
__all__ = [
"Tool",
"ToolResult",
"ToolCall",
"ToolExecutor",
"BashTool",
"TextEditorTool",
"JSONEditTool",
"SearchTool",
"SubmitResultTool",
]
tools_registry: dict[str, type[Tool]] = {
BASH_TOOL_NAME: BashTool,
STR_REPLACE_BASED_EDIT_TOOL_NAME: TextEditorTool,
SEARCH_TOOL_NAME: SearchTool,
SUBMIT_RESULT_TOOL_NAME: SubmitResultTool,
}
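# Illustrative lookup sketch: the registry maps tool names to classes, so a tool can be
# instantiated by name (the constructor keyword below is one of the optional Tool arguments):
#   tool_cls = tools_registry[BASH_TOOL_NAME]
#   tool = tool_cls(model_provider="openai")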

523
src/tools/base.py Normal file
View File

@@ -0,0 +1,523 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
#
# Original file was released under MIT License, with the full license text
# available at https://github.com/bytedance/trae-agent/blob/main/LICENSE
#
# This modified file is released under the same license.
"""Base classes for tools and tool calling."""
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import cached_property
from typing import override
from src.managers.log.logger import Logger
from typing import Dict, Any
from traceback import format_exc
from pathlib import Path
ParamSchemaValue = str | list[str] | bool | dict[str, object]
Property = dict[str, ParamSchemaValue]
BASH_TOOL_NAME = "bash"
STR_REPLACE_BASED_EDIT_TOOL_NAME = "str_replace_based_edit_tool"
SEARCH_TOOL_NAME = "search_tool"
SUBMIT_RESULT_TOOL_NAME = "submit_result"
class ToolError(Exception):
"""Base class for tool errors."""
def __init__(self, message: str):
super().__init__(message)
self.message: str = message
@dataclass
class ToolExecResult:
"""Intermediate result of a tool execution."""
output: str | None = None
error: str | None = None
error_code: int = 0
@dataclass
class ToolResult:
"""Result of a tool execution."""
call_id: str
name: str # Gemini specific field
success: bool
result: str | None = None
error: str | None = None
id: str | None = None # OpenAI-specific field
@dataclass
class SubmitToolResult:
"""Structured result for submit_result tool."""
return_code: int
output: str
is_task_done: bool
test_status: str
reasoning: str
def __str__(self) -> str:
"""Convert to JSON string for output."""
import json
return json.dumps(
{
"return_code": self.return_code,
"output": self.output,
"is_task_done": self.is_task_done,
"test_status": self.test_status,
"reasoning": self.reasoning,
}
)
@classmethod
def from_string(cls, json_str: str) -> "SubmitToolResult":
"""Create SubmitToolResult from JSON string."""
import json
data = json.loads(json_str)
return cls(
return_code=data.get("return_code", 0),
output=data.get("output", ""),
is_task_done=data.get("is_task_done", False),
test_status=data.get("test_status", "error"),
reasoning=data.get("reasoning", ""),
)
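# Illustrative round trip: str(SubmitToolResult(0, "ok", True, "passed", "why")) produces a
# JSON string that from_string() parses back into an equivalent instance.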
ToolCallArguments = dict[
str, str | int | float | dict[str, object] | list[object] | None
]
@dataclass
class ToolCall:
"""Represents a parsed tool call."""
name: str
call_id: str
arguments: ToolCallArguments = field(default_factory=dict)
id: str | None = None
@override
def __str__(self) -> str:
return f"ToolCall(name={self.name}, arguments={self.arguments}, call_id={self.call_id}, id={self.id})"
@dataclass
class ToolParameter:
"""Tool parameter definition."""
name: str
type: str | list[str]
description: str
enum: list[str] | None = None
items: dict[str, object] | None = None
required: bool = True
class Tool(ABC):
"""Base class for all tools."""
def __init__(
self,
model_provider: str | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
):
self._model_provider = model_provider
self.logger = logger
self.config = config
@cached_property
def model_provider(self) -> str | None:
return self.get_model_provider()
@cached_property
def name(self) -> str:
return self.get_name()
@cached_property
def description(self) -> str:
return self.get_description()
@cached_property
def parameters(self) -> list[ToolParameter]:
return self.get_parameters()
def get_model_provider(self) -> str | None:
"""Get the model provider."""
return self._model_provider
@abstractmethod
def get_name(self) -> str:
"""Get the tool name."""
pass
@abstractmethod
def get_description(self) -> str:
"""Get the tool description."""
pass
@abstractmethod
def get_parameters(self) -> list[ToolParameter]:
"""Get the tool parameters."""
pass
@abstractmethod
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the tool with given parameters."""
pass
# Optional container execution hooks (to be overridden by tools that support containers)
async def container_execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the tool inside a container shell (optional)."""
raise ToolError(
f"Tool '{self.get_name()}' does not support container execution"
)
def container_search(
self, arguments: ToolCallArguments, session_id: str = "0"
) -> ToolExecResult:
"""Execute a search-like operation inside container (optional)."""
raise ToolError(f"Tool '{self.get_name()}' does not support container search")
# Optional container file editing hooks used by edit tools
def container_read_file(self, path) -> str:
"""Read a file inside container (optional)."""
raise ToolError(
f"Tool '{self.get_name()}' does not support container_read_file"
)
def container_write_file(self, path, content: str) -> None:
"""Write a file inside container (optional)."""
raise ToolError(
f"Tool '{self.get_name()}' does not support container_write_file"
)
def container_str_replace(
self, path, old_str: str, new_str: str | None
) -> ToolExecResult:
"""String replace inside a file in container (optional)."""
raise ToolError(
f"Tool '{self.get_name()}' does not support container_str_replace"
)
def container_insert(self, path, insert_line: int, new_str: str) -> ToolExecResult:
"""Insert text into a file in container (optional)."""
raise ToolError(f"Tool '{self.get_name()}' does not support container_insert")
def view_handler_container(
self, arguments: ToolCallArguments, path: Path
) -> ToolExecResult:
"""View handler in container (optional)."""
raise ToolError(f"Tool '{self.get_name()}' does not support view_handler_container")
def json_definition(self) -> dict[str, object]:
"""Default return Claude format (backward compatibility)"""
return self._definition_for_claude_fmt()
def _definition_for_claude_fmt(self) -> dict[str, object]:
"""Return Claude format tool definition (Anthropic Messages API)"""
return {
"name": self.name,
"description": self.description,
"input_schema": self.get_input_schema(),
}
def _definition_for_openai_fmt(self) -> dict[str, object]:
"""Return OpenAI format tool definition"""
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.get_input_schema(),
},
}
def get_input_schema(self) -> dict[str, object]:
"""Get the input schema for the tool."""
schema: dict[str, object] = {
"type": "object",
}
properties: dict[str, Property] = {}
required: list[str] = []
for param in self.parameters:
param_schema: Property = {
"type": param.type,
"description": param.description,
}
# For OpenAI strict mode, all params must be in 'required'.
# Optional params are made "nullable" to be compliant.
if self.model_provider == "openai":
required.append(param.name)
if not param.required:
current_type = param_schema["type"]
if isinstance(current_type, str):
param_schema["type"] = [current_type, "null"]
elif isinstance(current_type, list) and "null" not in current_type:
param_schema["type"] = list(current_type) + ["null"]
elif param.required:
required.append(param.name)
if param.enum:
param_schema["enum"] = param.enum
if param.items:
param_schema["items"] = param.items
# For OpenAI, nested objects also need additionalProperties: false
if self.model_provider == "openai" and param.type == "object":
param_schema["additionalProperties"] = False
properties[param.name] = param_schema
schema["properties"] = properties
if len(required) > 0:
schema["required"] = required
# For OpenAI, the top-level schema needs additionalProperties: false
if self.model_provider == "openai":
schema["additionalProperties"] = False
return schema
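# Illustrative schema sketch: for a tool with a required "command" (string) parameter and an
# optional "restart" (boolean) parameter, an "openai" provider yields
#   {"type": "object",
#    "properties": {"command": {"type": "string", "description": "..."},
#                   "restart": {"type": ["boolean", "null"], "description": "..."}},
#    "required": ["command", "restart"], "additionalProperties": False}
# while other providers list only "command" under "required" and keep "restart" as "boolean".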
async def close(self):
"""Ensure proper tool resource deallocation before task completion."""
return None # Using "pass" will trigger a Ruff check error: B027
class ToolExecutor:
"""Tool executor that manages tool execution."""
def __init__(self, tools: list[Tool], logger: Logger | None = None):
self._tools = tools
self._tool_map: dict[str, Tool] | None = None
self.logger = logger
async def close_tools(self):
"""Ensure all tool resources are properly released."""
tasks = [tool.close() for tool in self._tools if hasattr(tool, "close")]
res = await asyncio.gather(*tasks)
return res
def _normalize_name(self, name: str) -> str:
"""Normalize tool name by making it lowercase and removing underscores."""
return name.lower().replace("_", "")
@property
def tools(self) -> dict[str, Tool]:
if self._tool_map is None:
self._tool_map = {
self._normalize_name(tool.name): tool for tool in self._tools
}
return self._tool_map
async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
"""Execute a tool call locally."""
normalized_name = self._normalize_name(tool_call.name)
if normalized_name not in self.tools:
return ToolResult(
name=tool_call.name,
success=False,
error=f"Tool '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}",
call_id=tool_call.call_id,
id=tool_call.id,
)
tool = self.tools[normalized_name]
try:
tool_exec_result = await tool.execute(tool_call.arguments)
return ToolResult(
name=tool_call.name,
success=tool_exec_result.error_code == 0,
result=tool_exec_result.output,
error=tool_exec_result.error,
call_id=tool_call.call_id,
id=tool_call.id,
)
except Exception as e:
return ToolResult(
name=tool_call.name,
success=False,
error=f"Error executing tool '{tool_call.name}': {str(e)}, traceback: {format_exc()}.",
call_id=tool_call.call_id,
id=tool_call.id,
)
async def container_execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
"""Execute a tool call in container."""
normalized_name = self._normalize_name(tool_call.name)
if normalized_name not in self.tools:
self.logger.warning(
f"[ToolExecutor] '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}"
)
return ToolResult(
name=tool_call.name,
success=False,
error=f"Tool '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}",
call_id=tool_call.call_id,
id=tool_call.id,
)
tool = self.tools[normalized_name]
try:
tool_exec_result = await self._container_execute_tool_by_name(
tool, tool_call
)
return ToolResult(
name=tool_call.name,
success=tool_exec_result.error_code == 0,
result=tool_exec_result.output,
error=tool_exec_result.error,
call_id=tool_call.call_id,
id=tool_call.id,
)
except Exception as e:
return ToolResult(
name=tool_call.name,
success=False,
error=f"Error executing tool '{tool_call.name}': {str(e)}, traceback: {format_exc()}.",
call_id=tool_call.call_id,
id=tool_call.id,
)
async def _container_execute_tool_by_name(
self, tool: Tool, tool_call: ToolCall
) -> ToolExecResult:
tool_name = tool.get_name()
if tool_name == BASH_TOOL_NAME:
# BashTool: execute through container
if hasattr(tool, "container_execute"):
return await tool.container_execute(tool_call.arguments)
else:
raise ToolError(
f"Tool '{tool_name}' does not support container execution"
)
elif tool_name == STR_REPLACE_BASED_EDIT_TOOL_NAME:
# TextEditorTool: execute through container
if hasattr(tool, "container_read_file"):
return await self._execute_edit_tool_in_container(
tool, tool_call.arguments
)
else:
raise ToolError(
f"Tool '{tool_name}' does not support container execution"
)
elif tool_name == SEARCH_TOOL_NAME:
# SearchTool: execute through container
if hasattr(tool, "container_search"):
return tool.container_search(tool_call.arguments)
else:
raise ToolError(
f"Tool '{tool_name}' does not support container execution"
)
elif tool_name == SUBMIT_RESULT_TOOL_NAME:
# SubmitResultTool: execute through container
if hasattr(tool, "container_execute"):
return await tool.container_execute(tool_call.arguments)
else:
raise ToolError(
f"Tool '{tool_name}' does not support container execution"
)
else:
# Other tools: container execution not supported
raise ToolError(f"Tool '{tool_name}' does not support container execution")
async def _execute_edit_tool_in_container(
self, tool: Tool, arguments: ToolCallArguments
) -> ToolExecResult:
command = str(arguments.get("command", ""))
path_str = str(arguments.get("path", ""))
if not path_str:
return ToolExecResult(
error="No path provided for the edit tool", error_code=-1
)
from pathlib import Path
path = Path(path_str)
try:
if command == "view":
return tool.view_handler_container(arguments, path)
#return ToolExecResult(output=tool._make_output(content, str(path)))
elif command == "create":
file_text = str(arguments.get("file_text", ""))
tool.container_write_file(path, file_text)
return ToolExecResult(output=f"File created successfully at: {path}")
elif command == "str_replace":
old_str = str(arguments.get("old_str", ""))
new_str = arguments.get("new_str")
if new_str is not None:
new_str = str(new_str)
return tool.container_str_replace(path, old_str, new_str)
elif command == "insert":
insert_line = int(arguments.get("insert_line", 0))
new_str = str(arguments.get("new_str", ""))
return tool.container_insert(path, insert_line, new_str)
else:
return ToolExecResult(
error=f"Unsupported command '{command}' for container execution",
error_code=-1,
)
except Exception as e:
return ToolExecResult(
error=f"Container edit tool error: {str(e)}.", error_code=-1
)
async def parallel_tool_call(self, tool_calls: list[ToolCall]) -> list[ToolResult]:
"""Execute tool calls in parallel locally"""
return await asyncio.gather(
*[self.execute_tool_call(call) for call in tool_calls]
)
async def sequential_tool_call(
self, tool_calls: list[ToolCall]
) -> list[ToolResult]:
"""Execute tool calls in sequential locally"""
return [await self.execute_tool_call(call) for call in tool_calls]
async def container_parallel_tool_call(
self, tool_calls: list[ToolCall]
) -> list[ToolResult]:
"""Execute tool calls in parallel in container"""
return await asyncio.gather(
*[self.container_execute_tool_call(call) for call in tool_calls]
)
async def container_sequential_tool_call(
self, tool_calls: list[ToolCall]
) -> list[ToolResult]:
"""Execute tool calls in sequential in container"""
return [await self.container_execute_tool_call(call) for call in tool_calls]
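# Minimal usage sketch (illustrative; BashTool lives in src.tools.bash_tool and is not imported
# in this module, so this is not runnable from here as-is):
#   import asyncio
#   executor = ToolExecutor([BashTool()], logger=None)
#   call = ToolCall(name=BASH_TOOL_NAME, call_id="1", arguments={"command": "echo hi"})
#   result = asyncio.run(executor.execute_tool_call(call))
#   result.success, result.result  # -> (True, "hi") when the command succeeds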

314
src/tools/bash_tool.py Normal file
View File

@@ -0,0 +1,314 @@
# Copyright (c) 2023 Anthropic
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
#
# Original file was released under MIT License, with the full license text
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
#
# This modified file is released under the same license.
import asyncio
import os
from typing import override
from src.tools.base import (
Tool,
ToolCallArguments,
ToolError,
ToolExecResult,
ToolParameter,
BASH_TOOL_NAME,
)
from src.tools.executor import Executor
from src.managers.log.logger import Logger
from typing import Dict, Any
from traceback import format_exc
class _BashSession:
"""A session of a bash shell."""
_started: bool
_timed_out: bool
command: str = "/bin/bash"
_output_delay: float = 0.2 # seconds
_timeout: float = 120.0 # seconds
_sentinel: str = (
",,,,bash-command-exit-__ERROR_CODE__-banner,,,," # `__ERROR_CODE__` will be replaced by `$?` or `!errorlevel!` later
)
def __init__(self) -> None:
self._started = False
self._timed_out = False
self._process: asyncio.subprocess.Process | None = None
async def start(self) -> None:
if self._started:
return
# Windows compatibility: os.setsid not available
if os.name != "nt": # Unix-like systems
self._process = await asyncio.create_subprocess_shell(
self.command,
shell=True,
bufsize=0,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
preexec_fn=os.setsid,
)
else:
self._process = await asyncio.create_subprocess_shell(
"cmd.exe /v:on", # enable delayed expansion to allow `echo !errorlevel!`
shell=True,
bufsize=0,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
self._started = True
async def stop(self) -> None:
"""Terminate the bash shell."""
if not self._started:
raise ToolError("Session has not started.")
if self._process is None:
return
if self._process.returncode is not None:
return
self._process.terminate()
# Wait until the process has truly terminated.
stdout, stderr = await self._process.communicate()
async def run(self, command: str) -> ToolExecResult:
"""Execute a command in the bash shell."""
if not self._started or self._process is None:
raise ToolError("Session has not started.")
if self._process.returncode is not None:
return ToolExecResult(
error=f"bash has exited with returncode {self._process.returncode}. tool must be restarted.",
error_code=-1,
)
if self._timed_out:
raise ToolError(
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
)
# we know these are not None because we created the process with PIPEs
assert self._process.stdin
assert self._process.stdout
assert self._process.stderr
error_code = 0
sentinel_before, pivot, sentinel_after = self._sentinel.partition(
"__ERROR_CODE__"
)
assert pivot == "__ERROR_CODE__"
errcode_retriever = "!errorlevel!" if os.name == "nt" else "$?"
command_sep = "&" if os.name == "nt" else ";"
# send command to the process
self._process.stdin.write(
b"(\n"
+ command.encode()
+ f"\n){command_sep} echo {self._sentinel.replace('__ERROR_CODE__', errcode_retriever)}\n".encode()
)
await self._process.stdin.drain()
# read output from the process, until the sentinel is found
try:
async with asyncio.timeout(self._timeout):
while True:
await asyncio.sleep(self._output_delay)
# if we read directly from stdout/stderr, it will wait forever for
# EOF. use the StreamReader buffer directly instead.
output: str = self._process.stdout._buffer.decode() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownVariableType]
if sentinel_before in output:
# strip the sentinel from output
output, pivot, exit_banner = output.rpartition(sentinel_before)
assert pivot
# get error code inside banner
error_code_str, pivot, _ = exit_banner.partition(sentinel_after)
if not pivot or not error_code_str.isdecimal():
continue
error_code = int(error_code_str)
break
except asyncio.TimeoutError:
self._timed_out = True
raise ToolError(
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
) from None
if output.endswith("\n"): # pyright: ignore[reportUnknownMemberType]
output = output[:-1] # pyright: ignore[reportUnknownVariableType]
error: str = self._process.stderr._buffer.decode() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue]
if error.endswith("\n"): # pyright: ignore[reportUnknownMemberType]
error = error[:-1] # pyright: ignore[reportUnknownVariableType]
# clear the buffers so that the next output can be read correctly
self._process.stdout._buffer.clear() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
self._process.stderr._buffer.clear() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
return ToolExecResult(
output=output, error=error, error_code=error_code
) # pyright: ignore[reportUnknownArgumentType]
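# Illustrative sentinel flow (Linux): for `echo hi` the session writes
#   (\necho hi\n); echo ,,,,bash-command-exit-$?-banner,,,,
# to stdin, polls stdout until ",,,,bash-command-exit-" appears, parses the exit code out of the
# banner, and returns everything before the banner as the command output.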
class BashTool(Tool):
"""
A tool that allows the agent to run bash commands.
The tool parameters are defined by Anthropic and are not editable.
"""
def __init__(
self,
model_provider: str | None = None,
executor: Executor | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
):
super().__init__(model_provider, logger, config)
self._session: _BashSession | None = None
self.executor = executor
@override
def get_model_provider(self) -> str | None:
return self._model_provider
@override
def get_name(self) -> str:
return BASH_TOOL_NAME
@override
def get_description(self) -> str:
return """Execute commands within a bash shell environment, either on the local system or inside a container.
* When providing the "command" parameter, its contents must be provided as-is without any XML escaping.
* You have access to a mirrored repository of common Linux (via apt) and Python (via pip) packages for installation.
* State is persisted across all command executions and throughout our conversation session.
* Avoid executing commands that are likely to generate excessively large outputs.
* Avoid executing interactive commands that require user input (e.g., password prompts, confirmation messages).
* For Git commands, always prefer non-interactive forms. For example, use git --no-pager diff instead of git diff to prevent opening a pager.
* To inspect a specific range of lines in a file (e.g., lines 5-10), you can use a command like: sed -n '5,10p' /path/to/file
"""
@override
def get_parameters(self) -> list[ToolParameter]:
# For OpenAI models, all parameters must be required=True
# For other providers, optional parameters can have required=False
restart_required = self.model_provider == "openai"
return [
ToolParameter(
name="command",
type="string",
description="The exact bash command string to be executed.",
required=True,
),
ToolParameter(
name="restart",
type="boolean",
description="If true, terminates the current shell session and starts a new one before executing the command. This clears the session state.",
required=restart_required,
),
]
@override
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
if arguments.get("restart"):
if self._session:
await self._session.stop()
self._session = _BashSession()
await self._session.start()
return ToolExecResult(output="tool has been restarted.")
if self._session is None:
try:
self._session = _BashSession()
await self._session.start()
except Exception as e:
return ToolExecResult(
error=f"Error starting bash session: {e}",
error_code=-1,
)
command = str(arguments["command"]) if "command" in arguments else None
if command is None:
return ToolExecResult(
error=f"No command provided for the {self.get_name()} tool",
error_code=-1,
)
try:
return await self._session.run(command)
except Exception as e:
return ToolExecResult(
error=f"Error running bash command: {e}",
error_code=-1,
)
async def container_execute(
self, arguments: ToolCallArguments, session_id: str = "0"
) -> ToolExecResult:
"""Execute a command in a container bash shell."""
if not self.executor:
return ToolExecResult(
error="Container execution requires an executor to be provided during tool initialization",
error_code=-1,
)
if arguments.get("restart"):
# Close the existing session if it exists
self.executor.close_session("0")
# The executor will automatically recreate session '0' when needed
return ToolExecResult(output="Container session has been restarted.")
command = str(arguments["command"]) if "command" in arguments else None
if command is None:
return ToolExecResult(
error=f"No command provided for container execution",
error_code=-1,
)
# command_with_init = f"source /opt/miniconda3/bin/activate && conda activate testbed && {command}"
# Check if the session is alive before executing the command
if not self.executor.check_session():
return ToolExecResult(
error="Container session is not alive and could not be restarted",
error_code=-1,
)
try:
return_code, output = self.executor.execute(session_id, command)
# return_code, output = self.executor.execute_once(command_with_init)
# The executor returns (return_code, output) tuple
# We'll treat any non-zero return code as an error
error = None
if return_code != 0:
error = f"Command failed with exit code {return_code}, output: {output}"
return ToolExecResult(output=output, error=error, error_code=return_code)
except Exception as e:
return ToolExecResult(
error=f"Error running container bash command: {e}", error_code=-1
)
@override
async def close(self):
"""Properly close self._process."""
if self._session:
await self._session.stop()
self._session = None

735
src/tools/edit_tool.py Normal file
View File

@@ -0,0 +1,735 @@
# Copyright (c) 2023 Anthropic
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
#
# Original file was released under MIT License, with the full license text
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
#
# This modified file is released under the same license.
import os
from pathlib import Path
import tempfile
from typing import Optional, override
import shlex
from src.tools.base import (
Tool,
ToolCallArguments,
ToolError,
ToolExecResult,
ToolParameter,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
)
from src.tools.run import maybe_truncate, run
from src.tools.executor import Executor
from src.managers.log.logger import Logger
from typing import Dict, Any
from traceback import format_exc
EditToolSubCommands = [
"view",
"create",
"str_replace",
"insert",
]
SNIPPET_LINES: int = 4
class TextEditorTool(Tool):
"""Tool to replace a string in a file."""
def __init__(
self,
model_provider: str | None = None,
executor: Executor | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(model_provider, logger, config)
self.executor = executor
@override
def get_model_provider(self) -> str | None:
return self._model_provider
@override
def get_name(self) -> str:
return STR_REPLACE_BASED_EDIT_TOOL_NAME
@override
def get_description(self) -> str:
return """This tool provides capabilities for viewing, creating and editing files
* This tool is stateless. No context is retained between individual command invocations.
* Content Examination `view`:
** For a file: Executing `view` on a file path will output the file's full content with sequential line numbers prefixed (using cat -n).
** For a directory: Executing view on a directory path will recursively list all non-hidden items, displaying contents up to two directory levels deep.
* File Creation `create`:
** The `create` operation is strictly prohibited if a file already exists at the specified `path`.
** Mandatory Pre-action: You must explicitly remove any existing file at the target `path` before proceeding with the creation of a new file.
* Output Handling:
** Should the output generated by a `command` exceed a certain length threshold, it will be automatically shortened and clearly marked with the indicator: <response clipped>.
* String Replacement `str_replace` Operational Rules:
** Precision Targeting: The `old_str` parameter must be an exact, character-for-character match of one or more complete lines from the source file. Special attention must be paid to invisible characters like spaces and tabs.
** Match Uniqueness: The replacement will be canceled if the specified `old_str` pattern is not absolutely unique within the file. To ensure a single match, expand the `old_str` scope to include sufficient preceding or following context lines.
** Content Insertion: The `new_str` parameter defines the complete set of lines that will be inserted into the file, directly replacing the content matched by `old_str`.
"""
@override
def get_parameters(self) -> list[ToolParameter]:
"""Get the parameters for the str_replace_based_edit_tool."""
return [
ToolParameter(
name="command",
type="string",
description=f"Operation to execute. Supported commands: {', '.join(EditToolSubCommands)}.",
required=True,
enum=EditToolSubCommands,
),
ToolParameter(
name="file_text",
type="string",
description="Required for `create` command. Specifies the textual content for the new file.",
required=False,
),
ToolParameter(
name="insert_line",
type="integer",
description="Required for `insert` command. The line number AFTER which the `new_str` will be inserted.",
required=False,
),
ToolParameter(
name="new_str",
type="string",
description="For `str_replace`: the replacement text (optional, defaults to empty). For `insert`: the text to insert (required).",
required=False,
),
ToolParameter(
name="old_str",
type="string",
description="Required for `str_replace` command. The exact text segment in the file to be replaced.",
required=False,
),
ToolParameter(
name="path",
type="string",
description="Absolute filesystem path to the target file or directory. Example: `/workspace/script.py` or `/workspace`.",
required=True,
),
ToolParameter(
name="view_range",
type="array",
description="Optional for `view` command on files. Defines the line range to display. Examples: `[5, 10]` shows lines 5-10; `[15, -1]` shows from line 15 to EOF. Line numbering starts at 1.",
items={"type": "integer"},
required=False,
),
]
@override
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the str_replace_editor tool."""
command = str(arguments["command"]) if "command" in arguments else None
if command is None:
return ToolExecResult(
error=f"No command provided for the {self.get_name()} tool",
error_code=-1,
)
path = str(arguments["path"]) if "path" in arguments else None
if path is None:
return ToolExecResult(
error=f"No path provided for the {self.get_name()} tool", error_code=-1
)
_path = Path(path)
try:
self.validate_path(command, _path)
match command:
case "view":
return await self._view_handler(arguments, _path)
case "create":
return self._create_handler(arguments, _path)
case "str_replace":
return self._str_replace_handler(arguments, _path)
case "insert":
return self._insert_handler(arguments, _path)
case _:
return ToolExecResult(
error=f"Unrecognized command {command}. The allowed commands for the {self.name} tool are: {', '.join(EditToolSubCommands)}",
error_code=-1,
)
except ToolError as e:
return ToolExecResult(error=str(e), error_code=-1)
def validate_path(self, command: str, path: Path):
"""Validate the path for the str_replace_editor tool."""
if not path.is_absolute():
suggested_path = Path("/") / path
raise ToolError(
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
)
# Check if path exists
if not path.exists() and command != "create":
raise ToolError(
f"The path {path} does not exist. Please provide a valid path."
)
if path.exists() and command == "create":
raise ToolError(
f"File already exists at: {path}. Cannot overwrite files using command `create`."
)
# Check if the path points to a directory
if path.is_dir() and command != "view":
raise ToolError(
f"The path {path} is a directory and only the `view` command can be used on directories"
)
async def _view(
self, path: Path, view_range: list[int] | None = None
) -> ToolExecResult:
"""Implement the view command"""
if path.is_dir():
if view_range:
raise ToolError(
"The `view_range` parameter is not allowed when `path` points to a directory."
)
return_code, stdout, stderr = await run(
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
)
if not stderr:
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
return ToolExecResult(error_code=return_code, output=stdout, error=stderr)
file_content = self.read_file(path)
init_line = 1
if view_range:
if len(view_range) != 2 or not all(
isinstance(i, int) for i in view_range
): # pyright: ignore[reportUnnecessaryIsInstance]
raise ToolError(
"Invalid `view_range`. It should be a list of two integers."
)
file_lines = file_content.split("\n")
n_lines_file = len(file_lines)
init_line, final_line = view_range
if init_line < 1 or init_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
)
if final_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
)
if final_line != -1 and final_line < init_line:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
)
if final_line == -1:
file_content = "\n".join(file_lines[init_line - 1 :])
else:
file_content = "\n".join(file_lines[init_line - 1 : final_line])
return ToolExecResult(
output=self._make_output(file_content, str(path), init_line=init_line)
)
def _view_container(
self, path: Path, view_range: list[int] | None = None
) -> ToolExecResult:
"""Implement the view command"""
if path.is_dir():
raise ToolError("The `path` parameter is not allowed be a directory.")
file_content = self.container_read_file(path)
init_line = 1
make_out_max_lines = None
if view_range:
if len(view_range) != 2 or not all(
isinstance(i, int) for i in view_range
): # pyright: ignore[reportUnnecessaryIsInstance]
raise ToolError(
"Invalid `view_range`. It should be a list of two integers."
)
file_lines = file_content.split("\n")
n_lines_file = len(file_lines)
init_line, final_line = view_range
# Initial line must start from 1, initial line cannot be greater than max line of file
if init_line < 1 or init_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
)
# When the end line takes effect, the end line cannot be less than the start line.
if final_line != -1 and final_line < init_line:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
)
if final_line == -1:
file_content = "\n".join(file_lines[init_line - 1 :])
elif final_line > n_lines_file:
file_content = "\n".join(file_lines[init_line - 1 : n_lines_file])
make_out_max_lines = n_lines_file
else:
file_content = "\n".join(file_lines[init_line - 1 : final_line])
return ToolExecResult(
output=self._make_output(
file_content,
str(path),
init_line=init_line,
max_lines=make_out_max_lines,
)
)
def str_replace(
self, path: Path, old_str: str, new_str: str | None
) -> ToolExecResult:
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
# Read the file content
file_content = self.read_file(path).expandtabs()
old_str = old_str.expandtabs()
new_str = new_str.expandtabs() if new_str is not None else ""
# Check if old_str is unique in the file
occurrences = file_content.count(old_str)
if occurrences == 0:
raise ToolError(
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
)
elif occurrences > 1:
file_content_lines = file_content.split("\n")
lines = [
idx + 1
for idx, line in enumerate(file_content_lines)
if old_str in line
]
raise ToolError(
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
)
# Replace old_str with new_str
new_file_content = file_content.replace(old_str, new_str)
# Write the new content to the file
self.write_file(path, new_file_content)
# Create a snippet of the edited section
replacement_line = file_content.split(old_str)[0].count("\n")
start_line = max(0, replacement_line - SNIPPET_LINES)
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
# Prepare the success message
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet, f"a snippet of {path}", start_line + 1
)
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
return ToolExecResult(
output=success_msg,
)
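# Illustrative behaviour sketch: given a file containing the single line `x = 1`,
# str_replace(path, "x = 1", "x = 2") rewrites the file and returns a snippet around the edit;
# if `old_str` is missing or occurs more than once, a ToolError is raised instead.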
def _insert(self, path: Path, insert_line: int, new_str: str) -> ToolExecResult:
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
file_text = self.read_file(path).expandtabs()
new_str = new_str.expandtabs()
file_text_lines = file_text.split("\n")
n_lines_file = len(file_text_lines)
if insert_line < 0 or insert_line > n_lines_file:
raise ToolError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
)
new_str_lines = new_str.split("\n")
new_file_text_lines = (
file_text_lines[:insert_line]
+ new_str_lines
+ file_text_lines[insert_line:]
)
snippet_lines = (
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
)
new_file_text = "\n".join(new_file_text_lines)
snippet = "\n".join(snippet_lines)
self.write_file(path, new_file_text)
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet,
"a snippet of the edited file",
max(1, insert_line - SNIPPET_LINES + 1),
)
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
return ToolExecResult(
output=success_msg,
)
# Note: undo_edit method is not implemented in this version as it was removed
def read_file(self, path: Path):
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
try:
return path.read_text()
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, read_file command error, ran into {e} while trying to read {path}, traceback: {format_exc()}."
)
raise ToolError(f"Ran into {e} while trying to read {path}.") from None
def write_file(self, path: Path, file: str):
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
try:
_ = path.write_text(file)
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, write_file command error, ran into {e} while trying to write to {path}, traceback: {format_exc()}."
)
raise ToolError(f"Ran into {e} while trying to write to {path}.") from None
def container_read_file(self, path: Path, session_id: str = "0") -> str:
"""Read the content of a file from a container using cat command."""
if not self.executor:
raise ToolError("No executor provided for container operations")
try:
# Check if session is alive and restart if needed
if not self.executor.check_session(session_id):
raise ToolError(
"Container session is not alive and could not be restarted"
)
# Use cat command to read file content
command = f"cat {path}"
            return_code, output = self.executor.execute_once(command)
if return_code != 0:
raise ToolError(
f"Failed to read file {path} from container. Exit code: {return_code}, Output: {output}"
)
            # Strip only the trailing newline appended by the shell, preserving the
            # file content exactly as produced by `cat`.
            final = output[:-1] if output.endswith("\n") else output
            return final
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, container_read_file command error, ran into {e} while trying to read {path} from container, traceback: {format_exc()}."
)
raise ToolError(
f"Ran into {e} while trying to read {path} from container."
) from None
def container_write_file(
self, path: Path, content: str, session_id: str = "0"
) -> None:
"""Write content to a file in a container using cat with here document."""
if not self.executor:
raise ToolError("No executor provided for container operations")
try:
# Check if session is alive and restart if needed
            if not self.executor.check_session(session_id):
raise ToolError(
"Container session is not alive and could not be restarted"
)
            # Create the parent directory first
return_code, output = self.executor.execute_once(f"mkdir -p {path.parent}")
if return_code != 0:
raise ToolError(
f"Failed to create dir {path.parent} in container. Exit code: {return_code}, Output: {output}"
)
with tempfile.NamedTemporaryFile(
mode="w+", delete=False, encoding="utf-8"
) as temp_file:
temp_file.write(content)
temp_file_path = temp_file.name
return_code, output = self.executor.cpfile_host_to_container(
temp_file_path, path
)
os.remove(temp_file_path)
if return_code != 0:
raise ToolError(
f"Failed to write to file {path} in container. Exit code: {return_code}, Output: {output}"
)
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, container_write_file command error, ran into {e} while trying to write to {path} in container, traceback: {format_exc()}."
)
raise ToolError(
f"Ran into {e} while trying to write to {path} in container."
) from None
def container_str_replace(
self, path: Path, old_str: str, new_str: str | None, session_id: str = "0"
) -> ToolExecResult:
"""Replace old_str with new_str in a file in a container using sed command."""
if not self.executor:
raise ToolError("No executor provided for container operations")
try:
# Check if session is alive and restart if needed
            if not self.executor.check_session(session_id):
raise ToolError(
"Container session is not alive and could not be restarted"
)
# First, read the file to check if old_str exists
file_content = self.container_read_file(path, session_id)
# Check if old_str is unique in the file
occurrences = file_content.count(old_str)
if occurrences == 0:
raise ToolError(
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
)
            elif occurrences > 1:
                # old_str may span multiple lines, so per-line locations are not reported here
                raise ToolError(
                    f"No replacement was performed. Multiple occurrences of old_str `{old_str}`. Total occurrences: {occurrences}. Please ensure it is unique"
                )
            updated_content = file_content.replace(old_str, new_str if new_str is not None else "")
self.container_write_file(path=path, content=updated_content)
# Read the file to show a snippet of the changes
try:
file_content = self.container_read_file(path, session_id)
# Create a simple snippet showing the change
lines = file_content.split("\n")
snippet_lines = lines[: min(10, len(lines))] # Show first 10 lines
snippet = "\n".join(snippet_lines)
success_msg = f"The file {path} has been edited in container. "
success_msg += self._make_output(
snippet, f"a snippet of {path}", init_line=1
)
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
return ToolExecResult(output=success_msg)
except Exception:
# If we can't read the file for snippet, just return success
return ToolExecResult(
output=f"Successfully replaced string in file {path} in container."
)
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, container_str_replace command error, ran into {e} while trying to replace string in {path} in container, traceback: {format_exc()}."
)
raise ToolError(
f"Ran into {e} while trying to replace string in {path} in container."
) from None
    def container_insert(
        self, path: Path, insert_line: int, new_str: str, session_id: str = "0"
    ) -> ToolExecResult:
        """Implement the insert command for a file inside the container, inserting new_str at the specified line."""
if not self.executor:
raise ToolError("No executor provided for container operations")
try:
            if not self.executor.check_session(session_id):
raise ToolError(
"Container session is not alive and could not be restarted"
)
file_content = self.container_read_file(path, session_id)
file_text_lines = file_content.split("\n")
n_lines_file = len(file_text_lines)
if insert_line < 0 or insert_line > n_lines_file:
raise ToolError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
)
new_str_lines = new_str.split("\n")
new_file_text_lines = (
file_text_lines[:insert_line]
+ new_str_lines
+ file_text_lines[insert_line:]
)
snippet_lines = (
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
)
new_file_text = "\n".join(new_file_text_lines)
snippet = "\n".join(snippet_lines)
self.container_write_file(path, new_file_text, session_id)
success_msg = f"The file {path} has been edited in container. "
success_msg += self._make_output(
snippet,
"a snippet of the edited file",
max(1, insert_line - SNIPPET_LINES + 1),
)
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
return ToolExecResult(
output=success_msg,
)
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, container_insert command error, ran into {e} while trying to insert content in {path} in container, traceback: {format_exc()}."
)
raise ToolError(
f"Ran into {e} while trying to insert content in {path} in container."
) from None
def _escape_sed(self, text: str) -> str:
"""Escape special characters in text for use with sed command."""
# Escape sed special characters: / \ &
escaped = text.replace("\\", "\\\\") # Escape backslashes first
escaped = escaped.replace("/", "\\/") # Escape forward slashes
escaped = escaped.replace("&", "\\&") # Escape ampersands
escaped = escaped.replace("\n", "\\n") # Handle newlines
escaped = escaped.replace("\t", "\\t") # Handle tabs
return escaped
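    # Illustrative example (added note, not in the original file): given the escapes above,
    # _escape_sed("path/to/file & more") returns the text "path\/to\/file \& more", and a
    # literal newline becomes the two characters "\n".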
def _make_output(
self,
file_content: str,
file_descriptor: str,
init_line: int = 1,
expand_tabs: bool = True,
        max_lines: Optional[int] = None,  # total number of lines in the file; if not None, it is mentioned in the returned header
):
"""Generate output for the CLI based on the content of a file."""
file_content = maybe_truncate(file_content)
if expand_tabs:
file_content = file_content.expandtabs()
file_content = "\n".join(
[
f"{i + init_line:6}\t{line}"
for i, line in enumerate(file_content.split("\n"))
]
)
if max_lines:
return (
f"Here's the result of running `cat -n` on {file_descriptor}(The file is only {max_lines} lines):\n"
+ file_content
+ "\n"
)
else:
return (
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
+ file_content
+ "\n"
)
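    # Illustrative example (added note, not in the original file; the path is hypothetical):
    # with file_content "a\nb", file_descriptor "a snippet of /tmp/x.py" and init_line 5, the
    # returned string is
    # "Here's the result of running `cat -n` on a snippet of /tmp/x.py:\n     5\ta\n     6\tb\n".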
async def _view_handler(
self, arguments: ToolCallArguments, _path: Path
) -> ToolExecResult:
view_range = arguments.get("view_range", None)
if view_range is None:
return await self._view(_path, None)
if not (
isinstance(view_range, list) and all(isinstance(i, int) for i in view_range)
):
return ToolExecResult(
error="Parameter `view_range` should be a list of integers.",
error_code=-1,
)
view_range_int: list[int] = [i for i in view_range if isinstance(i, int)]
return await self._view(_path, view_range_int)
def view_handler_container(
self, arguments: ToolCallArguments, path: Path
) -> ToolExecResult:
view_range = arguments.get("view_range", None)
if view_range is None:
return self._view_container(path, None)
if not (
isinstance(view_range, list) and all(isinstance(i, int) for i in view_range)
):
return ToolExecResult(
error="Parameter `view_range` should be a list of integers.",
error_code=-1,
)
view_range_int: list[int] = [i for i in view_range if isinstance(i, int)]
return self._view_container(path, view_range_int)
def _create_handler(
self, arguments: ToolCallArguments, _path: Path
) -> ToolExecResult:
file_text = arguments.get("file_text", None)
if not isinstance(file_text, str):
return ToolExecResult(
error="Parameter `file_text` is required and must be a string for command: create",
error_code=-1,
)
self.write_file(_path, file_text)
return ToolExecResult(output=f"File created successfully at: {_path}")
def _str_replace_handler(
self, arguments: ToolCallArguments, _path: Path
) -> ToolExecResult:
old_str = arguments.get("old_str") if "old_str" in arguments else None
if not isinstance(old_str, str):
return ToolExecResult(
error="Parameter `old_str` is required and should be a string for command: str_replace",
error_code=-1,
)
new_str = arguments.get("new_str") if "new_str" in arguments else None
if not (new_str is None or isinstance(new_str, str)):
return ToolExecResult(
error="Parameter `new_str` should be a string or null for command: str_replace",
error_code=-1,
)
return self.str_replace(_path, old_str, new_str)
def _insert_handler(
self, arguments: ToolCallArguments, _path: Path
) -> ToolExecResult:
insert_line = (
arguments.get("insert_line") if "insert_line" in arguments else None
)
if not isinstance(insert_line, int):
return ToolExecResult(
error="Parameter `insert_line` is required and should be integer for command: insert",
error_code=-1,
)
new_str_to_insert = arguments.get("new_str") if "new_str" in arguments else None
if not isinstance(new_str_to_insert, str):
return ToolExecResult(
error="Parameter `new_str` is required for command: insert",
error_code=-1,
)
return self._insert(_path, insert_line, new_str_to_insert)

198
src/tools/executor.py Normal file
View File

@@ -0,0 +1,198 @@
import subprocess
import uuid
import docker
import pexpect
import re
from docker.errors import DockerException, ImageNotFound, NotFound
from src.managers.log.logger import Logger
class Executor:
def __init__(self, image: str, logger: Logger | None = None):
self.image = image
self.container = None
self.sessions: dict[str, pexpect.spawn] = {}
self.client = docker.from_env()
self.logger = logger
try:
self.client.images.get(self.image)
except ImageNotFound:
raise DockerException(
f"Image '{self.image}' not found. Please build the image first."
)
try:
self.container = self.client.containers.run(
self.image,
command="sleep infinity",
detach=True,
working_dir="/workspace",
)
self.logger.info(f"Created container {self.container.id}")
except DockerException as e:
raise DockerException(
f"Failed to create container with image '{self.image}': {e}"
)
session_id = self.init_session()
if session_id is None:
raise DockerException("Failed to initialize default session")
if session_id in self.sessions:
self.sessions["0"] = self.sessions.pop(session_id)
    def init_session(self) -> str | None:
session_id = str(uuid.uuid4())
command = f"docker exec -it {self.container.id} /bin/bash"
for attempt in range(3): # Retry up to 3 times
try:
shell = pexpect.spawn(command, encoding="utf-8", timeout=120)
shell.expect([r"\$.*", r"#.*"], timeout=120)
# Source conda and activate testbed environment
shell.sendline("source /opt/miniconda3/bin/activate")
shell.expect([r"\$.*", r"#.*"], timeout=30)
shell.sendline("conda activate testbed")
shell.expect([r"\$.*", r"#.*"], timeout=30)
shell.sendline("export NO_COLOR=1 && export PAGER=cat")
shell.expect([r"\$.*", r"#.*"], timeout=30)
# Verify conda environment is alive by checking the full output
# The output should contain (testbed) if the environment is activated
# We can check this by looking at the full output from the conda activate command
output = shell.before
if "(testbed)" not in output:
# Environment not properly activated, retry
if attempt < 2: # Not the last attempt
shell.close(force=True)
continue
else:
shell.close(force=True)
raise DockerException(
"Failed to activate conda environment 'testbed' after 3 attempts"
)
self.sessions[session_id] = shell
return session_id
except pexpect.exceptions.TIMEOUT:
if attempt < 2: # Not the last attempt
if "shell" in locals() and shell.isalive():
shell.close(force=True)
continue
else:
return None
except Exception as e:
if attempt < 2: # Not the last attempt
if "shell" in locals() and shell.isalive():
shell.close(force=True)
continue
else:
raise DockerException(
f"Failed to initialize session after 3 attempts: {e}"
)
return None
def execute(
self, session_id: str, command: str, timeout: int = 300
) -> tuple[int, str]:
shell = self.sessions.get(session_id)
if not shell or not shell.isalive():
return -1, "Session not found or is dead."
full_command = command.strip()
shell.sendline(full_command)
marker = f"---CMD_DONE---"
marker_command = f"echo {marker}$?"
shell.sendline(marker_command)
try:
shell.expect(marker + r"(\d+).*[\n](.*)", timeout=timeout)
except pexpect.exceptions.TIMEOUT:
return (
-1,
f"Error: Command '{command}' timed out after {timeout} seconds. Partial output:\n{shell.before}",
)
exit_code = int(shell.match.group(1))
p = str(shell.match.group(2))
all_lines: str = p + shell.before
        # Delete all carriage returns.
        all_lines = re.sub(r"\r", "", all_lines)
        # Remove non-color terminal control sequences:
        # \x1b[?2004h - enable bracketed paste mode
        # \x1b[?2004l - disable bracketed paste mode
        all_lines = re.sub(r"\x1B\[\?2004[l|h]", "", all_lines)
        # Strip the echoed marker command from the last line.
        all_lines = re.sub(r"\n[^\n]+---CMD_DONE---.*", "", all_lines)
        return exit_code, all_lines
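    # Illustrative behaviour of execute() above (added note, not in the original file):
    # for execute("0", "ls /workspace") the shell receives "ls /workspace" followed by
    # "echo ---CMD_DONE---$?"; the pty stream then ends with a line such as
    # "---CMD_DONE---0", whose captured digits become the exit code, and the method
    # returns (0, <command output with \r, bracketed-paste codes and the marker echo removed>).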
def execute_once(self, command: str, timeout: int = 300) -> tuple[int, str]:
# cmd = ["docker", "exec", self.container.id, "bash", "-c", command]
cmd = ["docker", "exec", "-i", self.container.id, "bash", "-s"]
sub = subprocess.run(
cmd,
encoding="utf-8",
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
input=f"{command}\n",
)
if sub.returncode != 0:
return sub.returncode, sub.stderr
return sub.returncode, sub.stdout
def cpfile_host_to_container(self, source: str, dest: str) -> tuple[int, str]:
cmd = ["docker", "cp", source, f"{self.container.id}:{dest}"]
sub = subprocess.run(
cmd,
encoding="utf-8",
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
self.execute_once(f"chmod 0777 {dest}")
if sub.returncode != 0:
return sub.returncode, sub.stderr
return sub.returncode, sub.stdout
def check_session(self, session_id: str = "0") -> bool:
"""
Check whether the current '0' session is alive and restart it if not.
"""
if session_id in self.sessions:
session = self.sessions[session_id]
if session and session.isalive():
return True
else:
self.sessions.pop(session_id)
new_session_id = self.init_session()
if new_session_id is None:
return False
if new_session_id != session_id:
self.sessions[session_id] = self.sessions.pop(new_session_id)
return True
def close_session(self, session_id: str):
if session_id in self.sessions:
session = self.sessions.pop(session_id)
if session and session.isalive():
session.close(force=True)
# Session not found - this is not an error condition
def shutdown(self):
for session_id in list(self.sessions.keys()):
self.close_session(session_id)
if self.container:
try:
self.container.stop()
self.container.remove()
            except DockerException:
                pass  # Silently ignore cleanup errors
self.container = None
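# Minimal usage sketch (added note, not in the original file; the image name and logger are hypothetical):
#     executor = Executor("sweb.eval.x86_64.example:latest", logger=my_logger)
#     exit_code, output = executor.execute("0", "python --version")
#     executor.shutdown()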

57
src/tools/run.py Normal file
View File

@@ -0,0 +1,57 @@
# Copyright (c) 2023 Anthropic
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
#
# Original file was released under MIT License, with the full license text
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
#
# This modified file is released under the same license.
"""Utility to run shell commands asynchronously with a timeout."""
import asyncio
import contextlib
TRUNCATED_MESSAGE: str = (
"<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
)
MAX_RESPONSE_LEN: int = 16000
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
"""Truncate content and append a notice if content exceeds the specified length."""
return (
content
if not truncate_after or len(content) <= truncate_after
else content[:truncate_after] + TRUNCATED_MESSAGE
)
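# Illustrative example (added note, not in the original file): maybe_truncate("x" * 20000)
# returns the first 16000 characters followed by TRUNCATED_MESSAGE, while
# maybe_truncate("short") is returned unchanged.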
async def run(
cmd: str,
timeout: float | None = 120.0, # seconds
truncate_after: int | None = MAX_RESPONSE_LEN,
):
"""Run a shell command asynchronously with a timeout."""
process = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
try:
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
return (
process.returncode or 0,
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
)
except asyncio.TimeoutError as exc:
with contextlib.suppress(ProcessLookupError):
process.kill()
raise TimeoutError(
f"Command '{cmd}' timed out after {timeout} seconds"
) from exc
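# Minimal usage sketch (added note, not in the original file):
#     import asyncio
#     return_code, stdout, stderr = asyncio.run(run("echo hello", timeout=10.0))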

404
src/tools/search_tool.py Normal file
View File

@@ -0,0 +1,404 @@
"""Search tool for finding files based on text content using ripgrep (rg)."""
import asyncio
import json
import re
from pathlib import Path
import shlex
from typing import override
from traceback import format_exc
from src.tools.base import (
Tool,
ToolCallArguments,
ToolError,
ToolExecResult,
ToolParameter,
SEARCH_TOOL_NAME,
)
from src.tools.run import run
from src.tools.executor import Executor
from src.managers.log.logger import Logger
from typing import Dict, Any
class SearchTool(Tool):
"""Tool for searching files based on text content using ripgrep."""
def __init__(
self,
model_provider: str | None = None,
executor: Executor | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(model_provider, logger, config)
self._executor = executor
@override
def get_model_provider(self) -> str | None:
return self._model_provider
@override
def get_name(self) -> str:
return SEARCH_TOOL_NAME
@override
def get_description(self) -> str:
return """Search tool for finding files based on text content
* Searches for text patterns in files and directories recursively
* Returns file paths, line numbers, and surrounding context
* Supports regex patterns and various search options
* Provides fast and efficient content searching
Features:
- Pattern matching with full regular expression support
- Line number display for all matches
- Configurable context lines surrounding each match (before and after)
- Filtering by file type
- Control over case sensitivity
- Option to include hidden files in searches
- Handling of binary files
Example patterns (all patterns must be valid regular expressions):
- Simple text: "function main"
- Regex: "def\\s+\\w+\\s*\\("
"""
@override
def get_parameters(self) -> list[ToolParameter]:
"""Get the parameters for the search tool."""
params = [
ToolParameter(
name="pattern",
type="string",
description=(
"The regular expression pattern to search for within the file content. "
"To match literal characters that are also regex metacharacters (e.g., '.', '*', '+', '?', '(', ')', '[', ']', '{', '}', '|', '^', '$', '\\'), "
"they must be escaped with a backslash. "
"Examples: To find the literal string '(some_value)': '\\(some_value\\)'; To find Python function definitions: 'def\\s+[a-zA-Z_]\\w*\\s*\\('. "
),
required=True,
),
ToolParameter(
name="search_path",
type="string",
description="The directory or file path to search in. Must be an absolute path.",
required=True,
),
ToolParameter(
name="context_lines",
type="integer",
description="Number of context lines to show before and after each match. Default: 2.",
required=False,
),
ToolParameter(
type="boolean",
name="case_insensitive",
description="Whether to perform case-insensitive search. Default: false.",
required=False,
),
ToolParameter(
type="boolean",
name="include_hidden",
description="Whether to include hidden files and directories. Default: false.",
required=False,
),
ToolParameter(
type="boolean",
name="include_binary",
description="Whether to search in binary files. Default: false.",
required=False,
),
ToolParameter(
type="string",
name="file_types",
description="Comma-separated list of file types to search (e.g., 'py,js,md'). Optional.",
required=False,
),
ToolParameter(
type="integer",
name="max_results",
description="Maximum number of results to return per file. Default: 100.",
required=False,
),
]
return params
@override
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the search operation."""
try:
pattern = str(arguments.get("pattern", ""))
if not pattern:
return ToolExecResult(
error="Pattern parameter is required", error_code=-1
)
search_path_str = str(arguments.get("search_path", ""))
if not search_path_str:
return ToolExecResult(
error="search_path parameter is required", error_code=-1
)
search_path = Path(search_path_str)
if not search_path.is_absolute():
return ToolExecResult(
error=f"Search path must be absolute: {search_path}", error_code=-1
)
if not search_path.exists():
return ToolExecResult(
error=f"Search path does not exist: {search_path}", error_code=-1
)
# Parse optional parameters
context_lines = int(arguments.get("context_lines", 2))
case_insensitive = bool(arguments.get("case_insensitive", False))
include_hidden = bool(arguments.get("include_hidden", False))
include_binary = bool(arguments.get("include_binary", False))
file_types = arguments.get("file_types")
max_results = int(arguments.get("max_results", 100))
# Build ripgrep command
cmd_parts = ["rg"]
# Add context lines
if context_lines > 0:
cmd_parts.extend(["-C", str(context_lines)])
# Add case sensitivity
if case_insensitive:
cmd_parts.append("-i")
# Add hidden files
if include_hidden:
cmd_parts.append("--hidden")
# Add binary files
if include_binary:
cmd_parts.append("--binary")
else:
cmd_parts.append("--no-binary")
# Add file types
if file_types and isinstance(file_types, str):
for file_type in file_types.split(","):
file_type = file_type.strip()
if file_type:
cmd_parts.extend(["-g", f'"*.{file_type}"'])
# Add line numbers and filename
cmd_parts.extend(["-n", "-H"])
# Add max results
cmd_parts.extend(["-m", str(max_results)])
            # Add pattern and search path (shell-quote to handle spaces and metacharacters)
            cmd_parts.extend([shlex.quote(pattern), shlex.quote(str(search_path))])
# Execute the command
return_code, stdout, stderr = await run(" ".join(cmd_parts))
if return_code == 0:
# Parse and format results
results = self._parse_rg_output(stdout)
formatted_output = self._format_results(results, max_results)
return ToolExecResult(output=formatted_output)
elif return_code == 1:
# No matches found
return ToolExecResult(output=f"No matches found for pattern: {pattern}")
else:
# Error occurred
error_msg = (
stderr if stderr else f"ripgrep exited with code {return_code}"
)
return ToolExecResult(error=error_msg, error_code=return_code)
except Exception as e:
return ToolExecResult(
error=f"Search tool error: {str(e)}",
error_code=-1,
)
def container_search(
self, arguments: ToolCallArguments, session_id: str = "0"
) -> ToolExecResult:
if not self._executor:
return ToolExecResult(
error="No executor provided for container search", error_code=-1
)
try:
pattern = str(arguments.get("pattern", ""))
if not pattern:
return ToolExecResult(
error="Pattern parameter is required", error_code=-1
)
search_path_str = str(arguments.get("search_path", ""))
if not search_path_str:
return ToolExecResult(
error="search_path parameter is required", error_code=-1
)
context_lines = int(arguments.get("context_lines", 2))
case_insensitive = bool(arguments.get("case_insensitive", False))
include_hidden = bool(arguments.get("include_hidden", False))
include_binary = bool(arguments.get("include_binary", False))
file_types = arguments.get("file_types")
max_results = int(arguments.get("max_results", 100))
cmd_parts = ["rg"]
if context_lines > 0:
cmd_parts.extend(["-C", str(context_lines)])
if case_insensitive:
cmd_parts.append("-i")
if include_hidden:
cmd_parts.append("--hidden")
if include_binary:
cmd_parts.append("--binary")
else:
cmd_parts.append("--no-binary")
if file_types and isinstance(file_types, str):
for file_type in file_types.split(","):
file_type = file_type.strip()
if file_type:
cmd_parts.extend(["-g", f'"*.{file_type}"'])
cmd_parts.extend(["-n", "-H"])
cmd_parts.extend(["-m", str(max_results * 2)])
cmd_parts.extend(["--color=never", "-U"])
cmd_parts.extend(["--", shlex.quote(pattern), search_path_str])
command = " ".join(cmd_parts)
return_code, output = self._executor.execute_once(command)
if self.logger:
self.logger.debug(f"search_tool cmd: {command}")
# self.logger.debug(f"DEBUG: SearchTool result - Return code: {return_code}, Output: \n{output}")
if return_code == 0:
results = self._parse_rg_output(output)
# self.logger.debug(f"DEBUG: SearchTool _parse_rg_output results: {results}")
formatted_output = self._format_results(results, max_results)
# self.logger.debug(f"DEBUG: SearchTool _format_results formatted_output: {formatted_output}")
return ToolExecResult(output=formatted_output)
elif return_code == 1:
return ToolExecResult(output=f"No matches found for pattern: {pattern}")
else:
return ToolExecResult(
error=f"ripgrep exited with code {return_code}. Output: {output}",
error_code=return_code,
)
except Exception as e:
return ToolExecResult(
error=f"Container search error: {str(e)}",
error_code=-1,
)
def _parse_rg_output(self, output: str) -> list[dict]:
"""Parse ripgrep output into structured results."""
# Remove ANSI escape codes
ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
clean_output = ansi_escape.sub("", output)
results = []
current_file = None
for line in clean_output.split("\n"):
if not line.strip():
continue
# Check if this is a file path line (no colon, just a path)
if ":" not in line and "/" in line and not line.strip().startswith("-"):
# This is a file path line
current_file = line.strip()
continue
# Parse ripgrep output format: file:line:content or file:line-content
if ":" in line:
# Split by colon to get file, line info, and content
parts = line.split(":", 2)
if len(parts) >= 3:
file_path = parts[0].strip()
line_info = parts[1].strip()
content = parts[2].strip()
# Use current_file if file_path is empty or just a dash
if not file_path or file_path == "-":
file_path = current_file
# Check if line_info is a number (match line) or contains dash (context line)
if line_info.isdigit():
# This is a match line
line_num = int(line_info)
results.append(
{
"file": file_path,
"line": line_num,
"content": content,
"full_line": line,
"is_match": True,
}
)
elif "-" in line_info:
# This is a context line (before/after match)
# Extract line number from context line format like "12-15" or "12-"
try:
line_num = int(line_info.split("-")[0])
results.append(
{
"file": file_path,
"line": line_num,
"content": content,
"full_line": line,
"is_match": False,
}
)
except ValueError:
continue
return results
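    # Illustrative example (added note, not in the original file): the line
    # "src/app.py:42:def main():" is parsed into
    # {"file": "src/app.py", "line": 42, "content": "def main():", "is_match": True}
    # (plus the raw line under "full_line").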
def _format_results(self, results: list[dict], max_results: int) -> str:
"""Format search results for display."""
if not results:
return "No matches found."
# Filter only match lines for counting
match_results = [r for r in results if r.get("is_match", True)]
limited_results = results[:max_results]
output_lines = [f"Found {len(match_results)} matches:"]
output_lines.append("=" * 50)
current_file = None
for result in limited_results:
file_path = result["file"]
line_num = result["line"]
content = result["content"]
is_match = result.get("is_match", True)
# Add file header if this is a new file
if current_file != file_path:
current_file = file_path
output_lines.append(f"\n📁 {file_path}")
output_lines.append("-" * (len(file_path) + 4))
            # Context lines are indented one extra space; match lines are flush with the line number
            marker = "" if is_match else " "
            output_lines.append(f"{marker} {line_num:4d}: {content}")
if len(results) > max_results:
output_lines.append(f"\n... and {len(results) - max_results} more lines")
return "\n".join(output_lines)

View File

@@ -0,0 +1,127 @@
"""Search tool for finding files based on text content using ripgrep (rg)."""
import asyncio
import json
from src.managers.log.logger import Logger
import re
from pathlib import Path
from typing import Any, Dict, override
from traceback import format_exc
from src.tools.base import (
Tool,
ToolCallArguments,
ToolError,
ToolExecResult,
ToolParameter,
SubmitToolResult,
SUBMIT_RESULT_TOOL_NAME,
)
from src.tools.run import run
from src.tools.executor import Executor
class SubmitResultTool(Tool):
"""Tool for git diff, not for model to invoke"""
def __init__(
self,
model_provider: str | None = None,
executor: Executor | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(model_provider, logger, config)
self._executor = executor
@override
def get_model_provider(self) -> str | None:
return self._model_provider
@override
def get_name(self) -> str:
return SUBMIT_RESULT_TOOL_NAME
@override
def get_description(self) -> str:
return """
Submit the final result to complete the task.
This tool should be called when you are confident that the issue has been resolved. Simply indicate that you are ready to submit the result - the system will automatically capture the git diff and generate the final patch.
You don't need to provide the actual patch content manually. Just call this tool to signal completion, and the system will handle the rest.
"""
@override
def get_parameters(self) -> list[ToolParameter]:
params = [
ToolParameter(
name="is_task_done",
type="boolean",
description="Whether the task is done",
required=True,
),
ToolParameter(
name="test_status",
type="string",
description="The status of test execution after applying the patch",
required=True,
enum=["passed", "failed", "skipped", "error"],
),
ToolParameter(
name="reasoning",
type="string",
description="Detailed explanation of the logic behind the patch, including root cause analysis and solution approach",
required=True,
),
]
return params
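    # Illustrative arguments (added note, not in the original file):
    #     {"is_task_done": true, "test_status": "passed", "reasoning": "Root cause ... fix ..."}
    # The patch itself is not passed in; container_execute() below captures it via `git diff`.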
@override
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the tool locally (not supported for submit_result tool)."""
return ToolExecResult(
error="SubmitResultTool only supports container execution", error_code=-1
)
@override
async def container_execute(self, arguments: ToolCallArguments) -> ToolExecResult:
if not self._executor:
return ToolExecResult(
error="No executor provided for git diff tool", error_code=-1
)
try:
is_task_done = arguments.get("is_task_done", False)
test_status = arguments.get("test_status", "error")
reasoning = arguments.get("reasoning", "")
root_path = self.config.get("builder", {}).get("repo_root_path", "/")
cmd_parts = ["cd", str(root_path), "&&", "git", "--no-pager", "diff"]
command = " ".join(cmd_parts)
            if self.logger:
                self.logger.debug(
                    f"DEBUG: SubmitResultTool executing command: {command}"
                )  # Debug output
            return_code, output = self._executor.execute_once(command)
            if self.logger:
                self.logger.debug(
                    f"DEBUG: SubmitResultTool result - Return code: {return_code}, Output: \n{output}"
                )  # Debug output
if return_code == 0:
submit_result = SubmitToolResult(
return_code=return_code,
output=output,
is_task_done=is_task_done,
test_status=test_status,
reasoning=reasoning,
)
return ToolExecResult(output=str(submit_result))
else:
return ToolExecResult(
error=f"GitDiffTool exited with code {return_code}. Output: {output}",
error_code=return_code,
)
except Exception as e:
return ToolExecResult(
error=f"Container search error: {str(e)}", error_code=-1
)