package agent import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "strings" "time" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" "go.uber.org/zap" ) // Agent AI代理 type Agent struct { openAIClient *http.Client config *config.OpenAIConfig mcpServer *mcp.Server logger *zap.Logger } // NewAgent 创建新的Agent func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, logger *zap.Logger) *Agent { return &Agent{ openAIClient: &http.Client{Timeout: 5 * time.Minute}, config: cfg, mcpServer: mcpServer, logger: logger, } } // ChatMessage 聊天消息 type ChatMessage struct { Role string `json:"role"` Content string `json:"content,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"` } // MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串 func (cm ChatMessage) MarshalJSON() ([]byte, error) { // 构建序列化结构 aux := map[string]interface{}{ "role": cm.Role, } // 添加content(如果存在) if cm.Content != "" { aux["content"] = cm.Content } // 添加tool_call_id(如果存在) if cm.ToolCallID != "" { aux["tool_call_id"] = cm.ToolCallID } // 转换tool_calls,将arguments转换为JSON字符串 if len(cm.ToolCalls) > 0 { toolCallsJSON := make([]map[string]interface{}, len(cm.ToolCalls)) for i, tc := range cm.ToolCalls { // 将arguments转换为JSON字符串 argsJSON := "" if tc.Function.Arguments != nil { argsBytes, err := json.Marshal(tc.Function.Arguments) if err != nil { return nil, err } argsJSON = string(argsBytes) } toolCallsJSON[i] = map[string]interface{}{ "id": tc.ID, "type": tc.Type, "function": map[string]interface{}{ "name": tc.Function.Name, "arguments": argsJSON, }, } } aux["tool_calls"] = toolCallsJSON } return json.Marshal(aux) } // OpenAIRequest OpenAI API请求 type OpenAIRequest struct { Model string `json:"model"` Messages []ChatMessage `json:"messages"` Tools []Tool `json:"tools,omitempty"` } // OpenAIResponse OpenAI API响应 type OpenAIResponse struct { ID string `json:"id"` Choices []Choice `json:"choices"` Error *Error `json:"error,omitempty"` } // Choice 选择 type Choice struct { Message MessageWithTools `json:"message"` FinishReason string `json:"finish_reason"` } // MessageWithTools 带工具调用的消息 type MessageWithTools struct { Role string `json:"role"` Content string `json:"content"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` } // Tool OpenAI工具定义 type Tool struct { Type string `json:"type"` Function FunctionDefinition `json:"function"` } // FunctionDefinition 函数定义 type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description"` Parameters map[string]interface{} `json:"parameters"` } // Error OpenAI错误 type Error struct { Message string `json:"message"` Type string `json:"type"` } // ToolCall 工具调用 type ToolCall struct { ID string `json:"id"` Type string `json:"type"` Function FunctionCall `json:"function"` } // FunctionCall 函数调用 type FunctionCall struct { Name string `json:"name"` Arguments map[string]interface{} `json:"arguments"` } // UnmarshalJSON 自定义JSON解析,处理arguments可能是字符串或对象的情况 func (fc *FunctionCall) UnmarshalJSON(data []byte) error { type Alias FunctionCall aux := &struct { Name string `json:"name"` Arguments interface{} `json:"arguments"` *Alias }{ Alias: (*Alias)(fc), } if err := json.Unmarshal(data, &aux); err != nil { return err } fc.Name = aux.Name // 处理arguments可能是字符串或对象的情况 switch v := aux.Arguments.(type) { case map[string]interface{}: fc.Arguments = v case string: // 如果是字符串,尝试解析为JSON if err := json.Unmarshal([]byte(v), &fc.Arguments); err != nil { // 如果解析失败,创建一个包含原始字符串的map fc.Arguments = map[string]interface{}{ "raw": v, } } case nil: fc.Arguments = make(map[string]interface{}) default: // 其他类型,尝试转换为map fc.Arguments = map[string]interface{}{ "value": v, } } return nil } // AgentLoopResult Agent Loop执行结果 type AgentLoopResult struct { Response string MCPExecutionIDs []string } // ProgressCallback 进度回调函数类型 type ProgressCallback func(eventType, message string, data interface{}) // AgentLoop 执行Agent循环 func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) { return a.AgentLoopWithProgress(ctx, userInput, historyMessages, nil) } // AgentLoopWithProgress 执行Agent循环(带进度回调) func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, callback ProgressCallback) (*AgentLoopResult, error) { // 发送进度更新 sendProgress := func(eventType, message string, data interface{}) { if callback != nil { callback(eventType, message, data) } } // 系统提示词,指导AI如何处理工具错误 systemPrompt := `你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。 重要:当工具调用失败时,请遵循以下原则: 1. 仔细分析错误信息,理解失败的具体原因 2. 如果工具不存在或未启用,尝试使用其他替代工具完成相同目标 3. 如果参数错误,根据错误提示修正参数后重试 4. 如果工具执行失败但输出了有用信息,可以基于这些信息继续分析 5. 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作 6. 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务 当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。` messages := []ChatMessage{ { Role: "system", Content: systemPrompt, }, } // 添加历史消息(数据库只保存user和assistant消息) a.logger.Info("处理历史消息", zap.Int("count", len(historyMessages)), ) addedCount := 0 for i, msg := range historyMessages { // 只添加有内容的消息 if msg.Content != "" { messages = append(messages, ChatMessage{ Role: msg.Role, Content: msg.Content, }) addedCount++ contentPreview := msg.Content if len(contentPreview) > 50 { contentPreview = contentPreview[:50] + "..." } a.logger.Info("添加历史消息到上下文", zap.Int("index", i), zap.String("role", msg.Role), zap.String("content", contentPreview), ) } } a.logger.Info("构建消息数组", zap.Int("historyMessages", len(historyMessages)), zap.Int("addedMessages", addedCount), zap.Int("totalMessages", len(messages)), ) // 添加当前用户消息 messages = append(messages, ChatMessage{ Role: "user", Content: userInput, }) result := &AgentLoopResult{ MCPExecutionIDs: make([]string, 0), } maxIterations := 10 for i := 0; i < maxIterations; i++ { // 获取可用工具 tools := a.getAvailableTools() // 发送进度更新 if i == 0 { sendProgress("progress", "正在分析请求并制定测试策略...", nil) } else { sendProgress("progress", fmt.Sprintf("正在继续分析(第 %d 轮迭代)...", i+1), nil) } // 记录每次调用OpenAI if i == 0 { a.logger.Info("调用OpenAI", zap.Int("iteration", i+1), zap.Int("messagesCount", len(messages)), ) // 记录前几条消息的内容(用于调试) for j, msg := range messages { if j >= 5 { // 只记录前5条 break } contentPreview := msg.Content if len(contentPreview) > 100 { contentPreview = contentPreview[:100] + "..." } a.logger.Debug("消息内容", zap.Int("index", j), zap.String("role", msg.Role), zap.String("content", contentPreview), ) } } else { a.logger.Info("调用OpenAI", zap.Int("iteration", i+1), zap.Int("messagesCount", len(messages)), ) } // 调用OpenAI sendProgress("progress", "正在调用AI模型...", nil) response, err := a.callOpenAI(ctx, messages, tools) if err != nil { result.Response = "" return result, fmt.Errorf("调用OpenAI失败: %w", err) } if response.Error != nil { result.Response = "" return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) } if len(response.Choices) == 0 { result.Response = "" return result, fmt.Errorf("没有收到响应") } choice := response.Choices[0] // 检查是否有工具调用 if len(choice.Message.ToolCalls) > 0 { // 添加assistant消息(包含工具调用) messages = append(messages, ChatMessage{ Role: "assistant", Content: choice.Message.Content, ToolCalls: choice.Message.ToolCalls, }) // 发送工具调用进度 sendProgress("progress", fmt.Sprintf("检测到 %d 个工具调用,开始执行...", len(choice.Message.ToolCalls)), nil) // 执行所有工具调用 for idx, toolCall := range choice.Message.ToolCalls { // 发送工具调用开始事件 toolArgsJSON, _ := json.Marshal(toolCall.Function.Arguments) sendProgress("tool_call", fmt.Sprintf("正在调用工具: %s", toolCall.Function.Name), map[string]interface{}{ "toolName": toolCall.Function.Name, "arguments": string(toolArgsJSON), "index": idx + 1, "total": len(choice.Message.ToolCalls), }) // 执行工具 execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments) if err != nil { // 构建详细的错误信息,帮助AI理解问题并做出决策 errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err) messages = append(messages, ChatMessage{ Role: "tool", ToolCallID: toolCall.ID, Content: errorMsg, }) // 发送工具执行失败事件 sendProgress("tool_result", fmt.Sprintf("工具 %s 执行失败", toolCall.Function.Name), map[string]interface{}{ "toolName": toolCall.Function.Name, "success": false, "error": err.Error(), "index": idx + 1, "total": len(choice.Message.ToolCalls), }) a.logger.Warn("工具执行失败,已返回详细错误信息", zap.String("tool", toolCall.Function.Name), zap.Error(err), ) } else { // 即使工具返回了错误结果(IsError=true),也继续处理,让AI决定下一步 messages = append(messages, ChatMessage{ Role: "tool", ToolCallID: toolCall.ID, Content: execResult.Result, }) // 收集执行ID if execResult.ExecutionID != "" { result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID) } // 发送工具执行成功事件 resultPreview := execResult.Result if len(resultPreview) > 200 { resultPreview = resultPreview[:200] + "..." } sendProgress("tool_result", fmt.Sprintf("工具 %s 执行完成", toolCall.Function.Name), map[string]interface{}{ "toolName": toolCall.Function.Name, "success": !execResult.IsError, "isError": execResult.IsError, "result": resultPreview, "executionId": execResult.ExecutionID, "index": idx + 1, "total": len(choice.Message.ToolCalls), }) // 如果工具返回了错误,记录日志但不中断流程 if execResult.IsError { a.logger.Warn("工具返回错误结果,但继续处理", zap.String("tool", toolCall.Function.Name), zap.String("result", execResult.Result), ) } } } continue } // 添加assistant响应 messages = append(messages, ChatMessage{ Role: "assistant", Content: choice.Message.Content, }) // 如果完成,返回结果 if choice.FinishReason == "stop" { sendProgress("progress", "正在生成最终回复...", nil) result.Response = choice.Message.Content return result, nil } } result.Response = "达到最大迭代次数" return result, nil } // getAvailableTools 获取可用工具 // 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗 func (a *Agent) getAvailableTools() []Tool { // 从MCP服务器获取所有已注册的工具 mcpTools := a.mcpServer.GetAllTools() // 转换为OpenAI格式的工具定义 tools := make([]Tool, 0, len(mcpTools)) for _, mcpTool := range mcpTools { // 使用简短描述(如果存在),否则使用详细描述 description := mcpTool.ShortDescription if description == "" { description = mcpTool.Description } // 转换schema中的类型为OpenAI标准类型 convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema) tools = append(tools, Tool{ Type: "function", Function: FunctionDefinition{ Name: mcpTool.Name, Description: description, // 使用简短描述减少token消耗 Parameters: convertedSchema, }, }) } a.logger.Debug("获取可用工具列表", zap.Int("count", len(tools)), ) return tools } // convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型 func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} { if schema == nil { return schema } // 创建新的schema副本 converted := make(map[string]interface{}) for k, v := range schema { converted[k] = v } // 转换properties中的类型 if properties, ok := converted["properties"].(map[string]interface{}); ok { convertedProperties := make(map[string]interface{}) for propName, propValue := range properties { if prop, ok := propValue.(map[string]interface{}); ok { convertedProp := make(map[string]interface{}) for pk, pv := range prop { if pk == "type" { // 转换类型 if typeStr, ok := pv.(string); ok { convertedProp[pk] = a.convertToOpenAIType(typeStr) } else { convertedProp[pk] = pv } } else { convertedProp[pk] = pv } } convertedProperties[propName] = convertedProp } else { convertedProperties[propName] = propValue } } converted["properties"] = convertedProperties } return converted } // convertToOpenAIType 将配置中的类型转换为OpenAI/JSON Schema标准类型 func (a *Agent) convertToOpenAIType(configType string) string { switch configType { case "bool": return "boolean" case "int", "integer": return "number" case "float", "double": return "number" case "string", "array", "object": return configType default: // 默认返回原类型 return configType } } // callOpenAI 调用OpenAI API func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) { reqBody := OpenAIRequest{ Model: a.config.Model, Messages: messages, } if len(tools) > 0 { reqBody.Tools = tools } jsonData, err := json.Marshal(reqBody) if err != nil { return nil, err } req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+a.config.APIKey) resp, err := a.openAIClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } // 记录响应内容(用于调试) if resp.StatusCode != http.StatusOK { a.logger.Warn("OpenAI API返回非200状态码", zap.Int("status", resp.StatusCode), zap.String("body", string(body)), ) } var response OpenAIResponse if err := json.Unmarshal(body, &response); err != nil { a.logger.Error("解析OpenAI响应失败", zap.Error(err), zap.String("body", string(body)), ) return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body)) } return &response, nil } // parseToolCall 解析工具调用 func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) { // 简单解析,实际应该更复杂 // 格式: [TOOL_CALL]tool_name:arg1=value1,arg2=value2 if !strings.HasPrefix(content, "[TOOL_CALL]") { return nil, fmt.Errorf("不是有效的工具调用格式") } parts := strings.Split(content[len("[TOOL_CALL]"):], ":") if len(parts) < 2 { return nil, fmt.Errorf("工具调用格式错误") } toolName := strings.TrimSpace(parts[0]) argsStr := strings.TrimSpace(parts[1]) args := make(map[string]interface{}) argPairs := strings.Split(argsStr, ",") for _, pair := range argPairs { kv := strings.Split(pair, "=") if len(kv) == 2 { args[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) } } args["_tool_name"] = toolName return args, nil } // ToolExecutionResult 工具执行结果 type ToolExecutionResult struct { Result string ExecutionID string IsError bool // 标记是否为错误结果 } // executeToolViaMCP 通过MCP执行工具 // 即使工具执行失败,也返回结果而不是错误,让AI能够处理错误情况 func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map[string]interface{}) (*ToolExecutionResult, error) { a.logger.Info("通过MCP执行工具", zap.String("tool", toolName), zap.Any("args", args), ) // 通过MCP服务器调用工具 result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args) // 如果调用失败(如工具不存在),返回友好的错误信息而不是抛出异常 if err != nil { errorMsg := fmt.Sprintf(`工具调用失败 工具名称: %s 错误类型: 系统错误 错误详情: %v 可能的原因: - 工具 "%s" 不存在或未启用 - 系统配置问题 - 网络或权限问题 建议: - 检查工具名称是否正确 - 尝试使用其他替代工具 - 如果这是必需的工具,请向用户说明情况`, toolName, err, toolName) return &ToolExecutionResult{ Result: errorMsg, ExecutionID: executionID, IsError: true, }, nil // 返回 nil 错误,让调用者处理结果 } // 格式化结果 var resultText strings.Builder for _, content := range result.Content { resultText.WriteString(content.Text) resultText.WriteString("\n") } return &ToolExecutionResult{ Result: resultText.String(), ExecutionID: executionID, IsError: result != nil && result.IsError, }, nil } // formatToolError 格式化工具错误信息,提供更友好的错误描述 func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string { errorMsg := fmt.Sprintf(`工具执行失败 工具名称: %s 调用参数: %v 错误信息: %v 请分析错误原因并采取以下行动之一: 1. 如果参数错误,请修正参数后重试 2. 如果工具不可用,请尝试使用替代工具 3. 如果这是系统问题,请向用户说明情况并提供建议 4. 如果错误信息中包含有用信息,可以基于这些信息继续分析`, toolName, args, err) return errorMsg }