Merge pull request #1 from Tokfinity/feat_lzc

feat initial
This commit is contained in:
Pete Wong
2025-10-29 14:21:51 +08:00
committed by GitHub
52 changed files with 12244 additions and 0 deletions

225
.gitignore vendored Normal file
View File

@@ -0,0 +1,225 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# workspace
workspace/
# logs
logs/
nohup.out
gold.validate-gold.json
.DS_Store
.idea/
# instance_to_image.json
src/managers/image_builder/instance_to_image.json
/batch_out

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12.9

7
LICENSE Normal file
View File

@@ -0,0 +1,7 @@
Copyright 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

1
__init__.py Normal file
View File

@@ -0,0 +1 @@
# project root directory marker file

275
_batch_run.py Normal file
View File

@@ -0,0 +1,275 @@
import argparse
from dataclasses import dataclass
import logging
from pathlib import Path
import shutil
import signal
import sys
import time
import yaml
from pydantic import BaseModel, Field
from typing import Any, List, Optional
import multiprocessing
from functools import partial
class Config(BaseModel):
issue_list: Optional[List[str]] = None
pass
class Args:
config: str
output: str
parallel: int
name: str
issue_list: str
custom: str
clean: bool
dry_run: bool
pass
class Context:
config: Config
args: Args
def __init__(self, config: Config, args: Args):
self.args = args
self.config = config
def main():
(args, cfg) = parse_args()
ctx = Context(config=cfg, args=args)
print(f"Input args: {args} {args.config}")
out_dir = Path(ctx.args.output).joinpath(f"batch-{ctx.args.name}")
out_dir.mkdir(parents=True, exist_ok=True)
if not out_dir.is_dir():
raise ValueError(f"{out_dir}is not a directory")
if ctx.args.clean:
shutil.rmtree(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
now = time.time()
p = BatchProcessExecutor(ctx, out_dir=out_dir)
p.execute_all_tasks()
duration = int(time.time() - now)
print(
f"DONE Total time cost {int(duration/3600)}h {int(duration/60%60)}m {int(duration%60)}s"
)
@dataclass
class ProcResult:
issue: str
idx: int
duration: int
summary: dict[str, Any]
def worder_func_for_gen_result(ctx: Context, out_dir: Path):
print(f"worder_func_for_gen_result...")
import _run
config_path = Path(__file__).parent / "config" / "config.yaml"
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
cfg["workspace"]["path"] = str(out_dir)
cfg["result"]["preds"]["result"] = "result"
cfg["runner"]["selector_result_dump_path"] = str(
out_dir.joinpath("selector_result_dump")
)
start_time = time.time()
if not ctx.args.dry_run:
_run.gen_result(cfg)
time.sleep(1)
else:
time.sleep(2)
duration = time.time() - start_time
print(
f"worder_func_for_gen_result DONE time cost:{int(duration/60)}m {int(duration%60)}s"
)
def worker_function(ctx: Context, issue: str, idx: int, out_dir: Path) -> ProcResult:
print(f"worker_function idx:{idx} issue:{issue}")
issue_out_dir = out_dir.joinpath("issues").joinpath(f"{idx:03d}-{issue}")
issue_out_dir.mkdir(parents=True, exist_ok=True)
issue_out_log_dir = issue_out_dir.joinpath("logs")
issue_out_log_dir.mkdir(parents=True, exist_ok=True)
generator_result_dump_path = out_dir.joinpath("generator_result_dump")
generator_result_dump_path.mkdir(parents=True, exist_ok=True)
selector_result_dump_path = out_dir.joinpath("selector_result_dump")
selector_result_dump_path.mkdir(parents=True, exist_ok=True)
original_stdout = sys.stdout
original_stderr = sys.stderr
sys.stdout = open(issue_out_dir.joinpath("stdout.log"), "a")
sys.stderr = open(issue_out_dir.joinpath("stdout.log"), "a")
# signal.signal(signal.SIGINT, signal.SIG_IGN)
config_path = Path(__file__).parent / "config" / "config.yaml"
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
start_time = time.time()
import _run
config_path = Path(__file__).parent / "config" / "config.yaml"
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
cfg["log"]["base_path"] = str(issue_out_log_dir)
cfg["runner"]["generator_result_dump_path"] = str(generator_result_dump_path)
cfg["runner"]["selector_result_dump_path"] = str(selector_result_dump_path)
if not ctx.args.dry_run:
summary = _run.run_one_issue(cfg, issue=issue)
time.sleep(1)
else:
time.sleep(2)
summary = {}
summary["success"] = 1
summary["failed"] = 0
summary["total"] = 1
sys.stdout.flush()
sys.stderr.flush()
sys.stdout = original_stdout
sys.stderr = original_stderr
duration = time.time() - start_time
print(
f"worker_function DONE idx:{idx} issue:{issue} 耗时:{int(duration/60)}m {int(duration%60)}s"
)
if summary["success"] > 0:
add_done_set(out_dir, issue)
return ProcResult(duration=duration, issue=issue, idx=idx, summary=summary)
class BatchProcessExecutor:
ctx: Context
out_dir: Path
def __init__(self, ctx: Context, out_dir: Path):
self.ctx = ctx
self.out_dir = out_dir
def execute_all_tasks(self):
done_set = load_done_set(self.out_dir)
parallel = self.ctx.args.parallel
formatted_issues = []
for idx, issue in enumerate(self.ctx.config.issue_list, 1):
if issue in done_set:
print(f"done set: skip {idx}, {issue}")
else:
formatted_issues.append((issue, idx, self.out_dir))
results = []
with multiprocessing.Pool(processes=parallel, maxtasksperchild=1) as pool:
try:
worker = partial(worker_function, self.ctx)
results = pool.starmap(worker, formatted_issues)
except KeyboardInterrupt:
print("ctrl-c received, exit")
pool.terminate()
cum_total = 0
cum_success = 0
cum_failed = 0
for result in results:
total = result.summary["total"]
success = result.summary["success"]
failed = result.summary["failed"]
cum_total += total
cum_success += success
cum_failed += failed
print(f"Total instances: {cum_total}")
print(f"Success: {cum_success}")
print(f"Fail: {cum_failed}")
print(f"Success rate: {(cum_success/cum_total*100):.1f}%" if cum_total > 0 else "0%")
print(f"Start generating final results")
process = multiprocessing.Process(
target=worder_func_for_gen_result, args=(self.ctx, self.out_dir)
)
process.start()
process.join()
def parse_args() -> tuple[Args, Config]:
parser = argparse.ArgumentParser(description="This is a concurrent execution tool.")
parser.add_argument("-c", "--config", help="config file", required=True)
parser.add_argument(
"--name", help="task name,which will be concatenated as a directory name under the output path.", required=True
)
parser.add_argument("-i", "--issue-list", help="question list", type=str)
parser.add_argument(
"-o",
"--output",
help="output directory, ./batch_out/as default",
type=str,
default="./batch_out/",
)
parser.add_argument(
"-p",
"--parallel",
help="Parallelism20 as default",
type=int,
default=20,
)
parser.add_argument(
"--clean", help="Clean up data with the same name in the output directory before starting.", action="store_true"
)
parser.add_argument(
"--dry-run", help="Skip the actual inference task execution, but proceed with all other logic.", action="store_true"
)
args: Args = parser.parse_args()
setattr(args, "custom", "str")
with open(args.config, "r") as f:
config_data = yaml.safe_load(f)
cfg = Config(**config_data)
if args.issue_list:
issues: list[str] = []
with open(args.issue_list, "r", encoding="utf-8") as f:
for line in f:
line_clean = line.strip()
if line_clean:
issues.append(str(line_clean))
cfg.issue_list = issues
print(f"list len = {len(cfg.issue_list)}")
return (args, cfg)
def signal_handler(signum, frame):
print(f"Received signal {signum}, terminating child processes...")
# sys.exit(0)
def load_done_set(out_dir: Path) -> set[str]:
file_name = out_dir.joinpath("done.txt")
if not file_name.exists():
return set()
with open(file_name, "r") as f:
return set(line.strip() for line in f if line.strip())
def add_done_set(out_dir: Path, issue: str):
file_name = out_dir.joinpath("done.txt")
with open(file_name, "a") as f:
f.write(f"{issue}\n")
if __name__ == "__main__":
main()

50
_build_image.py Normal file
View File

@@ -0,0 +1,50 @@
from src.managers.image_builder.build_image import SWEBenchImageBuilder
import yaml
from pathlib import Path
from src.managers.log.logger import ImageBuilderLogger
from typing import List
def load_instance_ids(file_path: str) -> List[str]:
with open(file_path, "r", encoding="utf-8") as f:
return [line.strip() for line in f if line.strip()]
class BuildImage:
def __init__(self, instance_ids: List[str] = None):
self.instance_ids = instance_ids or []
self.config = self._load_config()
self.logger = ImageBuilderLogger(
log_base_path=self.config.get("log", {}).get("image_builder", "workspace/image_logs"),
console_output=True,
)
self.builder = self._init_builder()
def _init_builder(self):
dataset_name = self.config.get("dataset", {}).get("name", "princeton-nlp/SWE-bench_Lite")
dataset_split = self.config.get("dataset", {}).get("split", "dev")
return SWEBenchImageBuilder(
dataset_name=dataset_name,
split=dataset_split,
logger=self.logger,
)
def _load_config(self):
config_path = Path(__file__).parent / "config" / "config.yaml"
with open(config_path, "r", encoding="utf-8") as f:
return yaml.safe_load(f) or {}
def _build(self):
max_workers = self.config.get("builder", {}).get("max_workers", 3)
max_retries = self.config.get("builder", {}).get("max_retries", 1)
self.builder.build_target_images(instance_ids=self.instance_ids, max_workers=max_workers, max_retries=max_retries)
if __name__ == "__main__":
#instance_ids = ["django__django-10097", "django__django-10880"]
file_path = "" # fill in the list path
instance_ids = load_instance_ids(file_path=file_path)
print(instance_ids)
build_image = BuildImage(instance_ids)
build_image._build()

783
_run.py Normal file
View File

@@ -0,0 +1,783 @@
from typing import List, Dict, Any
from pathlib import Path
import json
import yaml
import asyncio
from datetime import datetime
from traceback import format_exc
from src.tools.executor import Executor
from src.tools import BashTool, TextEditorTool, SearchTool, SubmitResultTool
from src.managers.result_builder.result_builder import ResultBuilder
from src.tools.base import (
ToolExecutor,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
from src.managers.log.logger import create_logger, Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.image_builder.build_image import SWEBenchImageBuilder
from src.managers.prompts.prompts_manager import PromptsManager
from src.managers.loop.patch_generator import PatchGenerator
from src.managers.loop.patch_selector import PatchSelector
from src.managers.loop.types import GeneratorResult, SelectorResult
class SelectorLoop:
def __init__(
self,
instance_id: str,
image_name: str,
runner_log_base: Path,
llm_manager: LLMAPIManager | None,
prompts_manager: PromptsManager | None,
instance_data: Dict[str, Any],
config: Dict[str, Any],
):
self.instance_id = instance_id
self.image_name = image_name
self.llm_manager = llm_manager
self.prompts_manager = prompts_manager
self.instance_data = instance_data
self.config = config
self.log_dir = runner_log_base / "run" / instance_id / "selector"
self.log_dir.mkdir(parents=True, exist_ok=True)
self.logger = Logger(
log_base_path=str(self.log_dir.parent),
logger_name=f"selector_{instance_id}",
console_output=True,
instance_id=self.log_dir.name,
)
def _dump_select_result(self, result: SelectorResult) -> None:
try:
runner_cfg = (
self.config.get("runner", {}) if isinstance(self.config, dict) else {}
)
dump_dir_str = runner_cfg.get(
"selector_result_dump_path", "workspace/selector_result_dump"
)
dump_dir = Path(dump_dir_str)
dump_dir.mkdir(parents=True, exist_ok=True)
out_path = dump_dir / f"{self.instance_id}.json"
payload = (
result.to_dict()
if hasattr(result, "to_dict") and callable(getattr(result, "to_dict"))
else {}
)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
except Exception as e:
self.logger.warning(
f"dump selector result fail: {e}, traceback: {format_exc()}"
)
def _load_dumped_generator_results(self) -> List[GeneratorResult]:
try:
runner_cfg = (
self.config.get("runner", {}) if isinstance(self.config, dict) else {}
)
dump_dir_str = runner_cfg.get(
"generator_result_dump_path", "workspace/generator_result_dump"
)
dump_path = Path(dump_dir_str) / f"{self.instance_id}.json"
if not dump_path.exists():
self.logger.warning(f"fail to find dump 文件: {dump_path}")
return []
with open(dump_path, "r", encoding="utf-8") as f:
data = json.load(f) or []
results: List[GeneratorResult] = []
for item in data:
try:
if isinstance(item, dict):
results.append(GeneratorResult.from_dict(item))
except Exception as e:
self.logger.warning(f"parse dump fail: {e}")
continue
return results
except Exception as e:
self.logger.warning(f"load dump fail: {e}, traceback: {format_exc()}")
return []
async def select(self, generator_results: List[GeneratorResult]) -> SelectorResult:
if bool(self.config.get("runner", {}).get("skip_generator", False)):
self.logger.info("jump generatorselector will load generator results from dump files")
generator_results = self._load_dumped_generator_results()
self.logger.info(f"load from dump: {len(generator_results)} candidates")
if not generator_results:
from src.managers.loop.types import (
SelectorResult,
PatchInfo,
LLMUsage,
ToolStats,
)
self.logger.error("No choosable candidates found")
return SelectorResult(
instance_id=self.instance_id,
generator_id=-1,
image="",
success=False,
golden_patch=PatchInfo(patch_content="", test_status="", reasoning=""),
llm_usage=LLMUsage(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
tool_stats=ToolStats(bash=0, edit=0, search=0, submit_result=0),
total_turns=0,
select_reason="",
error="No candidates available",
)
self.logger.info(f"Start choosing best result{len(generator_results)} candidates in total")
executor = Executor(self.image_name, self.logger)
bash_tool = BashTool(
model_provider=None,
executor=executor,
logger=self.logger,
config=self.config,
)
edit_tool = TextEditorTool(
model_provider=None,
executor=executor,
logger=self.logger,
config=self.config,
)
search_tool = SearchTool(
model_provider=None,
executor=executor,
logger=self.logger,
config=self.config,
)
tool_executor = ToolExecutor([bash_tool, edit_tool, search_tool], self.logger)
code, out = executor.execute("0", "echo READY && rg --version || true")
self.logger.info(f"Container Health check: exit={code}, out=\n{out}")
successful_candidates = [r for r in generator_results if r.success]
if not successful_candidates:
self.logger.warning("No successful candidates found, randomly choose one from all candidates")
candidates = generator_results
else:
self.logger.info(f"Find {len(successful_candidates)} successful candidates")
candidates = successful_candidates
patch_selector = PatchSelector(
instance_id=self.instance_id,
instance_data=self.instance_data,
logger=self.logger,
prompts_manager=self.prompts_manager,
llm_manager=self.llm_manager,
tool_executor=tool_executor,
config=self.config,
)
selected = await patch_selector._select_patch(candidates=candidates)
try:
runner_cfg = (
self.config.get("runner", {}) if isinstance(self.config, dict) else {}
)
if bool(runner_cfg.get("selector_result_dump", False)):
self._dump_select_result(selected)
except Exception as e:
self.logger.warning(
f"Error occurred when choosing dump : {e}, traceback: {format_exc()}"
)
# import random
# selected = random.choice(candidates)
self.logger.info(f"Choosing complete: choose#{selected.generator_id}.")
return selected
class GeneratorLoop:
def __init__(
self,
instance_id: str,
image_name: str,
runner_log_base: Path,
llm_manager: LLMAPIManager | None,
prompts_manager: PromptsManager | None,
instance_data: Dict[str, Any],
config: Dict[str, Any],
generator_id: int = 0,
):
self.instance_id = instance_id
self.image_name = image_name
self.generator_id = generator_id
self.llm_manager = llm_manager
self.prompts_manager = prompts_manager
self.instance_data = instance_data
self.config = config
self.log_dir = (
runner_log_base / "run" / instance_id / "generator" / f"{generator_id:03d}"
)
self.log_dir.mkdir(parents=True, exist_ok=True)
self.logger = Logger(
log_base_path=str(self.log_dir.parent),
logger_name=f"generator_{instance_id}_{generator_id:03d}",
console_output=True,
instance_id=self.log_dir.name,
)
async def generate(self) -> GeneratorResult:
executor: Executor | None = None
try:
self.logger.info(
f"Activate instance GeneratorLoop #{self.generator_id:03d}: {self.instance_id} -> {self.image_name}"
)
self.logger.info(f"Use image: {self.image_name}")
executor = Executor(self.image_name, self.logger)
bash_tool = BashTool(
model_provider=None,
executor=executor,
logger=self.logger,
config=self.config,
)
edit_tool = TextEditorTool(
model_provider=None,
executor=executor,
logger=self.logger,
config=self.config,
)
search_tool = SearchTool(
model_provider=None,
executor=executor,
logger=self.logger,
config=self.config,
)
submit_result_tool = SubmitResultTool(
model_provider=None,
executor=executor,
logger=self.logger,
config=self.config,
)
tool_executor = ToolExecutor(
[bash_tool, edit_tool, search_tool, submit_result_tool], self.logger
)
# tool_executor = ToolExecutor([bash_tool, edit_tool])
# optional: do a container health check
code, out = executor.execute("0", "echo READY && rg --version || true")
self.logger.info(f"Container Health Check: exit={code}, out=\n{out}")
patch_generator = PatchGenerator(
instance_id=self.instance_id,
instance_data=self.instance_data,
logger=self.logger,
prompts_manager=self.prompts_manager,
llm_manager=self.llm_manager,
tool_executor=tool_executor,
config=self.config,
)
patch_result = await patch_generator._generate_patch()
if patch_result is None:
result_data = {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image_name,
"success": False,
"golden_patch": [],
"llm_usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
"tool_stats": {
BASH_TOOL_NAME: 0,
STR_REPLACE_BASED_EDIT_TOOL_NAME: 0,
SEARCH_TOOL_NAME: 0,
SUBMIT_RESULT_TOOL_NAME: 0,
},
"total_turns": 0,
}
else:
result_data = {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image_name,
"success": patch_result.get("success", False),
"golden_patch": patch_result.get("golden_patch", []),
"llm_usage": patch_result.get(
"llm_usage",
{
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
),
"tool_stats": patch_result.get(
"tool_stats",
{
BASH_TOOL_NAME: 0,
STR_REPLACE_BASED_EDIT_TOOL_NAME: 0,
SEARCH_TOOL_NAME: 0,
SUBMIT_RESULT_TOOL_NAME: 0,
},
),
"total_turns": patch_result.get("total_turns", 0),
}
self.logger.debug(f"[Generator Loop] result_data: {result_data}")
return GeneratorResult.from_dict(result_data)
except Exception as e:
self.logger.error(
f"Instance {self.instance_id} Generator #{self.generator_id:03d} fail: {e}, traceback: {format_exc()}"
)
error_data = {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image_name,
"success": False,
"error": str(e),
"golden_patch": [],
"llm_usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
},
"tool_stats": {
BASH_TOOL_NAME: 0,
STR_REPLACE_BASED_EDIT_TOOL_NAME: 0,
SEARCH_TOOL_NAME: 0,
SUBMIT_RESULT_TOOL_NAME: 0,
},
"total_turns": 0,
}
return GeneratorResult.from_dict(error_data)
finally:
if executor:
try:
executor.shutdown()
except Exception:
pass
class Runner:
def __init__(self, cfg: Dict[str, Any], instance_ids: List[str] = None):
self.cfg = cfg
dataset_cfg = cfg.get("dataset", {})
workspace_cfg = cfg.get("workspace", {})
builder_cfg = cfg.get("builder", {})
log_cfg = cfg.get("log", {})
runner_cfg = cfg.get("runner", {})
providers_cfg = cfg.get("providers", {})
self.instance_ids = instance_ids
self.dataset_name = dataset_cfg.get("name", "princeton-nlp/SWE-bench_Lite")
self.dataset_split = dataset_cfg.get("split", "dev")
self.max_workers = int(builder_cfg.get("max_workers", 2))
self.generator_loop_concurrency = int(
runner_cfg.get("generator_concurrency", 2)
)
log_base_path = log_cfg.get("base_path", "workspace/logs")
timestamp = datetime.now().strftime("%Y%m%d%H%M")
self.logs_base = Path(log_base_path) / timestamp
self.logs_base.mkdir(parents=True, exist_ok=True)
self.logger = Logger(
log_base_path=str(self.logs_base.parent),
logger_name="main",
console_output=True,
instance_id=self.logs_base.name,
)
self.builder = self.load_images()
self.logger.debug(f"builder instance_to_image: {self.builder.instance_to_image}.")
self.llm_manager = LLMAPIManager(logger=self.logger, config=cfg)
self.prompts_manager: PromptsManager | None = None
try:
self.prompts_manager = PromptsManager(cfg)
except Exception as e:
self.logger.warning(
f"Failed to initialize PromptsManager: {e}, traceback: {format_exc()}"
)
self.prompts_manager = None
def dump_generator_results(
self, instance_id: str, generator_results: List[GeneratorResult]
) -> None:
"""Dump generator results to disk if enabled in config.
- Reads runner.generator_result_dump (bool) and runner.generator_result_dump_path (str)
- When enabled, writes JSON file named by instance_id under the dump path
"""
try:
runner_cfg: Dict[str, Any] = (
self.cfg.get("runner", {}) if isinstance(self.cfg, dict) else {}
)
enabled: bool = bool(runner_cfg.get("generator_result_dump", False))
if not enabled:
return
dump_dir_str: str = runner_cfg.get(
"generator_result_dump_path", "workspace/generator_result_dump"
)
dump_dir = Path(dump_dir_str)
dump_dir.mkdir(parents=True, exist_ok=True)
out_path = dump_dir / f"{instance_id}.json"
# Convert results to serializable form
serialized: List[Dict[str, Any]] = []
for r in generator_results:
try:
if hasattr(r, "to_dict") and callable(getattr(r, "to_dict")):
serialized.append(r.to_dict())
elif isinstance(r, dict):
serialized.append(r)
else:
# Fallback: best-effort representation
serialized.append({"repr": repr(r)})
except Exception as e: # noqa: BLE001
serialized.append({"error": f"failed to serialize item: {e}"})
with open(out_path, "w", encoding="utf-8") as f:
json.dump(serialized, f, ensure_ascii=False, indent=2)
self.logger.info(
f"Generator results write to: {out_path} ({len(serialized)} items in total)"
)
except Exception as e:
self.logger.warning(
f"dump_generator_results failed: {e}, traceback: {format_exc()}"
)
def load_images(self):
self.logger.info("Initialize SWEBenchImageBuilder and ready to load image...")
builder = SWEBenchImageBuilder(
dataset_name=self.dataset_name,
split=self.dataset_split,
logger=self.logger,
)
builder.load_images(instance_ids=self.instance_ids)
return builder
async def _run_one(
self,
instance_id: str,
image_name: str,
instance_data: Dict[str, Any],
generator_id: int = 0,
) -> GeneratorResult:
loop = GeneratorLoop(
instance_id,
image_name,
self.logs_base,
self.llm_manager,
self.prompts_manager,
instance_data,
self.cfg,
generator_id,
)
return await loop.generate()
async def process_one_instance(self, instance: Dict[str, Any]) -> SelectorResult:
instance_id = instance["instance_id"]
try:
image_name = self.builder.get_image_name(instance_id)
except KeyError:
self.logger.warning(f"Jump instance (image mapping unfind): {instance_id}")
return None
self.logger.info(f"Start processing instance: {instance_id}")
# optional: skip generator and let the selector load candidates from dump directly
skip_generator = bool(self.cfg.get("runner", {}).get("skip_generator", False))
if skip_generator:
self.logger.info(
"Jump GeneratorLoopSelector load from dump directly"
)
selector = SelectorLoop(
instance_id=instance_id,
image_name=image_name,
runner_log_base=self.logs_base,
llm_manager=self.llm_manager,
prompts_manager=self.prompts_manager,
instance_data=instance,
config=self.cfg,
)
selected_result = await selector.select(
[]
)
self.logger.info(
f"Instance {instance_id} done(generating jumped)Generator #{selected_result.generator_id:03d} chosen"
)
return selected_result
generator_load_dump = bool(
self.cfg.get("runner", {}).get("generator_load_dump_result", False)
)
valid_results = []
if generator_load_dump:
runner_cfg = self.cfg.get("runner", {})
dump_path = runner_cfg.get(
"generator_result_dump_path", "workspace/generator_result_dump"
)
if self._check_dump_result(dump_path, instance_id):
self.logger.info(f"load generator results from dump: {instance_id}")
valid_results = self._load_generator_dump_result(dump_path, instance_id)
else:
self.logger.info(
f"Fail to find generator dump fileGenerate candidate patches concurrently: {instance_id}"
)
if not valid_results:
generator_tasks = []
for generator_id in range(self.generator_loop_concurrency):
task = asyncio.create_task(
self._run_one(instance_id, image_name, instance, generator_id)
)
generator_tasks.append(task)
generator_results = await asyncio.gather(
*generator_tasks, return_exceptions=True
)
self.logger.debug(
f"In process_one_instance, generator_results len: {len(generator_results)}"
)
for result in generator_results:
if isinstance(result, Exception):
self.logger.error(f"GeneratorLoop exception: {result}")
else:
valid_results.append(result)
self.logger.debug(
f"In process_one_instance, valid_results len: {len(valid_results)}"
)
# optional: Dump the generator results for subsequent selector debugging.
try:
self.dump_generator_results(instance_id, valid_results)
except Exception:
# Dump failure should not block the main process/flow.
pass
if not valid_results:
self.logger.warning(f"Instance {instance_id} has no valid GeneratorLoop results")
return None
selector_load_dump = bool(
self.cfg.get("runner", {}).get("selector_load_dump_result", False)
)
if selector_load_dump:
runner_cfg = self.cfg.get("runner", {})
dump_path = runner_cfg.get(
"selector_result_dump_path", "workspace/selector_result_dump"
)
if self._check_dump_result(dump_path, instance_id):
self.logger.info(f"load selector results from dump file: {instance_id}")
try:
dump_dir = Path(dump_path)
file_path = dump_dir / f"{instance_id}.json"
with open(file_path, "r", encoding="utf-8") as f:
selected_data = json.load(f)
from src.managers.loop.types import SelectorResult
selected_result = SelectorResult.from_dict(selected_data)
self.logger.info(
f"Instance{instance_id} process done (load from dump)Generator #{selected_result.generator_id:03d} chosen"
)
return selected_result
except Exception as e:
self.logger.warning(
f"Fail to load selector dump result: {e}, execute normal choosing procedure"
)
else:
self.logger.info(
f"Fail to load selector dump result, execute normal choosing procedure: {instance_id}"
)
self.logger.info(f"Star choosing best result for instance {instance_id}")
selector = SelectorLoop(
instance_id=instance_id,
image_name=image_name,
runner_log_base=self.logs_base,
llm_manager=self.llm_manager,
prompts_manager=self.prompts_manager,
instance_data=instance,
config=self.cfg,
)
selected_result = await selector.select(valid_results)
self.logger.info(
f"Instance {instance_id} processedGenerator #{selected_result.generator_id:03d} chosen"
)
return selected_result
async def run(self) -> Dict[str, Any]:
assert self.builder is not None
if self.instance_ids:
target_ids = set(self.instance_ids)
instances_to_run = [
inst
for inst in self.builder.full_dataset
if inst.get("instance_id") in target_ids
]
else:
instances_to_run = list(self.builder.full_dataset)
self.logger.info(f"Start to process {len(instances_to_run)} instances")
final_results = []
for i, instance in enumerate(instances_to_run, 1):
self.logger.info(
f"process instance{i}/{len(instances_to_run)}: {instance['instance_id']}"
)
try:
result = await self.process_one_instance(instance)
if result is not None:
final_results.append(result)
except Exception as e:
self.logger.error(
f"Instance {instance['instance_id']} process fail: {e}, traceback: {format_exc()}."
)
final_results.append(
SelectorResult(
instance_id=instance["instance_id"],
generator_id=0,
success=False,
golden_patch=None,
)
)
summary = self._calculate_summary(final_results)
self.logger.info(
f"Done. Total={summary['total']} Success={summary['success']} Fail={summary['failed']}"
)
return summary
def _check_dump_result(self, dump_path: str, instance_id: str) -> bool:
try:
dump_dir = Path(dump_path)
file_path = dump_dir / f"{instance_id}.json"
return file_path.exists()
except Exception as e:
self.logger.warning(f"Fail to check dump file: {e}")
return False
def _load_generator_dump_result(
self, dump_path: str, instance_id: str
) -> List[GeneratorResult]:
try:
dump_dir = Path(dump_path)
file_path = dump_dir / f"{instance_id}.json"
if not file_path.exists():
self.logger.warning(f"Generator dump file does not exist: {file_path}")
return []
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f) or []
results: List[GeneratorResult] = []
for item in data:
try:
if isinstance(item, dict):
results.append(GeneratorResult.from_dict(item))
else:
self.logger.warning(f"Jump non-dict dump: {type(item)}")
except Exception as e:
self.logger.warning(f"parse dump fail: {e}")
continue
self.logger.info(f"load {len(results)} GeneratorResults from dump")
return results
except Exception as e:
self.logger.warning(
f"Load generator dump results failed: {e}, traceback: {format_exc()}"
)
return []
def _calculate_summary(self, results: List[SelectorResult]) -> Dict[str, Any]:
summary: Dict[str, Any] = {
"total": 0,
"success": 0,
"failed": 0,
}
for r in results:
summary["total"] += 1
if r.success:
summary["success"] += 1
else:
summary["failed"] += 1
return summary
def run_one_issue(cfg, issue: str) -> dict[str, Any]:
ids = [issue]
summary = {}
if not cfg.get("runner", {}).get("skip_selector", False):
runner = Runner(cfg, ids)
summary = asyncio.run(runner.run())
print("\n" + "=" * 80)
print("total results")
print("=" * 80)
print(f"Total instances: {summary['total']}")
print(f"Success: {summary['success']}")
print(f"Fail: {summary['failed']}")
print(
f"Success rate: {(summary['success']/summary['total']*100):.1f}%"
if summary["total"] > 0
else "0%"
)
return summary
def gen_result(cfg):
result_builder = ResultBuilder(cfg)
result_path = result_builder.build_preds()
def main() -> None:
config_path = Path(__file__).parent / "config" / "config.yaml"
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
test_instance_ids = ["astropy__astropy-14309"]
if not cfg.get("runner", {}).get("skip_selector", False):
runner = Runner(cfg, test_instance_ids)
summary = asyncio.run(runner.run())
print("\n" + "=" * 80)
print("Total results")
print("=" * 80)
print(f"Total instances: {summary['total']}")
print(f"Success: {summary['success']}")
print(f"Fail: {summary['failed']}")
print(
f"Success rate: {(summary['success']/summary['total']*100):.1f}%"
if summary["total"] > 0
else "0%"
)
result_builder = ResultBuilder(cfg)
result_builder.build_preds()
if __name__ == "__main__":
main()

1
config/batch.yaml Normal file
View File

@@ -0,0 +1 @@
providers:

61
config/config.yaml Normal file
View File

@@ -0,0 +1,61 @@
providers:
openrouter:
- "anthropic/claude-sonnet-4.5"
dataset:
name: "princeton-nlp/SWE-bench_Verified"
split: "test"
workspace:
path: "workspace"
# Runner Settings
runner:
# generator
skip_generator: false # Whether to skip the generator; for selector testing you can skip and load generator_results directly from dump files
generator_concurrency: 5 # Number of GeneratorLoop candidate patches per instance
generator_loop: # GeneratorLoop settings
max_turn: 140 # Max turns
temperature: 0.2 # Temperature
generator_result_dump: true # Whether to dump generator results for later selector testing
generator_result_dump_path: "workspace/generator_result_dump" # generator result dump path
generator_load_dump_result: true # Whether to skip examples that already have generated results
# selector
skip_selector: false # Whether to skip the selector; useful to test validation directly
selector_loop: # SelectorLoop settings
max_turn: 200 # Max turns
temperature: 0.2 # Temperature
selector_result_dump: true # Whether to dump selector results
selector_result_dump_path: "workspace/selector_result_dump"
selector_load_dump_result: true # Whether to skip examples that already have selection results
# Image Builder
builder:
max_workers: 16 # Image build concurrency
max_retries: 3 # Image build retry count
repo_root_path: "/testbed" # Repository root path
# Log settings
log:
base_path: "workspace/logs"
image_builder: "workspace/image_logs" # Use temporary directory for image builder logs
disable_image_builder_logs: false # Set to true to prevent saving log files
# Final result paths
result:
preds:
path: "result"
# token_cache settings for claude models
token_cache:
role_turns:
- turn: 0
role: "user"
- turn: 25
role: "assistant"
- turn: 50
role: "assistant"
- turn: 75
role: "assistant"

500
config/list_all.txt Normal file
View File

@@ -0,0 +1,500 @@
astropy__astropy-12907
astropy__astropy-13033
astropy__astropy-13236
astropy__astropy-13398
astropy__astropy-13453
astropy__astropy-13579
astropy__astropy-13977
astropy__astropy-14096
astropy__astropy-14182
astropy__astropy-14309
astropy__astropy-14365
astropy__astropy-14369
astropy__astropy-14508
astropy__astropy-14539
astropy__astropy-14598
astropy__astropy-14995
astropy__astropy-7166
astropy__astropy-7336
astropy__astropy-7606
astropy__astropy-7671
astropy__astropy-8707
astropy__astropy-8872
django__django-10097
django__django-10554
django__django-10880
django__django-10914
django__django-10973
django__django-10999
django__django-11066
django__django-11087
django__django-11095
django__django-11099
django__django-11119
django__django-11133
django__django-11138
django__django-11141
django__django-11149
django__django-11163
django__django-11179
django__django-11206
django__django-11211
django__django-11239
django__django-11265
django__django-11276
django__django-11292
django__django-11299
django__django-11333
django__django-11400
django__django-11433
django__django-11451
django__django-11477
django__django-11490
django__django-11532
django__django-11551
django__django-11555
django__django-11603
django__django-11728
django__django-11734
django__django-11740
django__django-11749
django__django-11790
django__django-11815
django__django-11820
django__django-11848
django__django-11880
django__django-11885
django__django-11951
django__django-11964
django__django-11999
django__django-12039
django__django-12050
django__django-12125
django__django-12143
django__django-12155
django__django-12193
django__django-12209
django__django-12262
django__django-12273
django__django-12276
django__django-12304
django__django-12308
django__django-12325
django__django-12406
django__django-12419
django__django-12663
django__django-12708
django__django-12713
django__django-12741
django__django-12754
django__django-12774
django__django-12858
django__django-12965
django__django-13012
django__django-13023
django__django-13028
django__django-13033
django__django-13089
django__django-13109
django__django-13112
django__django-13121
django__django-13128
django__django-13158
django__django-13195
django__django-13212
django__django-13279
django__django-13297
django__django-13315
django__django-13343
django__django-13344
django__django-13346
django__django-13363
django__django-13401
django__django-13406
django__django-13410
django__django-13417
django__django-13449
django__django-13512
django__django-13513
django__django-13516
django__django-13551
django__django-13568
django__django-13569
django__django-13590
django__django-13658
django__django-13670
django__django-13741
django__django-13786
django__django-13794
django__django-13807
django__django-13809
django__django-13810
django__django-13820
django__django-13821
django__django-13837
django__django-13925
django__django-13933
django__django-13964
django__django-14007
django__django-14011
django__django-14017
django__django-14034
django__django-14053
django__django-14089
django__django-14122
django__django-14140
django__django-14155
django__django-14170
django__django-14238
django__django-14311
django__django-14315
django__django-14349
django__django-14351
django__django-14373
django__django-14376
django__django-14404
django__django-14434
django__django-14493
django__django-14500
django__django-14534
django__django-14539
django__django-14559
django__django-14580
django__django-14608
django__django-14631
django__django-14672
django__django-14725
django__django-14752
django__django-14765
django__django-14771
django__django-14787
django__django-14792
django__django-14855
django__django-14915
django__django-14999
django__django-15022
django__django-15037
django__django-15098
django__django-15103
django__django-15104
django__django-15127
django__django-15128
django__django-15161
django__django-15252
django__django-15268
django__django-15277
django__django-15278
django__django-15280
django__django-15315
django__django-15368
django__django-15375
django__django-15380
django__django-15382
django__django-15467
django__django-15499
django__django-15503
django__django-15525
django__django-15554
django__django-15561
django__django-15563
django__django-15569
django__django-15572
django__django-15629
django__django-15695
django__django-15731
django__django-15732
django__django-15741
django__django-15814
django__django-15851
django__django-15863
django__django-15916
django__django-15930
django__django-15957
django__django-15973
django__django-15987
django__django-16032
django__django-16082
django__django-16100
django__django-16116
django__django-16136
django__django-16139
django__django-16145
django__django-16255
django__django-16256
django__django-16263
django__django-16315
django__django-16333
django__django-16429
django__django-16454
django__django-16485
django__django-16493
django__django-16502
django__django-16527
django__django-16560
django__django-16569
django__django-16595
django__django-16612
django__django-16631
django__django-16642
django__django-16661
django__django-16662
django__django-16667
django__django-16801
django__django-16819
django__django-16877
django__django-16899
django__django-16901
django__django-16938
django__django-16950
django__django-17029
django__django-17084
django__django-17087
django__django-7530
django__django-9296
matplotlib__matplotlib-13989
matplotlib__matplotlib-14623
matplotlib__matplotlib-20488
matplotlib__matplotlib-20676
matplotlib__matplotlib-20826
matplotlib__matplotlib-20859
matplotlib__matplotlib-21568
matplotlib__matplotlib-22719
matplotlib__matplotlib-22865
matplotlib__matplotlib-22871
matplotlib__matplotlib-23299
matplotlib__matplotlib-23314
matplotlib__matplotlib-23412
matplotlib__matplotlib-23476
matplotlib__matplotlib-24026
matplotlib__matplotlib-24149
matplotlib__matplotlib-24177
matplotlib__matplotlib-24570
matplotlib__matplotlib-24627
matplotlib__matplotlib-24637
matplotlib__matplotlib-24870
matplotlib__matplotlib-24970
matplotlib__matplotlib-25122
matplotlib__matplotlib-25287
matplotlib__matplotlib-25311
matplotlib__matplotlib-25332
matplotlib__matplotlib-25479
matplotlib__matplotlib-25775
matplotlib__matplotlib-25960
matplotlib__matplotlib-26113
matplotlib__matplotlib-26208
matplotlib__matplotlib-26291
matplotlib__matplotlib-26342
matplotlib__matplotlib-26466
mwaskom__seaborn-3069
mwaskom__seaborn-3187
pallets__flask-5014
psf__requests-1142
psf__requests-1724
psf__requests-1766
psf__requests-1921
psf__requests-2317
psf__requests-2931
psf__requests-5414
psf__requests-6028
pydata__xarray-2905
pydata__xarray-3095
pydata__xarray-3151
pydata__xarray-3305
pydata__xarray-3677
pydata__xarray-3993
pydata__xarray-4075
pydata__xarray-4094
pydata__xarray-4356
pydata__xarray-4629
pydata__xarray-4687
pydata__xarray-4695
pydata__xarray-4966
pydata__xarray-6461
pydata__xarray-6599
pydata__xarray-6721
pydata__xarray-6744
pydata__xarray-6938
pydata__xarray-6992
pydata__xarray-7229
pydata__xarray-7233
pydata__xarray-7393
pylint-dev__pylint-4551
pylint-dev__pylint-4604
pylint-dev__pylint-4661
pylint-dev__pylint-4970
pylint-dev__pylint-6386
pylint-dev__pylint-6528
pylint-dev__pylint-6903
pylint-dev__pylint-7080
pylint-dev__pylint-7277
pylint-dev__pylint-8898
pytest-dev__pytest-10051
pytest-dev__pytest-10081
pytest-dev__pytest-10356
pytest-dev__pytest-5262
pytest-dev__pytest-5631
pytest-dev__pytest-5787
pytest-dev__pytest-5809
pytest-dev__pytest-5840
pytest-dev__pytest-6197
pytest-dev__pytest-6202
pytest-dev__pytest-7205
pytest-dev__pytest-7236
pytest-dev__pytest-7324
pytest-dev__pytest-7432
pytest-dev__pytest-7490
pytest-dev__pytest-7521
pytest-dev__pytest-7571
pytest-dev__pytest-7982
pytest-dev__pytest-8399
scikit-learn__scikit-learn-10297
scikit-learn__scikit-learn-10844
scikit-learn__scikit-learn-10908
scikit-learn__scikit-learn-11310
scikit-learn__scikit-learn-11578
scikit-learn__scikit-learn-12585
scikit-learn__scikit-learn-12682
scikit-learn__scikit-learn-12973
scikit-learn__scikit-learn-13124
scikit-learn__scikit-learn-13135
scikit-learn__scikit-learn-13142
scikit-learn__scikit-learn-13328
scikit-learn__scikit-learn-13439
scikit-learn__scikit-learn-13496
scikit-learn__scikit-learn-13779
scikit-learn__scikit-learn-14053
scikit-learn__scikit-learn-14087
scikit-learn__scikit-learn-14141
scikit-learn__scikit-learn-14496
scikit-learn__scikit-learn-14629
scikit-learn__scikit-learn-14710
scikit-learn__scikit-learn-14894
scikit-learn__scikit-learn-14983
scikit-learn__scikit-learn-15100
scikit-learn__scikit-learn-25102
scikit-learn__scikit-learn-25232
scikit-learn__scikit-learn-25747
scikit-learn__scikit-learn-25931
scikit-learn__scikit-learn-25973
scikit-learn__scikit-learn-26194
scikit-learn__scikit-learn-26323
scikit-learn__scikit-learn-9288
sphinx-doc__sphinx-10323
sphinx-doc__sphinx-10435
sphinx-doc__sphinx-10449
sphinx-doc__sphinx-10466
sphinx-doc__sphinx-10614
sphinx-doc__sphinx-10673
sphinx-doc__sphinx-11445
sphinx-doc__sphinx-11510
sphinx-doc__sphinx-7440
sphinx-doc__sphinx-7454
sphinx-doc__sphinx-7462
sphinx-doc__sphinx-7590
sphinx-doc__sphinx-7748
sphinx-doc__sphinx-7757
sphinx-doc__sphinx-7889
sphinx-doc__sphinx-7910
sphinx-doc__sphinx-7985
sphinx-doc__sphinx-8035
sphinx-doc__sphinx-8056
sphinx-doc__sphinx-8120
sphinx-doc__sphinx-8265
sphinx-doc__sphinx-8269
sphinx-doc__sphinx-8459
sphinx-doc__sphinx-8475
sphinx-doc__sphinx-8548
sphinx-doc__sphinx-8551
sphinx-doc__sphinx-8593
sphinx-doc__sphinx-8595
sphinx-doc__sphinx-8621
sphinx-doc__sphinx-8638
sphinx-doc__sphinx-8721
sphinx-doc__sphinx-9229
sphinx-doc__sphinx-9230
sphinx-doc__sphinx-9258
sphinx-doc__sphinx-9281
sphinx-doc__sphinx-9320
sphinx-doc__sphinx-9367
sphinx-doc__sphinx-9461
sphinx-doc__sphinx-9591
sphinx-doc__sphinx-9602
sphinx-doc__sphinx-9658
sphinx-doc__sphinx-9673
sphinx-doc__sphinx-9698
sphinx-doc__sphinx-9711
sympy__sympy-11618
sympy__sympy-12096
sympy__sympy-12419
sympy__sympy-12481
sympy__sympy-12489
sympy__sympy-13031
sympy__sympy-13091
sympy__sympy-13372
sympy__sympy-13480
sympy__sympy-13551
sympy__sympy-13615
sympy__sympy-13647
sympy__sympy-13757
sympy__sympy-13798
sympy__sympy-13852
sympy__sympy-13877
sympy__sympy-13878
sympy__sympy-13974
sympy__sympy-14248
sympy__sympy-14531
sympy__sympy-14711
sympy__sympy-14976
sympy__sympy-15017
sympy__sympy-15345
sympy__sympy-15349
sympy__sympy-15599
sympy__sympy-15809
sympy__sympy-15875
sympy__sympy-15976
sympy__sympy-16450
sympy__sympy-16597
sympy__sympy-16766
sympy__sympy-16792
sympy__sympy-16886
sympy__sympy-17139
sympy__sympy-17318
sympy__sympy-17630
sympy__sympy-17655
sympy__sympy-18189
sympy__sympy-18199
sympy__sympy-18211
sympy__sympy-18698
sympy__sympy-18763
sympy__sympy-19040
sympy__sympy-19346
sympy__sympy-19495
sympy__sympy-19637
sympy__sympy-19783
sympy__sympy-19954
sympy__sympy-20154
sympy__sympy-20428
sympy__sympy-20438
sympy__sympy-20590
sympy__sympy-20801
sympy__sympy-20916
sympy__sympy-21379
sympy__sympy-21596
sympy__sympy-21612
sympy__sympy-21847
sympy__sympy-21930
sympy__sympy-22080
sympy__sympy-22456
sympy__sympy-22714
sympy__sympy-22914
sympy__sympy-23262
sympy__sympy-23413
sympy__sympy-23534
sympy__sympy-23824
sympy__sympy-23950
sympy__sympy-24066
sympy__sympy-24213
sympy__sympy-24443
sympy__sympy-24539
sympy__sympy-24562
sympy__sympy-24661

31
env.example Normal file
View File

@@ -0,0 +1,31 @@
# API
## OpenRouter API
OPENROUTER_API_KEY=your-openrouter-api-key-here
## Anthropic Claude API
ANTHROPIC_API_KEY=your-anthropic-api-key-here
## DeepSeek API
DEEPSEEK_API_KEY=your-deepseek-api-key-here
## OpenAI API
OPENAI_API_KEY=your-openai-api-key-here
## Private API
PRIVATE_API_KEY=EMPTY
PRIVATE_MODEL_NAME=your-private-model-name-here
PRIVATE_DEPLOYMENT_TYPE=vllm
PRIVATE_URL=https://127.0.0.1:0/v1
OPENAI_BASE_URL=https://api.openai.com/v1
ANTHROPIC_BASE_URL=https://api.anthropic.com
DEEPSEEK_BASE_URL=https://api.deepseek.com/v1
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
## OpenRouter
OPENROUTER_APP_NAME=tokfinity-llm-client
OPENROUTER_SITE_URL=https://github.com/your-repo
# config path
CONFIG_PATH=config/config.yaml
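A minimal sketch of loading these variables with python-dotenv (listed in requirements.txt); the .env location and the fallback value are assumptions:

import os
from dotenv import load_dotenv

load_dotenv()  # reads a local .env file by default
openrouter_key = os.getenv("OPENROUTER_API_KEY")
config_path = os.getenv("CONFIG_PATH", "config/config.yaml")
print(config_path, "key set:", bool(openrouter_key))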

63
eval.sh Normal file
View File

@@ -0,0 +1,63 @@
#!/usr/bin/env bash
set -euo pipefail
PROJECT_ROOT="$(cd "$(dirname "$0")" && pwd)"
VENV_DIR="$PROJECT_ROOT/venv"
if [[ ! -d "$VENV_DIR" ]]; then
echo "[ERROR] wrong venv directory: $VENV_DIR" >&2
exit 1
fi
source "$VENV_DIR/bin/activate"
CONFIG_PATH="$PROJECT_ROOT/config/config.yaml"
if [[ ! -f "$CONFIG_PATH" ]]; then
echo "[ERROR] config file not found: $CONFIG_PATH" >&2
exit 1
fi
eval "$(python - <<'PY'
from __future__ import annotations
from pathlib import Path
import yaml
project_root = Path(__file__).resolve().parent
config_path = project_root / 'config' / 'config.yaml'
with open(config_path, 'r', encoding='utf-8') as f:
cfg = yaml.safe_load(f) or {}
dataset_name = ((cfg.get('dataset') or {}).get('name')) or 'princeton-nlp/SWE-bench_Lite'
workspace_path = ((cfg.get('workspace') or {}).get('path')) or 'workspace'
preds_dir = (((cfg.get('result') or {}).get('preds') or {}).get('path')) or 'result'
preds_path = (project_root / workspace_path / preds_dir / 'preds.json').resolve()
print(f'DATASET_NAME="{dataset_name}"')
print(f'PREDICTIONS_PATH="{preds_path}"')
PY
)"
if [[ $# -eq 0 ]]; then
echo "[ERROR] run_id is required" >&2
echo "Usage: $0 <run_id>" >&2
exit 1
fi
RUN_ID="$1"
echo "Using dataset: $DATASET_NAME"
echo "Using predictions: $PREDICTIONS_PATH"
echo "Using run_id: $RUN_ID"
# run evaluation
python -m swebench.harness.run_evaluation \
--dataset_name "$DATASET_NAME" \
--predictions_path "$PREDICTIONS_PATH" \
--max_workers 20 \
--cache_level instance \
--run_id "$RUN_ID"

1914
figures/framework.svg Normal file

File diff suppressed because one or more lines are too long

Size: 236 KiB

41
requirements.txt Normal file
View File

@@ -0,0 +1,41 @@
# HTTP clients
httpx>=0.25.0
requests>=2.28.0
# Data validation and serialization
pydantic>=2.11.9
# Logging
loguru>=0.7.0
# Type-checking support
typing-extensions>=4.5.0
# Environment management
python-dotenv>=1.0.0
# YAML config file parsing
PyYAML>=6.0
# Dataset processing
datasets>=2.14.0
# JSONPath parsing
jsonpath-ng>=1.6.0
# Docker support
docker>=6.0.0
# CLI interaction support
pexpect>=4.8.0
# SWE-bench related requirements
swebench>=1.0.0
# Development requirements (optional)
# pytest>=7.0.0
# pytest-asyncio>=0.21.0
# black>=23.0.0
# flake8>=6.0.0
# mypy>=1.0.0

1
src/__init__.py Normal file
View File

@@ -0,0 +1 @@
# src directory marker file

16
src/managers/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# Export all modules
from .image_builder import *
from .log import *
from .llm_api import *
from .prompts import *
from .loop import *
from .decorators import *
__all__ = [
"image_builder",
"log",
"llm_api",
"prompts",
"loop",
"decorators"
]

View File

@@ -0,0 +1,13 @@
import threading
def singleton(cls):
instances = {}
lock = threading.Lock()
def get_instance(*args, **kwargs):
if cls not in instances:
with lock:
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
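A hypothetical usage sketch of the decorator above; the Counter class is made up for illustration. The double-checked locking means concurrent first calls still yield a single instance.

@singleton
class Counter:
    def __init__(self):
        self.value = 0

a = Counter()
b = Counter()
assert a is b  # both names point to the same instance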

View File

@@ -0,0 +1,8 @@
# Export all modules
from .build_image import *
from .dockerfiles import *
__all__ = [
"build_image",
"dockerfiles"
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,8 @@
# User image Dockerfile template
# image_key: the instance_image_key corresponding to the instance_id
# apt_packages: space-separated list of extra packages to install in the Ubuntu base image
_DOCKERFILE_USER_IMAGE_PY = r"""
FROM {image_key}
RUN apt-get update && apt-get install -y --no-install-recommends {apt_packages}
"""

View File

@@ -0,0 +1,125 @@
"""
User-defined logger for swebench
"""
import logging
from typing import Union, Optional
from swebench.harness.docker_utils import remove_image
from src.managers.log.logger import Logger as CustomLogger
from swebench.harness.docker_build import build_instance_image, run_threadpool
def patched_build_instance_images(
client,
dataset: list,
force_rebuild: bool = False,
max_workers: int = 4,
namespace: str = None,
tag: str = None,
env_image_tag: str = None,
custom_logger: Optional[Union[logging.Logger, CustomLogger]] = None,
):
"""
Monkey patched version of build_instance_images that supports custom logger
Args:
client: Docker client
dataset: List of test specs or dataset to build images for
force_rebuild: Whether to force rebuild the images even if they already exist
max_workers: Maximum number of worker threads
namespace: Namespace for images
tag: Tag for images
env_image_tag: Environment image tag
custom_logger: Custom logger to use (instead of creating new ones)
"""
from swebench.harness.docker_build import make_test_spec, build_env_images
test_specs = []
for instance in dataset:
spec = make_test_spec(
instance,
namespace=namespace,
instance_image_tag=tag,
env_image_tag=env_image_tag,
)
test_specs.append(spec)
if force_rebuild:
for spec in test_specs:
remove_image(client, spec.instance_image_key, "quiet")
_, env_failed = build_env_images(client, test_specs, force_rebuild, max_workers)
if len(env_failed) > 0:
# Don't build images for instances that depend on failed-to-build env images
dont_run_specs = [
spec for spec in test_specs if spec.env_image_key in env_failed
]
test_specs = [
spec for spec in test_specs if spec.env_image_key not in env_failed
]
if custom_logger:
custom_logger.info(
f"Skipping {len(dont_run_specs)} instances - due to failed env image builds"
)
else:
print(f"Skipping {len(dont_run_specs)} instances - due to failed env image builds")
if custom_logger:
custom_logger.info(f"Building instance images for {len(test_specs)} instances")
else:
print(f"Building instance images for {len(test_specs)} instances")
successful, failed = [], []
if custom_logger:
payloads = [(spec, client, custom_logger, False) for spec in test_specs]
else:
payloads = [(spec, client, None, False) for spec in test_specs]
successful, failed = run_threadpool(build_instance_image, payloads, max_workers)
if len(failed) == 0:
if custom_logger:
custom_logger.info("All instance images built successfully.")
else:
print("All instance images built successfully.")
else:
if custom_logger:
custom_logger.warning(f"{len(failed)} instance images failed to build.")
else:
print(f"{len(failed)} instance images failed to build.")
return successful, failed
def apply_logger_patch():
import swebench.harness.docker_build as docker_build_module
original_build_instance_images = docker_build_module.build_instance_images
docker_build_module.build_instance_images = patched_build_instance_images
return original_build_instance_images
def restore_logger_patch(original_function):
"""Recover original logger actions"""
import swebench.harness.docker_build as docker_build_module
docker_build_module.build_instance_images = original_function
class LoggerPatch:
"""Context manager"""
def __init__(self, logger: Optional[Union[logging.Logger, CustomLogger]] = None):
self.logger = logger
self.original_function = None
def __enter__(self):
self.original_function = apply_logger_patch()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.original_function:
restore_logger_patch(self.original_function)
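A usage sketch for the patch, assuming a docker client from the docker SDK and a pre-built SWE-bench dataset (both names are illustrative):
import docker
import logging

logger = logging.getLogger("swebench_build")  # any logging.Logger or CustomLogger
client = docker.from_env()

with LoggerPatch(logger):
    # while the patch is active, swebench's build_instance_images is the patched version
    import swebench.harness.docker_build as docker_build
    successful, failed = docker_build.build_instance_images(
        client, dataset, max_workers=4, custom_logger=logger  # dataset is assumed to exist
    )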

View File

@@ -0,0 +1,132 @@
"""
Redirect print output of a third party repo to a specific log
"""
import sys
import logging
import re
from typing import Union, Set, Pattern
from typing import Optional, TextIO
from pathlib import Path
from src.managers.log.logger import Logger as CustomLogger
class StreamWrapper:
"""Simple stream wrapperfor redirecting stdout/stderr"""
def __init__(self, original_stream, write_func):
self.original_stream = original_stream
self.write_func = write_func
self.buffer = ""
def write(self, text):
self.buffer += text
while '\n' in self.buffer:
line, self.buffer = self.buffer.split('\n', 1)
if line.strip():
self.write_func(line + '\n')
return len(text)
def flush(self):
if self.buffer:
if self.buffer.strip():
self.write_func(self.buffer)
self.buffer = ""
if hasattr(self.original_stream, 'flush'):
self.original_stream.flush()
def __getattr__(self, name):
return getattr(self.original_stream, name)
class PrintRedirector:
"""Redirect print output of a third party repo to a specific log"""
TRACEBACK_START_PATTERN = re.compile(r'Traceback\s*\(most recent call last\):', re.IGNORECASE)
EXCEPTION_END_PATTERN = re.compile(
r'\b(Error|Exception|KeyError|ValueError|TypeError|AttributeError|'
r'ImportError|BuildError|DockerError|IndentationError|SyntaxError|'
r'RuntimeError|OSError|FileNotFoundError|PermissionError)\s*:',
re.IGNORECASE
)
ERROR_KEYWORDS = {'error', 'failed', 'exception', 'fatal', 'critical'}
WARNING_KEYWORDS = {'warning', 'warn', 'deprecated'}
SKIP_KEYWORDS = {'skipping', 'skip', 'ignoring', 'ignore'}
def __init__(self, logger: Union[logging.Logger, CustomLogger]):
self.logger = logger
self.original_print = None
self.original_stdout = None
self.original_stderr = None
self.traceback_buffer = []
self.in_traceback = False
def __enter__(self):
self.start_redirect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop_redirect()
def start_redirect(self):
if isinstance(__builtins__, dict):
self.original_print = __builtins__['print']
else:
self.original_print = getattr(__builtins__, 'print')
if isinstance(__builtins__, dict):
__builtins__['print'] = self._redirected_print
else:
setattr(__builtins__, 'print', self._redirected_print)
def stop_redirect(self):
if self.original_print:
if isinstance(__builtins__, dict):
__builtins__['print'] = self.original_print
else:
setattr(__builtins__, 'print', self.original_print)
def _redirected_print(self, *args, **kwargs):
if not args:
return
output = ' '.join(str(arg) for arg in args)
if self.TRACEBACK_START_PATTERN.search(output):
self.in_traceback = True
self.traceback_buffer = [output]
return
if self.in_traceback:
self.traceback_buffer.append(output)
if self.EXCEPTION_END_PATTERN.search(output):
self._log_traceback_and_reset()
return
self._log_by_level(output)
def _log_traceback_and_reset(self):
full_traceback = '\n'.join(self.traceback_buffer)
self.logger.error(f"[Third party repo exception stack]\n{full_traceback}")
self.in_traceback = False
self.traceback_buffer.clear()
def _log_by_level(self, output: str):
output_lower = output.lower()
if any(keyword in output_lower for keyword in self.ERROR_KEYWORDS):
self.logger.error(f"[Third party] {output}")
elif any(keyword in output_lower for keyword in self.WARNING_KEYWORDS):
self.logger.warning(f"[Third party] {output}")
elif any(keyword in output_lower for keyword in self.SKIP_KEYWORDS):
self.logger.info(f"[Third party] {output}")
else:
self.logger.info(f"[Third party] {output}")
def redirect_swebench_prints(logger: Union[logging.Logger, CustomLogger]):
return PrintRedirector(logger)
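A minimal usage sketch with a standard logging.Logger; the printed messages are illustrative and only show how keywords map to log levels:
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("third_party")

with redirect_swebench_prints(logger):
    # prints inside this block are routed to the logger by keyword level
    print("Building image...")              # -> INFO
    print("warning: cache is deprecated")   # -> WARNING
    print("error: build failed")            # -> ERROR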

View File

@@ -0,0 +1,98 @@
"""
LLM API Management Module
Provide a standard OpenAI format LLM API interface that supports:
- Unified Chat Completions endpoint
- Synchronous and asynchronous operations
- Streaming responses
- Tool Calling
- Error handling and retry mechanism(s)
Supported providers:
- OpenAI: OpenAI API and compatible services
- Anthropic: Claude models
- DeepSeek: DeepSeek models
- OpenRouter: OpenRouter aggregation API
- Private: Private models (vLLM, TGI, Ollama, etc.)
Usage example::
# 1: use the general manager (recommended)
from llm_api import LLMAPIManager
# create manager
manager = LLMAPIManager(
client_name="openai",
model_name="gpt-3.5-turbo",
stream=False
)
# sendf messages
response = manager.chat("Hello world!")
print(response)
# 2: Use client directly
from llm_api import OpenAIClient, ChatMessage, MessageRole
# create client
client = OpenAIClient(api_key="your-api-key")
# create message
messages = [
ChatMessage(role=MessageRole.USER, content="Hello, world!")
]
# send request
request = client.create_request(messages=messages, model="gpt-3.5-turbo")
response = client.chat_completions_create(request)
print(response.choices[0].message.content)
"""
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
MessageRole,
Choice,
Usage,
)
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
from src.managers.llm_api.api_manager import (
LLMAPIManager,
create_manager,
create_common_manager,
COMMON_CONFIGS,
)
__all__ = [
"BaseLLMAPI",
"ChatMessage",
"ChatCompletionRequest",
"ChatCompletionResponse",
"ChatCompletionChunk",
"MessageRole",
"Choice",
"Usage",
"OpenAIClient",
"AnthropicClient",
"DeepSeekClient",
"OpenRouterClient",
"PrivateModelClient",
"LLMAPIManager",
"create_manager",
"create_common_manager",
"COMMON_CONFIGS",
]
__version__ = "1.0.0"
__author__ = "Tokfinity Team"
__description__ = "标准 OpenAI 格式的 LLM API 基类库"

View File

@@ -0,0 +1,477 @@
"""
LLM API manager
"""
import os
import time
from typing import Dict, List, Any, Optional, Union, Generator
from dotenv import load_dotenv
import yaml
from traceback import format_exc
from .base_client import (
BaseLLMAPI,
ChatMessage,
MessageRole,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
EmbeddingRequest,
EmbeddingResponse,
)
from .clients.openai.openai_client import OpenAIClient
from .clients.anthropic.anthropic_client import AnthropicClient
from .clients.deepseek.deepseek_client import DeepSeekClient
from .clients.openrouter.openrouter_client import OpenRouterClient
from .clients.private.private_client import PrivateModelClient
class LLMAPIManager:
SUPPORTED_CLIENTS = {
"openai": OpenAIClient,
"anthropic": AnthropicClient,
"deepseek": DeepSeekClient,
"openrouter": OpenRouterClient,
"private": PrivateModelClient,
}
DEFAULT_CONFIGS = {
"openai": {"base_url": None, "api_key_env": "OPENAI_API_KEY"},
"anthropic": {"base_url": None, "api_key_env": "ANTHROPIC_API_KEY"},
"deepseek": {"base_url": None, "api_key_env": "DEEPSEEK_API_KEY"},
"openrouter": {
"base_url": None,
"api_key_env": "OPENROUTER_API_KEY",
"extra_config": {
"app_name": "tokfinity-llm-client",
"site_url": "https://github.com/your-repo",
},
},
"private": {
"base_url": "http://localhost:8000/v1",
"api_key_env": "PRIVATE_API_KEY",
"extra_config": {"deployment_type": "vllm"},
},
}
def __init__(
self,
client_name: Optional[str] = None,
stream: bool = False,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
auto_load_env: bool = True,
logger: Optional[Any] = None,
config: Optional[Dict[str, Any]] = None,
**kwargs,
):
# if client not provided, get first provider from default providers in config
if client_name is None:
default_client, default_model = (
self._load_default_client_and_model_from_config()
)
self.client_name = default_client
self.default_model = default_model
else:
self.client_name = client_name.lower()
self.default_model = None
self.stream = stream
self.timeout = timeout
self.max_retries = max_retries
self.logger = logger
self.logger.info(f"[LLMAPIManager]: Using client: {self.client_name}, model: {self.default_model}.")
self.config = config
if auto_load_env:
self._load_environment()
if self.client_name not in self.SUPPORTED_CLIENTS:
raise ValueError(
f"Unsupported client: {client_name}"
f"Support client: {list(self.SUPPORTED_CLIENTS.keys())}"
)
self.client = self._create_client(api_key, base_url, logger, **kwargs)
def _load_environment(self) -> None:
"""load environment variables"""
env_paths = [".env", "../.env", "../../.env", "../../../.env"]
for env_path in env_paths:
if os.path.exists(env_path):
load_dotenv(env_path)
break
def _create_client(
self, api_key: Optional[str] = None, base_url: Optional[str] = None, logger: Optional[Any] = None, **kwargs
) -> BaseLLMAPI:
client_class = self.SUPPORTED_CLIENTS[self.client_name]
config = self.DEFAULT_CONFIGS[self.client_name]
if api_key is None:
api_key = os.getenv(config["api_key_env"])
if not api_key:
if self.client_name == "private":
api_key = "EMPTY" # private mode may not need a key
else:
raise ValueError(
f"Fail to find env variable, please set: {config['api_key_env']} "
f"or upload ai_key parameter when initialize"
)
if base_url is None:
env_key = f"{self.client_name.upper()}_BASE_URL"
if self.client_name == "private":
env_key = "PRIVATE_URL"
base_url = os.getenv(env_key)
if base_url is None:
base_url = config.get("base_url")
client_kwargs = {
"api_key": api_key,
"timeout": self.timeout,
"max_retries": self.max_retries,
}
if base_url:
client_kwargs["base_url"] = base_url
if logger is not None:
client_kwargs["logger"] = logger
extra_config = config.get("extra_config", {})
client_kwargs.update(extra_config)
client_kwargs.update(kwargs)
if self.client_name == "openrouter":
client_kwargs.setdefault("app_name", "tokfinity-llm-client")
elif self.client_name == "private":
client_kwargs.setdefault("deployment_type", "vllm")
return client_class(**client_kwargs)
def _load_default_client_and_model_from_config(self) -> tuple[str, Optional[str]]:
"""
Get first item from providers from config/config.yaml as the default client
And take the first model as default model
"""
# Resolve the config file path relative to the project root
# This file is in src/managers/llm_api/api_manager.py
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
config_path = os.path.join(base_dir, "config", "config.yaml")
if not os.path.exists(config_path):
raise FileNotFoundError(f"No config file: {config_path}")
with open(config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
providers = cfg.get("providers")
if not isinstance(providers, dict) or len(providers) == 0:
raise ValueError("config.yaml lack of providers config or format error")
first_provider_name = next(iter(providers.keys()))
models = providers.get(first_provider_name) or []
first_model = (
models[0] if isinstance(models, list) and len(models) > 0 else None
)
client_key = first_provider_name.strip().lower()
if client_key not in self.SUPPORTED_CLIENTS:
raise ValueError(
f"Default provider '{first_provider_name}' in config not registered in SUPPORTED_CLIENTS"
)
return client_key, first_model
def chat(
self,
messages: List[Union[Dict[str, Any], ChatMessage]],
model: Optional[str] = None,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
**kwargs,
) -> Union[
ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None
]:
"""
Send chat message and get response
Args:
model: model name (optional, default self.default_model)
messages: Conversation message list.
- If list of dict: [{"role": "system|user|assistant|tool", "content": "..."}]
- Or `ChatMessage` list
temperature: Temperature
max_tokens: Max tokens
timeout: Request timeout, use the value from initialization if not specified
retry: Max retries, use the value from initialization if not specified
tools: Tools description list (OpenAI tools format)
tool_choice: Tool choice strategy: "auto" | "none" | {"type": ..., ...}
**kwargs: Other request args
Returns:
Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None], None]:
- Non-streaming: Return the complete ChatCompletionResponse object
- Stream: Return a streaming response generator
- Fail: Return None
"""
actual_timeout = timeout if timeout is not None else self.timeout
actual_retry = retry if retry is not None else self.max_retries
actual_model = model if model is not None else self.default_model
if not actual_model:
raise ValueError("Model unprovided, and cannot get default model from config")
normalized_messages: List[ChatMessage] = []
for msg in messages:
if isinstance(msg, ChatMessage):
if msg.content is None:
msg.content = ""
normalized_messages.append(msg)
else:
role_value = msg.get("role")
content_value = msg.get("content")
if content_value is None:
content_value = ""
role_enum = MessageRole(role_value)
normalized_messages.append(
ChatMessage(
role=role_enum,
content=content_value,
name=msg.get("name"),
tool_calls=msg.get("tool_calls"),
tool_call_id=msg.get("tool_call_id"),
)
)
last_exception = None
for attempt in range(actual_retry + 1):
try:
request = self.client.create_request(
messages=normalized_messages,
model=actual_model,
temperature=temperature,
max_tokens=max_tokens,
stream=self.stream,
tools=tools,
tool_choice=tool_choice,
config=self.config,
**kwargs,
)
response = self.client.chat_completions_create(request)
if self.stream:
# Streaming response: returned as-is (not yet post-processed)
return response
else:
# Non-streaming response: return complete ChatCompletionResponse
return response
except Exception as e:
last_exception = e
if attempt < actual_retry:
delay = min(2**attempt, 30)
if self.logger:
self.logger.warning(
f"{attempt + 1}th try failedretry after {delay} seconds: {str(e)}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(
f"All {actual_retry + 1} tries failed: {str(e)}"
)
return None
def chat_simple(
self,
model: str,
messages: List[Union[Dict[str, Any], ChatMessage]],
temperature: float = 0.7,
max_tokens: Optional[int] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
**kwargs,
) -> Optional[str]:
# Disable streaming for this call
original_stream = self.stream
self.stream = False
try:
response = self.chat(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
timeout=timeout,
retry=retry,
tools=tools,
tool_choice=tool_choice,
**kwargs,
)
if response and hasattr(response, "choices") and response.choices:
return response.choices[0].message.content
return None
finally:
# Recover original stream settings
self.stream = original_stream
def create_embeddings(
self,
input_text: Union[str, List[str]],
model: str,
encoding_format: str = "float",
dimensions: Optional[int] = None,
user: Optional[str] = None,
timeout: Optional[int] = None,
retry: Optional[int] = None,
**kwargs,
) -> Optional[EmbeddingResponse]:
actual_timeout = timeout if timeout is not None else self.timeout
actual_retry = retry if retry is not None else self.max_retries
if isinstance(input_text, str):
text_list = [input_text]
single_input = True
else:
text_list = input_text
single_input = False
if self.logger:
self.logger.debug(
f"Start embedding request - client: {self.client_name}, model: {model}, text num: {len(text_list)}"
)
if not hasattr(self.client, "create_embeddings"):
error_msg = f"Client {self.client_name} not support embedding"
if self.logger:
self.logger.error(error_msg)
return None
last_exception = None
for attempt in range(actual_retry + 1):
try:
if self.logger:
self.logger.debug(f"{attempt + 1}th try to create embedding vector")
response = self.client.create_embeddings(
input_text=text_list,
model=model,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
timeout=actual_timeout,
max_retries=1,
**kwargs,
)
if response:
return response
else:
raise Exception("Client return empty response")
except Exception as e:
last_exception = e
if attempt < actual_retry:
delay = min(2**attempt, 30)
if self.logger:
self.logger.warning(
f"{attempt + 1}th embedding request failedretry after {delay}s: {str(e)}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(
f"All {actual_retry + 1} embedding tries failed: {str(e)}, traceback: {format_exc()}."
)
return None
def get_client_info(self) -> Dict[str, Any]:
return {
"client_name": self.client_name,
"stream": self.stream,
"client_info": (
self.client.get_model_info()
if hasattr(self.client, "get_model_info")
else {}
),
}
def get_model_name(self) -> str:
return self.default_model or "unknown"
def close(self) -> None:
if hasattr(self.client, "close"):
self.client.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"LLMAPIManager(client={self.client_name}, stream={self.stream})"
def __repr__(self) -> str:
return (
f"LLMAPIManager("
f"client_name='{self.client_name}', "
f"stream={self.stream})"
)
def create_manager(
client_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
) -> LLMAPIManager:
return LLMAPIManager(
client_name=client_name, stream=stream, logger=logger, **kwargs
)
COMMON_CONFIGS = {
"openai_gpt4": {"client_name": "openai", "model_name": "gpt-4o"},
"openai_gpt35": {"client_name": "openai", "model_name": "gpt-3.5-turbo"},
"claude_sonnet": {
"client_name": "anthropic",
"model_name": "claude-3-5-sonnet-20241022",
},
"claude_haiku": {
"client_name": "anthropic",
"model_name": "claude-3-haiku-20240307",
},
"deepseek_chat": {"client_name": "deepseek", "model_name": "deepseek-chat"},
"deepseek_coder": {"client_name": "deepseek", "model_name": "deepseek-coder"},
}
def create_common_manager(
config_name: str, stream: bool = False, logger: Optional[Any] = None, **kwargs
) -> LLMAPIManager:
if config_name not in COMMON_CONFIGS:
raise ValueError(
f"Unknown config: {config_name}. Available configs: {list(COMMON_CONFIGS.keys())}"
)
config = COMMON_CONFIGS[config_name]
return LLMAPIManager(
client_name=config["client_name"], stream=stream, logger=logger, **kwargs
)
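A usage sketch for the factory helpers above; it assumes ANTHROPIC_API_KEY is set in the environment, and note that COMMON_CONFIGS only records the model name, so chat_simple still needs the model passed explicitly:
import logging

logger = logging.getLogger("llm")
manager = create_common_manager("claude_sonnet", stream=False, logger=logger)

text = manager.chat_simple(
    model="claude-3-5-sonnet-20241022",  # from COMMON_CONFIGS["claude_sonnet"]
    messages=[{"role": "user", "content": "Summarize SWE-bench in one sentence."}],
)
print(text)
manager.close()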

View File

@@ -0,0 +1,594 @@
"""
LLM API base class - standard OpenAI format
Multi LLM providers' universal API supported
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
from dataclasses import dataclass, field
from enum import Enum
import json
import time
import logging
import requests
import asyncio
from traceback import format_exc
from src.managers.log.logger import Logger
class MessageRole(Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@dataclass
class ChatMessage:
role: MessageRole
content: Union[str, List[Dict[str, Any]]]
name: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_call_id: Optional[str] = None
@dataclass
class ChatCompletionRequest:
messages: List[ChatMessage]
model: str
temperature: float = 0.7
max_tokens: Optional[int] = None
top_p: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
user: Optional[str] = None
@dataclass
class Usage:
prompt_tokens: int
completion_tokens: int
total_tokens: int
@dataclass
class Choice:
index: int
message: ChatMessage
finish_reason: Optional[str] = None
logprobs: Optional[Dict[str, Any]] = None
@dataclass
class ChatCompletionResponse:
id: str
object: str
created: int
model: str
choices: List[Choice]
usage: Optional[Usage] = None
system_fingerprint: Optional[str] = None
@dataclass
class ChatCompletionChunk:
id: str
object: str
created: int
model: str
choices: List[Dict[str, Any]]
system_fingerprint: Optional[str] = None
@dataclass
class EmbeddingRequest:
input: Union[str, List[str]]
model: str
encoding_format: str = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
@dataclass
class EmbeddingData:
object: str
embedding: List[float]
index: int
@dataclass
class EmbeddingUsage:
prompt_tokens: int
total_tokens: int
@dataclass
class EmbeddingResponse:
object: str
data: List[EmbeddingData]
model: str
usage: EmbeddingUsage
class BaseLLMAPI(ABC):
"""
LLM API base class
Provide standard OpenAI format API, support:
- Synchronous/Asynchronous Chat Completions
- Streaming Responses
- Tool Calling
- Error handling and Retry
- Usage statistics
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
self.api_key = api_key
self.base_url = base_url
self.timeout = timeout
self.max_retries = max_retries
self.retry_delay = retry_delay
self.extra_config = kwargs
self.logger = logger
self.session: Optional[requests.Session] = None
self._initialize_client()
def _create_http_clients(self, headers: Dict[str, str]) -> None:
self.session = requests.Session()
self.session.headers.update(headers)
self.session.timeout = self.timeout
@abstractmethod
def _initialize_client(self) -> None:
pass
@abstractmethod
def _get_chat_endpoint(self) -> str:
pass
@abstractmethod
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
pass
@abstractmethod
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
pass
@abstractmethod
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
pass
def chat_completions_create(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, Generator[ChatCompletionChunk, None, None]]:
self.validate_request(request)
def _make_request():
payload = self._build_request_payload(request)
endpoint = self._get_chat_endpoint()
if request.stream:
return self._stream_chat_completion(payload, endpoint)
else:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout
)
if response.status_code != 200:
error_msg = f"API Error {response.status_code}: {response.text}"
if self.logger:
self.logger.error(error_msg)
else:
print(error_msg)
raise requests.exceptions.HTTPError(error_msg, response=response)
response.raise_for_status()
return self._parse_response(response.json())
return self._retry_with_backoff(_make_request)
async def achat_completions_create(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, AsyncGenerator[ChatCompletionChunk, None]]:
self.validate_request(request)
def _make_async_request():
payload = self._build_request_payload(request)
endpoint = self._get_chat_endpoint()
if request.stream:
return self._stream_chat_completion(
payload, endpoint
)
else:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout
)
response.raise_for_status()
return self._parse_response(response.json())
import asyncio
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, _make_async_request)
def create_message(self, role: MessageRole, content: str, config: Optional[Dict[str, Any]] = None, **kwargs) -> ChatMessage:
return ChatMessage(role=role, content=content, **kwargs)
def _make_token_cache_request(self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None) -> list[ChatMessage]:
"""
Insert cache_control blocks for Claude/Anthropic models according to config, keeping forward compatibility
- Only activated when the model is Claude/Anthropic
- Preserve the content as is for other providers to avoid breaking compatibility.
- If token_cache.role_turns is configured, inject it matching the turn and role; otherwise, only inject it into the first system message.
"""
role_turns = self._load_role_turns(config)
#if self.logger:
# self.logger.debug(f"In _make_token_cache_request, got role_turns: {role_turns}.")
if not self._is_claude_like(model):
return messages
turn_to_roles: Dict[int, List[str]] = {}
try:
for rt in role_turns or []:
try:
turn = int(rt.get("turn"))
role = (rt.get("role") or "").strip().lower()
if role:
if turn not in turn_to_roles:
turn_to_roles[turn] = []
turn_to_roles[turn].append(role)
except Exception:
continue
except Exception:
turn_to_roles = {}
current_turn = 0
result_messages: List[ChatMessage] = []
for msg in messages:
msg_blocks = self._to_claude_blocks(msg.content)
# Turn-0 rule: for the first user message, add the cache flag to its own content and append it to the result.
if current_turn == 0 and msg.role == MessageRole.USER:
cached_blocks = self._add_cache_flag(msg_blocks)
result_messages.append(
ChatMessage(
role=msg.role,
content=cached_blocks,
name=msg.name,
tool_calls=msg.tool_calls,
tool_call_id=msg.tool_call_id,
)
)
#if self.logger:
# self.logger.debug(
# f"Applied cache to initial user message at turn {current_turn}."
# )
continue
# Other messages: append as-is
result_messages.append(
ChatMessage(
role=msg.role,
content=msg_blocks,
name=msg.name,
tool_calls=msg.tool_calls,
tool_call_id=msg.tool_call_id,
)
)
# Anchor point hit: when the current turn's configuration includes "assistant" and the current message is from the assistant,
# append an extra system message to refresh the cache.
roles_for_turn = turn_to_roles.get(current_turn, [])
if msg.role == MessageRole.ASSISTANT and roles_for_turn and ("assistant" in roles_for_turn):
refresh_text = f"refresh cache tokens."
refresh_blocks = self._to_claude_blocks(refresh_text)
refresh_blocks = self._add_cache_flag(refresh_blocks)
result_messages.append(
ChatMessage(
role=MessageRole.SYSTEM,
content=refresh_blocks,
)
)
#if self.logger:
# self.logger.debug(
# f"Appended system refresh after assistant at turn {current_turn}, refresh_blocks: {refresh_blocks}."
# )
# an assistant message advances the turn counter by 1
if msg.role == MessageRole.ASSISTANT:
current_turn += 1
return result_messages
def _to_claude_blocks(self, content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
if isinstance(content, list):
return content
return [{"type": "text", "text": content}]
def _add_cache_flag(self, blocks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
out: List[Dict[str, Any]] = []
added = False
for blk in blocks:
if not added and isinstance(blk, dict) and blk.get("type") == "text":
out.append({**blk, "cache_control": {"type": "ephemeral"}})
added = True
else:
out.append(blk)
return out
def _is_claude_like(self, model: str) -> bool:
try:
model_lc = (model or "").lower()
return ("claude" in model_lc) or ("anthropic" in model_lc)
except Exception:
return False
def _load_role_turns(self, config: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
"""Load token_cache.role_turnsPriorityInvoker config > default value"""
try:
role_turns = None
# 1) Read from invoker's config
if config and isinstance(config, dict):
token_cache_cfg = config.get("token_cache")
if isinstance(token_cache_cfg, dict):
role_turns = token_cache_cfg.get("role_turns")
# 2) If still missing, fall back to the default values from the example config
if not role_turns:
role_turns = [
{"turn": 0, "role": "user"},
{"turn": 25, "role": "assistant"},
{"turn": 50, "role": "assistant"},
{"turn": 75, "role": "assistant"},
]
return role_turns
except Exception as e:
if self.logger:
self.logger.warning(f"Read token_cache.role_turns failuse default value: {e}")
return [
{"turn": 0, "role": "user"},
{"turn": 25, "role": "assistant"},
{"turn": 50, "role": "assistant"},
{"turn": 75, "role": "assistant"},
]
def create_request(
self, messages: List[ChatMessage], model: str, config: Optional[Dict[str, Any]] = None, **kwargs
) -> ChatCompletionRequest:
messages = self._make_token_cache_request(messages, model, config)
return ChatCompletionRequest(messages=messages, model=model, **kwargs)
def _handle_error(self, error: Exception) -> None:
if self.logger:
self.logger.error(f"API request failed: {error}")
else:
print(f"API request failed: {error}")
raise error
def _retry_with_backoff(self, func, *args, **kwargs):
last_exception = None
for attempt in range(self.max_retries + 1):
try:
return func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < self.max_retries:
delay = self.retry_delay * (2**attempt)
if self.logger:
self.logger.warning(
f"{attempt + 1}th try failedretry after {delay}s: {e}, traceback: {format_exc()}."
)
time.sleep(delay)
else:
if self.logger:
self.logger.error(f"All retries failed: {e}, traceback: {format_exc()}.")
raise last_exception
def get_model_info(self) -> Dict[str, Any]:
return {
"provider": self.__class__.__name__,
"base_url": self.base_url,
"timeout": self.timeout,
"max_retries": self.max_retries,
}
def validate_request(self, request: ChatCompletionRequest) -> bool:
if not request.messages:
raise ValueError("Message list is empty")
if not request.model:
raise ValueError("Model name is empty")
for idx, message in enumerate(request.messages):
if not message.content and not message.tool_calls:
try:
msg_info = {
"index": idx,
"role": getattr(message.role, "value", str(message.role)),
"content_len": (
len(message.content)
if isinstance(message.content, str)
else (len(message.content) if isinstance(message.content, list) else 0)
),
"has_tool_calls": bool(message.tool_calls),
"tool_calls": message.tool_calls,
}
if self.logger:
self.logger.warning(
f"Request validation failed: Invalid message exists (lacking both content and tool_calls): {json.dumps(msg_info, ensure_ascii=False)}"
)
except Exception as e:
if self.logger:
self.logger.warning(
f"Request validation failed: Invalid message exists (lacking both content and tool_calls), index={idx}, error: {e}, traceback: {format_exc()}."
)
raise ValueError("Cannot lacking both content and tool_calls")
return True
def format_messages_for_api(
self, messages: List[ChatMessage]
) -> List[Dict[str, Any]]:
formatted_messages = []
for message in messages:
msg_dict = {"role": message.role.value, "content": message.content}
if message.name:
msg_dict["name"] = message.name
if message.tool_calls:
msg_dict["tool_calls"] = message.tool_calls
if message.tool_call_id:
msg_dict["tool_call_id"] = message.tool_call_id
formatted_messages.append(msg_dict)
return formatted_messages
def _stream_chat_completion(
self, payload: Dict[str, Any], endpoint: str
) -> Generator[ChatCompletionChunk, None, None]:
full_url = self.base_url.rstrip("/") + endpoint
headers = {}
if self.session and self.session.headers:
headers.update(self.session.headers)
else:
headers = {
"Content-Type": "application/json",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
response = requests.post(
full_url, json=payload, headers=headers, timeout=self.timeout, stream=True
)
response.raise_for_status()
try:
line_count = 0
for line in response.iter_lines():
if line:
line_count += 1
line_str = line.decode("utf-8")
chunk = self._parse_stream_line(line_str)
if chunk:
yield chunk
else:
print(f"Jump invalid line")
print(f"Streaming request process done, {line_count} processed")
finally:
response.close()
async def _astream_chat_completion(
self, payload: Dict[str, Any], endpoint: str
) -> AsyncGenerator[ChatCompletionChunk, None]:
import asyncio
def _sync_stream():
return list(self._stream_chat_completion(payload, endpoint))
loop = asyncio.get_event_loop()
chunks = await loop.run_in_executor(None, _sync_stream)
for chunk in chunks:
yield chunk
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
# Process standard SSE format
if line.startswith("data: "):
data = line[6:]
if data.strip() == "[DONE]":
return None
try:
chunk_data = json.loads(data)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
self.logger.warning(f"Unable to parse streaming data: {data}")
return None
return None
def close(self) -> None:
if self.session:
self.session.close()
async def aclose(self) -> None:
if self.session:
self.session.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
def __str__(self) -> str:
return f"{self.__class__.__name__}(base_url={self.base_url})"
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"base_url={self.base_url}, "
f"timeout={self.timeout}, "
f"max_retries={self.max_retries})"
)
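A minimal sketch of the block shape the token-cache helpers above produce for Claude-style models; the text is illustrative, and only the first text block of a flagged message carries the cache_control marker:
# Illustrative only: a text block before and after the cache flag is added
plain_block = {"type": "text", "text": "You are a helpful assistant."}
cached_block = {
    "type": "text",
    "text": "You are a helpful assistant.",
    "cache_control": {"type": "ephemeral"},
}
# _make_token_cache_request flags the first user message at turn 0 and, for turns listed in
# token_cache.role_turns with role "assistant", appends a system "refresh cache tokens."
# message whose first text block is flagged the same way.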

View File

@@ -0,0 +1,23 @@
"""
LLM Client implement module
Supported LLM providers:
- OpenAI: OpenAI API and compatible services
- Anthropic: Claude models
- DeepSeek: DeepSeek models
- OpenRouter: OpenRouter aggregation API
- Private: Private models (vLLM, TGI, Ollama, etc.)
"""
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
__all__ = [
"OpenAIClient",
"AnthropicClient",
"DeepSeekClient",
"OpenRouterClient",
"PrivateModelClient",
]

View File

@@ -0,0 +1,7 @@
"""
Anthropic Claude Client Module
"""
from src.managers.llm_api.clients.anthropic.anthropic_client import AnthropicClient
__all__ = ["AnthropicClient"]

View File

@@ -0,0 +1,288 @@
"""
Anthropic Claude Client Module
"""
import json
import time
from typing import Dict, List, Any, Optional, Union
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class AnthropicClient(BaseLLMAPI):
"""
Anthropic Claude API Client
Anthropic Claude models supported
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
anthropic_version: str = "2023-06-01",
logger: Optional[Logger] = None,
**kwargs
):
"""
Initialize Anthropic client
Args:
api_key: Anthropic API key
base_url: API base URL (defaults to the official Anthropic API)
timeout: timeout in seconds
max_retries: max retries
retry_delay: retry delay in seconds
anthropic_version: API version
**kwargs: other args
"""
self.anthropic_version = anthropic_version
if base_url is None:
base_url = "https://api.anthropic.com"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs
)
def _initialize_client(self) -> None:
"""Initialize HTTP client"""
headers = {
"x-api-key": self.api_key,
"Content-Type": "application/json",
"anthropic-version": self.anthropic_version,
"User-Agent": "tokfinity-llm-client/1.0",
}
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/v1/messages"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
"""Build Anthropic API request payload"""
messages, system_prompt = self._convert_messages_to_anthropic_format(
request.messages
)
payload = {
"model": request.model,
"messages": messages,
"max_tokens": request.max_tokens or 1024,
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
if system_prompt:
payload["system"] = system_prompt
if request.stop is not None:
payload["stop_sequences"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
if request.tools is not None:
payload["tools"] = self._convert_tools_to_anthropic_format(request.tools)
if request.tool_choice is not None:
payload["tool_choice"] = self._convert_tool_choice_to_anthropic_format(
request.tool_choice
)
return payload
def _convert_messages_to_anthropic_format(
self, messages: List[ChatMessage]
) -> tuple[List[Dict[str, Any]], Optional[str]]:
"""Convert messages to anthropic format"""
anthropic_messages = []
system_prompt = None
for message in messages:
if message.role == MessageRole.SYSTEM:
system_prompt = message.content
elif message.role in [MessageRole.USER, MessageRole.ASSISTANT]:
anthropic_messages.append(
{"role": message.role.value, "content": message.content}
)
return anthropic_messages, system_prompt
def _convert_tools_to_anthropic_format(
self, tools: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Convert OpenAI format tools to anthropic format"""
anthropic_tools = []
for tool in tools:
if tool.get("type") == "function":
function_def = tool.get("function", {})
anthropic_tool = {
"name": function_def.get("name", ""),
"description": function_def.get("description", ""),
"input_schema": function_def.get("parameters", {}),
}
anthropic_tools.append(anthropic_tool)
return anthropic_tools
def _convert_tool_choice_to_anthropic_format(
self, tool_choice: Union[str, Dict[str, Any]]
) -> Union[str, Dict[str, Any]]:
"""Convert OpenAI format tool choice to anthropic format"""
if isinstance(tool_choice, str):
if tool_choice == "auto":
return "auto"
elif tool_choice == "none":
return "none"
else:
return {"type": "tool", "name": tool_choice}
elif isinstance(tool_choice, dict):
if tool_choice.get("type") == "function":
return {
"type": "tool",
"name": tool_choice.get("function", {}).get("name", ""),
}
elif tool_choice.get("type") == "tool":
return tool_choice
return "auto"
def _convert_anthropic_tool_calls(
self, content_list: List[Dict[str, Any]]
) -> Optional[List[Dict[str, Any]]]:
"""Convert anthropic tool calls to OpenAI format"""
tool_calls = []
for item in content_list:
if item.get("type") == "tool_use":
tool_call = {
"id": item.get("id", ""),
"type": "function",
"function": {
"name": item.get("name", ""),
"arguments": item.get("input", {}),
},
}
tool_calls.append(tool_call)
return tool_calls if tool_calls else None
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
"""Convert anthropic API response to OpenAI format"""
content = ""
tool_calls = None
if response_data.get("content"):
content_data = (
response_data["content"][0] if response_data["content"] else {}
)
content = content_data.get("text", "")
if content_data.get("type") == "tool_use":
tool_calls = self._convert_anthropic_tool_calls(
response_data.get("content", [])
)
message = ChatMessage(
role=MessageRole.ASSISTANT, content=content, tool_calls=tool_calls
)
choice = Choice(
index=0,
message=message,
finish_reason=self._convert_stop_reason(response_data.get("stop_reason")),
)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("input_tokens", 0),
completion_tokens=usage_data.get("output_tokens", 0),
total_tokens=usage_data.get("input_tokens", 0)
+ usage_data.get("output_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object="chat.completion",
created=int(time.time()),
model=response_data.get("model", ""),
choices=[choice],
usage=usage,
)
def _convert_stop_reason(self, stop_reason: Optional[str]) -> Optional[str]:
"""Convert anthropic stop reason to OpenAI format"""
if stop_reason == "end_turn":
return "stop"
elif stop_reason == "max_tokens":
return "length"
elif stop_reason == "stop_sequence":
return "stop"
else:
return stop_reason
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""Convert anthropic steam response to OpenAI format"""
event_type = chunk_data.get("type")
if event_type == "content_block_delta":
delta = chunk_data.get("delta", {})
text = delta.get("text", "")
if text:
choices = [
{"index": 0, "delta": {"content": text}, "finish_reason": None}
]
return ChatCompletionChunk(
id=chunk_data.get("message", {}).get("id", ""),
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("message", {}).get("model", ""),
choices=choices,
)
elif event_type == "message_stop":
choices = [{"index": 0, "delta": {}, "finish_reason": "stop"}]
return ChatCompletionChunk(
id=chunk_data.get("message", {}).get("id", ""),
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("message", {}).get("model", ""),
choices=choices,
)
return None
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
try:
chunk_data = json.loads(line)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
# if not json, try SSE
return super()._parse_stream_line(line)

View File

@@ -0,0 +1,7 @@
"""
DeepSeek Client Module
"""
from src.managers.llm_api.clients.deepseek.deepseek_client import DeepSeekClient
__all__ = ["DeepSeekClient"]

View File

@@ -0,0 +1,164 @@
"""
DeepSeek Client Module
"""
import time
from typing import Dict, List, Any, Optional
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class DeepSeekClient(BaseLLMAPI):
"""
DeepSeek API Client
DeepSeek models supported, including DeepSeek-Coder, DeepSeek-Chat, etc.
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
if base_url is None:
base_url = "https://api.deepseek.com/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
# optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""Parse stream chunk"""
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
"""Get available models"""
response = self.client.get("/models")
response.raise_for_status()
return response.json()
async def alist_models(self) -> Dict[str, Any]:
response = await self.async_client.get("/models")
response.raise_for_status()
return response.json()
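A usage sketch for this client; it assumes a valid DEEPSEEK_API_KEY in the environment, and the model name follows the docstring above:
import os

client = DeepSeekClient(api_key=os.environ["DEEPSEEK_API_KEY"])
request = client.create_request(
    messages=[ChatMessage(role=MessageRole.USER, content="Write a haiku about code.")],
    model="deepseek-chat",
)
response = client.chat_completions_create(request)
print(response.choices[0].message.content)
client.close()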

View File

@@ -0,0 +1,7 @@
"""
OpenAI Client Module
"""
from src.managers.llm_api.clients.openai.openai_client import OpenAIClient
__all__ = ["OpenAIClient"]

View File

@@ -0,0 +1,279 @@
"""
OpenAI Client
"""
import time
from typing import Dict, List, Any, Optional, Union
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingData,
EmbeddingUsage,
)
class OpenAIClient(BaseLLMAPI):
"""
OpenAI API Client
OpenAI API supported, and other API services compatible with the OpenAI format
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
organization: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
self.organization = organization
if base_url is None:
base_url = "https://api.openai.com/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
if self.organization:
headers["OpenAI-Organization"] = self.organization
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"frequency_penalty": request.frequency_penalty,
"presence_penalty": request.presence_penalty,
"stream": request.stream,
}
# Optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
response = self.client.get("/models")
response.raise_for_status()
return response.json()
async def alist_models(self) -> Dict[str, Any]:
response = await self.async_client.get("/models")
response.raise_for_status()
return response.json()
def create_embeddings(
self,
input_text: Union[str, List[str]],
model: str,
encoding_format: str = "float",
dimensions: Optional[int] = None,
user: Optional[str] = None,
timeout: Optional[int] = None,
max_retries: Optional[int] = None,
retry_delay: Optional[float] = None,
) -> EmbeddingResponse:
actual_timeout = timeout if timeout is not None else self.timeout
actual_max_retries = (
max_retries if max_retries is not None else self.max_retries
)
actual_retry_delay = (
retry_delay if retry_delay is not None else self.retry_delay
)
request = EmbeddingRequest(
input=input_text,
model=model,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
payload = self._build_embedding_request_payload(request)
for attempt in range(actual_max_retries + 1):
try:
print(f"Debug: Payload: {payload}")
response = self.session.post(
f"{self.base_url}/embeddings", json=payload, timeout=actual_timeout
)
if response.status_code == 200:
response_data = response.json()
return self._parse_embedding_response(response_data)
else:
error_msg = f"embedding failed (try {attempt + 1}): HTTP {response.status_code}"
if hasattr(response, "text"):
error_msg += f" - {response.text}"
print(f"Debug: {error_msg}")
if attempt < actual_max_retries:
print(f"Debug: wait {actual_retry_delay} and retry...")
time.sleep(actual_retry_delay)
continue
else:
raise Exception(f"All retries failed: {error_msg}")
except Exception as e:
error_msg = f"Embedding request failed (try {attempt + 1}): {str(e)}, traceback: {format_exc()}."
print(f"Debug: {error_msg}")
if attempt < actual_max_retries:
print(f"Debug: wait {actual_retry_delay} and retry...")
time.sleep(actual_retry_delay)
continue
else:
raise Exception(f"All retries failed: {str(e)}, traceback: {format_exc()}.")
raise Exception("Unknown error")
def _build_embedding_request_payload(
self, request: EmbeddingRequest
) -> Dict[str, Any]:
payload = {
"input": request.input,
"model": request.model,
"encoding_format": request.encoding_format,
}
if request.dimensions is not None:
payload["dimensions"] = request.dimensions
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_embedding_response(
self, response_data: Dict[str, Any]
) -> EmbeddingResponse:
embedding_data_list = []
for data_item in response_data.get("data", []):
embedding_data = EmbeddingData(
object=data_item.get("object", "embedding"),
embedding=data_item.get("embedding", []),
index=data_item.get("index", 0),
)
embedding_data_list.append(embedding_data)
usage_data = response_data.get("usage", {})
usage = EmbeddingUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return EmbeddingResponse(
object=response_data.get("object", "list"),
data=embedding_data_list,
model=response_data.get("model", ""),
usage=usage,
)
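A usage sketch for the embeddings helper above; it assumes OPENAI_API_KEY is set, and the embedding model name is illustrative:
import os

client = OpenAIClient(api_key=os.environ["OPENAI_API_KEY"])
emb = client.create_embeddings(
    input_text=["hello world", "goodbye world"],
    model="text-embedding-3-small",  # assumed model name
)
print(len(emb.data), len(emb.data[0].embedding))
client.close()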

View File

@@ -0,0 +1,8 @@
"""
OpenRouter Client
"""
from src.managers.llm_api.clients.openrouter.openrouter_client import OpenRouterClient
__all__ = ["OpenRouterClient"]

View File

@@ -0,0 +1,329 @@
"""
OpenRouter Client
"""
import time
from typing import Dict, List, Any, Optional
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class OpenRouterClient(BaseLLMAPI):
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: Optional[str] = None,
site_url: Optional[str] = None,
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
logger: Optional[Logger] = None,
**kwargs,
):
"""
Initialize OpenRouter Client
Args:
api_key: OpenRouter API key
base_url: API base URL, default to OpenRouter official API
app_name: Application name (optional, for statistics)
site_url: Website URL (optional, for statistics)
timeout: Request timeout
max_retries: Maximum number of retries
retry_delay: Retry delay
**kwargs: Other configuration parameters
"""
self.app_name = app_name or "tokfinity-llm-client"
self.site_url = site_url
if base_url is None:
base_url = "https://openrouter.ai/api/v1"
super().__init__(
api_key=api_key,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
"X-Title": self.app_name,
}
if self.site_url:
headers["HTTP-Referer"] = self.site_url
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if request.tools is not None:
payload["tools"] = request.tools
if request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
if request.seed is not None:
payload["seed"] = request.seed
if request.user is not None:
payload["user"] = request.user
return payload
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
"""
Parse OpenRouter API response
Args:
response_data: API response data
Returns:
ChatCompletionResponse: Parsed response object
"""
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
tool_calls = message_data.get("tool_calls")
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=tool_calls,
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
logprobs=choice_data.get("logprobs"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
system_fingerprint=response_data.get("system_fingerprint"),
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
"""
Parse stream chunk
Args:
chunk_data: Raw chunk data
Returns:
Optional[ChatCompletionChunk]: Parsed chunk
"""
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
system_fingerprint=chunk_data.get("system_fingerprint"),
)
def list_models(self) -> Dict[str, Any]:
"""
Get available models
Returns:
Dict[str, Any]: Model list response
"""
response = self.client.get("/models")
response.raise_for_status()
return response.json()
async def alist_models(self) -> Dict[str, Any]:
"""
Async get available models
Returns:
Dict[str, Any]: Model list response
"""
response = await self.async_client.get("/models")
response.raise_for_status()
return response.json()
def get_generation_info(self, generation_id: str) -> Dict[str, Any]:
"""
Get specific generation request details
Args:
generation_id: Generation request ID
Returns:
Dict[str, Any]: Generation information
"""
response = self.client.get(f"/generation?id={generation_id}")
response.raise_for_status()
return response.json()
async def aget_generation_info(self, generation_id: str) -> Dict[str, Any]:
"""
Async get specific generation request details
Args:
generation_id: Generation request ID
Returns:
Dict[str, Any]: Generation information
"""
response = await self.async_client.get(f"/generation?id={generation_id}")
response.raise_for_status()
return response.json()
def get_account_credits(self) -> Dict[str, Any]:
"""
Get account credits
Returns:
Dict[str, Any]: Account credits information
"""
response = self.client.get("/auth/key")
response.raise_for_status()
return response.json()
async def aget_account_credits(self) -> Dict[str, Any]:
"""
Async get account credits
Returns:
Dict[str, Any]: Account credits information
"""
response = await self.async_client.get("/auth/key")
response.raise_for_status()
return response.json()
@staticmethod
def get_popular_models() -> List[str]:
"""
Get popular models
Returns:
List[str]: Popular model names
"""
return [
# OpenAI
"openai/gpt-4-turbo-preview",
"openai/gpt-4",
"openai/gpt-3.5-turbo",
# Anthropic
"anthropic/claude-3-opus",
"anthropic/claude-3-sonnet",
"anthropic/claude-3-haiku",
# Google
"google/gemini-pro",
"google/gemini-pro-vision",
# Meta
"meta-llama/llama-2-70b-chat",
"meta-llama/llama-2-13b-chat",
# Mistral
"mistralai/mixtral-8x7b-instruct",
"mistralai/mistral-7b-instruct",
# Open Source
"microsoft/wizardlm-2-8x22b",
"databricks/dbrx-instruct",
"cohere/command-r-plus",
]
def get_model_info(self, model_name: str) -> Dict[str, Any]:
"""
Get specific model details
Args:
model_name: Model name
Returns:
Dict[str, Any]: Model information
"""
models_response = self.list_models()
models = models_response.get("data", [])
for model in models:
if model.get("id") == model_name:
return model
raise ValueError(f"Model {model_name} not found")
async def aget_model_info(self, model_name: str) -> Dict[str, Any]:
"""
Async get specific model details
Args:
model_name: Model name
Returns:
Dict[str, Any]: Model information
"""
models_response = await self.alist_models()
models = models_response.get("data", [])
for model in models:
if model.get("id") == model_name:
return model
raise ValueError(f"Model {model_name} not found")

View File

@@ -0,0 +1,7 @@
"""
Private model client
"""
from src.managers.llm_api.clients.private.private_client import PrivateModelClient
__all__ = ["PrivateModelClient"]

View File

@@ -0,0 +1,321 @@
"""
Private model client
Implementation of a private model API client based on BaseLLMAPI.
Supports vLLM, Text Generation Inference (TGI), and Ollama.
"""
import json
import time
from typing import Dict, List, Any, Optional, Union, AsyncGenerator, Generator
from src.managers.log.logger import Logger
from src.managers.llm_api.base_client import (
BaseLLMAPI,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
ChatMessage,
MessageRole,
Choice,
Usage,
)
class PrivateModelClient(BaseLLMAPI):
def __init__(
self,
api_key: Optional[str] = None,
base_url: str = "http://localhost:8000/v1",
timeout: int = 60,
max_retries: int = 3,
retry_delay: float = 1.0,
deployment_type: str = "vllm",
custom_headers: Optional[Dict[str, str]] = None,
supports_tools: bool = True,
logger: Optional[Logger] = None,
**kwargs,
):
self.deployment_type = deployment_type.lower()
self.custom_headers = custom_headers or {}
self.supports_tools = supports_tools
super().__init__(
api_key=api_key or "EMPTY",
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
logger=logger,
**kwargs,
)
def _initialize_client(self) -> None:
headers = {
"Content-Type": "application/json",
"User-Agent": "tokfinity-llm-client/1.0",
}
if self.api_key and self.api_key != "EMPTY":
headers["Authorization"] = f"Bearer {self.api_key}"
headers.update(self.custom_headers)
if self.deployment_type == "ollama":
pass
elif self.deployment_type == "tgi":
pass
self._create_http_clients(headers)
def _get_chat_endpoint(self) -> str:
if self.deployment_type == "ollama":
return "/api/chat"
elif self.deployment_type == "tgi":
return "/generate"
else:
return "/chat/completions"
def _build_request_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
if self.deployment_type == "ollama":
return self._build_ollama_payload(request)
elif self.deployment_type == "tgi":
return self._build_tgi_payload(request)
else:
return self._build_openai_payload(request)
def _build_openai_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"temperature": request.temperature,
"top_p": request.top_p,
"stream": request.stream,
}
# Optional
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.stop is not None:
payload["stop"] = request.stop
if request.frequency_penalty != 0.0:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty != 0.0:
payload["presence_penalty"] = request.presence_penalty
if self.supports_tools and request.tools is not None:
payload["tools"] = request.tools
if self.supports_tools and request.tool_choice is not None:
payload["tool_choice"] = request.tool_choice
if request.response_format is not None:
payload["response_format"] = request.response_format
return payload
def _build_ollama_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
payload = {
"model": request.model,
"messages": self.format_messages_for_api(request.messages),
"stream": request.stream,
"options": {
"temperature": request.temperature,
"top_p": request.top_p,
},
}
if request.max_tokens is not None:
payload["options"]["num_predict"] = request.max_tokens
if request.stop is not None:
payload["options"]["stop"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
return payload
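# Illustrative Ollama payload built above (example values only):
#   {"model": "llama3", "messages": [{"role": "user", "content": "Hi"}],
#    "stream": False, "options": {"temperature": 0.2, "top_p": 1.0, "num_predict": 256}}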
def _build_tgi_payload(self, request: ChatCompletionRequest) -> Dict[str, Any]:
prompt = self._messages_to_prompt(request.messages)
payload = {
"inputs": prompt,
"parameters": {
"temperature": request.temperature,
"top_p": request.top_p,
"do_sample": True,
"stream": request.stream,
},
}
if request.max_tokens is not None:
payload["parameters"]["max_new_tokens"] = request.max_tokens
if request.stop is not None:
payload["parameters"]["stop_sequences"] = (
request.stop if isinstance(request.stop, list) else [request.stop]
)
return payload
def _messages_to_prompt(self, messages: List[ChatMessage]) -> str:
prompt_parts = []
for message in messages:
if message.role == MessageRole.SYSTEM:
prompt_parts.append(f"System: {message.content}")
elif message.role == MessageRole.USER:
prompt_parts.append(f"User: {message.content}")
elif message.role == MessageRole.ASSISTANT:
prompt_parts.append(f"Assistant: {message.content}")
prompt_parts.append("Assistant:")
return "\n\n".join(prompt_parts)
def _parse_response(self, response_data: Dict[str, Any]) -> ChatCompletionResponse:
if self.deployment_type == "ollama":
return self._parse_ollama_response(response_data)
elif self.deployment_type == "tgi":
return self._parse_tgi_response(response_data)
else:
return self._parse_openai_response(response_data)
def _parse_openai_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
choices = []
for choice_data in response_data.get("choices", []):
message_data = choice_data.get("message", {})
message = ChatMessage(
role=MessageRole(message_data.get("role", "assistant")),
content=message_data.get("content", ""),
tool_calls=message_data.get("tool_calls"),
)
choice = Choice(
index=choice_data.get("index", 0),
message=message,
finish_reason=choice_data.get("finish_reason"),
)
choices.append(choice)
usage_data = response_data.get("usage", {})
usage = None
if usage_data:
usage = Usage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_data.get("id", ""),
object=response_data.get("object", "chat.completion"),
created=response_data.get("created", int(time.time())),
model=response_data.get("model", ""),
choices=choices,
usage=usage,
)
def _parse_ollama_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
content = response_data.get("message", {}).get("content", "")
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
choice = Choice(
index=0,
message=message,
finish_reason="stop" if response_data.get("done", False) else None,
)
return ChatCompletionResponse(
id=f"ollama-{int(time.time())}",
object="chat.completion",
created=int(time.time()),
model=response_data.get("model", ""),
choices=[choice],
)
def _parse_tgi_response(
self, response_data: Dict[str, Any]
) -> ChatCompletionResponse:
if isinstance(response_data, list) and len(response_data) > 0:
response_data = response_data[0]
content = response_data.get("generated_text", "")
message = ChatMessage(role=MessageRole.ASSISTANT, content=content)
choice = Choice(
index=0,
message=message,
finish_reason=response_data.get("finish_reason", "stop"),
)
return ChatCompletionResponse(
id=f"tgi-{int(time.time())}",
object="chat.completion",
created=int(time.time()),
model="tgi-model",
choices=[choice],
)
def _parse_stream_line(self, line: str) -> Optional[ChatCompletionChunk]:
if self.deployment_type == "ollama":
try:
chunk_data = json.loads(line)
return self._parse_ollama_chunk(chunk_data)
except json.JSONDecodeError:
return None
else:
# Standard SSE format
if line.startswith("data: "):
data = line[6:]
if data.strip() == "[DONE]":
return None
try:
chunk_data = json.loads(data)
return self._parse_stream_chunk(chunk_data)
except json.JSONDecodeError:
return None
return None
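# Illustrative input lines handled above (values are made up):
#   OpenAI-style SSE: data: {"id": "c1", "choices": [{"delta": {"content": "Hi"}}]}
#   SSE terminator:   data: [DONE]   -> returns None
#   Ollama NDJSON:    {"model": "llama3", "message": {"content": "Hi"}, "done": false}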
def _parse_ollama_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
content = chunk_data.get("message", {}).get("content", "")
done = chunk_data.get("done", False)
choices = [
{
"index": 0,
"delta": {"content": content} if content else {},
"finish_reason": "stop" if done else None,
}
]
return ChatCompletionChunk(
id=f"ollama-{int(time.time())}",
object="chat.completion.chunk",
created=int(time.time()),
model=chunk_data.get("model", ""),
choices=choices,
)
def _parse_stream_chunk(
self, chunk_data: Dict[str, Any]
) -> Optional[ChatCompletionChunk]:
return ChatCompletionChunk(
id=chunk_data.get("id", ""),
object=chunk_data.get("object", "chat.completion.chunk"),
created=chunk_data.get("created", int(time.time())),
model=chunk_data.get("model", ""),
choices=chunk_data.get("choices", []),
)
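# Minimal usage sketch, assuming locally running servers on their usual default
# ports; model names and URLs are placeholders only.
if __name__ == "__main__":
    # OpenAI-compatible server such as vLLM (the default deployment_type).
    vllm_client = PrivateModelClient(base_url="http://localhost:8000/v1")
    # Ollama uses its own /api/chat endpoint and an "options" payload.
    ollama_client = PrivateModelClient(
        base_url="http://localhost:11434", deployment_type="ollama", supports_tools=False
    )
    # Text Generation Inference expects a flattened prompt on /generate.
    tgi_client = PrivateModelClient(
        base_url="http://localhost:8080", deployment_type="tgi", supports_tools=False
    )
    for c in (vllm_client, ollama_client, tgi_client):
        print(c.deployment_type, "->", c._get_chat_endpoint())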

View File

@@ -0,0 +1,43 @@
"""
Log management module
Provides log management functionality categorized by timestamp and level.
Usage example:
from src.managers.log import Logger, create_logger
# Basic Usage
logger = Logger("logs", "my_app")
logger.info("This is an info")
logger.error("This is an error")
logger.close()
# Use the context manager
with Logger("logs", "my_app") as logger:
logger.info("Automatically manage resources")
# Convenience function
logger = create_logger("logs", "my_app")
logger.info("Created via the convenience function")
logger.close()
"""
from .logger import (
Logger,
create_logger,
init_global_logger,
get_global_logger,
set_global_logger,
)
__all__ = [
"Logger",
"create_logger",
"init_global_logger",
"get_global_logger",
"set_global_logger",
]
__version__ = "1.0.0"
__author__ = "Tokfinity Team"
__description__ = "Timestamp:Directory and Level-Based Log Manager "

357
src/managers/log/logger.py Normal file
View File

@@ -0,0 +1,357 @@
"""
Log Manager
Creates a timestamped directory and stores level-based log files.
"""
import os
import logging
from datetime import datetime
from typing import Optional
from pathlib import Path
"""
Module-level custom NOTICE log level (between INFO and WARNING)
"""
NOTICE_LEVEL = 25
if not hasattr(logging, "NOTICE"):
logging.addLevelName(NOTICE_LEVEL, "NOTICE")
def notice(self, message, *args, **kwargs):
if self.isEnabledFor(NOTICE_LEVEL):
self._log(NOTICE_LEVEL, message, args, **kwargs)
logging.Logger.notice = notice # type: ignore[attr-defined]
class ImageBuilderLogger:
"""
ImageBuilder log manager
Function:
- Create a timestamped sub-directory
- Create debug.log, info.log, warning.log, error.log and all.log
- Provide a standard log format that records the calling code location (two stack levels up)
- Output to both console and files
"""
def __init__(self, log_base_path: str, console_output: bool = True):
self.log_base_path = Path(log_base_path)
self.console_output = console_output
self.log_dir = self._create_log_dir()
self.file_handlers = {}
self.logger = self._setup_logger()
def _create_log_dir(self) -> Path:
timestamp = datetime.now().strftime("%Y%m%d%H%M")
log_dir = self.log_base_path / f"build_images/{timestamp}"
log_dir.mkdir(parents=True, exist_ok=True)
return log_dir
def _setup_logger(self) -> logging.Logger:
logger = logging.getLogger("image_builder_logger")
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
log_format = (
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
)
formatter = logging.Formatter(log_format)
if self.console_output:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
self._add_file_handlers(logger, formatter)
return logger
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
}
for level_name, level_value in log_levels.items():
log_file = self.log_dir / f"{level_name}.log"
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
if level_name == "debug":
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
elif level_name == "info":
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
elif level_name == "warning":
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
elif level_name == "error":
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
logger.addHandler(file_handler)
self.file_handlers[level_name] = file_handler
all_log_file = self.log_dir / "all.log"
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
all_file_handler.setFormatter(formatter)
all_file_handler.setLevel(logging.DEBUG)
logger.addHandler(all_file_handler)
self.file_handlers["all"] = all_file_handler
def debug(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.debug(message, *args, **kwargs)
def info(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.info(message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.warning(message, *args, **kwargs)
def error(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.error(message, *args, **kwargs)
@property
def log_file(self) -> str:
return str(self.log_dir / "all.log")
def get_log_dir(self) -> str:
return str(self.log_dir.absolute())
def get_log_files(self) -> dict:
return {
"debug": str(self.log_dir / "debug.log"),
"info": str(self.log_dir / "info.log"),
"warning": str(self.log_dir / "warning.log"),
"error": str(self.log_dir / "error.log"),
"all": str(self.log_dir / "all.log"),
}
def close(self):
for handler in self.file_handlers.values():
handler.close()
for handler in self.logger.handlers[:]:
self.logger.removeHandler(handler)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"ImageBuilderLogger(dir={self.log_dir})"
def __repr__(self) -> str:
return (
f"ImageBuilderLogger("
f"base_path='{self.log_base_path}', "
f"log_dir='{self.log_dir}', "
f"console_output={self.console_output})"
)
class Logger:
"""
Log manager
Function:
- Create a timestamped sub-directory (or one named after instance_id)
- Create debug.log, info.log, notice.log, warning.log, error.log and all.log
- Provide a standard log format
- Output to both console and files
"""
def __init__(
self,
log_base_path: str,
logger_name: str = "tokfinity_logger",
console_output: bool = True,
log_format: Optional[str] = None,
instance_id: Optional[str] = None,
):
self.log_base_path = Path(log_base_path)
self.logger_name = logger_name
self.console_output = console_output
self.instance_id = instance_id
self.log_format = log_format or (
"%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
"%(pathname)s:%(lineno)d - %(funcName)s - %(message)s"
)
self.log_dir = self._create_log_dir()
self.file_handlers = {}
self.logger = self._setup_logger()
def _create_log_dir(self) -> Path:
if self.instance_id:
log_dir = self.log_base_path / self.instance_id
else:
timestamp = datetime.now().strftime("%Y%m%d%H%M")
log_dir = self.log_base_path / timestamp
log_dir.mkdir(parents=True, exist_ok=True)
return log_dir
def _setup_logger(self) -> logging.Logger:
logger = logging.getLogger(self.logger_name)
logger.setLevel(logging.DEBUG)
logger.handlers.clear()
formatter = logging.Formatter(self.log_format)
if self.console_output:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
self._add_file_handlers(logger, formatter)
return logger
def _add_file_handlers(self, logger: logging.Logger, formatter: logging.Formatter):
log_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
"notice": NOTICE_LEVEL,
"warning": logging.WARNING,
"error": logging.ERROR,
}
for level_name, level_value in log_levels.items():
log_file = self.log_dir / f"{level_name}.log"
file_handler = logging.FileHandler(log_file, encoding="utf-8")
file_handler.setFormatter(formatter)
if level_name == "debug":
file_handler.addFilter(lambda record: record.levelno == logging.DEBUG)
elif level_name == "info":
file_handler.addFilter(lambda record: record.levelno == logging.INFO)
elif level_name == "notice":
file_handler.addFilter(lambda record: record.levelno == NOTICE_LEVEL)
elif level_name == "warning":
file_handler.addFilter(lambda record: record.levelno == logging.WARNING)
elif level_name == "error":
file_handler.addFilter(lambda record: record.levelno == logging.ERROR)
logger.addHandler(file_handler)
self.file_handlers[level_name] = file_handler
all_log_file = self.log_dir / "all.log"
all_file_handler = logging.FileHandler(all_log_file, encoding="utf-8")
all_file_handler.setFormatter(formatter)
all_file_handler.setLevel(logging.DEBUG)
logger.addHandler(all_file_handler)
self.file_handlers["all"] = all_file_handler
def debug(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.debug(message, *args, **kwargs)
def info(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.info(message, *args, **kwargs)
def notice(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.log(NOTICE_LEVEL, message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.warning(message, *args, **kwargs)
def error(self, message: str, *args, **kwargs):
kwargs.setdefault("stacklevel", 2)
self.logger.error(message, *args, **kwargs)
def get_log_dir(self) -> str:
return str(self.log_dir.absolute())
def get_log_files(self) -> dict:
return {
"debug": str(self.log_dir / "debug.log"),
"info": str(self.log_dir / "info.log"),
"notice": str(self.log_dir / "notice.log"),
"warning": str(self.log_dir / "warning.log"),
"error": str(self.log_dir / "error.log"),
"all": str(self.log_dir / "all.log"),
}
@property
def log_file(self) -> str:
return str(self.log_dir / "all.log")
def close(self):
for handler in self.file_handlers.values():
handler.close()
for handler in self.logger.handlers[:]:
self.logger.removeHandler(handler)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __str__(self) -> str:
return f"Logger(name={self.logger_name}, dir={self.log_dir})"
def __repr__(self) -> str:
return (
f"Logger("
f"name='{self.logger_name}', "
f"base_path='{self.log_base_path}', "
f"log_dir='{self.log_dir}', "
f"console_output={self.console_output})"
)
def create_logger(
log_base_path: str,
logger_name: str = "tokfinity_logger",
console_output: bool = True,
) -> Logger:
return Logger(
log_base_path=log_base_path,
logger_name=logger_name,
console_output=console_output,
)
# Global log manager instance (optional)
_global_logger: Optional[Logger] = None
def get_global_logger() -> Optional[Logger]:
return _global_logger
def set_global_logger(logger: Logger):
global _global_logger
_global_logger = logger
def init_global_logger(
log_base_path: str, logger_name: str = "global_logger"
) -> Logger:
global _global_logger
_global_logger = Logger(log_base_path, logger_name)
return _global_logger
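# Minimal sketch of the custom NOTICE level and the global-logger helpers above
# (the log path and logger name are placeholders):
if __name__ == "__main__":
    demo = init_global_logger("logs", "demo_logger")
    demo.info("written to info.log and all.log")
    demo.notice("NOTICE sits between INFO and WARNING, in notice.log")
    demo.error("written to error.log")
    assert get_global_logger() is demo
    demo.close()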

275
src/managers/loop/base.py Normal file
View File

@@ -0,0 +1,275 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.tools.base import (
ToolExecutor,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
from src.managers.loop.types import ToolStats, LLMUsage
class BaseLoop:
def __init__(self, instance_id: str, instance_data: Dict[str, Any], logger: Logger, prompts_manager: PromptsManager | None, llm_manager: LLMAPIManager | None, tool_executor: ToolExecutor, config: Dict[str, Any] | None = None):
self.instance_id = instance_id
self.instance_data = instance_data
self.logger = logger
self.prompts_manager = prompts_manager
self.llm_manager = llm_manager
self.tool_executor = tool_executor
self.config = config or {}
self.component_name = self.__class__.__name__
def _make_assistant(
self, content: str | None, tool_calls: Any, messages: List[Dict[str, Any]]
) -> bool:
"""
Construct an assistant message based on the current content and tool calls, and append it to the messages.
"""
safe_content = content or ""
if not safe_content and not tool_calls:
self.logger.warning(
f"[{self.component_name}] Assistant returned an empty message with no tool calls; skipping this message and prompting to continue"
)
messages.append(
{"role": "user", "content": "请继续分析问题并使用工具来解决问题。"}
)
return False
assistant_message: Dict[str, Any] = {"role": "assistant"}
if tool_calls and not safe_content:
assistant_message["content"] = ""
elif safe_content:
assistant_message["content"] = safe_content
if tool_calls:
assistant_message["tool_calls"] = tool_calls
messages.append(assistant_message)
return True
def _make_tool_response(
self, tool_results: List[Any], messages: List[Dict[str, Any]]
) -> None:
"""Convert tool execution results into standard tool messages (role=tool) and append them to the messages.
- Generate content per result: use prompts_manager.tool_response_prompts([{...}]) to produce the content
- Set tool_call_id: prefer ToolResult.id; fallback to ToolResult.call_id
"""
if not tool_results:
return
for result in tool_results:
single_dict = [
{
"name": getattr(result, "name", "unknown"),
"success": getattr(result, "success", False),
"result": getattr(result, "result", None) or "",
"error": getattr(result, "error", None) or "",
}
]
content_text = (
self.prompts_manager.tool_response_prompts(single_dict)
if self.prompts_manager
else ""
)
tool_call_id = getattr(result, "id", None) or getattr(
result, "call_id", None
)
messages.append(
{
"role": "tool",
"content": content_text,
"tool_call_id": tool_call_id,
}
)
def _response_log(
self, response: Any, first_content: str, first_tool_calls: Any, total_turns: int
) -> None:
"""notice log for the current turn's LLM output"""
try:
response_log: Dict[str, Any] = {}
if hasattr(response, "usage") and response.usage:
response_log["usage"] = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", None),
"completion_tokens": getattr(response.usage, "completion_tokens", None),
"total_tokens": getattr(response.usage, "total_tokens", None),
}
if hasattr(response, "choices") and response.choices:
response_log["choice"] = {
"message": {
"content": first_content,
"tool_calls": first_tool_calls,
}
}
if response_log:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn output: {json.dumps(response_log, ensure_ascii=False)}"
)
else:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn output: {str(response)}"
)
except Exception:
self.logger.notice(
f"[{self.component_name}] 第 {total_turns} 轮: LLM 输出序列化失败,使用字符串表示: {str(response)}, traceback: {format_exc()}."
)
def _debug_messages(
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
) -> None:
"""debug log for the messages to be sent to the model"""
try:
self.logger.debug(f"[{self.component_name}] msg:")
recent_messages = messages[-2:] if len(messages) > 2 else messages
base_index = len(messages) - len(recent_messages)
for offset, msg in enumerate(recent_messages):
idx = base_index + offset
role = msg.get("role")
content = msg.get("content")
content_str = content if isinstance(content, str) else ""
preview = content_str[:prefix_len]
content_len = len(content_str)
extra = ""
if role == "assistant":
tool_calls = msg.get("tool_calls")
has_tool = tool_calls is not None and tool_calls != []
try:
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_messages function, fail: {format_exc()}, tool calls: {tool_calls}."
)
tool_calls_json = str(tool_calls)
extra = f", has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
elif role == "tool":
tool_call_id = msg.get("tool_call_id")
extra = f", tool_call_id={tool_call_id}"
self.logger.debug(
f"[{self.component_name}] {turn+1}th, msg#{idx}: role={role}, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}{extra}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_messages function, fail msg: {format_exc()}."
)
def _debug_last_message(
self, turn: int, messages: List[Dict[str, Any]], prefix_len: int = 300
) -> None:
"""debug last turn msg"""
try:
if not messages:
return
last_assistant_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "assistant":
last_assistant_idx = i
break
if last_assistant_idx is None:
return
msg = messages[last_assistant_idx]
content = msg.get("content")
content_str = content if isinstance(content, str) else ""
preview = content_str[:prefix_len]
content_len = len(content_str)
tool_calls = msg.get("tool_calls")
has_tool = tool_calls is not None and tool_calls != []
try:
tool_calls_json = json.dumps(tool_calls, ensure_ascii=False)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_last_message function, fail: {format_exc()}, tool calls: {tool_calls}."
)
tool_calls_json = str(tool_calls)
self.logger.debug(
f"[{self.component_name}] {turn+1}th turn, output_preview: role=assistant, content_len={content_len}, content_preview={json.dumps(preview, ensure_ascii=False)}, has_tool_calls={has_tool}, tool_calls={tool_calls_json}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] In debug_last_message function, last turn fail: {format_exc()}."
)
def _debug_tools(self, tools: List[Dict[str, Any]]) -> None:
"""debug tools msg"""
try:
self.logger.debug(f"[{self.component_name}] tools num: {len(tools)}")
for i, tool in enumerate(tools):
try:
tool_json = json.dumps(tool, ensure_ascii=False)
self.logger.debug(f"[{self.component_name}] tool #{i+1}: {tool_json}")
except Exception:
self.logger.debug(
f"[{self.component_name}] tool #{i+1} fail: {format_exc()}, string: {str(tool)}."
)
except Exception:
try:
self.logger.warning(
f"[{self.component_name}] fail; traceback: {format_exc()}."
)
self.logger.warning(f"[{self.component_name}] tools string: {str(tools)}")
except Exception:
pass
def _get_tools(self) -> List[Dict[str, Any]]:
raise NotImplementedError("Subclasses must implement _get_tools")
def _is_bash_tool(self, tool_name: str) -> bool:
return BASH_TOOL_NAME in tool_name
def _is_edit_tool(self, tool_name: str) -> bool:
return "edit" in tool_name or "str_replace" in tool_name or STR_REPLACE_BASED_EDIT_TOOL_NAME in tool_name
def _is_search_tool(self, tool_name: str) -> bool:
return SEARCH_TOOL_NAME in tool_name or "search" in tool_name
def _is_submit_result_tool(self, tool_name: str) -> bool:
return SUBMIT_RESULT_TOOL_NAME in tool_name
def _update_usage(self, response: Any, usage_stats: LLMUsage) -> None:
if hasattr(response, "usage") and response.usage:
usage_stats.prompt_tokens += int(getattr(response.usage, "prompt_tokens", 0) or 0)
usage_stats.completion_tokens += int(
getattr(response.usage, "completion_tokens", 0) or 0
)
usage_stats.total_tokens += int(getattr(response.usage, "total_tokens", 0) or 0)
def _init_usage_stats(self) -> LLMUsage:
return LLMUsage()
def _init_tools_stats(self) -> ToolStats:
return ToolStats()
def _update_tool_call_statistic(
self, tool_results: List[Any], tool_stats: ToolStats
) -> None:
for result in tool_results:
try:
tool_name = getattr(result, "name", "")
tool_name = tool_name.lower() if isinstance(tool_name, str) else ""
success = bool(getattr(result, "success", False))
if self._is_bash_tool(tool_name):
tool_stats.bash["count"] += 1
if not success:
tool_stats.bash["failed"] += 1
elif self._is_edit_tool(tool_name):
tool_stats.edit["count"] += 1
if not success:
tool_stats.edit["failed"] += 1
elif self._is_search_tool(tool_name):
tool_stats.search["count"] += 1
if not success:
tool_stats.search["failed"] += 1
elif self._is_submit_result_tool(tool_name):
tool_stats.submit_result["count"] += 1
if not success:
tool_stats.submit_result["failed"] += 1
except Exception:
continue
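# Illustrative shapes of the messages the helpers above append (field values
# are made up): _make_assistant adds
#   {"role": "assistant", "content": "", "tool_calls": [{"id": "call_1", ...}]}
# and _make_tool_response adds, per tool result,
#   {"role": "tool", "content": "<rendered by prompts_manager>", "tool_call_id": "call_1"}
# so the next turn can match each tool output back to the call that produced it.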

View File

@@ -0,0 +1,339 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.managers.loop.base import BaseLoop
from src.tools.base import (
ToolExecutor,
ToolResult,
SubmitToolResult,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
class PatchGenerator(BaseLoop):
def __init__(
self,
instance_id: str,
instance_data: Dict[str, Any],
logger: Logger,
prompts_manager: PromptsManager | None,
llm_manager: LLMAPIManager | None,
tool_executor: ToolExecutor,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
async def _submit_all_tool_calls(
self, other_tool_calls: List[Dict[str, Any]]
) -> List[Any]:
"""execute tool calls, return tool execution results list"""
if not other_tool_calls:
return []
from src.tools.base import ToolCall
tool_call_objects = []
for tool_call_dict in other_tool_calls:
raw_args = tool_call_dict.get("function", {}).get("arguments", {})
parsed_args = raw_args
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except Exception as e:
self.logger.warning(f"[{self.component_name}] In _submit_all_tool_calls function, fail: {e}, traceback: {format_exc()}, args: {raw_args}.")
parsed_args = {}
tool_call_obj = ToolCall(
name=tool_call_dict.get("function", {}).get("name", ""),
call_id=tool_call_dict.get("id", ""),
arguments=parsed_args,
id=tool_call_dict.get("id", ""),
)
tool_call_objects.append(tool_call_obj)
return await self.tool_executor.container_sequential_tool_call(
tool_call_objects
)
def _process_submit_result_tool_result(
self,
submit_result: ToolResult,
golden_patch: List[Dict[str, Any]],
) -> None:
"""process submit_result tool call, fill golden_patch and log"""
if not submit_result.success or not submit_result.result:
self.logger.warning(f"[{self.component_name}] submit_result failed and no result.")
return
try:
submit_tool_result = SubmitToolResult.from_string(submit_result.result)
if submit_tool_result.output:
patch_info = {
"patch_content": submit_tool_result.output,
"test_status": submit_tool_result.test_status,
"reasoning": submit_tool_result.reasoning,
}
golden_patch.clear()
golden_patch.append(patch_info)
self.logger.info(
f"[{self.component_name}] patch len: {len(submit_tool_result.output)}."
)
self.logger.info(
f"[{self.component_name}] test status: {submit_tool_result.test_status}."
)
self.logger.info(
f"[{self.component_name}] reasoning: {submit_tool_result.reasoning[:100]}..."
)
else:
self.logger.warning(
f"[{self.component_name}] submit_result success but no patch content."
)
except Exception as e:
self.logger.error(f"[{self.component_name}] parse submit_result result fail: {e}, traceback: {format_exc()}.")
def _get_tools(self) -> List[Dict[str, Any]]:
tools = []
#use_openai_format = self._should_use_openai_format()
use_openai_format = True
for tool in self.tool_executor.tools.values():
if use_openai_format:
tool_def = tool._definition_for_openai_fmt()
else:
tool_def = tool._definition_for_claude_fmt()
tools.append(tool_def)
return tools
def _should_use_openai_format(self) -> bool:
if not self.llm_manager or not hasattr(self.llm_manager, "get_model_name"):
return True  # OpenAI format by default
model_name = self.llm_manager.get_model_name().lower()
return "claude" not in model_name
def _get_issue_prompt(self) -> str:
"""generate issue prompt based on instance data"""
if not self.prompts_manager:
self.logger.warning("PromptsManager not initialized, cannot generate issue prompt.")
return ""
#instance_id = self.instance_data.get("instance_id", "")
#repo = self.instance_data.get("repo", "")
created_at = self.instance_data.get("created_at", "")
base_commit = self.instance_data.get("base_commit", "")
environment_setup_commit = self.instance_data.get(
"environment_setup_commit", ""
)
version = self.instance_data.get("version", "")
problem_statement = self.instance_data.get("problem_statement", "")
difficulty = self.instance_data.get("difficulty", "")
return self.prompts_manager.format_issue_prompt(
created_at=created_at,
base_commit=base_commit,
environment_setup_commit=environment_setup_commit,
version=version,
problem_statement=problem_statement,
difficulty=difficulty,
)
async def _generate_patch(self) -> Dict[str, Any] | None:
"""main loop logic for generating candidate patch"""
usage_stats = self._init_usage_stats()
tool_stats = self._init_tools_stats()
if not self.llm_manager or not self.prompts_manager:
self.logger.error(f"[{self.component_name}] LLM manager or prompts manager not initialized.")
return {
"success": False,
"golden_patch": [],
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": 0,
}
tools = self._get_tools()
self._debug_tools(tools)
root_path = self.config.get("builder", {}).get("repo_root_path", "")
max_turn = (
self.config.get("runner", {}).get("generator_loop", {}).get("max_turn", 10)
)
temperature = (
self.config.get("runner", {})
.get("generator_loop", {})
.get("temperature", 0.2)
)
issue_prompt = self._get_issue_prompt()
user_prompt = self.prompts_manager.get_generator_user(root_path, issue_prompt)
system_prompt = self.prompts_manager.get_generator_system(root_path)
total_turns = 0
golden_patch = []
try:
self.logger.info(
f"[{self.component_name}] {self.instance_id}: start generating candidate patch, max turn: {max_turn}"
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
)
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
)
for turn in range(max_turn):
total_turns = turn + 1
self.logger.info(f"[{self.component_name}] The {total_turns}th turn started.")
try:
current_input_msg = messages[-1] if messages else None
if current_input_msg is not None:
self.logger.notice(
f"[{self.component_name}] The {total_turns}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
)
except Exception as e:
self.logger.warning(
f"[{self.component_name}] {total_turns}th turn: LLM input fail: {messages[-1] if messages else None}, error: {e}, traceback: {format_exc()}."
)
self._debug_messages(turn, messages)
response = self.llm_manager.chat(
messages=messages,
tools=tools,
tool_choice="auto",
temperature=temperature,
)
first_content: str = ""
first_tool_calls: Any = None
if hasattr(response, "choices") and response.choices:
ch0 = response.choices[0]
first_content = (
getattr(getattr(ch0, "message", None), "content", None) or ""
)
first_tool_calls = getattr(
getattr(ch0, "message", None), "tool_calls", None
)
self._response_log(
response, first_content, first_tool_calls, total_turns
)
self._update_usage(response, usage_stats)
if hasattr(response, "choices") and response.choices:
content = first_content
tool_calls = first_tool_calls
if not self._make_assistant(content, tool_calls, messages):
continue
if tool_calls:
self.logger.info(
f"[{self.component_name}] {total_turns}th turn: call {len(tool_calls)} tools."
)
tool_results = await self._submit_all_tool_calls(tool_calls)
self._update_tool_call_statistic(tool_results, tool_stats)
if tool_results:
submit_result = None
other_tool_results = []
for tool_result in tool_results:
tool_name = getattr(tool_result, "name", "")
if tool_name == SUBMIT_RESULT_TOOL_NAME:
submit_result = tool_result
else:
other_tool_results.append(tool_result)
if submit_result:
self.logger.debug(
f"[{self.component_name}] {total_turns}th turn: got submit_result tool call."
)
self.logger.debug(f"[{self.component_name}] {total_turns}th turn: submit_result result: {submit_result}")
self._process_submit_result_tool_result(
submit_result, golden_patch
)
self._debug_last_message(turn, messages)
break
if other_tool_results:
self._make_tool_response(other_tool_results, messages)
else:
messages.append(
{
"role": "user",
"content": "请继续分析问题并使用工具来解决问题。",
}
)
self.logger.debug(f"[{self.component_name}] final golden_patch: {golden_patch}")
success = (
len(golden_patch) > 0 and golden_patch[0].get("patch_content", "") != ""
)
self.logger.info(
f"[{self.component_name}] status={success}, total_turns={total_turns}, tools_stats={tool_stats}"
)
result_payload = {
"success": success,
"golden_patch": golden_patch,
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": total_turns,
}
try:
self.logger.notice(
f"[{self.component_name}] final output: {json.dumps(result_payload, ensure_ascii=False)}"
)
except Exception as e:
self.logger.warning(
f"[{self.component_name}] output: {str(result_payload)}, error: {e}, traceback: {format_exc()}."
)
return result_payload
except Exception as e:
self.logger.error(f"[{self.component_name}] fail: {e}, traceback: {format_exc()}.")
result_payload = {
"success": False,
"golden_patch": [],
"llm_usage": usage_stats.to_dict(),
"tool_stats": tool_stats.to_dict(),
"total_turns": total_turns,
}
try:
self.logger.notice(
f"[{self.component_name}] 最终返回数据(失败): {json.dumps(result_payload, ensure_ascii=False)}"
)
except Exception as e:
self.logger.notice(
f"[{self.component_name}] 最终返回数据(失败, 字符串回退): {str(result_payload)}, error: {e}, traceback: {format_exc()}."
)
return result_payload

View File

@@ -0,0 +1,338 @@
from typing import Any, Dict, List
import json
from traceback import format_exc
from src.managers.log.logger import Logger
from src.managers.llm_api.api_manager import LLMAPIManager
from src.managers.prompts.prompts_manager import PromptsManager
from src.managers.loop.types import GeneratorResult, SelectorResult, LLMUsage, ToolStats, PatchInfo
from src.tools.base import ToolExecutor, ToolCall, ToolResult
from src.managers.loop.base import BaseLoop
SELECTOR_SUBMIT_TOOL_NAME = "submit_result"
class PatchSelector(BaseLoop):
def __init__(
self,
instance_id: str,
instance_data: Dict[str, Any],
logger: Logger,
prompts_manager: PromptsManager | None,
llm_manager: LLMAPIManager | None,
tool_executor: ToolExecutor,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(instance_id, instance_data, logger, prompts_manager, llm_manager, tool_executor, config)
def _get_submit_result_tool_name(self):
return SELECTOR_SUBMIT_TOOL_NAME
def _definition_for_submit_tool(self, use_openai_format: bool) -> Dict[str, Any]:
"""submit_result tool"""
if use_openai_format:
return {
"type": "function",
"function": {
"name": self._get_submit_result_tool_name(),
"description": "Submit the final selected patch index and reasoning.",
"parameters": {
"type": "object",
"properties": {
"index": {
"type": "integer",
"description": "The chosen patch index (0-based).",
},
"reason": {
"type": "string",
"description": "Detailed reasoning for the selection.",
},
},
"required": ["index", "reason"],
},
},
}
return {
"type": "function",
"function": {
"name": self._get_submit_result_tool_name(),
"description": "Submit the final selected patch index and reasoning.",
"parameters": {
"type": "object",
"properties": {
"index": {
"type": "integer",
"description": "The chosen patch index (0-based).",
},
"reason": {
"type": "string",
"description": "Detailed reasoning for the selection.",
},
},
"required": ["index", "reason"],
},
},
}
def _build_user_prompt(self, candidates: List[GeneratorResult], root_path: str) -> str:
if not self.prompts_manager:
return ""
return self.prompts_manager.get_selector_user(self.instance_data, candidates, root_path)
def _get_system_prompt(self, patches_count: int, root_path: str) -> str:
if not self.prompts_manager:
return ""
return self.prompts_manager.get_selector_system(patches_count, root_path)
def _get_tools(self) -> List[Dict[str, Any]]:
tool_defs: List[Dict[str, Any]] = []
try:
for tool in self.tool_executor.tools.values():
try:
tool_defs.append(tool._definition_for_openai_fmt())
except Exception:
continue
except Exception:
pass
tool_defs.append(self._definition_for_submit_tool(True))
return tool_defs
def _extract_submit_choice(self, tool_call: Dict[str, Any]) -> Dict[str, Any] | None:
if not tool_call:
return None
fn = tool_call.get("function", {})
if fn.get("name") != self._get_submit_result_tool_name():
return None
raw_args = fn.get("arguments", {})
try:
args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
except Exception:
args = {}
index = args.get("index")
reason = args.get("reason")
if isinstance(index, int) and index >= 0:
return {"index": index, "reason": reason or ""}
return None
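# Illustrative tool call accepted above (values made up); "arguments" may be a
# JSON string or an already-parsed dict:
#   {"id": "call_1", "function": {"name": "submit_result",
#    "arguments": "{\"index\": 1, \"reason\": \"patch 1 passes all tests\"}"}}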
async def _submit_other_tool_calls(
self, tool_calls: List[Dict[str, Any]]
) -> List[ToolResult]:
if not tool_calls:
return []
to_run: List[ToolCall] = []
for tool_call_dict in tool_calls:
fn = tool_call_dict.get("function", {})
name = fn.get("name", "")
if name == SELECTOR_SUBMIT_TOOL_NAME:
continue
raw_args = fn.get("arguments", {})
parsed_args = raw_args
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
except Exception:
parsed_args = {}
to_run.append(
ToolCall(
name=name,
call_id=tool_call_dict.get("id", ""),
arguments=parsed_args,
id=tool_call_dict.get("id", ""),
)
)
if not to_run:
return []
results: List[ToolResult] = await self.tool_executor.container_sequential_tool_call(to_run)
return results
async def _select_patch(self, candidates: List[GeneratorResult]) -> SelectorResult:
if not candidates:
raise ValueError("No candidates provided")
if not self.llm_manager:
raise ValueError("LLM manager is not initialized")
tools = self._get_tools()
self._debug_tools(tools)
root_path = self.config.get("builder", {}).get("repo_root_path", "")
system_prompt = self._get_system_prompt(len(candidates), root_path)
user_prompt = self._build_user_prompt(candidates, root_path)
messages: List[Dict[str, Any]] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
try:
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[0], ensure_ascii=False)}"
)
self.logger.notice(
f"[{self.component_name}]: {json.dumps(messages[1], ensure_ascii=False)}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] Initial fail in selector loop: SP={str(messages[0])}, UP={str(messages[1])}, traceback: {format_exc()}."
)
max_turn = int(
self.config.get("runner", {})
.get("selector_loop", {})
.get("max_turn", 200)
)
temperature = (
self.config.get("runner", {})
.get("selector_loop", {})
.get("temperature", 0.2)
)
usage_stats = self._init_usage_stats()
tool_stats = self._init_tools_stats()
total_turns = 0
chosen_index: int | None = None
select_reason: str = ""
for turn in range(max_turn):
try:
try:
current_input_msg = messages[-1] if messages else None
if current_input_msg is not None:
self.logger.notice(
f"[{self.component_name}] The {turn+1}th turn input: {json.dumps(current_input_msg, ensure_ascii=False)}"
)
except Exception:
self.logger.warning(
f"[{self.component_name}] {turn+1}th turn fail: {messages[-1] if messages else None}, traceback: {format_exc()}."
)
self._debug_messages(turn, messages)
response = self.llm_manager.chat(
messages=messages,
tools=tools,
tool_choice="auto",
temperature=temperature,
)
first_tool_calls = None
if hasattr(response, "choices") and response.choices:
ch0 = response.choices[0]
first_tool_calls = getattr(getattr(ch0, "message", None), "tool_calls", None)
first_content = getattr(getattr(ch0, "message", None), "content", None) or ""
else:
first_content = ""
total_turns = turn + 1
self._response_log(response, first_content, first_tool_calls, turn + 1)
self._update_usage(response, usage_stats)
if first_tool_calls:
if not self._make_assistant(first_content, first_tool_calls, messages):
messages.append(
{
"role": "user",
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
}
)
continue
submit_found = False
for tc in first_tool_calls:
choice = self._extract_submit_choice(tc)
if choice is not None:
chosen_index = choice["index"]
reason = choice.get("reason", "")
self.logger.info(
f"[{self.component_name}] choose: index={chosen_index}, reason={reason}"
)
select_reason = reason or ""
submit_found = True
self._debug_last_message(turn, messages)
break
if not submit_found:
results = await self._submit_other_tool_calls(first_tool_calls)
self._make_tool_response(results, messages)
self._update_tool_call_statistic(results, tool_stats)
else:
messages.append(
{
"role": "user",
"content": "请完成分析并调用 submit_result 工具给出最终选择与理由。",
}
)
if chosen_index is not None:
break
except Exception as e:
self.logger.warning(
f"[{self.component_name}] fail: {e}, traceback: {format_exc()}"
)
break
if chosen_index is None:
# If the model provides no choice, fallback: pick the first successful one; otherwise the first
for i, r in enumerate(candidates):
try:
if r.success:
chosen_index = i
break
except Exception:
continue
if chosen_index is None:
chosen_index = 0
if not (0 <= chosen_index < len(candidates)):
chosen_index = 0
selected = candidates[chosen_index]
try:
gp = selected.golden_patch[0] if selected.golden_patch else None
if gp is None:
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
else:
patch_info = PatchInfo(
patch_content=gp.patch_content,
test_status=gp.test_status,
reasoning=gp.reasoning,
)
except Exception:
patch_info = PatchInfo(patch_content="", test_status="", reasoning="")
selector_result = SelectorResult(
instance_id=selected.instance_id,
generator_id=selected.generator_id,
image=selected.image,
success=True,
golden_patch=patch_info,
llm_usage=usage_stats,
tool_stats=tool_stats,
total_turns=total_turns,
select_reason=select_reason,
error=None,
)
return selector_result

254
src/managers/loop/types.py Normal file
View File

@@ -0,0 +1,254 @@
"""
This module defines the GeneratorResult data structure for patch generation results.
"""
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
from src.tools.base import (
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
@dataclass
class LLMUsage:
"""LLM usage statistics."""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
def to_dict(self) -> Dict[str, int]:
"""Serialize LLMUsage to a plain dictionary."""
return {
"prompt_tokens": int(self.prompt_tokens),
"completion_tokens": int(self.completion_tokens),
"total_tokens": int(self.total_tokens),
}
@dataclass
class ToolStats:
"""Tool usage statistics per tool.
Each tool is represented by a small map with two fields:
- count: total invocation count
- failed: failed invocation count
"""
bash: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
edit: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
search: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
submit_result: Dict[str, int] = field(default_factory=lambda: {"count": 0, "failed": 0})
def to_dict(self) -> Dict[str, Dict[str, int]]:
"""Serialize ToolStats to a plain dictionary."""
return {
BASH_TOOL_NAME: {"count": int(self.bash.get("count", 0)), "failed": int(self.bash.get("failed", 0))},
STR_REPLACE_BASED_EDIT_TOOL_NAME: {"count": int(self.edit.get("count", 0)), "failed": int(self.edit.get("failed", 0))},
SEARCH_TOOL_NAME: {"count": int(self.search.get("count", 0)), "failed": int(self.search.get("failed", 0))},
SUBMIT_RESULT_TOOL_NAME: {"count": int(self.submit_result.get("count", 0)), "failed": int(self.submit_result.get("failed", 0))},
}
@dataclass
class PatchInfo:
"""Information about a generated patch."""
patch_content: str
test_status: str
reasoning: str
@dataclass
class GeneratorResult:
"""Result from a patch generator."""
instance_id: str
generator_id: int
image: str
success: bool
golden_patch: List[PatchInfo]
llm_usage: LLMUsage
tool_stats: ToolStats
total_turns: int
error: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GeneratorResult":
"""Create GeneratorResult from dictionary."""
# Handle golden_patch conversion
golden_patch = []
if data.get("golden_patch"):
for patch_data in data["golden_patch"]:
if isinstance(patch_data, dict):
golden_patch.append(
PatchInfo(
patch_content=patch_data.get("patch_content", ""),
test_status=patch_data.get("test_status", ""),
reasoning=patch_data.get("reasoning", ""),
)
)
else:
# Legacy format: just patch content string
golden_patch.append(
PatchInfo(
patch_content=str(patch_data), test_status="", reasoning=""
)
)
# Handle LLM usage
llm_usage_data = data.get("llm_usage", {})
llm_usage = LLMUsage(
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
completion_tokens=llm_usage_data.get("completion_tokens", 0),
total_tokens=llm_usage_data.get("total_tokens", 0),
)
# Handle tool stats
tool_stats_data = data.get("tool_stats", {})
tool_stats = ToolStats(
bash=tool_stats_data.get(BASH_TOOL_NAME, {"count": 0, "failed": 0}),
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, {"count": 0, "failed": 0}),
search=tool_stats_data.get(SEARCH_TOOL_NAME, {"count": 0, "failed": 0}),
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, {"count": 0, "failed": 0}),
)
return cls(
instance_id=data.get("instance_id", ""),
generator_id=data.get("generator_id", 0),
image=data.get("image", ""),
success=data.get("success", False),
golden_patch=golden_patch,
llm_usage=llm_usage,
tool_stats=tool_stats,
total_turns=data.get("total_turns", 0),
error=data.get("error"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert GeneratorResult to dictionary."""
return {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image,
"success": self.success,
"golden_patch": [
{
"patch_content": patch.patch_content,
"test_status": patch.test_status,
"reasoning": patch.reasoning,
}
for patch in self.golden_patch
],
"llm_usage": {
"prompt_tokens": self.llm_usage.prompt_tokens,
"completion_tokens": self.llm_usage.completion_tokens,
"total_tokens": self.llm_usage.total_tokens,
},
"tool_stats": {
BASH_TOOL_NAME: self.tool_stats.bash,
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
SEARCH_TOOL_NAME: self.tool_stats.search,
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
},
"total_turns": self.total_turns,
"error": self.error,
}
@dataclass
class SelectorResult:
"""Result from a patch selector.
"""
instance_id: str
generator_id: int
image: str
success: bool
golden_patch: PatchInfo
llm_usage: LLMUsage
tool_stats: ToolStats
total_turns: int
select_reason: str
error: Optional[str] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SelectorResult":
"""Create SelectorResult from dictionary."""
gp_data = data.get("golden_patch", {})
if isinstance(gp_data, dict):
golden_patch = PatchInfo(
patch_content=gp_data.get("patch_content", ""),
test_status=gp_data.get("test_status", ""),
reasoning=gp_data.get("reasoning", ""),
)
else:
golden_patch = PatchInfo(
patch_content=str(gp_data) if gp_data is not None else "",
test_status="",
reasoning="",
)
# LLM usage
llm_usage_data = data.get("llm_usage", {})
llm_usage = LLMUsage(
prompt_tokens=llm_usage_data.get("prompt_tokens", 0),
completion_tokens=llm_usage_data.get("completion_tokens", 0),
total_tokens=llm_usage_data.get("total_tokens", 0),
)
# Tool stats
tool_stats_data = data.get("tool_stats", {})
tool_stats = ToolStats(
bash=tool_stats_data.get(BASH_TOOL_NAME, {"count": 0, "failed": 0}),
edit=tool_stats_data.get(STR_REPLACE_BASED_EDIT_TOOL_NAME, {"count": 0, "failed": 0}),
search=tool_stats_data.get(SEARCH_TOOL_NAME, {"count": 0, "failed": 0}),
submit_result=tool_stats_data.get(SUBMIT_RESULT_TOOL_NAME, {"count": 0, "failed": 0}),
)
return cls(
instance_id=data.get("instance_id", ""),
generator_id=data.get("generator_id", 0),
image=data.get("image", ""),
success=data.get("success", False),
golden_patch=golden_patch,
llm_usage=llm_usage,
tool_stats=tool_stats,
total_turns=data.get("total_turns", 0),
select_reason=data.get("select_reason", ""),
error=data.get("error"),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert SelectorResult to dictionary."""
return {
"instance_id": self.instance_id,
"generator_id": self.generator_id,
"image": self.image,
"success": self.success,
"golden_patch": {
"patch_content": self.golden_patch.patch_content,
"test_status": self.golden_patch.test_status,
"reasoning": self.golden_patch.reasoning,
},
"llm_usage": {
"prompt_tokens": self.llm_usage.prompt_tokens,
"completion_tokens": self.llm_usage.completion_tokens,
"total_tokens": self.llm_usage.total_tokens,
},
"tool_stats": {
BASH_TOOL_NAME: self.tool_stats.bash,
STR_REPLACE_BASED_EDIT_TOOL_NAME: self.tool_stats.edit,
SEARCH_TOOL_NAME: self.tool_stats.search,
SUBMIT_RESULT_TOOL_NAME: self.tool_stats.submit_result,
},
"total_turns": self.total_turns,
"select_reason": self.select_reason,
"error": self.error,
}
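# Minimal round-trip sketch for the dataclasses above (all values are made up):
if __name__ == "__main__":
    result = GeneratorResult(
        instance_id="demo__repo-1",
        generator_id=0,
        image="demo:latest",
        success=True,
        golden_patch=[PatchInfo(patch_content="diff --git ...", test_status="passed", reasoning="fixes the root cause")],
        llm_usage=LLMUsage(prompt_tokens=100, completion_tokens=20, total_tokens=120),
        tool_stats=ToolStats(),
        total_turns=3,
    )
    assert GeneratorResult.from_dict(result.to_dict()).to_dict() == result.to_dict()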

View File

@@ -0,0 +1,268 @@
from typing import Any, List, Dict
class PromptsManager:
def __init__(self, config):
self.candidate_length = config.get("runner", {}).get("generator_concurrency", 5)
def get_generator_system(self, root_path: str | None = None):
return f"""
# You are a highly skilled expert in software engineering focused on resolving complex GitHub issues by effectively analyzing codebases, implementing fixes, and ensuring code reliability through rigorous testing.
## Skills
1. Code Analysis and Debugging
- Issue Exploration: Ability to explore and comprehend codebases within repositories.
- Workflow Tracing: Skilled in using debugging techniques to trace issues through code.
- Root Cause Identification: Proficient in pinpointing the underlying causes of software issues.
- Test Creation: Expertise in establishing tests that replicate and validate issues.
2. Solution Implementation and Testing
- Fix Implementation: Experience in crafting precise and minimal code patches.
- Comprehensive Testing: Skilled in running and analyzing both existing and newly created tests.
- Regression Prevention: Ensures all changes maintain overall code stability.
- Continuous Improvement: Iterates on solutions based on test results to achieve optimal functionality.
## Task
Your task is to resolve the given GitHub issue by understanding the repository and the issue, implementing a fix, and checking your changes against existing tests and your own test(s).
Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project.
For example, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/testbed`.
Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and I activated the virtual environment for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
Follow these steps:
1. Problem Analysis:
- Read the issue description carefully to fully grasp the issue and explore the repository (source code, tests, examples) to understand expected behavior of relevant components.
- Identify the full scope. Does the issue mention multiple components, backends, or functions? Your solution must address all of them.
2. Reproduce the issue (IMPORTANT):
- Create a test that reproduces the issue as a baseline for verification.
- Check that the output of your test matches your understanding of the issue in step 1.
3. Identify the root cause:
- Go through relevant files, create debugging scripts with print statements, or use other methods if necessary to trace the workflow and the exact cause of the issue.
- **Trace the problem to its root cause.** Do not just patch the symptom where the error appears. Trace the data and execution flow upstream to find where the problem originates.
4. Implement a Fix:
- Once you have identified the root cause, develop a precise and targeted fix and then apply it as a minimal patch using the `str_replace_based_edit_tool` tools.
5. Test comprehensively:
- Verify the Fix: Run your initial reproduction script to confirm that the bug is resolved.
- Prevent Regressions:
--Identify the right tests: Once you have verified your fix, identify the most relevant tests within the project's existing test suite that correspond to your code changes.
--Run the tests: Then you **must** run these tests to ensure that your fix does not introduce any new bugs.
--Analyze failures carefully:
---If tests fail, do not immediately assume your fix is wrong. Critically analyze the failure.
---Is it a **regression**? Did your change break existing, valid functionality? If so, you must refine your fix.
---Is it an **unrelated failure**? It could be an environmental issue (e.g., missing dependency, network error) or a pre-existing flaky test. If you suspect this, try to run a more focused test and note the issue in your final reasoning.
---Is the **test now obsolete**? If your fix improves behavior in a way that makes an old test's assertions incorrect, you should **update the test** to match the new, correct behavior and explain why in your reasoning.
- Write New Tests: Create new, specific test cases (e.g., using `pytest`) that cover the original bug scenario.
- Consider Edge Cases: Think about and test potential edge cases related to your changes.
6. Revisit steps 1 through 5 if unexpected behavior occurs, then call `submit_result` to submit the reliable and verified solution patch after successful testing and validation.
**Mandatory Workflow**: As a senior engineer, ensure solution correctness and safety. Upon successful verification, immediately conclude the task by calling `submit_result`.
"""
def format_issue_prompt(
self,
created_at: str,
base_commit: str,
environment_setup_commit: str,
version: str,
problem_statement: str,
difficulty: str,
) -> str:
template = f"""
[📝 Issue Description]
**Created at**: {created_at}
**Base commit**: {base_commit}
---
### 📌 Problem Statement
{problem_statement}
---
### ⚙️ Difficulty Level
{difficulty}
---
"""
return template.strip()
def get_generator_user(self, root_path: str, issue_text: str):
return (
f"""
[Project root path]:
{root_path}
[Issue Information]:
{issue_text}
"""
+ self.get_generator_notice()
)
def get_generator_notice(self):
return """
[notice]
1. Use the available tools to locate the root cause.
2. Prioritize the `search_tool` to retrieve key information and locate it precisely within the project.
3. Collect supporting evidence: stack traces, logs, configs, recent changes, related modules.
"""
def get_selector_system(self, patches_count: int, root_path: str):
return f"""
# ROLE:
*You are a highly proficient software engineer tasked with evaluating and selecting optimal code patches to resolve specific issues within a given project.
*Your colleagues worked on {patches_count} potential patches for a GitHub issue. Select ONE correct patch to solve the issue.
*Here's the project root path: `{root_path or "/testbed"}`. The target repository has already been cloned, and the virtual environment has been activated for you. You can start analyzing the issue, searching and reading relevant files, and performing necessary fixes directly.
*Write ABSOLUTE PATHS as arguments for tools that take a `file_path`. Combine the project root path `{root_path or "/testbed"}` with the file's path inside the project. For instance, pass `/root/testbed/_run.py` as `file_path` if you need to edit `_run.py` given the root path `/root/testbed`.
# WORKFLOWS:
*Follow these steps without skipping any of them:
1.Problem Analysis:
- Read the issue description and the current code that needs to be fixed. Explore the repository (source code, tests, examples) to understand the expected behavior of relevant components, and gather comprehensive information about the problem area.
2.Conduct a thorough review of each patch:
- Scrutinize all code modifications.
- Decipher the core logic and problem-solving methodology.
- Evaluate potential edge cases and unintended consequences.
- Validate that each patch fully addresses the initial issue specifications.
3.Verify Your Analysis
- Use the available tools to verify that your analysis of this issue holds.
- Test your conclusions against relevant code sections.
- Ensure full contextual understanding.
4.Proceed with Your Decision
- Upon completion of the preceding three steps, utilize the `submit_result` tool with your detailed reasoning.
#RULES:
1.It is MANDATORY to utilize the available tools prior to finalizing any selection:
-- Start with `bash` to explore the codebase structure;
-- Employ the `str_replace_based_edit_tool` to inspect the current code;
-- Use `search_tool` to search related code and files;
2.You MUST first explore the codebase before using the `submit_result` tool.
3.Substantiate your reasoning with evidence from your analysis.
4.Only selections made after employing the tools will be accepted.
#FINAL DECISION:
Upon completion of your tool-based analysis, finalize the process by submitting your choice via the `submit_result` tool.
#NOTICE:
1. Tool usage is MANDATORY - do not skip this step.
2. Failing to make a decision after completing your analysis is not permitted.
3. Never generate new patches on your own; just make the selection.
4. Always provide detailed reasoning for the selection based on your tool-based investigation.
"""
def get_selector_user(
self, instance_data: Dict[str, Any] | None = None, candidates: List[Any] | None = None, root_path: str | None = None
) -> str:
"""
Generate the selector user prompt, including issue information and the first golden patch of each candidate.
- instance_data: Current instance metadata (issue description etc.)
- candidates: List of candidates (supporting .to_dict()); only the fields of golden_patch[0] are used
"""
if not instance_data or not candidates:
return ""
created_at = instance_data.get("created_at", "")
base_commit = instance_data.get("base_commit", "")
environment_setup_commit = instance_data.get("environment_setup_commit", "")
version = instance_data.get("version", "")
problem_statement = instance_data.get("problem_statement", "")
difficulty = instance_data.get("difficulty", "")
issue_block = self.format_issue_prompt(
created_at=created_at,
base_commit=base_commit,
environment_setup_commit=environment_setup_commit,
version=version,
problem_statement=problem_statement,
difficulty=difficulty,
)
root_path_block = f"""
[Project root path]:
{root_path or "/testbed"}
"""
parts: List[str] = [root_path_block, issue_block, "\n[🔎 Candidates]\n"]
for idx, r in enumerate(candidates):
try:
data = r.to_dict() if hasattr(r, "to_dict") else {}
except Exception:
data = {}
golden_patch = data.get("golden_patch", [])
patch_content = golden_patch[0].get("patch_content", "") if golden_patch else ""
test_status = golden_patch[0].get("test_status", "") if golden_patch else ""
reasoning = golden_patch[0].get("reasoning", "") if golden_patch else ""
parts.append(self.format_selector_candidate(idx, patch_content, test_status, reasoning))
parts.append(
"\nPlease analyze the candidates, then call the submit_result tool with the final index and reasoning."
)
return "\n".join(parts)
def get_terminal_response(self, exit_code: int, output: str, timeout_status: bool):
if timeout_status:
return """[Terminal response]
Terminal timed out."""
else:
return f"""[Terminal response]
Exit code: {exit_code}
Output: {output}"""
def tool_response_prompts(self, tool_results: list) -> str:
if not tool_results:
return ""
response_parts = ["[tool_response]"]
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
output = result.get("result", "")
error = result.get("error", "")
response_parts.append(f"Tool {i}: {tool_name}")
response_parts.append(f"Success: {success}")
if success and output:
response_parts.append(f"Output:\n{output}")
elif error:
response_parts.append(f"Error: {error}")
else:
response_parts.append("No output")
response_parts.append("")
return "\n".join(response_parts)
def format_selector_candidate(self, index: int, patch_content: str, test_status: str, reasoning: str) -> str:
"""
Generate the description of a selector candidate item, including key information from the first golden_patch
- index: Candidate index (0-based)
- patch_content: golden_patch[0].patch_content
- test_status: golden_patch[0].test_status, the test status from the generation stage
- reasoning: golden_patch[0].reasoning, the model's reasoning from the generation stage
"""
header = f"- Candidate #{index}:"
patch_block = patch_content or ""
status_block = test_status or ""
reasoning_block = reasoning or ""
return (
f"--{header}\n"
f"--Patch content (the proposed fix):\n{patch_block}\n\n"
f"--Test status during generation: {status_block}\n\n"
f"--Reasoning during generation (model's logic):\n{reasoning_block}"
)

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict
import json
class ResultBuilder:
"""
Build preds.json
Iterate through JSON files named by instance_id in the directory specified by `runner.selector_result_dump_path` in the configuration.
- Parse the `golden_patch` field from each JSON file and extract the patch text as `model_patch`.
- Read the first top-level field from providers and the first model under it, and concatenate them to form `model_name_or_path`.
- The output location is `{workspace.path}/{result.preds.path}/preds.json`.
"""
def __init__(self, config: Dict[str, Any]):
self.config = config or {}
def _get_selector_dump_dir(self) -> Path:
runner_cfg = self.config.get("runner", {}) if isinstance(self.config, dict) else {}
dump_dir_str = runner_cfg.get(
"selector_result_dump_path", "workspace/selector_result_dump"
)
return Path(dump_dir_str)
def _get_preds_output_dir(self) -> Path:
workspace_cfg = self.config.get("workspace", {}) if isinstance(self.config, dict) else {}
result_cfg = self.config.get("result", {}) if isinstance(self.config, dict) else {}
preds_cfg = result_cfg.get("preds", {}) if isinstance(result_cfg, dict) else {}
workspace_path = workspace_cfg.get("path", "workspace")
preds_path = preds_cfg.get("path", "result")
return Path(workspace_path) / preds_path
def _get_model_name_or_path(self) -> str:
providers = self.config.get("providers", {}) if isinstance(self.config, dict) else {}
if not isinstance(providers, dict) or not providers:
return ""
first_provider_name = next(iter(providers.keys()))
first_models = providers.get(first_provider_name, [])
if isinstance(first_models, list) and first_models:
first_model = first_models[0]
else:
first_model = ""
return f"{first_provider_name}/{first_model}" if first_provider_name and first_model else ""
@staticmethod
def _extract_model_patch(golden_patch: Any) -> str:
"""
Extract patch content from golden_patch
Forms supported:
- dict: prefer 'patch_content', then fall back to 'model_patch'
- str: returned directly
- anything else: returns an empty string
"""
if isinstance(golden_patch, dict):
if "patch_content" in golden_patch and isinstance(golden_patch["patch_content"], str):
return golden_patch["patch_content"]
if "model_patch" in golden_patch and isinstance(golden_patch["model_patch"], str):
return golden_patch["model_patch"]
return ""
if isinstance(golden_patch, str):
return golden_patch
return ""
def build_preds(self) -> Path:
dump_dir = self._get_selector_dump_dir()
output_dir = self._get_preds_output_dir()
output_dir.mkdir(parents=True, exist_ok=True)
output_file = output_dir / "preds.json"
model_name_or_path = self._get_model_name_or_path()
# SWE-bench evaluation expects: list[dict], each element includes instance_id / model_patch / model_name_or_path
predictions: list[dict[str, str]] = []
if dump_dir.exists() and dump_dir.is_dir():
for path in sorted(dump_dir.glob("*.json")):
try:
instance_id = path.stem
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
golden_patch = data.get("golden_patch", {}) if isinstance(data, dict) else {}
model_patch = self._extract_model_patch(golden_patch)
predictions.append(
{
"instance_id": instance_id,
"model_patch": model_patch,
"model_name_or_path": model_name_or_path,
}
)
except Exception:
continue
with open(output_file, "w", encoding="utf-8") as f:
json.dump(predictions, f, ensure_ascii=False, indent=2)
return output_file
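# --- Illustrative usage sketch (added for exposition, not used by the runtime) ---
# A minimal example of building preds.json; the provider/model names below are
# placeholders, and the dump directory may simply be empty on a fresh checkout.
if __name__ == "__main__":
    _builder = ResultBuilder(
        {
            "runner": {"selector_result_dump_path": "workspace/selector_result_dump"},
            "workspace": {"path": "workspace"},
            "result": {"preds": {"path": "result"}},
            "providers": {"example_provider": ["example-model"]},
        }
    )
    print(_builder.build_preds())  # -> workspace/result/preds.json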

35
src/tools/__init__.py Normal file
View File

@@ -0,0 +1,35 @@
"""Tools module for Code Agent."""
from src.tools.base import (
Tool,
ToolCall,
ToolExecutor,
ToolResult,
BASH_TOOL_NAME,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
SEARCH_TOOL_NAME,
SUBMIT_RESULT_TOOL_NAME,
)
from src.tools.bash_tool import BashTool
from src.tools.edit_tool import TextEditorTool
from src.tools.search_tool import SearchTool
from src.tools.submit_result_tool import SubmitResultTool
__all__ = [
"Tool",
"ToolResult",
"ToolCall",
"ToolExecutor",
"BashTool",
"TextEditorTool",
"JSONEditTool",
"SearchTool",
"SubmitResultTool",
]
tools_registry: dict[str, type[Tool]] = {
BASH_TOOL_NAME: BashTool,
STR_REPLACE_BASED_EDIT_TOOL_NAME: TextEditorTool,
SEARCH_TOOL_NAME: SearchTool,
SUBMIT_RESULT_TOOL_NAME: SubmitResultTool,
}
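# --- Illustrative usage sketch (added for exposition, not used by the runtime) ---
# Resolves a tool class from the registry by name and builds its JSON schema
# definition; the "openai" provider string is only an example value.
def _example_bash_tool_definition() -> dict:
    tool = tools_registry[BASH_TOOL_NAME](model_provider="openai")
    return tool.json_definition()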

523
src/tools/base.py Normal file
View File

@@ -0,0 +1,523 @@
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
#
# Original file was released under MIT License, with the full license text
# available at https://github.com/bytedance/trae-agent/blob/main/LICENSE
#
# This modified file is released under the same license.
"""Base classes for tools and tool calling."""
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import cached_property
from typing import override
from src.managers.log.logger import Logger
from typing import Dict, Any
from traceback import format_exc
from pathlib import Path
ParamSchemaValue = str | list[str] | bool | dict[str, object]
Property = dict[str, ParamSchemaValue]
BASH_TOOL_NAME = "bash"
STR_REPLACE_BASED_EDIT_TOOL_NAME = "str_replace_based_edit_tool"
SEARCH_TOOL_NAME = "search_tool"
SUBMIT_RESULT_TOOL_NAME = "submit_result"
class ToolError(Exception):
"""Base class for tool errors."""
def __init__(self, message: str):
super().__init__(message)
self.message: str = message
@dataclass
class ToolExecResult:
"""Intermediate result of a tool execution."""
output: str | None = None
error: str | None = None
error_code: int = 0
@dataclass
class ToolResult:
"""Result of a tool execution."""
call_id: str
name: str # Gemini specific field
success: bool
result: str | None = None
error: str | None = None
id: str | None = None # OpenAI-specific field
@dataclass
class SubmitToolResult:
"""Structured result for submit_result tool."""
return_code: int
output: str
is_task_done: bool
test_status: str
reasoning: str
def __str__(self) -> str:
"""Convert to JSON string for output."""
import json
return json.dumps(
{
"return_code": self.return_code,
"output": self.output,
"is_task_done": self.is_task_done,
"test_status": self.test_status,
"reasoning": self.reasoning,
}
)
@classmethod
def from_string(cls, json_str: str) -> "SubmitToolResult":
"""Create SubmitToolResult from JSON string."""
import json
data = json.loads(json_str)
return cls(
return_code=data.get("return_code", 0),
output=data.get("output", ""),
is_task_done=data.get("is_task_done", False),
test_status=data.get("test_status", "error"),
reasoning=data.get("reasoning", ""),
)
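# --- Illustrative round-trip sketch (added for exposition, not used by the runtime) ---
# Demonstrates the __str__ / from_string symmetry of SubmitToolResult using
# arbitrary example values.
def _example_submit_result_roundtrip() -> bool:
    sample = SubmitToolResult(
        return_code=0,
        output="all tests passed",
        is_task_done=True,
        test_status="passed",
        reasoning="fix verified by the reproduction script",
    )
    return SubmitToolResult.from_string(str(sample)) == sample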
ToolCallArguments = dict[
str, str | int | float | dict[str, object] | list[object] | None
]
@dataclass
class ToolCall:
"""Represents a parsed tool call."""
name: str
call_id: str
arguments: ToolCallArguments = field(default_factory=dict)
id: str | None = None
@override
def __str__(self) -> str:
return f"ToolCall(name={self.name}, arguments={self.arguments}, call_id={self.call_id}, id={self.id})"
@dataclass
class ToolParameter:
"""Tool parameter definition."""
name: str
type: str | list[str]
description: str
enum: list[str] | None = None
items: dict[str, object] | None = None
required: bool = True
class Tool(ABC):
"""Base class for all tools."""
def __init__(
self,
model_provider: str | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
):
self._model_provider = model_provider
self.logger = logger
self.config = config
@cached_property
def model_provider(self) -> str | None:
return self.get_model_provider()
@cached_property
def name(self) -> str:
return self.get_name()
@cached_property
def description(self) -> str:
return self.get_description()
@cached_property
def parameters(self) -> list[ToolParameter]:
return self.get_parameters()
def get_model_provider(self) -> str | None:
"""Get the model provider."""
return self._model_provider
@abstractmethod
def get_name(self) -> str:
"""Get the tool name."""
pass
@abstractmethod
def get_description(self) -> str:
"""Get the tool description."""
pass
@abstractmethod
def get_parameters(self) -> list[ToolParameter]:
"""Get the tool parameters."""
pass
@abstractmethod
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the tool with given parameters."""
pass
# Optional container execution hooks (to be overridden by tools that support containers)
async def container_execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the tool inside a container shell (optional)."""
raise ToolError(
f"Tool '{self.get_name()}' does not support container execution"
)
def container_search(
self, arguments: ToolCallArguments, session_id: str = "0"
) -> ToolExecResult:
"""Execute a search-like operation inside container (optional)."""
raise ToolError(f"Tool '{self.get_name()}' does not support container search")
# Optional container file editing hooks used by edit tools
def container_read_file(self, path) -> str:
"""Read a file inside container (optional)."""
raise ToolError(
f"Tool '{self.get_name()}' does not support container_read_file"
)
def container_write_file(self, path, content: str) -> None:
"""Write a file inside container (optional)."""
raise ToolError(
f"Tool '{self.get_name()}' does not support container_write_file"
)
def container_str_replace(
self, path, old_str: str, new_str: str | None
) -> ToolExecResult:
"""String replace inside a file in container (optional)."""
raise ToolError(
f"Tool '{self.get_name()}' does not support container_str_replace"
)
def container_insert(self, path, insert_line: int, new_str: str) -> ToolExecResult:
"""Insert text into a file in container (optional)."""
raise ToolError(f"Tool '{self.get_name()}' does not support container_insert")
def view_handler_container(
self, arguments: ToolCallArguments, path: Path
) -> ToolExecResult:
"""View handler in container (optional)."""
raise ToolError(f"Tool '{self.get_name()}' does not support view_handler_container")
def json_definition(self) -> dict[str, object]:
"""Default return Claude format (backward compatibility)"""
return self._definition_for_claude_fmt()
def _definition_for_claude_fmt(self) -> dict[str, object]:
"""Return Claude format tool definition (Anthropic Messages API)"""
return {
"name": self.name,
"description": self.description,
"input_schema": self.get_input_schema(),
}
def _definition_for_openai_fmt(self) -> dict[str, object]:
"""Return OpenAI format tool definition"""
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": self.get_input_schema(),
},
}
def get_input_schema(self) -> dict[str, object]:
"""Get the input schema for the tool."""
schema: dict[str, object] = {
"type": "object",
}
properties: dict[str, Property] = {}
required: list[str] = []
for param in self.parameters:
param_schema: Property = {
"type": param.type,
"description": param.description,
}
# For OpenAI strict mode, all params must be in 'required'.
# Optional params are made "nullable" to be compliant.
if self.model_provider == "openai":
required.append(param.name)
if not param.required:
current_type = param_schema["type"]
if isinstance(current_type, str):
param_schema["type"] = [current_type, "null"]
elif isinstance(current_type, list) and "null" not in current_type:
param_schema["type"] = list(current_type) + ["null"]
elif param.required:
required.append(param.name)
if param.enum:
param_schema["enum"] = param.enum
if param.items:
param_schema["items"] = param.items
# For OpenAI, nested objects also need additionalProperties: false
if self.model_provider == "openai" and param.type == "object":
param_schema["additionalProperties"] = False
properties[param.name] = param_schema
schema["properties"] = properties
if len(required) > 0:
schema["required"] = required
# For OpenAI, the top-level schema needs additionalProperties: false
if self.model_provider == "openai":
schema["additionalProperties"] = False
return schema
async def close(self):
"""Ensure proper tool resource deallocation before task completion."""
return None # Using "pass" will trigger a Ruff check error: B027
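# --- Illustrative minimal Tool subclass (added for exposition, not registered anywhere) ---
# Shows the four abstract members a concrete tool must implement; the tool
# name and parameter below are arbitrary example values.
class _EchoTool(Tool):
    def get_name(self) -> str:
        return "echo_tool"

    def get_description(self) -> str:
        return "Echo back the provided text."

    def get_parameters(self) -> list[ToolParameter]:
        return [
            ToolParameter(
                name="text",
                type="string",
                description="Text to echo back.",
                required=True,
            )
        ]

    async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
        return ToolExecResult(output=str(arguments.get("text", "")))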
class ToolExecutor:
"""Tool executor that manages tool execution."""
def __init__(self, tools: list[Tool], logger: Logger | None = None):
self._tools = tools
self._tool_map: dict[str, Tool] | None = None
self.logger = logger
async def close_tools(self):
"""Ensure all tool resources are properly released."""
tasks = [tool.close() for tool in self._tools if hasattr(tool, "close")]
res = await asyncio.gather(*tasks)
return res
def _normalize_name(self, name: str) -> str:
"""Normalize tool name by making it lowercase and removing underscores."""
return name.lower().replace("_", "")
@property
def tools(self) -> dict[str, Tool]:
if self._tool_map is None:
self._tool_map = {
self._normalize_name(tool.name): tool for tool in self._tools
}
return self._tool_map
async def execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
"""Execute a tool call locally."""
normalized_name = self._normalize_name(tool_call.name)
if normalized_name not in self.tools:
return ToolResult(
name=tool_call.name,
success=False,
error=f"Tool '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}",
call_id=tool_call.call_id,
id=tool_call.id,
)
tool = self.tools[normalized_name]
try:
tool_exec_result = await tool.execute(tool_call.arguments)
return ToolResult(
name=tool_call.name,
success=tool_exec_result.error_code == 0,
result=tool_exec_result.output,
error=tool_exec_result.error,
call_id=tool_call.call_id,
id=tool_call.id,
)
except Exception as e:
return ToolResult(
name=tool_call.name,
success=False,
error=f"Error executing tool '{tool_call.name}': {str(e)}, traceback: {format_exc()}.",
call_id=tool_call.call_id,
id=tool_call.id,
)
async def container_execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
"""Execute a tool call in container."""
normalized_name = self._normalize_name(tool_call.name)
if normalized_name not in self.tools:
self.logger.warning(
f"[ToolExecutor] '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}"
)
return ToolResult(
name=tool_call.name,
success=False,
error=f"Tool '{tool_call.name}' not found. Available tools: {[tool.name for tool in self._tools]}",
call_id=tool_call.call_id,
id=tool_call.id,
)
tool = self.tools[normalized_name]
try:
tool_exec_result = await self._container_execute_tool_by_name(
tool, tool_call
)
return ToolResult(
name=tool_call.name,
success=tool_exec_result.error_code == 0,
result=tool_exec_result.output,
error=tool_exec_result.error,
call_id=tool_call.call_id,
id=tool_call.id,
)
except Exception as e:
return ToolResult(
name=tool_call.name,
success=False,
error=f"Error executing tool '{tool_call.name}': {str(e)}, traceback: {format_exc()}.",
call_id=tool_call.call_id,
id=tool_call.id,
)
async def _container_execute_tool_by_name(
self, tool: Tool, tool_call: ToolCall
) -> ToolExecResult:
tool_name = tool.get_name()
if tool_name == BASH_TOOL_NAME:
# BashTool: execute through container
if hasattr(tool, "container_execute"):
return await tool.container_execute(tool_call.arguments)
else:
raise ToolError(
f"Tool '{tool_name}' does not support container execution"
)
elif tool_name == STR_REPLACE_BASED_EDIT_TOOL_NAME:
# TextEditorTool: execute through container
if hasattr(tool, "container_read_file"):
return await self._execute_edit_tool_in_container(
tool, tool_call.arguments
)
else:
raise ToolError(
f"Tool '{tool_name}' does not support container execution"
)
elif tool_name == SEARCH_TOOL_NAME:
# SearchTool: execute through container
if hasattr(tool, "container_search"):
return tool.container_search(tool_call.arguments)
else:
raise ToolError(
f"Tool '{tool_name}' does not support container execution"
)
elif tool_name == SUBMIT_RESULT_TOOL_NAME:
# SubmitResultTool: execute through container
if hasattr(tool, "container_execute"):
return await tool.container_execute(tool_call.arguments)
else:
raise ToolError(
f"Tool '{tool_name}' does not support container execution"
)
else:
# Other tools: container execution not supported
raise ToolError(f"Tool '{tool_name}' does not support container execution")
async def _execute_edit_tool_in_container(
self, tool: Tool, arguments: ToolCallArguments
) -> ToolExecResult:
command = str(arguments.get("command", ""))
path_str = str(arguments.get("path", ""))
if not path_str:
return ToolExecResult(
error="No path provided for the edit tool", error_code=-1
)
from pathlib import Path
path = Path(path_str)
try:
if command == "view":
return tool.view_handler_container(arguments, path)
#return ToolExecResult(output=tool._make_output(content, str(path)))
elif command == "create":
file_text = str(arguments.get("file_text", ""))
tool.container_write_file(path, file_text)
return ToolExecResult(output=f"File created successfully at: {path}")
elif command == "str_replace":
old_str = str(arguments.get("old_str", ""))
new_str = arguments.get("new_str")
if new_str is not None:
new_str = str(new_str)
return tool.container_str_replace(path, old_str, new_str)
elif command == "insert":
insert_line = int(arguments.get("insert_line", 0))
new_str = str(arguments.get("new_str", ""))
return tool.container_insert(path, insert_line, new_str)
else:
return ToolExecResult(
error=f"Unsupported command '{command}' for container execution",
error_code=-1,
)
except Exception as e:
return ToolExecResult(
error=f"Container edit tool error: {str(e)}.", error_code=-1
)
async def parallel_tool_call(self, tool_calls: list[ToolCall]) -> list[ToolResult]:
"""Execute tool calls in parallel locally"""
return await asyncio.gather(
*[self.execute_tool_call(call) for call in tool_calls]
)
async def sequential_tool_call(
self, tool_calls: list[ToolCall]
) -> list[ToolResult]:
"""Execute tool calls in sequential locally"""
return [await self.execute_tool_call(call) for call in tool_calls]
async def container_parallel_tool_call(
self, tool_calls: list[ToolCall]
) -> list[ToolResult]:
"""Execute tool calls in parallel in container"""
return await asyncio.gather(
*[self.container_execute_tool_call(call) for call in tool_calls]
)
async def container_sequential_tool_call(
self, tool_calls: list[ToolCall]
) -> list[ToolResult]:
"""Execute tool calls in sequential in container"""
return [await self.container_execute_tool_call(call) for call in tool_calls]
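# --- Illustrative usage sketch (added for exposition, not used by the runtime) ---
# Executes a single bash ToolCall through ToolExecutor; BashTool is imported
# inside the function to avoid a circular import with this module. The command
# and call_id are arbitrary example values. Run with e.g.
# asyncio.run(_example_run_bash_call()).
async def _example_run_bash_call() -> ToolResult:
    from src.tools.bash_tool import BashTool

    executor = ToolExecutor([BashTool()])
    call = ToolCall(
        name=BASH_TOOL_NAME,
        call_id="call-1",
        arguments={"command": "echo hello"},
    )
    results = await executor.parallel_tool_call([call])
    await executor.close_tools()
    return results[0]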

314
src/tools/bash_tool.py Normal file
View File

@@ -0,0 +1,314 @@
# Copyright (c) 2023 Anthropic
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
#
# Original file was released under MIT License, with the full license text
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
#
# This modified file is released under the same license.
import asyncio
import os
from typing import override
from src.tools.base import (
Tool,
ToolCallArguments,
ToolError,
ToolExecResult,
ToolParameter,
BASH_TOOL_NAME,
)
from src.tools.executor import Executor
from src.managers.log.logger import Logger
from typing import Dict, Any
from traceback import format_exc
class _BashSession:
"""A session of a bash shell."""
_started: bool
_timed_out: bool
command: str = "/bin/bash"
_output_delay: float = 0.2 # seconds
_timeout: float = 120.0 # seconds
_sentinel: str = (
",,,,bash-command-exit-__ERROR_CODE__-banner,,,," # `__ERROR_CODE__` will be replaced by `$?` or `!errorlevel!` later
)
def __init__(self) -> None:
self._started = False
self._timed_out = False
self._process: asyncio.subprocess.Process | None = None
async def start(self) -> None:
if self._started:
return
# Windows compatibility: os.setsid not available
if os.name != "nt": # Unix-like systems
self._process = await asyncio.create_subprocess_shell(
self.command,
shell=True,
bufsize=0,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
preexec_fn=os.setsid,
)
else:
self._process = await asyncio.create_subprocess_shell(
"cmd.exe /v:on", # enable delayed expansion to allow `echo !errorlevel!`
shell=True,
bufsize=0,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
self._started = True
async def stop(self) -> None:
"""Terminate the bash shell."""
if not self._started:
raise ToolError("Session has not started.")
if self._process is None:
return
if self._process.returncode is not None:
return
self._process.terminate()
# Wait until the process has truly terminated.
stdout, stderr = await self._process.communicate()
async def run(self, command: str) -> ToolExecResult:
"""Execute a command in the bash shell."""
if not self._started or self._process is None:
raise ToolError("Session has not started.")
if self._process.returncode is not None:
return ToolExecResult(
error=f"bash has exited with returncode {self._process.returncode}. tool must be restarted.",
error_code=-1,
)
if self._timed_out:
raise ToolError(
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
)
# we know these are not None because we created the process with PIPEs
assert self._process.stdin
assert self._process.stdout
assert self._process.stderr
error_code = 0
sentinel_before, pivot, sentinel_after = self._sentinel.partition(
"__ERROR_CODE__"
)
assert pivot == "__ERROR_CODE__"
errcode_retriever = "!errorlevel!" if os.name == "nt" else "$?"
command_sep = "&" if os.name == "nt" else ";"
# send command to the process
self._process.stdin.write(
b"(\n"
+ command.encode()
+ f"\n){command_sep} echo {self._sentinel.replace('__ERROR_CODE__', errcode_retriever)}\n".encode()
)
await self._process.stdin.drain()
# read output from the process, until the sentinel is found
try:
async with asyncio.timeout(self._timeout):
while True:
await asyncio.sleep(self._output_delay)
# if we read directly from stdout/stderr, it will wait forever for
# EOF. use the StreamReader buffer directly instead.
output: str = self._process.stdout._buffer.decode() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownVariableType]
if sentinel_before in output:
# strip the sentinel from output
output, pivot, exit_banner = output.rpartition(sentinel_before)
assert pivot
# get error code inside banner
error_code_str, pivot, _ = exit_banner.partition(sentinel_after)
if not pivot or not error_code_str.isdecimal():
continue
error_code = int(error_code_str)
break
except asyncio.TimeoutError:
self._timed_out = True
raise ToolError(
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
) from None
if output.endswith("\n"): # pyright: ignore[reportUnknownMemberType]
output = output[:-1] # pyright: ignore[reportUnknownVariableType]
error: str = self._process.stderr._buffer.decode() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType, reportAttributeAccessIssue]
if error.endswith("\n"): # pyright: ignore[reportUnknownMemberType]
error = error[:-1] # pyright: ignore[reportUnknownVariableType]
# clear the buffers so that the next output can be read correctly
self._process.stdout._buffer.clear() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
self._process.stderr._buffer.clear() # type: ignore[attr-defined] # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
return ToolExecResult(
output=output, error=error, error_code=error_code
) # pyright: ignore[reportUnknownArgumentType]
class BashTool(Tool):
"""
A tool that allows the agent to run bash commands.
The tool parameters are defined by Anthropic and are not editable.
"""
def __init__(
self,
model_provider: str | None = None,
executor: Executor | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
):
super().__init__(model_provider, logger, config)
self._session: _BashSession | None = None
self.executor = executor
@override
def get_model_provider(self) -> str | None:
return self._model_provider
@override
def get_name(self) -> str:
return BASH_TOOL_NAME
@override
def get_description(self) -> str:
return """Execute commands within a bash shell environment, either on the local system or inside a container.
* When providing the "command" parameter, its contents must be provided as-is without any XML escaping.
* You have access to a mirrored repository of common Linux (via apt) and Python (via pip) packages for installation.
* State is persisted across all command executions and throughout our conversation session.
* Avoid executing commands that are likely to generate excessively large outputs.
* Avoid executing interactive commands that require user input (e.g., password prompts, confirmation messages).
* For Git commands, always prefer non-interactive forms. For example, use git --no-pager diff instead of git diff to prevent opening a pager.
* To inspect a specific range of lines in a file (e.g., lines 5-10), you can use a command like: sed -n '5,10p' /path/to/file
"""
@override
def get_parameters(self) -> list[ToolParameter]:
# For OpenAI models, all parameters must be required=True
# For other providers, optional parameters can have required=False
restart_required = self.model_provider == "openai"
return [
ToolParameter(
name="command",
type="string",
description="The exact bash command string to be executed.",
required=True,
),
ToolParameter(
name="restart",
type="boolean",
description="If true, terminates the current shell session and starts a new one before executing the command. This clears the session state.",
required=restart_required,
),
]
@override
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
if arguments.get("restart"):
if self._session:
await self._session.stop()
self._session = _BashSession()
await self._session.start()
return ToolExecResult(output="tool has been restarted.")
if self._session is None:
try:
self._session = _BashSession()
await self._session.start()
except Exception as e:
return ToolExecResult(
error=f"Error starting bash session: {e}",
error_code=-1,
)
command = str(arguments["command"]) if "command" in arguments else None
if command is None:
return ToolExecResult(
error=f"No command provided for the {self.get_name()} tool",
error_code=-1,
)
try:
return await self._session.run(command)
except Exception as e:
return ToolExecResult(
error=f"Error running bash command: {e}",
error_code=-1,
)
async def container_execute(
self, arguments: ToolCallArguments, session_id: str = "0"
) -> ToolExecResult:
"""Execute a command in a container bash shell."""
if not self.executor:
return ToolExecResult(
error="Container execution requires an executor to be provided during tool initialization",
error_code=-1,
)
if arguments.get("restart"):
# Close the existing session if it exists
self.executor.close_session("0")
# The executor will automatically recreate session '0' when needed
return ToolExecResult(output="Container session has been restarted.")
command = str(arguments["command"]) if "command" in arguments else None
if command is None:
return ToolExecResult(
error=f"No command provided for container execution",
error_code=-1,
)
# command_with_init = f"source /opt/miniconda3/bin/activate && conda activate testbed && {command}"
# Check if the session is alive before executing the command
if not self.executor.check_session():
return ToolExecResult(
error="Container session is not alive and could not be restarted",
error_code=-1,
)
try:
return_code, output = self.executor.execute(session_id, command)
# return_code, output = self.executor.execute_once(command_with_init)
# The executor returns (return_code, output) tuple
# We'll treat any non-zero return code as an error
error = None
if return_code != 0:
error = f"Command failed with exit code {return_code}, output: {output}"
return ToolExecResult(output=output, error=error, error_code=return_code)
except Exception as e:
return ToolExecResult(
error=f"Error running container bash command: {e}", error_code=-1
)
@override
async def close(self):
"""Properly close self._process."""
if self._session:
await self._session.stop()
self._session = None
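# --- Illustrative usage sketch (added for exposition, not used by the runtime) ---
# Runs a single command through the local execution path of BashTool; the
# command is an arbitrary example. Run with e.g. asyncio.run(_example_bash_echo()).
async def _example_bash_echo() -> ToolExecResult:
    tool = BashTool()
    try:
        return await tool.execute({"command": "echo hi"})
    finally:
        await tool.close()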

735
src/tools/edit_tool.py Normal file
View File

@@ -0,0 +1,735 @@
# Copyright (c) 2023 Anthropic
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
#
# Original file was released under MIT License, with the full license text
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
#
# This modified file is released under the same license.
import os
from pathlib import Path
import tempfile
from typing import Optional, override
import shlex
from src.tools.base import (
Tool,
ToolCallArguments,
ToolError,
ToolExecResult,
ToolParameter,
STR_REPLACE_BASED_EDIT_TOOL_NAME,
)
from src.tools.run import maybe_truncate, run
from src.tools.executor import Executor
from src.managers.log.logger import Logger
from typing import Dict, Any
from traceback import format_exc
EditToolSubCommands = [
"view",
"create",
"str_replace",
"insert",
]
SNIPPET_LINES: int = 4
class TextEditorTool(Tool):
"""Tool to replace a string in a file."""
def __init__(
self,
model_provider: str | None = None,
executor: Executor | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(model_provider, logger, config)
self.executor = executor
@override
def get_model_provider(self) -> str | None:
return self._model_provider
@override
def get_name(self) -> str:
return STR_REPLACE_BASED_EDIT_TOOL_NAME
@override
def get_description(self) -> str:
return """This tool provides capabilities for viewing, creating and editing files
* This tool is stateless. No context is retained between individual command invocations.
* Content Examination `view`:
** For a file: Executing `view` on a file path will output the file's full content with sequential line numbers prefixed (using cat -n).
** For a directory: Executing view on a directory path will recursively list all non-hidden items, displaying contents up to two directory levels deep.
* File Creation `create`:
** The `create` operation is strictly prohibited if a file already exists at the specified `path`.
** Mandatory Pre-action: You must explicitly remove any existing file at the target `path` before proceeding with the creation of a new file.
* Output Handling:
** Should the output generated by a `command` exceed a certain length threshold, it will be automatically shortened and clearly marked with the indicator: <response clipped>.
* String Replacement `str_replace` Operational Rules:
** Precision Targeting: The `old_str` parameter must be an exact, character-for-character match of one or more complete lines from the source file. Special attention must be paid to invisible characters like spaces and tabs.
** Match Uniqueness: The replacement will be canceled if the specified `old_str` pattern is not absolutely unique within the file. To ensure a single match, expand the `old_str` scope to include sufficient preceding or following context lines.
** Content Insertion: The `new_str` parameter defines the complete set of lines that will be inserted into the file, directly replacing the content matched by `old_str`.
"""
@override
def get_parameters(self) -> list[ToolParameter]:
"""Get the parameters for the str_replace_based_edit_tool."""
return [
ToolParameter(
name="command",
type="string",
description=f"Operation to execute. Supported commands: {', '.join(EditToolSubCommands)}.",
required=True,
enum=EditToolSubCommands,
),
ToolParameter(
name="file_text",
type="string",
description="Required for `create` command. Specifies the textual content for the new file.",
required=False,
),
ToolParameter(
name="insert_line",
type="integer",
description="Required for `insert` command. The line number AFTER which the `new_str` will be inserted.",
required=False,
),
ToolParameter(
name="new_str",
type="string",
description="For `str_replace`: the replacement text (optional, defaults to empty). For `insert`: the text to insert (required).",
required=False,
),
ToolParameter(
name="old_str",
type="string",
description="Required for `str_replace` command. The exact text segment in the file to be replaced.",
required=False,
),
ToolParameter(
name="path",
type="string",
description="Absolute filesystem path to the target file or directory. Example: `/workspace/script.py` or `/workspace`.",
required=True,
),
ToolParameter(
name="view_range",
type="array",
description="Optional for `view` command on files. Defines the line range to display. Examples: `[5, 10]` shows lines 5-10; `[15, -1]` shows from line 15 to EOF. Line numbering starts at 1.",
items={"type": "integer"},
required=False,
),
]
@override
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the str_replace_editor tool."""
command = str(arguments["command"]) if "command" in arguments else None
if command is None:
return ToolExecResult(
error=f"No command provided for the {self.get_name()} tool",
error_code=-1,
)
path = str(arguments["path"]) if "path" in arguments else None
if path is None:
return ToolExecResult(
error=f"No path provided for the {self.get_name()} tool", error_code=-1
)
_path = Path(path)
try:
self.validate_path(command, _path)
match command:
case "view":
return await self._view_handler(arguments, _path)
case "create":
return self._create_handler(arguments, _path)
case "str_replace":
return self._str_replace_handler(arguments, _path)
case "insert":
return self._insert_handler(arguments, _path)
case _:
return ToolExecResult(
error=f"Unrecognized command {command}. The allowed commands for the {self.name} tool are: {', '.join(EditToolSubCommands)}",
error_code=-1,
)
except ToolError as e:
return ToolExecResult(error=str(e), error_code=-1)
def validate_path(self, command: str, path: Path):
"""Validate the path for the str_replace_editor tool."""
if not path.is_absolute():
suggested_path = Path("/") / path
raise ToolError(
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
)
# Check if path exists
if not path.exists() and command != "create":
raise ToolError(
f"The path {path} does not exist. Please provide a valid path."
)
if path.exists() and command == "create":
raise ToolError(
f"File already exists at: {path}. Cannot overwrite files using command `create`."
)
# Check if the path points to a directory
if path.is_dir() and command != "view":
raise ToolError(
f"The path {path} is a directory and only the `view` command can be used on directories"
)
async def _view(
self, path: Path, view_range: list[int] | None = None
) -> ToolExecResult:
"""Implement the view command"""
if path.is_dir():
if view_range:
raise ToolError(
"The `view_range` parameter is not allowed when `path` points to a directory."
)
return_code, stdout, stderr = await run(
rf"find {path} -maxdepth 2 -not -path '*/\.*'"
)
if not stderr:
stdout = f"Here's the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
return ToolExecResult(error_code=return_code, output=stdout, error=stderr)
file_content = self.read_file(path)
init_line = 1
if view_range:
if len(view_range) != 2 or not all(
isinstance(i, int) for i in view_range
): # pyright: ignore[reportUnnecessaryIsInstance]
raise ToolError(
"Invalid `view_range`. It should be a list of two integers."
)
file_lines = file_content.split("\n")
n_lines_file = len(file_lines)
init_line, final_line = view_range
if init_line < 1 or init_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
)
if final_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
)
if final_line != -1 and final_line < init_line:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
)
if final_line == -1:
file_content = "\n".join(file_lines[init_line - 1 :])
else:
file_content = "\n".join(file_lines[init_line - 1 : final_line])
return ToolExecResult(
output=self._make_output(file_content, str(path), init_line=init_line)
)
def _view_container(
self, path: Path, view_range: list[int] | None = None
) -> ToolExecResult:
"""Implement the view command"""
if path.is_dir():
raise ToolError("The `path` parameter is not allowed be a directory.")
file_content = self.container_read_file(path)
init_line = 1
make_out_max_lines = None
if view_range:
if len(view_range) != 2 or not all(
isinstance(i, int) for i in view_range
): # pyright: ignore[reportUnnecessaryIsInstance]
raise ToolError(
"Invalid `view_range`. It should be a list of two integers."
)
file_lines = file_content.split("\n")
n_lines_file = len(file_lines)
init_line, final_line = view_range
# The initial line must be at least 1 and cannot be greater than the file's last line
if init_line < 1 or init_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
)
# When the end line takes effect, the end line cannot be less than the start line.
if final_line != -1 and final_line < init_line:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
)
if final_line == -1:
file_content = "\n".join(file_lines[init_line - 1 :])
elif final_line > n_lines_file:
file_content = "\n".join(file_lines[init_line - 1 : n_lines_file])
make_out_max_lines = n_lines_file
else:
file_content = "\n".join(file_lines[init_line - 1 : final_line])
return ToolExecResult(
output=self._make_output(
file_content,
str(path),
init_line=init_line,
max_lines=make_out_max_lines,
)
)
def str_replace(
self, path: Path, old_str: str, new_str: str | None
) -> ToolExecResult:
"""Implement the str_replace command, which replaces old_str with new_str in the file content"""
# Read the file content
file_content = self.read_file(path).expandtabs()
old_str = old_str.expandtabs()
new_str = new_str.expandtabs() if new_str is not None else ""
# Check if old_str is unique in the file
occurrences = file_content.count(old_str)
if occurrences == 0:
raise ToolError(
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
)
elif occurrences > 1:
file_content_lines = file_content.split("\n")
lines = [
idx + 1
for idx, line in enumerate(file_content_lines)
if old_str in line
]
raise ToolError(
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
)
# Replace old_str with new_str
new_file_content = file_content.replace(old_str, new_str)
# Write the new content to the file
self.write_file(path, new_file_content)
# Create a snippet of the edited section
replacement_line = file_content.split(old_str)[0].count("\n")
start_line = max(0, replacement_line - SNIPPET_LINES)
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
# Prepare the success message
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet, f"a snippet of {path}", start_line + 1
)
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
return ToolExecResult(
output=success_msg,
)
def _insert(self, path: Path, insert_line: int, new_str: str) -> ToolExecResult:
"""Implement the insert command, which inserts new_str at the specified line in the file content."""
file_text = self.read_file(path).expandtabs()
new_str = new_str.expandtabs()
file_text_lines = file_text.split("\n")
n_lines_file = len(file_text_lines)
if insert_line < 0 or insert_line > n_lines_file:
raise ToolError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
)
new_str_lines = new_str.split("\n")
new_file_text_lines = (
file_text_lines[:insert_line]
+ new_str_lines
+ file_text_lines[insert_line:]
)
snippet_lines = (
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
)
new_file_text = "\n".join(new_file_text_lines)
snippet = "\n".join(snippet_lines)
self.write_file(path, new_file_text)
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet,
"a snippet of the edited file",
max(1, insert_line - SNIPPET_LINES + 1),
)
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
return ToolExecResult(
output=success_msg,
)
# Note: undo_edit method is not implemented in this version as it was removed
def read_file(self, path: Path):
"""Read the content of a file from a given path; raise a ToolError if an error occurs."""
try:
return path.read_text()
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, read_file command error, ran into {e} while trying to read {path}, traceback: {format_exc()}."
)
raise ToolError(f"Ran into {e} while trying to read {path}.") from None
def write_file(self, path: Path, file: str):
"""Write the content of a file to a given path; raise a ToolError if an error occurs."""
try:
_ = path.write_text(file)
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, write_file command error, ran into {e} while trying to write to {path}, traceback: {format_exc()}."
)
raise ToolError(f"Ran into {e} while trying to write to {path}.") from None
def container_read_file(self, path: Path, session_id: str = "0") -> str:
"""Read the content of a file from a container using cat command."""
if not self.executor:
raise ToolError("No executor provided for container operations")
try:
# Check if session is alive and restart if needed
if not self.executor.check_session(session_id):
raise ToolError(
"Container session is not alive and could not be restarted"
)
# Use cat command to read file content
command = f"cat {path}"
# return_code, output = self.executor.execute(session_id, command)
return_code, output = self.executor.execute_once(command)
if return_code != 0:
raise ToolError(
f"Failed to read file {path} from container. Exit code: {return_code}, Output: {output}"
)
# Clean the output by removing only the command echo, preserving file content exactly
# lines = output.split("\n")
# Remove the first line if it contains the command echo
# if lines and f"cat {path}" in lines[0]:
# lines = lines[2:-1]
final = output[:-1] if output.endswith("\n") else output
# return "\n".join(lines)
return final
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, container_read_file command error, ran into {e} while trying to read {path} from container, traceback: {format_exc()}."
)
raise ToolError(
f"Ran into {e} while trying to read {path} from container."
) from None
def container_write_file(
self, path: Path, content: str, session_id: str = "0"
) -> None:
"""Write content to a file in a container using cat with here document."""
if not self.executor:
raise ToolError("No executor provided for container operations")
try:
# Check if session is alive and restart if needed
if not self.executor.check_session():
raise ToolError(
"Container session is not alive and could not be restarted"
)
# Create the parent directory first
return_code, output = self.executor.execute_once(f"mkdir -p {path.parent}")
if return_code != 0:
raise ToolError(
f"Failed to create dir {path.parent} in container. Exit code: {return_code}, Output: {output}"
)
with tempfile.NamedTemporaryFile(
mode="w+", delete=False, encoding="utf-8"
) as temp_file:
temp_file.write(content)
temp_file_path = temp_file.name
return_code, output = self.executor.cpfile_host_to_container(
temp_file_path, path
)
os.remove(temp_file_path)
if return_code != 0:
raise ToolError(
f"Failed to write to file {path} in container. Exit code: {return_code}, Output: {output}"
)
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, container_write_file command error, ran into {e} while trying to write to {path} in container, traceback: {format_exc()}."
)
raise ToolError(
f"Ran into {e} while trying to write to {path} in container."
) from None
def container_str_replace(
self, path: Path, old_str: str, new_str: str | None, session_id: str = "0"
) -> ToolExecResult:
"""Replace old_str with new_str in a file in a container using sed command."""
if not self.executor:
raise ToolError("No executor provided for container operations")
try:
# Check if session is alive and restart if needed
if not self.executor.check_session():
raise ToolError(
"Container session is not alive and could not be restarted"
)
# First, read the file to check if old_str exists
file_content = self.container_read_file(path, session_id)
# Check if old_str is unique in the file
occurrences = file_content.count(old_str)
if occurrences == 0:
raise ToolError(
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
)
elif occurrences > 1:
# The line-number calculation below is unreliable: old_str may span multiple lines, so occurrence locations cannot be derived this way
# file_content_lines = file_content.split("\n")
# lines = [
# idx + 1
# for idx, line in enumerate(file_content_lines)
# if old_str in line
# ]
raise ToolError(
f"No replacement was performed. Multiple occurrences of old_str `{old_str}`. Total occurrences: {occurrences}. Please ensure it is unique"
)
updated_content = file_content.replace(old_str, new_str if new_str is not None else "")
self.container_write_file(path=path, content=updated_content)
# Read the file to show a snippet of the changes
try:
file_content = self.container_read_file(path, session_id)
# Create a simple snippet showing the change
lines = file_content.split("\n")
snippet_lines = lines[: min(10, len(lines))] # Show first 10 lines
snippet = "\n".join(snippet_lines)
success_msg = f"The file {path} has been edited in container. "
success_msg += self._make_output(
snippet, f"a snippet of {path}", init_line=1
)
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
return ToolExecResult(output=success_msg)
except Exception:
# If we can't read the file for snippet, just return success
return ToolExecResult(
output=f"Successfully replaced string in file {path} in container."
)
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, container_str_replace command error, ran into {e} while trying to replace string in {path} in container, traceback: {format_exc()}."
)
raise ToolError(
f"Ran into {e} while trying to replace string in {path} in container."
) from None
def container_insert(
self, path: Path, insert_line: int, new_str: str, session_id: str = "0"
) -> ToolExecResult:
if not self.executor:
raise ToolError("No executor provided for container operations")
try:
if not self.executor.check_session():
raise ToolError(
"Container session is not alive and could not be restarted"
)
file_content = self.container_read_file(path, session_id)
file_text_lines = file_content.split("\n")
n_lines_file = len(file_text_lines)
if insert_line < 0 or insert_line > n_lines_file:
raise ToolError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
)
new_str_lines = new_str.split("\n")
new_file_text_lines = (
file_text_lines[:insert_line]
+ new_str_lines
+ file_text_lines[insert_line:]
)
snippet_lines = (
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
)
new_file_text = "\n".join(new_file_text_lines)
snippet = "\n".join(snippet_lines)
self.container_write_file(path, new_file_text, session_id)
success_msg = f"The file {path} has been edited in container. "
success_msg += self._make_output(
snippet,
"a snippet of the edited file",
max(1, insert_line - SNIPPET_LINES + 1),
)
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
return ToolExecResult(
output=success_msg,
)
except Exception as e:
if self.logger:
self.logger.error(
f"In edit_tool, container_insert command error, ran into {e} while trying to insert content in {path} in container, traceback: {format_exc()}."
)
raise ToolError(
f"Ran into {e} while trying to insert content in {path} in container."
) from None
def _escape_sed(self, text: str) -> str:
"""Escape special characters in text for use with sed command."""
# Escape sed special characters: / \ &
escaped = text.replace("\\", "\\\\") # Escape backslashes first
escaped = escaped.replace("/", "\\/") # Escape forward slashes
escaped = escaped.replace("&", "\\&") # Escape ampersands
escaped = escaped.replace("\n", "\\n") # Handle newlines
escaped = escaped.replace("\t", "\\t") # Handle tabs
return escaped
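# Illustrative behaviour of the escaping above (a sketch, values are hypothetical):
#   "path/to/file & more"  ->  "path\/to\/file \& more"
#   a literal newline or tab becomes the two-character sequence "\n" or "\t"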
def _make_output(
self,
file_content: str,
file_descriptor: str,
init_line: int = 1,
expand_tabs: bool = True,
max_lines: Optional[int] = None,
):
"""Generate output for the CLI based on the content of a file."""
file_content = maybe_truncate(file_content)
if expand_tabs:
file_content = file_content.expandtabs()
file_content = "\n".join(
[
f"{i + init_line:6}\t{line}"
for i, line in enumerate(file_content.split("\n"))
]
)
if max_lines:
return (
f"Here's the result of running `cat -n` on {file_descriptor}(The file is only {max_lines} lines):\n"
+ file_content
+ "\n"
)
else:
return (
f"Here's the result of running `cat -n` on {file_descriptor}:\n"
+ file_content
+ "\n"
)
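# For reference, the numbered output produced above looks roughly like this
# (the path and contents are hypothetical):
#   Here's the result of running `cat -n` on /workspace/foo.py:
#        1	import os
#        2	print(os.getcwd())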
async def _view_handler(
self, arguments: ToolCallArguments, _path: Path
) -> ToolExecResult:
view_range = arguments.get("view_range", None)
if view_range is None:
return await self._view(_path, None)
if not (
isinstance(view_range, list) and all(isinstance(i, int) for i in view_range)
):
return ToolExecResult(
error="Parameter `view_range` should be a list of integers.",
error_code=-1,
)
view_range_int: list[int] = [i for i in view_range if isinstance(i, int)]
return await self._view(_path, view_range_int)
def view_handler_container(
self, arguments: ToolCallArguments, path: Path
) -> ToolExecResult:
view_range = arguments.get("view_range", None)
if view_range is None:
return self._view_container(path, None)
if not (
isinstance(view_range, list) and all(isinstance(i, int) for i in view_range)
):
return ToolExecResult(
error="Parameter `view_range` should be a list of integers.",
error_code=-1,
)
view_range_int: list[int] = [i for i in view_range if isinstance(i, int)]
return self._view_container(path, view_range_int)
def _create_handler(
self, arguments: ToolCallArguments, _path: Path
) -> ToolExecResult:
file_text = arguments.get("file_text", None)
if not isinstance(file_text, str):
return ToolExecResult(
error="Parameter `file_text` is required and must be a string for command: create",
error_code=-1,
)
self.write_file(_path, file_text)
return ToolExecResult(output=f"File created successfully at: {_path}")
def _str_replace_handler(
self, arguments: ToolCallArguments, _path: Path
) -> ToolExecResult:
old_str = arguments.get("old_str") if "old_str" in arguments else None
if not isinstance(old_str, str):
return ToolExecResult(
error="Parameter `old_str` is required and should be a string for command: str_replace",
error_code=-1,
)
new_str = arguments.get("new_str") if "new_str" in arguments else None
if not (new_str is None or isinstance(new_str, str)):
return ToolExecResult(
error="Parameter `new_str` should be a string or null for command: str_replace",
error_code=-1,
)
return self.str_replace(_path, old_str, new_str)
def _insert_handler(
self, arguments: ToolCallArguments, _path: Path
) -> ToolExecResult:
insert_line = arguments.get("insert_line")
if not isinstance(insert_line, int):
return ToolExecResult(
error="Parameter `insert_line` is required and should be integer for command: insert",
error_code=-1,
)
new_str_to_insert = arguments.get("new_str")
if not isinstance(new_str_to_insert, str):
return ToolExecResult(
error="Parameter `new_str` is required for command: insert",
error_code=-1,
)
return self._insert(_path, insert_line, new_str_to_insert)
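# A hedged sketch of the argument shapes the handlers above expect (field names
# are taken from the checks in this file; the values are hypothetical):
#   str_replace: {"old_str": "foo()", "new_str": "bar()"}
#   insert:      {"insert_line": 10, "new_str": "print('hi')"}
#   view:        {"view_range": [1, 40]}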

198
src/tools/executor.py Normal file
View File

@@ -0,0 +1,198 @@
import subprocess
import uuid
import docker
import pexpect
import re
from docker.errors import DockerException, ImageNotFound, NotFound
from src.managers.log.logger import Logger
class Executor:
def __init__(self, image: str, logger: Logger | None = None):
self.image = image
self.container = None
self.sessions: dict[str, pexpect.spawn] = {}
self.client = docker.from_env()
self.logger = logger
try:
self.client.images.get(self.image)
except ImageNotFound:
raise DockerException(
f"Image '{self.image}' not found. Please build the image first."
)
try:
self.container = self.client.containers.run(
self.image,
command="sleep infinity",
detach=True,
working_dir="/workspace",
)
self.logger.info(f"Created container {self.container.id}")
except DockerException as e:
raise DockerException(
f"Failed to create container with image '{self.image}': {e}"
)
session_id = self.init_session()
if session_id is None:
raise DockerException("Failed to initialize default session")
if session_id in self.sessions:
self.sessions["0"] = self.sessions.pop(session_id)
def init_session(self) -> str | None:
session_id = str(uuid.uuid4())
command = f"docker exec -it {self.container.id} /bin/bash"
for attempt in range(3): # Retry up to 3 times
try:
shell = pexpect.spawn(command, encoding="utf-8", timeout=120)
shell.expect([r"\$.*", r"#.*"], timeout=120)
# Source conda and activate testbed environment
shell.sendline("source /opt/miniconda3/bin/activate")
shell.expect([r"\$.*", r"#.*"], timeout=30)
shell.sendline("conda activate testbed")
shell.expect([r"\$.*", r"#.*"], timeout=30)
shell.sendline("export NO_COLOR=1 && export PAGER=cat")
shell.expect([r"\$.*", r"#.*"], timeout=30)
# Verify conda environment is alive by checking the full output
# The output should contain (testbed) if the environment is activated
# We can check this by looking at the full output from the conda activate command
output = shell.before or ""
if "(testbed)" not in output:
# Environment not properly activated, retry
if attempt < 2: # Not the last attempt
shell.close(force=True)
continue
else:
shell.close(force=True)
raise DockerException(
"Failed to activate conda environment 'testbed' after 3 attempts"
)
self.sessions[session_id] = shell
return session_id
except pexpect.exceptions.TIMEOUT:
if attempt < 2: # Not the last attempt
if "shell" in locals() and shell.isalive():
shell.close(force=True)
continue
else:
return None
except Exception as e:
if attempt < 2: # Not the last attempt
if "shell" in locals() and shell.isalive():
shell.close(force=True)
continue
else:
raise DockerException(
f"Failed to initialize session after 3 attempts: {e}"
)
return None
def execute(
self, session_id: str, command: str, timeout: int = 300
) -> tuple[int, str]:
shell = self.sessions.get(session_id)
if not shell or not shell.isalive():
return -1, "Session not found or is dead."
full_command = command.strip()
shell.sendline(full_command)
marker = f"---CMD_DONE---"
marker_command = f"echo {marker}$?"
shell.sendline(marker_command)
try:
shell.expect(marker + r"(\d+).*[\n](.*)", timeout=timeout)
except pexpect.exceptions.TIMEOUT:
return (
-1,
f"Error: Command '{command}' timed out after {timeout} seconds. Partial output:\n{shell.before}",
)
exit_code = int(shell.match.group(1))
p = str(shell.match.group(2))
all_lines: str = p + shell.before
# delete all \r
all_lines = re.sub(r"\r", "", all_lines)
# Remove some non-color-related terminal control characters.
# \x1b[?2004h - enable bracketed paste mode
# \x1b[?2004l - disable bracketed paste mode
all_lines = re.sub(r"\x1B\[\?2004[lh]", "", all_lines)
# Strip the last line's echo.
all_lines = re.sub(r"\n[^\n]+---CMD_DONE---.*", "", all_lines)
# self.logger.info(f"'{[all_lines]}'")
return exit_code, all_lines
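# Rough shape of what the pty sees for the sentinel protocol above (a sketch):
#   <echo of command> / <command output> / <echo of the marker command> /
#   ---CMD_DONE---0
# expect() matches the marker line that is followed by the exit-status digits,
# and the cleanup above strips \r, bracketed-paste codes, and the echoed marker line.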
def execute_once(self, command: str, timeout: int = 300) -> tuple[int, str]:
# cmd = ["docker", "exec", self.container.id, "bash", "-c", command]
cmd = ["docker", "exec", "-i", self.container.id, "bash", "-s"]
sub = subprocess.run(
cmd,
encoding="utf-8",
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
input=f"{command}\n",
)
if sub.returncode != 0:
return sub.returncode, sub.stderr
return sub.returncode, sub.stdout
def cpfile_host_to_container(self, source: str, dest: str) -> tuple[int, str]:
cmd = ["docker", "cp", source, f"{self.container.id}:{dest}"]
sub = subprocess.run(
cmd,
encoding="utf-8",
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
self.execute_once(f"chmod 0777 {dest}")
if sub.returncode != 0:
return sub.returncode, sub.stderr
return sub.returncode, sub.stdout
def check_session(self, session_id: str = "0") -> bool:
"""
Check whether the session identified by `session_id` (default "0") is alive and restart it if not.
"""
if session_id in self.sessions:
session = self.sessions[session_id]
if session and session.isalive():
return True
else:
self.sessions.pop(session_id)
new_session_id = self.init_session()
if new_session_id is None:
return False
if new_session_id != session_id:
self.sessions[session_id] = self.sessions.pop(new_session_id)
return True
def close_session(self, session_id: str):
if session_id in self.sessions:
session = self.sessions.pop(session_id)
if session and session.isalive():
session.close(force=True)
# Session not found - this is not an error condition
def shutdown(self):
for session_id in list(self.sessions.keys()):
self.close_session(session_id)
if self.container:
try:
self.container.stop()
self.container.remove()
except DockerException:
pass  # Silently ignore cleanup errors
self.container = None
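# Minimal usage sketch (the image tag below is hypothetical):
#   executor = Executor(image="swebench/testbed:latest")
#   code, out = executor.execute("0", "python -c 'print(1 + 1)'")
#   executor.shutdown()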

57
src/tools/run.py Normal file
View File

@@ -0,0 +1,57 @@
# Copyright (c) 2023 Anthropic
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# Copyright (c) 2025 Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by Beijing Tokens Infinity Technology Co., Ltd. and/or its affiliates. on 27 Oct 2025
#
# Original file was released under MIT License, with the full license text
# available at https://github.com/anthropics/anthropic-quickstarts/blob/main/LICENSE
# and https://github.com/bytedance/trae-agent/blob/main/LICENSE
#
# This modified file is released under the same license.
"""Utility to run shell commands asynchronously with a timeout."""
import asyncio
import contextlib
TRUNCATED_MESSAGE: str = (
"<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
)
MAX_RESPONSE_LEN: int = 16000
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
"""Truncate content and append a notice if content exceeds the specified length."""
return (
content
if not truncate_after or len(content) <= truncate_after
else content[:truncate_after] + TRUNCATED_MESSAGE
)
async def run(
cmd: str,
timeout: float | None = 120.0, # seconds
truncate_after: int | None = MAX_RESPONSE_LEN,
):
"""Run a shell command asynchronously with a timeout."""
process = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
try:
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
return (
process.returncode or 0,
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
)
except asyncio.TimeoutError as exc:
with contextlib.suppress(ProcessLookupError):
process.kill()
raise TimeoutError(
f"Command '{cmd}' timed out after {timeout} seconds"
) from exc
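# Usage sketch for the helper above (inside an async context):
#   return_code, stdout, stderr = await run("ls -la /workspace", timeout=30.0)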

404
src/tools/search_tool.py Normal file
View File

@@ -0,0 +1,404 @@
"""Search tool for finding files based on text content using ripgrep (rg)."""
import asyncio
import json
import re
from pathlib import Path
import shlex
from typing import override
from traceback import format_exc
from src.tools.base import (
Tool,
ToolCallArguments,
ToolError,
ToolExecResult,
ToolParameter,
SEARCH_TOOL_NAME,
)
from src.tools.run import run
from src.tools.executor import Executor
from src.managers.log.logger import Logger
from typing import Dict, Any
class SearchTool(Tool):
"""Tool for searching files based on text content using ripgrep."""
def __init__(
self,
model_provider: str | None = None,
executor: Executor | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(model_provider, logger, config)
self._executor = executor
@override
def get_model_provider(self) -> str | None:
return self._model_provider
@override
def get_name(self) -> str:
return SEARCH_TOOL_NAME
@override
def get_description(self) -> str:
return """Search tool for finding files based on text content
* Searches for text patterns in files and directories recursively
* Returns file paths, line numbers, and surrounding context
* Supports regex patterns and various search options
* Provides fast and efficient content searching
Features:
- Pattern matching with full regular expression support
- Line number display for all matches
- Configurable context lines surrounding each match (before and after)
- Filtering by file type
- Control over case sensitivity
- Option to include hidden files in searches
- Handling of binary files
Example patterns (all patterns must be valid regular expressions):
- Simple text: "function main"
- Regex: "def\\s+\\w+\\s*\\("
"""
@override
def get_parameters(self) -> list[ToolParameter]:
"""Get the parameters for the search tool."""
params = [
ToolParameter(
name="pattern",
type="string",
description=(
"The regular expression pattern to search for within the file content. "
"To match literal characters that are also regex metacharacters (e.g., '.', '*', '+', '?', '(', ')', '[', ']', '{', '}', '|', '^', '$', '\\'), "
"they must be escaped with a backslash. "
"Examples: To find the literal string '(some_value)': '\\(some_value\\)'; To find Python function definitions: 'def\\s+[a-zA-Z_]\\w*\\s*\\('. "
),
required=True,
),
ToolParameter(
name="search_path",
type="string",
description="The directory or file path to search in. Must be an absolute path.",
required=True,
),
ToolParameter(
name="context_lines",
type="integer",
description="Number of context lines to show before and after each match. Default: 2.",
required=False,
),
ToolParameter(
type="boolean",
name="case_insensitive",
description="Whether to perform case-insensitive search. Default: false.",
required=False,
),
ToolParameter(
type="boolean",
name="include_hidden",
description="Whether to include hidden files and directories. Default: false.",
required=False,
),
ToolParameter(
type="boolean",
name="include_binary",
description="Whether to search in binary files. Default: false.",
required=False,
),
ToolParameter(
type="string",
name="file_types",
description="Comma-separated list of file types to search (e.g., 'py,js,md'). Optional.",
required=False,
),
ToolParameter(
type="integer",
name="max_results",
description="Maximum number of results to return per file. Default: 100.",
required=False,
),
]
return params
@override
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the search operation."""
try:
pattern = str(arguments.get("pattern", ""))
if not pattern:
return ToolExecResult(
error="Pattern parameter is required", error_code=-1
)
search_path_str = str(arguments.get("search_path", ""))
if not search_path_str:
return ToolExecResult(
error="search_path parameter is required", error_code=-1
)
search_path = Path(search_path_str)
if not search_path.is_absolute():
return ToolExecResult(
error=f"Search path must be absolute: {search_path}", error_code=-1
)
if not search_path.exists():
return ToolExecResult(
error=f"Search path does not exist: {search_path}", error_code=-1
)
# Parse optional parameters
context_lines = int(arguments.get("context_lines", 2))
case_insensitive = bool(arguments.get("case_insensitive", False))
include_hidden = bool(arguments.get("include_hidden", False))
include_binary = bool(arguments.get("include_binary", False))
file_types = arguments.get("file_types")
max_results = int(arguments.get("max_results", 100))
# Build ripgrep command
cmd_parts = ["rg"]
# Add context lines
if context_lines > 0:
cmd_parts.extend(["-C", str(context_lines)])
# Add case sensitivity
if case_insensitive:
cmd_parts.append("-i")
# Add hidden files
if include_hidden:
cmd_parts.append("--hidden")
# Add binary files
if include_binary:
cmd_parts.append("--binary")
else:
cmd_parts.append("--no-binary")
# Add file types
if file_types and isinstance(file_types, str):
for file_type in file_types.split(","):
file_type = file_type.strip()
if file_type:
cmd_parts.extend(["-g", f'"*.{file_type}"'])
# Add line numbers and filename
cmd_parts.extend(["-n", "-H"])
# Add max results
cmd_parts.extend(["-m", str(max_results)])
# Add pattern and search path (quote pattern to handle spaces)
cmd_parts.extend([f'"{pattern}"', str(search_path)])
# Execute the command
return_code, stdout, stderr = await run(" ".join(cmd_parts))
if return_code == 0:
# Parse and format results
results = self._parse_rg_output(stdout)
formatted_output = self._format_results(results, max_results)
return ToolExecResult(output=formatted_output)
elif return_code == 1:
# No matches found
return ToolExecResult(output=f"No matches found for pattern: {pattern}")
else:
# Error occurred
error_msg = (
stderr if stderr else f"ripgrep exited with code {return_code}"
)
return ToolExecResult(error=error_msg, error_code=return_code)
except Exception as e:
return ToolExecResult(
error=f"Search tool error: {str(e)}",
error_code=-1,
)
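# For orientation, a command assembled by the logic above might look like this
# (the pattern and path are hypothetical):
#   rg -C 2 --no-binary -n -H -m 100 "def\s+main" /workspace/src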
def container_search(
self, arguments: ToolCallArguments, session_id: str = "0"
) -> ToolExecResult:
if not self._executor:
return ToolExecResult(
error="No executor provided for container search", error_code=-1
)
try:
pattern = str(arguments.get("pattern", ""))
if not pattern:
return ToolExecResult(
error="Pattern parameter is required", error_code=-1
)
search_path_str = str(arguments.get("search_path", ""))
if not search_path_str:
return ToolExecResult(
error="search_path parameter is required", error_code=-1
)
context_lines = int(arguments.get("context_lines", 2))
case_insensitive = bool(arguments.get("case_insensitive", False))
include_hidden = bool(arguments.get("include_hidden", False))
include_binary = bool(arguments.get("include_binary", False))
file_types = arguments.get("file_types")
max_results = int(arguments.get("max_results", 100))
cmd_parts = ["rg"]
if context_lines > 0:
cmd_parts.extend(["-C", str(context_lines)])
if case_insensitive:
cmd_parts.append("-i")
if include_hidden:
cmd_parts.append("--hidden")
if include_binary:
cmd_parts.append("--binary")
else:
cmd_parts.append("--no-binary")
if file_types and isinstance(file_types, str):
for file_type in file_types.split(","):
file_type = file_type.strip()
if file_type:
cmd_parts.extend(["-g", f'"*.{file_type}"'])
cmd_parts.extend(["-n", "-H"])
cmd_parts.extend(["-m", str(max_results * 2)])
cmd_parts.extend(["--color=never", "-U"])
cmd_parts.extend(["--", shlex.quote(pattern), search_path_str])
command = " ".join(cmd_parts)
return_code, output = self._executor.execute_once(command)
if self.logger:
self.logger.debug(f"search_tool cmd: {command}")
# self.logger.debug(f"DEBUG: SearchTool result - Return code: {return_code}, Output: \n{output}")
if return_code == 0:
results = self._parse_rg_output(output)
# self.logger.debug(f"DEBUG: SearchTool _parse_rg_output results: {results}")
formatted_output = self._format_results(results, max_results)
# self.logger.debug(f"DEBUG: SearchTool _format_results formatted_output: {formatted_output}")
return ToolExecResult(output=formatted_output)
elif return_code == 1:
return ToolExecResult(output=f"No matches found for pattern: {pattern}")
else:
return ToolExecResult(
error=f"ripgrep exited with code {return_code}. Output: {output}",
error_code=return_code,
)
except Exception as e:
return ToolExecResult(
error=f"Container search error: {str(e)}",
error_code=-1,
)
def _parse_rg_output(self, output: str) -> list[dict]:
"""Parse ripgrep output into structured results."""
# Remove ANSI escape codes
ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
clean_output = ansi_escape.sub("", output)
results = []
current_file = None
for line in clean_output.split("\n"):
if not line.strip():
continue
# Check if this is a file path line (no colon, just a path)
if ":" not in line and "/" in line and not line.strip().startswith("-"):
# This is a file path line
current_file = line.strip()
continue
# Parse ripgrep output format: file:line:content or file:line-content
if ":" in line:
# Split by colon to get file, line info, and content
parts = line.split(":", 2)
if len(parts) >= 3:
file_path = parts[0].strip()
line_info = parts[1].strip()
content = parts[2].strip()
# Use current_file if file_path is empty or just a dash
if not file_path or file_path == "-":
file_path = current_file
# Check if line_info is a number (match line) or contains dash (context line)
if line_info.isdigit():
# This is a match line
line_num = int(line_info)
results.append(
{
"file": file_path,
"line": line_num,
"content": content,
"full_line": line,
"is_match": True,
}
)
elif "-" in line_info:
# This is a context line (before/after match)
# Extract line number from context line format like "12-15" or "12-"
try:
line_num = int(line_info.split("-")[0])
results.append(
{
"file": file_path,
"line": line_num,
"content": content,
"full_line": line,
"is_match": False,
}
)
except ValueError:
continue
return results
def _format_results(self, results: list[dict], max_results: int) -> str:
"""Format search results for display."""
if not results:
return "No matches found."
# Filter only match lines for counting
match_results = [r for r in results if r.get("is_match", True)]
limited_results = results[:max_results]
output_lines = [f"Found {len(match_results)} matches:"]
output_lines.append("=" * 50)
current_file = None
for result in limited_results:
file_path = result["file"]
line_num = result["line"]
content = result["content"]
is_match = result.get("is_match", True)
# Add file header if this is a new file
if current_file != file_path:
current_file = file_path
output_lines.append(f"\n📁 {file_path}")
output_lines.append("-" * (len(file_path) + 4))
# Add line with appropriate prefix
prefix = " " if is_match else " " # Match lines get no special prefix
marker = "" if is_match else " " # Mark actual matches
output_lines.append(f"{marker} {line_num:4d}: {content}")
if len(results) > max_results:
output_lines.append(f"\n... and {len(results) - max_results} more lines")
return "\n".join(output_lines)

View File

@@ -0,0 +1,127 @@
"""Search tool for finding files based on text content using ripgrep (rg)."""
import asyncio
import json
from logging import Logger
import re
from pathlib import Path
from typing import Any, Dict, override
from traceback import format_exc
from src.tools.base import (
Tool,
ToolCallArguments,
ToolError,
ToolExecResult,
ToolParameter,
SubmitToolResult,
SUBMIT_RESULT_TOOL_NAME,
)
from src.tools.run import run
from src.tools.executor import Executor
class SubmitResultTool(Tool):
"""Tool for git diff, not for model to invoke"""
def __init__(
self,
model_provider: str | None = None,
executor: Executor | None = None,
logger: Logger | None = None,
config: Dict[str, Any] | None = None,
) -> None:
super().__init__(model_provider, logger, config)
self._executor = executor
@override
def get_model_provider(self) -> str | None:
return self._model_provider
@override
def get_name(self) -> str:
return SUBMIT_RESULT_TOOL_NAME
@override
def get_description(self) -> str:
return """
Submit the final result to complete the task.
This tool should be called when you are confident that the issue has been resolved. Simply indicate that you are ready to submit the result - the system will automatically capture the git diff and generate the final patch.
You don't need to provide the actual patch content manually. Just call this tool to signal completion, and the system will handle the rest.
"""
@override
def get_parameters(self) -> list[ToolParameter]:
params = [
ToolParameter(
name="is_task_done",
type="boolean",
description="Whether the task is done",
required=True,
),
ToolParameter(
name="test_status",
type="string",
description="The status of test execution after applying the patch",
required=True,
enum=["passed", "failed", "skipped", "error"],
),
ToolParameter(
name="reasoning",
type="string",
description="Detailed explanation of the logic behind the patch, including root cause analysis and solution approach",
required=True,
),
]
return params
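# Example of the arguments the model is expected to supply (illustrative values):
#   {"is_task_done": true, "test_status": "passed",
#    "reasoning": "Root cause was X; patched Y to handle Z."}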
@override
async def execute(self, arguments: ToolCallArguments) -> ToolExecResult:
"""Execute the tool locally (not supported for submit_result tool)."""
return ToolExecResult(
error="SubmitResultTool only supports container execution", error_code=-1
)
@override
async def container_execute(self, arguments: ToolCallArguments) -> ToolExecResult:
if not self._executor:
return ToolExecResult(
error="No executor provided for git diff tool", error_code=-1
)
try:
is_task_done = arguments.get("is_task_done", False)
test_status = arguments.get("test_status", "error")
reasoning = arguments.get("reasoning", "")
root_path = self.config.get("builder", {}).get("repo_root_path", "/")
cmd_parts = ["cd", str(root_path), "&&", "git", "--no-pager", "diff"]
command = " ".join(cmd_parts)
if self.logger:
self.logger.debug(f"DEBUG: GitDiffTool executing command: {command}")
return_code, output = self._executor.execute_once(command)
if self.logger:
self.logger.debug(f"DEBUG: GitDiffTool result - Return code: {return_code}, Output: \n{output}")
if return_code == 0:
submit_result = SubmitToolResult(
return_code=return_code,
output=output,
is_task_done=is_task_done,
test_status=test_status,
reasoning=reasoning,
)
return ToolExecResult(output=str(submit_result))
else:
return ToolExecResult(
error=f"GitDiffTool exited with code {return_code}. Output: {output}",
error_code=return_code,
)
except Exception as e:
return ToolExecResult(
error=f"Container search error: {str(e)}", error_code=-1
)