From 5fe5f5b71f3dfaba4c1c19aa731c4cb32073559b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 20 Mar 2026 01:03:40 +0800 Subject: [PATCH] Add files via upload --- internal/agent/agent.go | 328 ++++++++++++++++++++++++++++---- internal/handler/agent.go | 57 +++++- internal/openai/openai.go | 340 ++++++++++++++++++++++++++++++++++ internal/security/executor.go | 105 ++++++++++- web/static/js/monitor.js | 336 ++++++++++++++++++++++++++++++--- web/static/js/webshell.js | 17 +- 6 files changed, 1123 insertions(+), 60 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index e9bcee6b..3ae5576d 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -15,6 +15,7 @@ import ( "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/openai" + "cyberstrike-ai/internal/security" "cyberstrike-ai/internal/storage" "go.uber.org/zap" @@ -196,6 +197,7 @@ type OpenAIRequest struct { Model string `json:"model"` Messages []ChatMessage `json:"messages"` Tools []Tool `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` } // OpenAIResponse OpenAI API响应 @@ -529,6 +531,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his var currentReActInput string maxIterations := a.maxIterations + thinkingStreamSeq := 0 for i := 0; i < maxIterations; i++ { // 先获取本轮可用工具并统计 tools token,再压缩,以便压缩时预留 tools 占用的空间 tools := a.getAvailableTools(roleTools) @@ -630,7 +633,28 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 调用OpenAI sendProgress("progress", "正在调用AI模型...", nil) - response, err := a.callOpenAI(ctx, messages, tools) + thinkingStreamSeq++ + thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq) + thinkingStreamStarted := false + + response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error { + if delta == "" { + return nil + } + if !thinkingStreamStarted { + thinkingStreamStarted = true + sendProgress("thinking_stream_start", " ", map[string]interface{}{ + "streamId": thinkingStreamId, + "iteration": i + 1, + "toolStream": false, + }) + } + sendProgress("thinking_stream_delta", delta, map[string]interface{}{ + "streamId": thinkingStreamId, + "iteration": i + 1, + }) + return nil + }) if err != nil { // API调用失败,保存当前的ReAct输入和错误信息作为输出 result.LastReActInput = currentReActInput @@ -682,10 +706,12 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his // 检查是否有工具调用 if len(choice.Message.ToolCalls) > 0 { - // 如果有思考内容,先发送思考事件 + // 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重; + // 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。 if choice.Message.Content != "" { sendProgress("thinking", choice.Message.Content, map[string]interface{}{ "iteration": i + 1, + "streamId": thinkingStreamId, }) } @@ -717,7 +743,21 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his }) // 执行工具 - execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments) + toolCtx := context.WithValue(ctx, security.ToolOutputCallbackCtxKey, security.ToolOutputCallback(func(chunk string) { + if strings.TrimSpace(chunk) == "" { + return + } + sendProgress("tool_result_delta", chunk, map[string]interface{}{ + "toolName": toolCall.Function.Name, + "toolCallId": toolCall.ID, + "index": idx + 1, + "total": len(choice.Message.ToolCalls), + "iteration": i + 1, + // success 在最终 tool_result 事件里会以 success/isError 标记为准 + }) + })) + + execResult, err := a.executeToolViaMCP(toolCtx, toolCall.Function.Name, toolCall.Function.Arguments) if err != nil { // 构建详细的错误信息,帮助AI理解问题并做出决策 errorMsg := a.formatToolError(toolCall.Function.Name, toolCall.Function.Arguments, err) @@ -792,16 +832,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", }) messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留 - // 立即调用OpenAI获取总结 - summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 - if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { - summaryChoice := summaryResponse.Choices[0] - if summaryChoice.Message.Content != "" { - result.Response = summaryChoice.Message.Content - result.LastReActOutput = result.Response - sendProgress("progress", "总结生成完成", nil) - return result, nil - } + // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) + sendProgress("response_start", "", map[string]interface{}{ + "conversationId": conversationID, + "mcpExecutionIds": result.MCPExecutionIDs, + "messageGeneratedBy": "summary", + }) + streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { + sendProgress("response_delta", delta, map[string]interface{}{ + "conversationId": conversationID, + }) + return nil + }) + if strings.TrimSpace(streamText) != "" { + result.Response = streamText + result.LastReActOutput = result.Response + sendProgress("progress", "总结生成完成", nil) + return result, nil } // 如果获取总结失败,跳出循环,让后续逻辑处理 break @@ -817,7 +864,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his }) // 发送AI思考内容(如果没有工具调用) - if choice.Message.Content != "" { + if choice.Message.Content != "" && !thinkingStreamStarted { sendProgress("thinking", choice.Message.Content, map[string]interface{}{ "iteration": i + 1, }) @@ -832,16 +879,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his Content: "这是最后一次迭代。请总结到目前为止的所有测试结果、发现的问题和已完成的工作。如果需要继续测试,请提供详细的下一步执行计划。请直接回复,不要调用工具。", }) messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留 - // 立即调用OpenAI获取总结 - summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 - if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { - summaryChoice := summaryResponse.Choices[0] - if summaryChoice.Message.Content != "" { - result.Response = summaryChoice.Message.Content - result.LastReActOutput = result.Response - sendProgress("progress", "总结生成完成", nil) - return result, nil - } + // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) + sendProgress("response_start", "", map[string]interface{}{ + "conversationId": conversationID, + "mcpExecutionIds": result.MCPExecutionIDs, + "messageGeneratedBy": "summary", + }) + streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { + sendProgress("response_delta", delta, map[string]interface{}{ + "conversationId": conversationID, + }) + return nil + }) + if strings.TrimSpace(streamText) != "" { + result.Response = streamText + result.LastReActOutput = result.Response + sendProgress("progress", "总结生成完成", nil) + return result, nil } // 如果获取总结失败,使用当前回复作为结果 if choice.Message.Content != "" { @@ -872,15 +926,23 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his messages = append(messages, finalSummaryPrompt) messages = a.applyMemoryCompression(ctx, messages, 0) // 总结时不带 tools,不预留 - summaryResponse, err := a.callOpenAI(ctx, messages, []Tool{}) // 不提供工具,强制AI直接回复 - if err == nil && summaryResponse != nil && len(summaryResponse.Choices) > 0 { - summaryChoice := summaryResponse.Choices[0] - if summaryChoice.Message.Content != "" { - result.Response = summaryChoice.Message.Content - result.LastReActOutput = result.Response - sendProgress("progress", "总结生成完成", nil) - return result, nil - } + // 流式调用OpenAI获取总结(不提供工具,强制AI直接回复) + sendProgress("response_start", "", map[string]interface{}{ + "conversationId": conversationID, + "mcpExecutionIds": result.MCPExecutionIDs, + "messageGeneratedBy": "max_iter_summary", + }) + streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { + sendProgress("response_delta", delta, map[string]interface{}{ + "conversationId": conversationID, + }) + return nil + }) + if strings.TrimSpace(streamText) != "" { + result.Response = streamText + result.LastReActOutput = result.Response + sendProgress("progress", "总结生成完成", nil) + return result, nil } // 如果无法生成总结,返回友好的提示 @@ -1200,6 +1262,206 @@ func (a *Agent) callOpenAISingle(ctx context.Context, messages []ChatMessage, to return &response, nil } +// callOpenAISingleStreamText 单次调用OpenAI的流式模式,只用于“不会调用工具”的纯文本输出(tools 为空时最佳)。 +// onDelta 每收到一段 content delta,就回调一次;如果 callback 返回错误,会终止读取并返回错误。 +func (a *Agent) callOpenAISingleStreamText(ctx context.Context, messages []ChatMessage, tools []Tool, onDelta func(delta string) error) (string, error) { + reqBody := OpenAIRequest{ + Model: a.config.Model, + Messages: messages, + Stream: true, + } + if len(tools) > 0 { + reqBody.Tools = tools + } + + if a.openAIClient == nil { + return "", fmt.Errorf("OpenAI客户端未初始化") + } + + return a.openAIClient.ChatCompletionStream(ctx, reqBody, onDelta) +} + +// callOpenAIStreamText 调用OpenAI流式模式(带重试),仅在“未输出任何 delta”时才允许重试,避免重复发送已下发的内容。 +func (a *Agent) callOpenAIStreamText(ctx context.Context, messages []ChatMessage, tools []Tool, onDelta func(delta string) error) (string, error) { + maxRetries := 3 + var lastErr error + + for attempt := 0; attempt < maxRetries; attempt++ { + var deltasSent bool + full, err := a.callOpenAISingleStreamText(ctx, messages, tools, func(delta string) error { + deltasSent = true + return onDelta(delta) + }) + if err == nil { + if attempt > 0 { + a.logger.Info("OpenAI stream 调用重试成功", + zap.Int("attempt", attempt+1), + zap.Int("maxRetries", maxRetries), + ) + } + return full, nil + } + + lastErr = err + // 已经开始输出了 delta,避免重复内容:直接失败让上层处理。 + if deltasSent { + return "", err + } + + if !a.isRetryableError(err) { + return "", err + } + + if attempt < maxRetries-1 { + backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + a.logger.Warn("OpenAI stream 调用失败,准备重试", + zap.Error(err), + zap.Int("attempt", attempt+1), + zap.Int("maxRetries", maxRetries), + zap.Duration("backoff", backoff), + ) + + select { + case <-ctx.Done(): + return "", fmt.Errorf("上下文已取消: %w", ctx.Err()) + case <-time.After(backoff): + } + } + } + + return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr) +} + +// callOpenAISingleStreamWithToolCalls 单次调用OpenAI流式模式(带工具调用解析),不包含重试逻辑。 +func (a *Agent) callOpenAISingleStreamWithToolCalls( + ctx context.Context, + messages []ChatMessage, + tools []Tool, + onContentDelta func(delta string) error, +) (*OpenAIResponse, error) { + reqBody := OpenAIRequest{ + Model: a.config.Model, + Messages: messages, + Stream: true, + } + if len(tools) > 0 { + reqBody.Tools = tools + } + if a.openAIClient == nil { + return nil, fmt.Errorf("OpenAI客户端未初始化") + } + + content, streamToolCalls, finishReason, err := a.openAIClient.ChatCompletionStreamWithToolCalls(ctx, reqBody, onContentDelta) + if err != nil { + return nil, err + } + + toolCalls := make([]ToolCall, 0, len(streamToolCalls)) + for _, stc := range streamToolCalls { + fnArgsStr := stc.FunctionArgsStr + args := make(map[string]interface{}) + if strings.TrimSpace(fnArgsStr) != "" { + if err := json.Unmarshal([]byte(fnArgsStr), &args); err != nil { + // 兼容:arguments 不一定是严格 JSON + args = map[string]interface{}{"raw": fnArgsStr} + } + } + + typ := stc.Type + if strings.TrimSpace(typ) == "" { + typ = "function" + } + + toolCalls = append(toolCalls, ToolCall{ + ID: stc.ID, + Type: typ, + Function: FunctionCall{ + Name: stc.FunctionName, + Arguments: args, + }, + }) + } + + response := &OpenAIResponse{ + ID: "", + Choices: []Choice{ + { + Message: MessageWithTools{ + Role: "assistant", + Content: content, + ToolCalls: toolCalls, + }, + FinishReason: finishReason, + }, + }, + } + return response, nil +} + +// callOpenAIStreamWithToolCalls 调用OpenAI流式模式(带重试),仅当还没有输出任何 content delta 时才允许重试。 +func (a *Agent) callOpenAIStreamWithToolCalls( + ctx context.Context, + messages []ChatMessage, + tools []Tool, + onContentDelta func(delta string) error, +) (*OpenAIResponse, error) { + maxRetries := 3 + var lastErr error + + for attempt := 0; attempt < maxRetries; attempt++ { + deltasSent := false + resp, err := a.callOpenAISingleStreamWithToolCalls(ctx, messages, tools, func(delta string) error { + deltasSent = true + if onContentDelta != nil { + return onContentDelta(delta) + } + return nil + }) + if err == nil { + if attempt > 0 { + a.logger.Info("OpenAI stream 调用重试成功", + zap.Int("attempt", attempt+1), + zap.Int("maxRetries", maxRetries), + ) + } + return resp, nil + } + + lastErr = err + if deltasSent { + // 已经开始输出了 delta:避免重复发送 + return nil, err + } + + if !a.isRetryableError(err) { + return nil, err + } + if attempt < maxRetries-1 { + backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + a.logger.Warn("OpenAI stream 调用失败,准备重试", + zap.Error(err), + zap.Int("attempt", attempt+1), + zap.Int("maxRetries", maxRetries), + zap.Duration("backoff", backoff), + ) + + select { + case <-ctx.Done(): + return nil, fmt.Errorf("上下文已取消: %w", ctx.Err()) + case <-time.After(backoff): + } + } + } + + return nil, fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr) +} + // ToolExecutionResult 工具执行结果 type ToolExecutionResult struct { Result string diff --git a/internal/handler/agent.go b/internal/handler/agent.go index e1582a4f..989b43ea 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -662,8 +662,16 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID } } - // 保存过程详情到数据库(排除response和done事件,它们会在后面单独处理) - if assistantMessageID != "" && eventType != "response" && eventType != "done" { + // 保存过程详情到数据库(排除response/done事件,它们会在后面单独处理) + // 另外:response_start/response_delta 是模型流式增量,保存会导致过程详情膨胀,因此不落库。 + if assistantMessageID != "" && + eventType != "response" && + eventType != "done" && + eventType != "response_start" && + eventType != "response_delta" && + eventType != "tool_result_delta" && + eventType != "thinking_stream_start" && + eventType != "thinking_stream_delta" { if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil { h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", eventType)) } @@ -703,8 +711,53 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { // 发送初始事件 // 用于跟踪客户端是否已断开连接 clientDisconnected := false + // 用于快速确认模型是否真的产生了流式 delta + var responseDeltaCount int + var responseStartLogged bool sendEvent := func(eventType, message string, data interface{}) { + if eventType == "response_start" { + responseDeltaCount = 0 + responseStartLogged = true + h.logger.Info("SSE: response_start", + zap.Int("conversationIdPresent", func() int { + if m, ok := data.(map[string]interface{}); ok { + if v, ok2 := m["conversationId"]; ok2 && v != nil && fmt.Sprint(v) != "" { + return 1 + } + } + return 0 + }()), + zap.String("messageGeneratedBy", func() string { + if m, ok := data.(map[string]interface{}); ok { + if v, ok2 := m["messageGeneratedBy"]; ok2 { + if s, ok3 := v.(string); ok3 { + return s + } + return fmt.Sprint(v) + } + } + return "" + }()), + ) + } else if eventType == "response_delta" { + responseDeltaCount++ + // 只打前几条,避免刷屏 + if responseStartLogged && responseDeltaCount <= 3 { + h.logger.Info("SSE: response_delta", + zap.Int("index", responseDeltaCount), + zap.Int("deltaLen", len(message)), + zap.String("deltaPreview", func() string { + p := strings.ReplaceAll(message, "\n", "\\n") + if len(p) > 80 { + return p[:80] + "..." + } + return p + }()), + ) + } + } + // 如果客户端已断开,不再发送事件 if clientDisconnected { return diff --git a/internal/openai/openai.go b/internal/openai/openai.go index e07f0d61..637ee48b 100644 --- a/internal/openai/openai.go +++ b/internal/openai/openai.go @@ -1,6 +1,7 @@ package openai import ( + "bufio" "bytes" "context" "encoding/json" @@ -142,3 +143,342 @@ func (c *Client) ChatCompletion(ctx context.Context, payload interface{}, out in return nil } + +// ChatCompletionStream 调用 /chat/completions 的流式模式(stream=true),并在每个 delta 到达时回调 onDelta。 +// 返回最终拼接的 content(只拼 content delta;工具调用 delta 未做处理)。 +func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{}, onDelta func(delta string) error) (string, error) { + if c == nil { + return "", fmt.Errorf("openai client is not initialized") + } + if c.config == nil { + return "", fmt.Errorf("openai config is nil") + } + if strings.TrimSpace(c.config.APIKey) == "" { + return "", fmt.Errorf("openai api key is empty") + } + + baseURL := strings.TrimSuffix(c.config.BaseURL, "/") + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + + body, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("marshal openai payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("build openai request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.config.APIKey) + + requestStart := time.Now() + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("call openai api: %w", err) + } + defer resp.Body.Close() + + // 非200:读完 body 返回 + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return "", &APIError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + } + } + + type streamDelta struct { + // OpenAI 兼容流式通常使用 content;但部分兼容实现可能用 text。 + Content string `json:"content,omitempty"` + Text string `json:"text,omitempty"` + } + type streamChoice struct { + Delta streamDelta `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` + } + type streamResponse struct { + ID string `json:"id,omitempty"` + Choices []streamChoice `json:"choices"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + } `json:"error,omitempty"` + } + + reader := bufio.NewReader(resp.Body) + var full strings.Builder + + // 典型 SSE 结构: + // data: {...}\n\n + // data: [DONE]\n\n + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + if readErr == io.EOF { + break + } + return full.String(), fmt.Errorf("read openai stream: %w", readErr) + } + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + if !strings.HasPrefix(trimmed, "data:") { + continue + } + dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if dataStr == "[DONE]" { + break + } + + var chunk streamResponse + if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil { + // 解析失败跳过(兼容各种兼容层的差异) + continue + } + if chunk.Error != nil && strings.TrimSpace(chunk.Error.Message) != "" { + return full.String(), fmt.Errorf("openai stream error: %s", chunk.Error.Message) + } + if len(chunk.Choices) == 0 { + continue + } + + delta := chunk.Choices[0].Delta.Content + if delta == "" { + delta = chunk.Choices[0].Delta.Text + } + if delta == "" { + continue + } + + full.WriteString(delta) + if onDelta != nil { + if err := onDelta(delta); err != nil { + return full.String(), err + } + } + } + + c.logger.Debug("received OpenAI stream completion", + zap.Duration("duration", time.Since(requestStart)), + zap.Int("contentLen", full.Len()), + ) + + return full.String(), nil +} + +// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。 +type StreamToolCall struct { + Index int + ID string + Type string + FunctionName string + FunctionArgsStr string +} + +// ChatCompletionStreamWithToolCalls 流式模式:同时把 content delta 实时回调,并在结束后返回 tool_calls 和 finish_reason。 +func (c *Client) ChatCompletionStreamWithToolCalls( + ctx context.Context, + payload interface{}, + onContentDelta func(delta string) error, +) (string, []StreamToolCall, string, error) { + if c == nil { + return "", nil, "", fmt.Errorf("openai client is not initialized") + } + if c.config == nil { + return "", nil, "", fmt.Errorf("openai config is nil") + } + if strings.TrimSpace(c.config.APIKey) == "" { + return "", nil, "", fmt.Errorf("openai api key is empty") + } + + baseURL := strings.TrimSuffix(c.config.BaseURL, "/") + if baseURL == "" { + baseURL = "https://api.openai.com/v1" + } + + body, err := json.Marshal(payload) + if err != nil { + return "", nil, "", fmt.Errorf("marshal openai payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return "", nil, "", fmt.Errorf("build openai request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.config.APIKey) + + requestStart := time.Now() + resp, err := c.httpClient.Do(req) + if err != nil { + return "", nil, "", fmt.Errorf("call openai api: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return "", nil, "", &APIError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + } + } + + // delta tool_calls 的增量结构 + type toolCallFunctionDelta struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + } + type toolCallDelta struct { + Index int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function toolCallFunctionDelta `json:"function,omitempty"` + } + type streamDelta2 struct { + Content string `json:"content,omitempty"` + Text string `json:"text,omitempty"` + ToolCalls []toolCallDelta `json:"tool_calls,omitempty"` + } + type streamChoice2 struct { + Delta streamDelta2 `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` + } + type streamResponse2 struct { + Choices []streamChoice2 `json:"choices"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + } `json:"error,omitempty"` + } + + type toolCallAccum struct { + id string + typ string + name string + args strings.Builder + } + toolCallAccums := make(map[int]*toolCallAccum) + + reader := bufio.NewReader(resp.Body) + var full strings.Builder + finishReason := "" + + for { + line, readErr := reader.ReadString('\n') + if readErr != nil { + if readErr == io.EOF { + break + } + return full.String(), nil, finishReason, fmt.Errorf("read openai stream: %w", readErr) + } + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + if !strings.HasPrefix(trimmed, "data:") { + continue + } + dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if dataStr == "[DONE]" { + break + } + + var chunk streamResponse2 + if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil { + // 兼容:解析失败跳过 + continue + } + if chunk.Error != nil && strings.TrimSpace(chunk.Error.Message) != "" { + return full.String(), nil, finishReason, fmt.Errorf("openai stream error: %s", chunk.Error.Message) + } + if len(chunk.Choices) == 0 { + continue + } + + choice := chunk.Choices[0] + if choice.FinishReason != nil && strings.TrimSpace(*choice.FinishReason) != "" { + finishReason = strings.TrimSpace(*choice.FinishReason) + } + + delta := choice.Delta + + content := delta.Content + if content == "" { + content = delta.Text + } + if content != "" { + full.WriteString(content) + if onContentDelta != nil { + if err := onContentDelta(content); err != nil { + return full.String(), nil, finishReason, err + } + } + } + + if len(delta.ToolCalls) > 0 { + for _, tc := range delta.ToolCalls { + acc, ok := toolCallAccums[tc.Index] + if !ok { + acc = &toolCallAccum{} + toolCallAccums[tc.Index] = acc + } + if tc.ID != "" { + acc.id = tc.ID + } + if tc.Type != "" { + acc.typ = tc.Type + } + if tc.Function.Name != "" { + acc.name = tc.Function.Name + } + if tc.Function.Arguments != "" { + acc.args.WriteString(tc.Function.Arguments) + } + } + } + } + + // 组装 tool calls + indices := make([]int, 0, len(toolCallAccums)) + for idx := range toolCallAccums { + indices = append(indices, idx) + } + // 手写简单排序(避免额外 import) + for i := 0; i < len(indices); i++ { + for j := i + 1; j < len(indices); j++ { + if indices[j] < indices[i] { + indices[i], indices[j] = indices[j], indices[i] + } + } + } + + toolCalls := make([]StreamToolCall, 0, len(indices)) + for _, idx := range indices { + acc := toolCallAccums[idx] + tc := StreamToolCall{ + Index: idx, + ID: acc.id, + Type: acc.typ, + FunctionName: acc.name, + FunctionArgsStr: acc.args.String(), + } + toolCalls = append(toolCalls, tc) + } + + c.logger.Debug("received OpenAI stream completion (tool_calls)", + zap.Duration("duration", time.Since(requestStart)), + zap.Int("contentLen", full.Len()), + zap.Int("toolCalls", len(toolCalls)), + zap.String("finishReason", finishReason), + ) + + if strings.TrimSpace(finishReason) == "" { + finishReason = "stop" + } + + return full.String(), toolCalls, finishReason, nil +} diff --git a/internal/security/executor.go b/internal/security/executor.go index 393650fd..d56ac46d 100644 --- a/internal/security/executor.go +++ b/internal/security/executor.go @@ -9,6 +9,8 @@ import ( "os/exec" "strconv" "strings" + "sync" + "time" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/mcp" @@ -17,6 +19,15 @@ import ( "go.uber.org/zap" ) +// ToolOutputCallback 用于在工具执行过程中把 stdout/stderr 增量推给上层(SSE)。 +// 通过 context 传递,避免修改 MCP ToolHandler 签名导致的“写死工具”问题。 +type ToolOutputCallback func(chunk string) + +type toolOutputCallbackCtxKey struct{} + +// ToolOutputCallbackCtxKey 是 context 中的 key,供 Agent 写入回调,Executor 读取并流式回调。 +var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{} + // Executor 安全工具执行器 type Executor struct { config *config.SecurityConfig @@ -144,7 +155,16 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st zap.Strings("args", cmdArgs), ) - output, err := cmd.CombinedOutput() + var output string + var err error + // 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。 + if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil { + output, err = streamCommandOutput(cmd, cb) + } else { + outputBytes, err2 := cmd.CombinedOutput() + output = string(outputBytes) + err = err2 + } if err != nil { // 检查退出码是否在允许列表中 exitCode := getExitCode(err) @@ -931,7 +951,16 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int } // 非后台命令:等待输出 - output, err := cmd.CombinedOutput() + var output string + var err error + // 若上层提供工具输出增量回调,则边执行边流式读取。 + if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil { + output, err = streamCommandOutput(cmd, cb) + } else { + outputBytes, err2 := cmd.CombinedOutput() + output = string(outputBytes) + err = err2 + } if err != nil { e.logger.Error("系统命令执行失败", zap.String("command", command), @@ -965,6 +994,78 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int }, nil } +// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。 +// 保持输出内容完整拼接返回,并用 cb(chunk) 向上层持续推送。 +func streamCommandOutput(cmd *exec.Cmd, cb ToolOutputCallback) (string, error) { + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return "", err + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + _ = stdoutPipe.Close() + return "", err + } + if err := cmd.Start(); err != nil { + _ = stdoutPipe.Close() + _ = stderrPipe.Close() + return "", err + } + + chunks := make(chan string, 64) + var wg sync.WaitGroup + readFn := func(r io.Reader) { + defer wg.Done() + br := bufio.NewReader(r) + for { + s, readErr := br.ReadString('\n') + if s != "" { + chunks <- s + } + if readErr != nil { + // EOF 正常结束 + return + } + } + } + + wg.Add(2) + go readFn(stdoutPipe) + go readFn(stderrPipe) + + go func() { + wg.Wait() + close(chunks) + }() + + var outBuilder strings.Builder + var deltaBuilder strings.Builder + lastFlush := time.Now() + + flush := func() { + if deltaBuilder.Len() == 0 { + return + } + cb(deltaBuilder.String()) + deltaBuilder.Reset() + lastFlush = time.Now() + } + + for chunk := range chunks { + outBuilder.WriteString(chunk) + deltaBuilder.WriteString(chunk) + // 简单节流:buffer 大于 2KB 或 200ms 就刷新一次 + if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond { + flush() + } + } + flush() + + // 等待命令结束,返回最终退出状态 + waitErr := cmd.Wait() + return outBuilder.String(), waitErr +} + // executeInternalTool 执行内部工具(不执行外部命令) func (e *Executor) executeInternalTool(ctx context.Context, toolName string, command string, args map[string]interface{}) (*mcp.ToolResult, error) { // 提取内部工具类型(去掉 "internal:" 前缀) diff --git a/web/static/js/monitor.js b/web/static/js/monitor.js index 9a04cb88..ef55face 100644 --- a/web/static/js/monitor.js +++ b/web/static/js/monitor.js @@ -67,6 +67,75 @@ if (typeof window !== 'undefined') { // 存储工具调用ID到DOM元素的映射,用于更新执行状态 const toolCallStatusMap = new Map(); +// 模型流式输出缓存:progressId -> { assistantId, buffer } +const responseStreamStateByProgressId = new Map(); + +// AI 思考流式输出:progressId -> Map(streamId -> { itemId, buffer }) +const thinkingStreamStateByProgressId = new Map(); + +// 工具输出流式增量:progressId::toolCallId -> { itemId, buffer } +const toolResultStreamStateByKey = new Map(); +function toolResultStreamKey(progressId, toolCallId) { + return String(progressId) + '::' + String(toolCallId); +} + +// markdown 渲染(用于最终合并渲染;流式增量阶段用纯转义避免部分语法不稳定) +const assistantMarkdownSanitizeConfig = { + ALLOWED_TAGS: ['p', 'br', 'strong', 'em', 'u', 's', 'code', 'pre', 'blockquote', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'ul', 'ol', 'li', 'a', 'img', 'table', 'thead', 'tbody', 'tr', 'th', 'td', 'hr'], + ALLOWED_ATTR: ['href', 'title', 'alt', 'src', 'class'], + ALLOW_DATA_ATTR: false, +}; + +function escapeHtmlLocal(text) { + if (!text) return ''; + const div = document.createElement('div'); + div.textContent = String(text); + return div.innerHTML; +} + +function formatAssistantMarkdownContent(text) { + const raw = text == null ? '' : String(text); + if (typeof marked !== 'undefined') { + try { + marked.setOptions({ breaks: true, gfm: true }); + const parsed = marked.parse(raw); + if (typeof DOMPurify !== 'undefined') { + return DOMPurify.sanitize(parsed, assistantMarkdownSanitizeConfig); + } + return parsed; + } catch (e) { + return escapeHtmlLocal(raw).replace(/\n/g, '
'); + } + } + return escapeHtmlLocal(raw).replace(/\n/g, '
'); +} + +function updateAssistantBubbleContent(assistantMessageId, content, renderMarkdown) { + const assistantElement = document.getElementById(assistantMessageId); + if (!assistantElement) return; + const bubble = assistantElement.querySelector('.message-bubble'); + if (!bubble) return; + + // 保留复制按钮:addMessage 会把按钮 append 在 message-bubble 里 + const copyBtn = bubble.querySelector('.message-copy-btn'); + if (copyBtn) copyBtn.remove(); + + const newContent = content == null ? '' : String(content); + const html = renderMarkdown + ? formatAssistantMarkdownContent(newContent) + : escapeHtmlLocal(newContent).replace(/\n/g, '
'); + + bubble.innerHTML = html; + + // 更新原始内容(给复制功能用) + assistantElement.dataset.originalContent = newContent; + + if (typeof wrapTablesInBubble === 'function') { + wrapTablesInBubble(bubble); + } + if (copyBtn) bubble.appendChild(copyBtn); +} + const conversationExecutionTracker = { activeConversations: new Set(), update(tasks = []) { @@ -543,7 +612,77 @@ function handleStreamEvent(event, progressElement, progressId, }); break; + case 'thinking_stream_start': { + const d = event.data || {}; + const streamId = d.streamId || null; + if (!streamId) break; + + let state = thinkingStreamStateByProgressId.get(progressId); + if (!state) { + state = new Map(); + thinkingStreamStateByProgressId.set(progressId, state); + } + // 若已存在,重置 buffer + const title = '🤔 ' + (typeof window.t === 'function' ? window.t('chat.aiThinking') : 'AI思考'); + const itemId = addTimelineItem(timeline, 'thinking', { + title: title, + message: ' ', + data: d + }); + state.set(streamId, { itemId, buffer: '' }); + break; + } + + case 'thinking_stream_delta': { + const d = event.data || {}; + const streamId = d.streamId || null; + if (!streamId) break; + + const state = thinkingStreamStateByProgressId.get(progressId); + if (!state || !state.has(streamId)) break; + const s = state.get(streamId); + + const delta = event.message || ''; + s.buffer += delta; + + const item = document.getElementById(s.itemId); + if (item) { + const contentEl = item.querySelector('.timeline-item-content'); + if (contentEl) { + if (typeof formatMarkdown === 'function') { + contentEl.innerHTML = formatMarkdown(s.buffer); + } else { + contentEl.textContent = s.buffer; + } + } + } + break; + } + case 'thinking': + // 如果本 thinking 是由 thinking_stream_* 聚合出来的(带 streamId),避免重复创建 timeline item + if (event.data && event.data.streamId) { + const streamId = event.data.streamId; + const state = thinkingStreamStateByProgressId.get(progressId); + if (state && state.has(streamId)) { + const s = state.get(streamId); + s.buffer = event.message || ''; + const item = document.getElementById(s.itemId); + if (item) { + const contentEl = item.querySelector('.timeline-item-content'); + if (contentEl) { + // contentEl.innerHTML 用于兼容 Markdown 展示 + if (typeof formatMarkdown === 'function') { + contentEl.innerHTML = formatMarkdown(s.buffer); + } else { + contentEl.textContent = s.buffer; + } + } + } + break; + } + } + addTimelineItem(timeline, 'thinking', { title: '🤔 ' + (typeof window.t === 'function' ? window.t('chat.aiThinking') : 'AI思考'), message: event.message, @@ -584,6 +723,55 @@ function handleStreamEvent(event, progressElement, progressId, updateToolCallStatus(toolCallId, 'running'); } break; + + case 'tool_result_delta': { + const deltaInfo = event.data || {}; + const toolCallId = deltaInfo.toolCallId || null; + if (!toolCallId) break; + + const key = toolResultStreamKey(progressId, toolCallId); + let state = toolResultStreamStateByKey.get(key); + const toolNameDelta = deltaInfo.toolName || (typeof window.t === 'function' ? window.t('chat.unknownTool') : '未知工具'); + const deltaText = event.message || ''; + if (!deltaText) break; + + if (!state) { + // 首次增量:创建一个 tool_result 占位条目,后续不断更新 pre 内容 + const runningLabel = typeof window.t === 'function' ? window.t('timeline.running') : '执行中...'; + const title = '⏳ ' + (typeof window.t === 'function' + ? window.t('timeline.running') + : runningLabel) + ' ' + (typeof window.t === 'function' ? window.t('chat.callTool', { name: escapeHtmlLocal(toolNameDelta), index: deltaInfo.index || 0, total: deltaInfo.total || 0 }) : toolNameDelta); + + const itemId = addTimelineItem(timeline, 'tool_result', { + title: title, + message: '', + data: { + toolName: toolNameDelta, + success: true, + isError: false, + result: deltaText, + toolCallId: toolCallId, + index: deltaInfo.index, + total: deltaInfo.total, + iteration: deltaInfo.iteration + }, + expanded: false + }); + + state = { itemId, buffer: '' }; + toolResultStreamStateByKey.set(key, state); + } + + state.buffer += deltaText; + const item = document.getElementById(state.itemId); + if (item) { + const pre = item.querySelector('pre.tool-result'); + if (pre) { + pre.textContent = state.buffer; + } + } + break; + } case 'tool_result': const resultInfo = event.data || {}; @@ -592,6 +780,39 @@ function handleStreamEvent(event, progressElement, progressId, const statusIcon = success ? '✅' : '❌'; const resultToolCallId = resultInfo.toolCallId || null; const resultExecText = success ? (typeof window.t === 'function' ? window.t('chat.toolExecComplete', { name: escapeHtml(resultToolName) }) : '工具 ' + escapeHtml(resultToolName) + ' 执行完成') : (typeof window.t === 'function' ? window.t('chat.toolExecFailed', { name: escapeHtml(resultToolName) }) : '工具 ' + escapeHtml(resultToolName) + ' 执行失败'); + + // 若此 tool 已经流式推送过增量,则复用占位条目并更新最终结果,避免重复添加一条 + if (resultToolCallId) { + const key = toolResultStreamKey(progressId, resultToolCallId); + const state = toolResultStreamStateByKey.get(key); + if (state && state.itemId) { + const item = document.getElementById(state.itemId); + if (item) { + const pre = item.querySelector('pre.tool-result'); + const resultVal = resultInfo.result || resultInfo.error || ''; + if (pre) pre.textContent = typeof resultVal === 'string' ? resultVal : JSON.stringify(resultVal); + + const section = item.querySelector('.tool-result-section'); + if (section) { + section.className = 'tool-result-section ' + (success ? 'success' : 'error'); + } + + const titleEl = item.querySelector('.timeline-item-title'); + if (titleEl) { + titleEl.textContent = statusIcon + ' ' + resultExecText; + } + } + toolResultStreamStateByKey.delete(key); + + // 同时更新 tool_call 的状态 + if (resultToolCallId && toolCallStatusMap.has(resultToolCallId)) { + updateToolCallStatus(resultToolCallId, success ? 'completed' : 'failed'); + toolCallStatusMap.delete(resultToolCallId); + } + break; + } + } + if (resultToolCallId && toolCallStatusMap.has(resultToolCallId)) { updateToolCallStatus(resultToolCallId, success ? 'completed' : 'failed'); toolCallStatusMap.delete(resultToolCallId); @@ -683,47 +904,108 @@ function handleStreamEvent(event, progressElement, progressId, loadActiveTasks(); break; - case 'response': - // 在更新之前,先获取任务对应的原始对话ID + case 'response_start': { const responseTaskState = progressTaskState.get(progressId); const responseOriginalConversationId = responseTaskState?.conversationId; - - // 先添加助手回复 + const responseData = event.data || {}; const mcpIds = responseData.mcpExecutionIds || []; setMcpIds(mcpIds); - - // 更新对话ID + if (responseData.conversationId) { - // 如果用户已经开始了新对话(currentConversationId 为 null), - // 且这个 response 事件来自旧对话,就不更新 currentConversationId 也不添加消息 + // 如果用户已经开始了新对话(currentConversationId 为 null),且这个事件来自旧对话,则忽略 if (currentConversationId === null && responseOriginalConversationId !== null) { - // 用户已经开始了新对话,忽略旧对话的 response 事件 - // 但仍然更新任务状态,以便正确显示任务信息 updateProgressConversation(progressId, responseData.conversationId); break; } - currentConversationId = responseData.conversationId; updateActiveConversation(); addAttackChainButton(currentConversationId); updateProgressConversation(progressId, responseData.conversationId); loadActiveTasks(); } - - // 添加助手回复,并传入进度ID以便集成详情 - const assistantId = addMessage('assistant', event.message, mcpIds, progressId); + + // 已存在则复用;否则创建空助手消息占位,用于增量追加 + const existing = responseStreamStateByProgressId.get(progressId); + if (existing && existing.assistantId) break; + + const assistantId = addMessage('assistant', '', mcpIds, progressId); setAssistantId(assistantId); - - // 将进度详情集成到工具调用区域 - integrateProgressToMCPSection(progressId, assistantId); - - // 延迟自动折叠详情(3秒后) + responseStreamStateByProgressId.set(progressId, { assistantId, buffer: '' }); + break; + } + + case 'response_delta': { + const responseData = event.data || {}; + const responseTaskState = progressTaskState.get(progressId); + const responseOriginalConversationId = responseTaskState?.conversationId; + + if (responseData.conversationId) { + if (currentConversationId === null && responseOriginalConversationId !== null) { + updateProgressConversation(progressId, responseData.conversationId); + break; + } + } + + let state = responseStreamStateByProgressId.get(progressId); + if (!state || !state.assistantId) { + const mcpIds = responseData.mcpExecutionIds || []; + const assistantId = addMessage('assistant', '', mcpIds, progressId); + setAssistantId(assistantId); + state = { assistantId, buffer: '' }; + responseStreamStateByProgressId.set(progressId, state); + } + + state.buffer += (event.message || ''); + updateAssistantBubbleContent(state.assistantId, state.buffer, false); + break; + } + + case 'response': + // 在更新之前,先获取任务对应的原始对话ID + const responseTaskState = progressTaskState.get(progressId); + const responseOriginalConversationId = responseTaskState?.conversationId; + + // 先更新 mcp ids + const responseData = event.data || {}; + const mcpIds = responseData.mcpExecutionIds || []; + setMcpIds(mcpIds); + + // 更新对话ID + if (responseData.conversationId) { + if (currentConversationId === null && responseOriginalConversationId !== null) { + updateProgressConversation(progressId, responseData.conversationId); + break; + } + + currentConversationId = responseData.conversationId; + updateActiveConversation(); + addAttackChainButton(currentConversationId); + updateProgressConversation(progressId, responseData.conversationId); + loadActiveTasks(); + } + + // 如果之前已经在 response_start/response_delta 阶段创建过占位,则复用该消息更新最终内容 + const streamState = responseStreamStateByProgressId.get(progressId); + const existingAssistantId = streamState?.assistantId || getAssistantId(); + let assistantIdFinal = existingAssistantId; + + if (!assistantIdFinal) { + assistantIdFinal = addMessage('assistant', event.message, mcpIds, progressId); + setAssistantId(assistantIdFinal); + } else { + setAssistantId(assistantIdFinal); + updateAssistantBubbleContent(assistantIdFinal, event.message, true); + } + + // 将进度详情集成到工具调用区域(放在最终 response 之后,保证时间线已完整) + integrateProgressToMCPSection(progressId, assistantIdFinal); + responseStreamStateByProgressId.delete(progressId); + setTimeout(() => { - collapseAllProgressDetails(assistantId, progressId); + collapseAllProgressDetails(assistantIdFinal, progressId); }, 3000); - - // 延迟刷新对话列表,确保助手消息已保存,updated_at已更新 + setTimeout(() => { loadConversations(); }, 200); @@ -802,6 +1084,16 @@ function handleStreamEvent(event, progressElement, progressId, break; case 'done': + // 清理流式输出状态 + responseStreamStateByProgressId.delete(progressId); + thinkingStreamStateByProgressId.delete(progressId); + // 清理工具流式输出占位 + const prefix = String(progressId) + '::'; + for (const key of Array.from(toolResultStreamStateByKey.keys())) { + if (String(key).startsWith(prefix)) { + toolResultStreamStateByKey.delete(key); + } + } // 完成,更新进度标题(如果进度消息还存在) const doneTitle = document.querySelector(`#${progressId} .progress-title`); if (doneTitle) { diff --git a/web/static/js/webshell.js b/web/static/js/webshell.js index fe615c1c..9f055898 100644 --- a/web/static/js/webshell.js +++ b/web/static/js/webshell.js @@ -797,10 +797,25 @@ function runWebshellAiSend(conn, inputEl, sendBtn, messagesContainer) { el.classList.toggle('active', el.dataset.convId === convId); }); }); + } else if (eventData.type === 'response_start') { + streamingTarget = ''; + webshellStreamingTypingId += 1; + streamingTypingId = webshellStreamingTypingId; + assistantDiv.textContent = '…'; + messagesContainer.scrollTop = messagesContainer.scrollHeight; + } else if (eventData.type === 'response_delta') { + var deltaText = (eventData.message != null && eventData.message !== '') ? String(eventData.message) : ''; + if (deltaText) { + streamingTarget += deltaText; + webshellStreamingTypingId += 1; + streamingTypingId = webshellStreamingTypingId; + runWebshellAiStreamingTyping(assistantDiv, streamingTarget, streamingTypingId, messagesContainer); + } } else if (eventData.type === 'response') { var text = (eventData.message != null && eventData.message !== '') ? eventData.message : (eventData.data && typeof eventData.data === 'string' ? eventData.data : ''); if (text) { - streamingTarget += text; + // response 为最终完整内容:避免与增量重复拼接 + streamingTarget = String(text); webshellStreamingTypingId += 1; streamingTypingId = webshellStreamingTypingId; runWebshellAiStreamingTyping(assistantDiv, streamingTarget, streamingTypingId, messagesContainer);