From b8dfb9556a8db6a27a198646f0ff103463302288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:25:00 +0800 Subject: [PATCH] Add files via upload --- internal/agent/agent.go | 1038 +---------------- .../agent/default_single_system_prompt.go | 6 +- internal/agent/token_counter.go | 54 + internal/database/database.go | 6 +- internal/multiagent/eino_single_runner.go | 2 +- internal/multiagent/eino_summarize.go | 4 +- 6 files changed, 66 insertions(+), 1044 deletions(-) create mode 100644 internal/agent/token_counter.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 026ecd70..accb6be4 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -17,9 +17,7 @@ import ( "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp/builtin" - "cyberstrike-ai/internal/project" "cyberstrike-ai/internal/openai" - "cyberstrike-ai/internal/security" "cyberstrike-ai/internal/storage" "go.uber.org/zap" @@ -30,7 +28,6 @@ type Agent struct { openAIClient *openai.Client config *config.OpenAIConfig agentConfig *config.AgentConfig - memoryCompressor *MemoryCompressor mcpServer *mcp.Server externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器 logger *zap.Logger @@ -56,8 +53,6 @@ type ResultStorage interface { DeleteResult(executionID string) error } -type toolCallInterceptorCtxKey struct{} - type agentConversationIDKey struct{} func withAgentConversationID(ctx context.Context, id string) context.Context { @@ -81,17 +76,6 @@ func ConversationIDFromContext(ctx context.Context) string { return agentConversationIDFromContext(ctx) } -// ToolCallInterceptor allows caller to gate or rewrite tool arguments just before execution. -// Returning a non-nil error means the tool call is rejected and execution is skipped. -type ToolCallInterceptor func(ctx context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error) - -func WithToolCallInterceptor(ctx context.Context, fn ToolCallInterceptor) context.Context { - if fn == nil { - return ctx - } - return context.WithValue(ctx, toolCallInterceptorCtxKey{}, fn) -} - // NewAgent 创建新的Agent func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer *mcp.Server, externalMCPMgr *mcp.ExternalMCPManager, logger *zap.Logger, maxIterations int) *Agent { // 如果 maxIterations 为 0 或负数,使用默认值 30 @@ -141,28 +125,10 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer } llmClient := openai.NewClient(cfg, httpClient, logger) - var memoryCompressor *MemoryCompressor - if cfg != nil { - mc, err := NewMemoryCompressor(MemoryCompressorConfig{ - MaxTotalTokens: cfg.MaxTotalTokens, - OpenAIConfig: cfg, - HTTPClient: httpClient, - Logger: logger, - }) - if err != nil { - logger.Warn("初始化MemoryCompressor失败,将跳过上下文压缩", zap.Error(err)) - } else { - memoryCompressor = mc - } - } else { - logger.Warn("OpenAI配置为空,无法初始化MemoryCompressor") - } - return &Agent{ openAIClient: llmClient, config: cfg, agentConfig: agentCfg, - memoryCompressor: memoryCompressor, mcpServer: mcpServer, externalMCPMgr: externalMCPMgr, logger: logger, @@ -353,28 +319,10 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error { return nil } -// AgentLoopResult Agent Loop执行结果 -type AgentLoopResult struct { - Response string - MCPExecutionIDs []string - LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messages,JSON;与 multiagent.RunResult 字段对齐) - LastAgentTraceOutput string // 最终助手输出文本 -} - // ProgressCallback 进度回调函数类型 type ProgressCallback func(eventType, message string, data interface{}) -// AgentLoop 执行Agent循环 -func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) { - return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil, "") -} - -// AgentLoopWithConversationID 执行Agent循环(带对话ID) -func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) { - return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil, "") -} - -// EinoSingleAgentSystemInstruction 供 Eino adk.ChatModelAgent.Instruction 使用,与 AgentLoopWithProgress 首条 system 对齐(含 system_prompt_path)。 +// EinoSingleAgentSystemInstruction 供 Eino adk.ChatModelAgent.Instruction 使用(含 system_prompt_path)。 func (a *Agent) EinoSingleAgentSystemInstruction() string { systemPrompt := DefaultSingleAgentSystemPrompt() if a.agentConfig != nil { @@ -396,576 +344,6 @@ func (a *Agent) EinoSingleAgentSystemInstruction() string { return systemPrompt } -// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID) -func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string, systemPromptExtra string) (*AgentLoopResult, error) { - ctx = withAgentConversationID(ctx, conversationID) - // 设置当前对话ID(兼容未走 context 的旧路径;并发会话应以 context 为准) - a.mu.Lock() - a.currentConversationID = conversationID - a.mu.Unlock() - // 发送进度更新 - sendProgress := func(eventType, message string, data interface{}) { - if callback != nil { - callback(eventType, message, data) - } - } - - systemPrompt := DefaultSingleAgentSystemPrompt() - if a.agentConfig != nil { - if p := strings.TrimSpace(a.agentConfig.SystemPromptPath); p != "" { - path := p - a.mu.RLock() - base := a.promptBaseDir - a.mu.RUnlock() - if !filepath.IsAbs(path) && base != "" { - path = filepath.Join(base, path) - } - if b, err := os.ReadFile(path); err != nil { - a.logger.Warn("读取单代理 system_prompt_path 失败,使用内置提示", zap.String("path", path), zap.Error(err)) - } else if s := strings.TrimSpace(string(b)); s != "" { - systemPrompt = s - } - } - } - systemPrompt = project.AppendSystemPromptBlock(systemPrompt, systemPromptExtra) - - messages := []ChatMessage{ - { - Role: "system", - Content: systemPrompt, - }, - } - - // 添加历史消息(保留所有字段,包括ToolCalls和ToolCallID) - a.logger.Info("处理历史消息", - zap.Int("count", len(historyMessages)), - ) - addedCount := 0 - for i, msg := range historyMessages { - // 对于tool消息,即使content为空也要添加(因为tool消息可能只有ToolCallID) - // 对于其他消息,只添加有内容的消息 - if msg.Role == "tool" || msg.Content != "" { - messages = append(messages, ChatMessage{ - Role: msg.Role, - Content: msg.Content, - ToolCalls: msg.ToolCalls, - ToolCallID: msg.ToolCallID, - ToolName: msg.ToolName, - }) - addedCount++ - contentPreview := msg.Content - if len(contentPreview) > 50 { - contentPreview = contentPreview[:50] + "..." - } - a.logger.Info("添加历史消息到上下文", - zap.Int("index", i), - zap.String("role", msg.Role), - zap.String("content", contentPreview), - zap.Int("toolCalls", len(msg.ToolCalls)), - zap.String("toolCallID", msg.ToolCallID), - ) - } - } - - a.logger.Info("构建消息数组", - zap.Int("historyMessages", len(historyMessages)), - zap.Int("addedMessages", addedCount), - zap.Int("totalMessages", len(messages)), - ) - - // 在添加当前用户消息之前,先修复可能存在的失配tool消息 - // 这可以防止在继续对话时出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误 - if len(messages) > 0 { - if fixed := a.repairOrphanToolMessages(&messages); fixed { - a.logger.Info("修复了历史消息中的失配tool消息") - } - } - - // 添加当前用户消息 - messages = append(messages, ChatMessage{ - Role: "user", - Content: userInput, - }) - - result := &AgentLoopResult{ - MCPExecutionIDs: make([]string, 0), - } - - // 用于保存当前的messages,以便在异常情况下也能保存ReAct输入 - var currentAgentTraceInput string - - maxIterations := a.maxIterations - thinkingStreamSeq := 0 - for i := 0; i < maxIterations; i++ { - // 先获取本轮可用工具并统计 tools token,再压缩,以便压缩时预留 tools 占用的空间 - tools := a.getAvailableTools(roleTools) - toolsTokens := a.countToolsTokens(tools) - messages = a.applyMemoryCompression(ctx, messages, toolsTokens) - - // 检查是否是最后一次迭代 - isLastIteration := (i == maxIterations-1) - - // 每次迭代都保存压缩后的messages,以便在异常中断(取消、错误等)时也能保存最新的ReAct输入 - // 保存压缩后的数据,这样后续使用时就不需要再考虑压缩了 - messagesJSON, err := json.Marshal(messages) - if err != nil { - a.logger.Warn("序列化ReAct输入失败", zap.Error(err)) - } else { - currentAgentTraceInput = string(messagesJSON) - // 更新result中的值,确保始终保存最新的ReAct输入(压缩后的) - result.LastAgentTraceInput = currentAgentTraceInput - } - - // 检查上下文是否已取消 - select { - case <-ctx.Done(): - // 上下文被取消(可能是用户主动暂停或其他原因) - a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err())) - result.LastAgentTraceInput = currentAgentTraceInput - if ctx.Err() == context.Canceled { - result.Response = "任务已被取消。" - } else { - result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err()) - } - result.LastAgentTraceOutput = result.Response - return result, ctx.Err() - default: - } - - // 记录当前上下文的 Token 用量(messages + tools),展示压缩器运行状态 - if a.memoryCompressor != nil { - messagesTokens, systemCount, regularCount := a.memoryCompressor.totalTokensFor(messages) - totalTokens := messagesTokens + toolsTokens - a.logger.Info("memory compressor context stats", - zap.Int("iteration", i+1), - zap.Int("messagesCount", len(messages)), - zap.Int("systemMessages", systemCount), - zap.Int("regularMessages", regularCount), - zap.Int("messagesTokens", messagesTokens), - zap.Int("toolsTokens", toolsTokens), - zap.Int("totalTokens", totalTokens), - zap.Int("maxTotalTokens", a.memoryCompressor.maxTotalTokens), - ) - } - - // 发送迭代开始事件 - if i == 0 { - sendProgress("iteration", "开始分析请求并制定测试策略", map[string]interface{}{ - "iteration": i + 1, - "total": maxIterations, - }) - } else if isLastIteration { - sendProgress("iteration", fmt.Sprintf("第 %d 轮迭代(最后一次)", i+1), map[string]interface{}{ - "iteration": i + 1, - "total": maxIterations, - "isLast": true, - }) - } else { - sendProgress("iteration", fmt.Sprintf("第 %d 轮迭代", i+1), map[string]interface{}{ - "iteration": i + 1, - "total": maxIterations, - }) - } - - // 记录每次调用OpenAI - if i == 0 { - a.logger.Info("调用OpenAI", - zap.Int("iteration", i+1), - zap.Int("messagesCount", len(messages)), - ) - // 记录前几条消息的内容(用于调试) - for j, msg := range messages { - if j >= 5 { // 只记录前5条 - break - } - contentPreview := msg.Content - if len(contentPreview) > 100 { - contentPreview = contentPreview[:100] + "..." - } - a.logger.Debug("消息内容", - zap.Int("index", j), - zap.String("role", msg.Role), - zap.String("content", contentPreview), - ) - } - } else { - a.logger.Info("调用OpenAI", - zap.Int("iteration", i+1), - zap.Int("messagesCount", len(messages)), - ) - } - - // 调用OpenAI - sendProgress("progress", "正在调用AI模型...", nil) - thinkingStreamSeq++ - thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq) - thinkingStreamStarted := false - var thinkingWire string - - response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error { - if delta == "" { - return nil - } - var deltaOut string - thinkingWire, deltaOut = openai.NormalizeStreamingDelta(thinkingWire, delta) - if deltaOut == "" { - return nil - } - if !thinkingStreamStarted { - thinkingStreamStarted = true - sendProgress("thinking_stream_start", " ", map[string]interface{}{ - "streamId": thinkingStreamId, - "iteration": i + 1, - "toolStream": false, - }) - } - sendProgress("thinking_stream_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{ - "streamId": thinkingStreamId, - "iteration": i + 1, - }, thinkingWire)) - return nil - }) - if err != nil { - // API调用失败,保存当前的ReAct输入和错误信息作为输出 - result.LastAgentTraceInput = currentAgentTraceInput - errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err) - result.Response = errorMsg - result.LastAgentTraceOutput = errorMsg - a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err)) - return result, fmt.Errorf("调用OpenAI失败: %w", err) - } - - if response.Error != nil { - if handled, toolName := a.handleMissingToolError(response.Error.Message, &messages); handled { - sendProgress("warning", fmt.Sprintf("模型尝试调用不存在的工具:%s,已提示其改用可用工具。", toolName), map[string]interface{}{ - "toolName": toolName, - }) - a.logger.Warn("模型调用了不存在的工具,将重试", - zap.String("tool", toolName), - zap.String("error", response.Error.Message), - ) - continue - } - if a.handleToolRoleError(response.Error.Message, &messages) { - sendProgress("warning", "检测到未配对的工具结果,已自动修复上下文并重试。", map[string]interface{}{ - "error": response.Error.Message, - }) - a.logger.Warn("检测到未配对的工具消息,已修复并重试", - zap.String("error", response.Error.Message), - ) - continue - } - // OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出 - result.LastAgentTraceInput = currentAgentTraceInput - errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message) - result.Response = errorMsg - result.LastAgentTraceOutput = errorMsg - return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) - } - - if len(response.Choices) == 0 { - // 没有收到响应,保存当前的ReAct输入和错误信息作为输出 - result.LastAgentTraceInput = currentAgentTraceInput - errorMsg := "没有收到响应" - result.Response = errorMsg - result.LastAgentTraceOutput = errorMsg - return result, fmt.Errorf("没有收到响应") - } - - choice := response.Choices[0] - - // 检查是否有工具调用 - if len(choice.Message.ToolCalls) > 0 { - // ReAct 助手正文流式增量(thinking_stream_*)在 UI 上归为「思考」;若与 streamId 重复则前端会去重。 - // 该条 thinking 用于刷新后持久化展示(与流式聚合一致)。 - if choice.Message.Content != "" { - sendProgress("thinking", choice.Message.Content, map[string]interface{}{ - "iteration": i + 1, - "streamId": thinkingStreamId, - }) - } - - // 添加assistant消息(包含工具调用) - messages = append(messages, ChatMessage{ - Role: "assistant", - Content: choice.Message.Content, - ToolCalls: choice.Message.ToolCalls, - }) - - // 发送工具调用进度 - sendProgress("tool_calls_detected", fmt.Sprintf("检测到 %d 个工具调用", len(choice.Message.ToolCalls)), map[string]interface{}{ - "count": len(choice.Message.ToolCalls), - "iteration": i + 1, - }) - - // 执行所有工具调用 - for idx, toolCall := range choice.Message.ToolCalls { - // 发送工具调用开始事件 - toolArgsJSON, _ := json.Marshal(toolCall.Function.Arguments) - sendProgress("tool_call", fmt.Sprintf("正在调用工具: %s", toolCall.Function.Name), map[string]interface{}{ - "toolName": toolCall.Function.Name, - "arguments": string(toolArgsJSON), - "argumentsObj": toolCall.Function.Arguments, - "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, - }) - - execArgs := toolCall.Function.Arguments - if interceptor, ok := ctx.Value(toolCallInterceptorCtxKey{}).(ToolCallInterceptor); ok && interceptor != nil { - newArgs, interceptErr := interceptor(ctx, toolCall.Function.Name, execArgs, toolCall.ID) - if interceptErr != nil { - errorMsg := fmt.Sprintf("工具调用被人工拒绝: %v", interceptErr) - messages = append(messages, ChatMessage{ - Role: "tool", - ToolCallID: toolCall.ID, - Content: errorMsg, - }) - sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{ - "toolName": toolCall.Function.Name, - "success": false, - "isError": true, - "error": errorMsg, - "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, - }) - continue - } - if newArgs != nil { - execArgs = newArgs - } - } - - // 执行工具 - toolCtx := context.WithValue(ctx, security.ToolOutputCallbackCtxKey, security.ToolOutputCallback(func(chunk string) { - if strings.TrimSpace(chunk) == "" { - return - } - sendProgress("tool_result_delta", chunk, map[string]interface{}{ - "toolName": toolCall.Function.Name, - "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, - // success 在最终 tool_result 事件里会以 success/isError 标记为准 - }) - })) - - execResult, err := a.executeToolViaMCP(toolCtx, toolCall.Function.Name, execArgs) - if err != nil { - // 构建详细的错误信息,帮助AI理解问题并做出决策 - errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err) - messages = append(messages, ChatMessage{ - Role: "tool", - ToolCallID: toolCall.ID, - Content: errorMsg, - }) - - // 发送工具执行失败事件 - sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{ - "toolName": toolCall.Function.Name, - "success": false, - "isError": true, - "error": err.Error(), - "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, - }) - - a.logger.Warn("工具执行失败,已返回详细错误信息", - zap.String("tool", toolCall.Function.Name), - zap.Error(err), - ) - } else { - // 即使工具返回了错误结果(IsError=true),也继续处理,让AI决定下一步 - messages = append(messages, ChatMessage{ - Role: "tool", - ToolCallID: toolCall.ID, - Content: execResult.Result, - }) - // 收集执行ID - if execResult.ExecutionID != "" { - result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID) - } - - // 发送工具执行成功事件 - resultPreview := execResult.Result - if len(resultPreview) > 200 { - resultPreview = resultPreview[:200] + "..." - } - sendProgress("tool_result", fmt.Sprintf("工具 %s 执行完成", toolCall.Function.Name), map[string]interface{}{ - "toolName": toolCall.Function.Name, - "success": !execResult.IsError, - "isError": execResult.IsError, - "result": execResult.Result, // 完整结果 - "resultPreview": resultPreview, // 预览结果 - "executionId": execResult.ExecutionID, - "toolCallId": toolCall.ID, - "index": idx + 1, - "total": len(choice.Message.ToolCalls), - "iteration": i + 1, - }) - - // 如果工具返回了错误,记录日志但不中断流程 - if execResult.IsError { - a.logger.Warn("工具返回错误结果,但继续处理", - zap.String("tool", toolCall.Function.Name), - zap.String("result", execResult.Result), - ) - } - } - } - - // 如果是最后一次迭代,执行完工具后要求AI进行总结 - if isLastIteration { - sendProgress("progress", "最后一次迭代:正在生成总结和下一步计划...", nil) - // 添加用户消息,要求AI进行总结 - messages = append(messages, ChatMessage{ - Role: "user", - Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", - }) - messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留 - // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) - sendProgress("response_start", "", map[string]interface{}{ - "conversationId": conversationID, - "mcpExecutionIds": result.MCPExecutionIDs, - "messageGeneratedBy": "summary", - }) - var summaryWire string - streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { - var deltaOut string - summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta) - if deltaOut == "" { - return nil - } - sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{ - "conversationId": conversationID, - }, summaryWire)) - return nil - }) - if strings.TrimSpace(streamText) != "" { - result.Response = streamText - result.LastAgentTraceOutput = result.Response - sendProgress("progress", "总结生成完成", nil) - return result, nil - } - // 如果获取总结失败,跳出循环,让后续逻辑处理 - break - } - - continue - } - - // 添加assistant响应 - messages = append(messages, ChatMessage{ - Role: "assistant", - Content: choice.Message.Content, - }) - - // 发送AI思考内容(如果没有工具调用) - if choice.Message.Content != "" && !thinkingStreamStarted { - sendProgress("thinking", choice.Message.Content, map[string]interface{}{ - "iteration": i + 1, - }) - } - - // 如果是最后一次迭代,无论finish_reason是什么,都要求AI进行总结 - if isLastIteration { - sendProgress("progress", "最后一次迭代:正在生成总结和下一步计划...", nil) - // 添加用户消息,要求AI进行总结 - messages = append(messages, ChatMessage{ - Role: "user", - Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", - }) - messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留 - // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) - sendProgress("response_start", "", map[string]interface{}{ - "conversationId": conversationID, - "mcpExecutionIds": result.MCPExecutionIDs, - "messageGeneratedBy": "summary", - }) - var summaryWire string - streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { - var deltaOut string - summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta) - if deltaOut == "" { - return nil - } - sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{ - "conversationId": conversationID, - }, summaryWire)) - return nil - }) - if strings.TrimSpace(streamText) != "" { - result.Response = streamText - result.LastAgentTraceOutput = result.Response - sendProgress("progress", "总结生成完成", nil) - return result, nil - } - // 如果获取总结失败,使用当前回复作为结果 - if choice.Message.Content != "" { - result.Response = choice.Message.Content - result.LastAgentTraceOutput = result.Response - return result, nil - } - // 如果都没有内容,跳出循环,让后续逻辑处理 - break - } - - // 如果完成,返回结果 - if choice.FinishReason == "stop" { - sendProgress("progress", "正在生成最终回复...", nil) - result.Response = choice.Message.Content - result.LastAgentTraceOutput = result.Response - return result, nil - } - } - - // 如果循环结束仍未返回,说明达到了最大迭代次数 - // 尝试最后一次调用AI获取总结 - sendProgress("progress", "达到最大迭代次数,正在生成总结...", nil) - finalSummaryPrompt := ChatMessage{ - Role: "user", - Content: fmt.Sprintf("已达到最大迭代次数(%d轮)。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", a.maxIterations), - } - messages = append(messages, finalSummaryPrompt) - messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留 - - // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) - sendProgress("response_start", "", map[string]interface{}{ - "conversationId": conversationID, - "mcpExecutionIds": result.MCPExecutionIDs, - "messageGeneratedBy": "max_iter_summary", - }) - var summaryWire string - streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { - var deltaOut string - summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta) - if deltaOut == "" { - return nil - } - sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{ - "conversationId": conversationID, - }, summaryWire)) - return nil - }) - if strings.TrimSpace(streamText) != "" { - result.Response = streamText - result.LastAgentTraceOutput = result.Response - sendProgress("progress", "总结生成完成", nil) - return result, nil - } - - // 如果无法生成总结,返回友好的提示 - result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations) - result.LastAgentTraceOutput = result.Response - return result, nil -} - // getAvailableTools 获取可用工具 // 从MCP服务器动态获取工具列表,描述模式由 tool_description_mode 控制 // roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色) @@ -1171,319 +549,11 @@ func (a *Agent) convertToOpenAIType(configType string) string { } } -// isRetryableError 判断错误是否可重试 -func (a *Agent) isRetryableError(err error) bool { - if err == nil { - return false - } - errStr := err.Error() - // 网络相关错误,可以重试 - retryableErrors := []string{ - "connection reset", - "connection reset by peer", - "connection refused", - "timeout", - "i/o timeout", - "context deadline exceeded", - "no such host", - "network is unreachable", - "broken pipe", - "EOF", - "read tcp", - "write tcp", - "dial tcp", - } - for _, retryable := range retryableErrors { - if strings.Contains(strings.ToLower(errStr), retryable) { - return true - } - } - return false -} - -// callOpenAI 调用OpenAI API(带重试机制) -func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) { - maxRetries := 3 - var lastErr error - - for attempt := 0; attempt < maxRetries; attempt++ { - response, err := a.callOpenAISingle(ctx, messages, tools) - if err == nil { - if attempt > 0 { - a.logger.Info("OpenAI API调用重试成功", - zap.Int("attempt", attempt+1), - zap.Int("maxRetries", maxRetries), - ) - } - return response, nil - } - - lastErr = err - - // 如果不是可重试的错误,直接返回 - if !a.isRetryableError(err) { - return nil, err - } - - // 如果不是最后一次重试,等待后重试 - if attempt < maxRetries-1 { - // 指数退避:2s, 4s, 8s... - backoff := time.Duration(1< 30*time.Second { - backoff = 30 * time.Second // 最大30秒 - } - a.logger.Warn("OpenAI API调用失败,准备重试", - zap.Error(err), - zap.Int("attempt", attempt+1), - zap.Int("maxRetries", maxRetries), - zap.Duration("backoff", backoff), - ) - - // 检查上下文是否已取消 - select { - case <-ctx.Done(): - return nil, fmt.Errorf("上下文已取消: %w", ctx.Err()) - case <-time.After(backoff): - // 继续重试 - } - } - } - - return nil, fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr) -} - -// callOpenAISingle 单次调用OpenAI API(不包含重试逻辑) -func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) { - reqBody := OpenAIRequest{ - Model: a.config.Model, - Messages: messages, - } - - if len(tools) > 0 { - reqBody.Tools = tools - } - - a.logger.Debug("准备发送OpenAI请求", - zap.Int("messagesCount", len(messages)), - zap.Int("toolsCount", len(tools)), - ) - - var response OpenAIResponse - if a.openAIClient == nil { - return nil, fmt.Errorf("OpenAI客户端未初始化") - } - if err := a.openAIClient.ChatCompletion(ctx, reqBody, &response); err != nil { - return nil, err - } - - return &response, nil -} - -// callOpenAISingleStreamText 单次调用OpenAI的流式模式,只用于“不会调用工具”的纯文本输出(tools 为空时最佳)。 -// onDelta 每收到一段 content delta,就回调一次;如果 callback 返回错误,会终止读取并返回错误。 -func (a *Agent) callOpenAISingleStreamText(ctx context.Context, messages []ChatMessage, tools []Tool, onDelta func(delta string) error) (string, error) { - reqBody := OpenAIRequest{ - Model: a.config.Model, - Messages: messages, - Stream: true, - } - if len(tools) > 0 { - reqBody.Tools = tools - } - - if a.openAIClient == nil { - return "", fmt.Errorf("OpenAI客户端未初始化") - } - - return a.openAIClient.ChatCompletionStream(ctx, reqBody, onDelta) -} - -// callOpenAIStreamText 调用OpenAI流式模式(带重试),仅在“未输出任何 delta”时才允许重试,避免重复发送已下发的内容。 -func (a *Agent) callOpenAIStreamText(ctx context.Context, messages []ChatMessage, tools []Tool, onDelta func(delta string) error) (string, error) { - maxRetries := 3 - var lastErr error - - for attempt := 0; attempt < maxRetries; attempt++ { - var deltasSent bool - full, err := a.callOpenAISingleStreamText(ctx, messages, tools, func(delta string) error { - deltasSent = true - return onDelta(delta) - }) - if err == nil { - if attempt > 0 { - a.logger.Info("OpenAI stream 调用重试成功", - zap.Int("attempt", attempt+1), - zap.Int("maxRetries", maxRetries), - ) - } - return full, nil - } - - lastErr = err - // 已经开始输出了 delta,避免重复内容:直接失败让上层处理。 - if deltasSent { - return "", err - } - - if !a.isRetryableError(err) { - return "", err - } - - if attempt < maxRetries-1 { - backoff := time.Duration(1< 30*time.Second { - backoff = 30 * time.Second - } - a.logger.Warn("OpenAI stream 调用失败,准备重试", - zap.Error(err), - zap.Int("attempt", attempt+1), - zap.Int("maxRetries", maxRetries), - zap.Duration("backoff", backoff), - ) - - select { - case <-ctx.Done(): - return "", fmt.Errorf("上下文已取消: %w", ctx.Err()) - case <-time.After(backoff): - } - } - } - - return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr) -} - -// callOpenAISingleStreamWithToolCalls 单次调用OpenAI流式模式(带工具调用解析),不包含重试逻辑。 -func (a *Agent) callOpenAISingleStreamWithToolCalls( - ctx context.Context, - messages []ChatMessage, - tools []Tool, - onContentDelta func(delta string) error, -) (*OpenAIResponse, error) { - reqBody := OpenAIRequest{ - Model: a.config.Model, - Messages: messages, - Stream: true, - } - if len(tools) > 0 { - reqBody.Tools = tools - } - if a.openAIClient == nil { - return nil, fmt.Errorf("OpenAI客户端未初始化") - } - - content, streamToolCalls, finishReason, err := a.openAIClient.ChatCompletionStreamWithToolCalls(ctx, reqBody, onContentDelta) - if err != nil { - return nil, err - } - - toolCalls := make([]ToolCall, 0, len(streamToolCalls)) - for _, stc := range streamToolCalls { - fnArgsStr := stc.FunctionArgsStr - args := make(map[string]interface{}) - if strings.TrimSpace(fnArgsStr) != "" { - if err := json.Unmarshal([]byte(fnArgsStr), &args); err != nil { - // 兼容:arguments 不一定是严格 JSON - args = map[string]interface{}{"raw": fnArgsStr} - } - } - - typ := stc.Type - if strings.TrimSpace(typ) == "" { - typ = "function" - } - - toolCalls = append(toolCalls, ToolCall{ - ID: stc.ID, - Type: typ, - Function: FunctionCall{ - Name: stc.FunctionName, - Arguments: args, - }, - }) - } - - response := &OpenAIResponse{ - ID: "", - Choices: []Choice{ - { - Message: MessageWithTools{ - Role: "assistant", - Content: content, - ToolCalls: toolCalls, - }, - FinishReason: finishReason, - }, - }, - } - return response, nil -} - -// callOpenAIStreamWithToolCalls 调用OpenAI流式模式(带重试),仅当还没有输出任何 content delta 时才允许重试。 -func (a *Agent) callOpenAIStreamWithToolCalls( - ctx context.Context, - messages []ChatMessage, - tools []Tool, - onContentDelta func(delta string) error, -) (*OpenAIResponse, error) { - maxRetries := 3 - var lastErr error - - for attempt := 0; attempt < maxRetries; attempt++ { - deltasSent := false - resp, err := a.callOpenAISingleStreamWithToolCalls(ctx, messages, tools, func(delta string) error { - deltasSent = true - if onContentDelta != nil { - return onContentDelta(delta) - } - return nil - }) - if err == nil { - if attempt > 0 { - a.logger.Info("OpenAI stream 调用重试成功", - zap.Int("attempt", attempt+1), - zap.Int("maxRetries", maxRetries), - ) - } - return resp, nil - } - - lastErr = err - if deltasSent { - // 已经开始输出了 delta:避免重复发送 - return nil, err - } - - if !a.isRetryableError(err) { - return nil, err - } - if attempt < maxRetries-1 { - backoff := time.Duration(1< 30*time.Second { - backoff = 30 * time.Second - } - a.logger.Warn("OpenAI stream 调用失败,准备重试", - zap.Error(err), - zap.Int("attempt", attempt+1), - zap.Int("maxRetries", maxRetries), - zap.Duration("backoff", backoff), - ) - - select { - case <-ctx.Done(): - return nil, fmt.Errorf("上下文已取消: %w", ctx.Err()) - case <-time.After(backoff): - } - } - } - - return nil, fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr) -} - -// ToolExecutionResult 工具执行结果 +// ToolExecutionResult MCP 工具执行结果(供 Eino 桥与监控落库使用)。 type ToolExecutionResult struct { Result string ExecutionID string - IsError bool // 标记是否为错误结果 + IsError bool } // executeToolViaMCP 通过MCP执行工具 @@ -1698,11 +768,6 @@ func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) { defer a.mu.Unlock() a.config = cfg - // 同时更新MemoryCompressor的配置(如果存在) - if a.memoryCompressor != nil { - a.memoryCompressor.UpdateConfig(cfg) - } - a.logger.Info("Agent配置已更新", zap.String("base_url", cfg.BaseURL), zap.String("model", cfg.Model), @@ -1731,103 +796,6 @@ func (a *Agent) UpdateToolDescriptionMode(mode string) { a.logger.Info("Agent工具描述模式已更新", zap.String("tool_description_mode", mode)) } -// formatToolError 格式化工具错误信息,提供更友好的错误描述 -func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string { - errorMsg := fmt.Sprintf(`工具执行失败 - -工具名称: %s -调用参数: %v -错误信息: %v - -请分析错误原因并采取以下行动之一: -1. 如果参数错误,请修正参数后重试 -2. 如果工具不可用,请尝试使用替代工具 -3. 如果这是系统问题,请向用户说明情况并提供建议 -4. 如果错误信息中包含有用信息,可以基于这些信息继续分析`, toolName, args, err) - - return errorMsg -} - -// applyMemoryCompression 在调用LLM前对消息进行压缩,避免超过 token 限制。reservedTokens 为预留给 tools 的 token 数,传 0 表示不预留。 -func (a *Agent) applyMemoryCompression(ctx context.Context, messages []ChatMessage, reservedTokens int) []ChatMessage { - if a.memoryCompressor == nil { - return messages - } - - compressed, changed, err := a.memoryCompressor.CompressHistory(ctx, messages, reservedTokens) - if err != nil { - a.logger.Warn("上下文压缩失败,将使用原始消息继续", zap.Error(err)) - return messages - } - if changed { - a.logger.Info("历史上下文已压缩", - zap.Int("originalMessages", len(messages)), - zap.Int("compressedMessages", len(compressed)), - ) - return compressed - } - - return messages -} - -// countToolsTokens 统计 tools 序列化后的 token 数,用于日志与压缩时预留空间。mc 为 nil 时返回 0。 -func (a *Agent) countToolsTokens(tools []Tool) int { - if len(tools) == 0 || a.memoryCompressor == nil { - return 0 - } - data, err := json.Marshal(tools) - if err != nil { - return 0 - } - return a.memoryCompressor.CountTextTokens(string(data)) -} - -// handleMissingToolError 当LLM调用不存在的工具时,向其追加提示消息并允许继续迭代 -func (a *Agent) handleMissingToolError(errMsg string, messages *[]ChatMessage) (bool, string) { - lowerMsg := strings.ToLower(errMsg) - if !(strings.Contains(lowerMsg, "non-exist tool") || strings.Contains(lowerMsg, "non exist tool")) { - return false, "" - } - - toolName := extractQuotedToolName(errMsg) - if toolName == "" { - toolName = "unknown_tool" - } - - notice := fmt.Sprintf("System notice: the previous call failed with error: %s. Please verify tool availability and proceed using existing tools or pure reasoning.", errMsg) - *messages = append(*messages, ChatMessage{ - Role: "user", - Content: notice, - }) - - return true, toolName -} - -// handleToolRoleError 自动修复因缺失tool_calls导致的OpenAI错误 -func (a *Agent) handleToolRoleError(errMsg string, messages *[]ChatMessage) bool { - if messages == nil { - return false - } - - lowerMsg := strings.ToLower(errMsg) - if !(strings.Contains(lowerMsg, "role 'tool'") && strings.Contains(lowerMsg, "tool_calls")) { - return false - } - - fixed := a.repairOrphanToolMessages(messages) - if !fixed { - return false - } - - notice := "System notice: the previous call failed because some tool outputs lost their corresponding assistant tool_calls context. The history has been repaired. Please continue." - *messages = append(*messages, ChatMessage{ - Role: "user", - Content: notice, - }) - - return true -} - // RepairOrphanToolMessages 清理失去配对的tool消息和未完成的tool_calls,避免OpenAI报错 // 同时确保历史消息中的tool_calls只作为上下文记忆,不会触发重新执行 // 这是一个公开方法,可以在恢复历史消息时调用 diff --git a/internal/agent/default_single_system_prompt.go b/internal/agent/default_single_system_prompt.go index 0ee8468e..0ccdd352 100644 --- a/internal/agent/default_single_system_prompt.go +++ b/internal/agent/default_single_system_prompt.go @@ -4,7 +4,7 @@ import ( "cyberstrike-ai/internal/project" ) -// DefaultSingleAgentSystemPrompt 单代理(ReAct / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。 +// DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。 func DefaultSingleAgentSystemPrompt() string { return `你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。 @@ -112,6 +112,6 @@ func DefaultSingleAgentSystemPrompt() string { ## 技能库(Skills)与知识库 - 技能包位于服务器 skills/ 目录(各子目录 SKILL.md,遵循 agentskills.io);知识库用于向量检索片段,Skills 为可执行工作流指令。 -- 单代理本会话通过 MCP 使用知识库与漏洞记录等;Skills 的渐进式加载在「多代理 / Eino DeepAgent」中由内置 skill 工具完成(需在配置中启用 multi_agent.eino_skills)。 -- 若当前无 skill 工具,需要完整 Skill 工作流时请使用多代理模式或切换为 Eino 编排会话(亦可选 Eino ADK 单代理路径 /api/eino-agent)。` +- 本会话通过 MCP 使用知识库与漏洞记录等。Skills 由 Eino ADK skill 工具按需加载(配置 multi_agent.eino_skills;单代理与多代理均可,未启用时无 skill 工具)。 +- 需要完整 Skill 工作流但当前无 skill 工具时,请确认已启用 multi_agent.eino_skills,或改用 Deep / Supervisor 等多代理编排(/api/multi-agent/stream)。` } diff --git a/internal/agent/token_counter.go b/internal/agent/token_counter.go new file mode 100644 index 00000000..8795461b --- /dev/null +++ b/internal/agent/token_counter.go @@ -0,0 +1,54 @@ +package agent + +import ( + "sync" + + "github.com/pkoukk/tiktoken-go" +) + +// TokenCounter 估算文本 token 数(tiktoken;模型未知时回退 cl100k_base)。 +type TokenCounter interface { + Count(model, text string) (int, error) +} + +type tikTokenCounter struct { + mu sync.Mutex + cache map[string]*tiktoken.Tiktoken +} + +// NewTikTokenCounter 创建基于 tiktoken 的 TokenCounter。 +func NewTikTokenCounter() TokenCounter { + return &tikTokenCounter{cache: make(map[string]*tiktoken.Tiktoken)} +} + +func (c *tikTokenCounter) encoding(model string) (*tiktoken.Tiktoken, error) { + key := model + if key == "" { + key = "cl100k_base" + } + c.mu.Lock() + defer c.mu.Unlock() + if enc, ok := c.cache[key]; ok { + return enc, nil + } + enc, err := tiktoken.EncodingForModel(key) + if err != nil { + enc, err = tiktoken.GetEncoding("cl100k_base") + } + if err != nil { + return nil, err + } + c.cache[key] = enc + return enc, nil +} + +func (c *tikTokenCounter) Count(model, text string) (int, error) { + if text == "" { + return 0, nil + } + enc, err := c.encoding(model) + if err != nil { + return 0, err + } + return len(enc.Encode(text, nil, nil)), nil +} diff --git a/internal/database/database.go b/internal/database/database.go index 7b39b52e..06dc35f3 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -388,7 +388,7 @@ func (db *DB) initTables() error { id TEXT PRIMARY KEY, title TEXT, role TEXT, - agent_mode TEXT NOT NULL DEFAULT 'single', + agent_mode TEXT NOT NULL DEFAULT 'eino_single', schedule_mode TEXT NOT NULL DEFAULT 'manual', cron_expr TEXT, next_run_at DATETIME, @@ -984,14 +984,14 @@ func (db *DB) migrateBatchTaskQueuesTable() error { var agentModeCount int err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='agent_mode'").Scan(&agentModeCount) if err != nil { - if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); addErr != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'eino_single'"); addErr != nil { errMsg := strings.ToLower(addErr.Error()) if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { db.logger.Warn("添加agent_mode字段失败", zap.Error(addErr)) } } } else if agentModeCount == 0 { - if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); err != nil { + if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'eino_single'"); err != nil { db.logger.Warn("添加agent_mode字段失败", zap.Error(err)) } } diff --git a/internal/multiagent/eino_single_runner.go b/internal/multiagent/eino_single_runner.go index 980e118b..06ff1fd0 100644 --- a/internal/multiagent/eino_single_runner.go +++ b/internal/multiagent/eino_single_runner.go @@ -26,7 +26,7 @@ import ( const einoSingleAgentName = "cyberstrike-eino-single" // RunEinoSingleChatModelAgent 使用 Eino adk.NewChatModelAgent + adk.NewRunner.Run(官方 Quick Start 的 Query 同属 Runner API;此处用历史 + 用户消息切片等价于多轮 Query)。 -// 不替代既有原生 ReAct;与 RunDeepAgent 共享 runEinoADKAgentLoop 的 SSE 映射与 MCP 桥。 +// 与 RunDeepAgent 共享 runEinoADKAgentLoop 的 SSE 映射与 MCP 桥。 func RunEinoSingleChatModelAgent( ctx context.Context, appCfg *config.Config, diff --git a/internal/multiagent/eino_summarize.go b/internal/multiagent/eino_summarize.go index b0e418a5..d1ab90b2 100644 --- a/internal/multiagent/eino_summarize.go +++ b/internal/multiagent/eino_summarize.go @@ -18,7 +18,7 @@ import ( "go.uber.org/zap" ) -// einoSummarizeUserInstruction 与单 Agent MemoryCompressor 目标一致:压缩时保留渗透关键信息。 +// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。 const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。 必须保留:已确认漏洞与攻击路径、工具输出中的核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策。 @@ -29,7 +29,7 @@ const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完 输出须使后续代理能无缝继续同一授权测试任务。` // newEinoSummarizationMiddleware 使用 Eino ADK Summarization 中间件(见 https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/)。 -// 触发阈值与单 Agent MemoryCompressor 一致:当估算 token 超过 openai.max_total_tokens 的 90% 时摘要。 +// 触发阈值:估算 token 超过 openai.max_total_tokens * summarization_trigger_ratio(默认 0.8)时摘要。 func newEinoSummarizationMiddleware( ctx context.Context, summaryModel model.BaseChatModel,