mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-17 21:44:43 +02:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bb9e3f9477 | |||
| a57720fb29 | |||
| 9e34b480e7 | |||
| cd30953a84 | |||
| a273d6d7ba | |||
| 87d9e50781 | |||
| 54b9e2e2fa | |||
| 946d347dc9 | |||
| ed8c0b15dd | |||
| f658cc6e93 | |||
| 7bf0697526 | |||
| 7e8cc3e2b8 |
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.4.10"
|
||||
version: "v1.4.13"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
|
||||
@@ -675,6 +675,7 @@ func setupRoutes(
|
||||
protected.DELETE("/groups/:id", groupHandler.DeleteGroup)
|
||||
protected.PUT("/groups/:id/pinned", groupHandler.UpdateGroupPinned)
|
||||
protected.GET("/groups/:id/conversations", groupHandler.GetGroupConversations)
|
||||
protected.GET("/groups/mappings", groupHandler.GetAllMappings)
|
||||
protected.POST("/groups/conversations", groupHandler.AddConversationToGroup)
|
||||
protected.DELETE("/groups/:id/conversations/:conversationId", groupHandler.RemoveConversationFromGroup)
|
||||
protected.PUT("/groups/:id/conversations/:conversationId/pinned", groupHandler.UpdateConversationPinnedInGroup)
|
||||
@@ -682,6 +683,7 @@ func setupRoutes(
|
||||
// 监控
|
||||
protected.GET("/monitor", monitorHandler.Monitor)
|
||||
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
|
||||
protected.POST("/monitor/executions/names", monitorHandler.BatchGetToolNames)
|
||||
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
|
||||
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
|
||||
protected.GET("/monitor/stats", monitorHandler.GetStats)
|
||||
@@ -691,6 +693,7 @@ func setupRoutes(
|
||||
protected.GET("/config/tools", configHandler.GetTools)
|
||||
protected.PUT("/config", configHandler.UpdateConfig)
|
||||
protected.POST("/config/apply", configHandler.ApplyConfig)
|
||||
protected.POST("/config/test-openai", configHandler.TestOpenAI)
|
||||
|
||||
// 系统设置 - 终端(执行命令,提高运维效率)
|
||||
protected.POST("/terminal/run", terminalHandler.RunCommand)
|
||||
|
||||
@@ -310,15 +310,14 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
||||
var err error
|
||||
|
||||
if search != "" {
|
||||
// 使用LIKE进行模糊搜索,搜索标题和消息内容
|
||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||
searchPattern := "%" + search + "%"
|
||||
// 使用DISTINCT避免重复,因为一个对话可能有多条消息匹配
|
||||
rows, err = db.Query(
|
||||
`SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at
|
||||
FROM conversations c
|
||||
LEFT JOIN messages m ON c.id = m.conversation_id
|
||||
WHERE c.title LIKE ? OR m.content LIKE ?
|
||||
ORDER BY c.updated_at DESC
|
||||
WHERE c.title LIKE ?
|
||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
||||
ORDER BY c.updated_at DESC
|
||||
LIMIT ? OFFSET ?`,
|
||||
searchPattern, searchPattern, limit, offset,
|
||||
)
|
||||
|
||||
@@ -403,6 +403,35 @@ func (db *DB) UpdateGroupPinned(id string, pinned bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupMapping 分组映射关系
|
||||
type GroupMapping struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
GroupID string `json:"groupId"`
|
||||
}
|
||||
|
||||
// GetAllGroupMappings 批量获取所有分组映射(消除 N+1 查询)
|
||||
func (db *DB) GetAllGroupMappings() ([]GroupMapping, error) {
|
||||
rows, err := db.Query("SELECT conversation_id, group_id FROM conversation_group_mappings")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分组映射失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var mappings []GroupMapping
|
||||
for rows.Next() {
|
||||
var m GroupMapping
|
||||
if err := rows.Scan(&m.ConversationID, &m.GroupID); err != nil {
|
||||
return nil, fmt.Errorf("扫描分组映射失败: %w", err)
|
||||
}
|
||||
mappings = append(mappings, m)
|
||||
}
|
||||
|
||||
if mappings == nil {
|
||||
mappings = []GroupMapping{}
|
||||
}
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态
|
||||
func (db *DB) UpdateConversationPinnedInGroup(conversationID, groupID string, pinned bool) error {
|
||||
pinnedValue := 0
|
||||
|
||||
@@ -160,13 +160,17 @@ func runMCPToolInvocation(
|
||||
}
|
||||
|
||||
// UnknownToolReminderHandler 供 compose.ToolsNodeConfig.UnknownToolsHandler 使用:
|
||||
// 模型请求了未注册的工具名时,仅返回说明性文本,error 恒为 nil,以便 ReAct 继续迭代而不中断图执行。
|
||||
// 模型请求了未注册的工具名时,返回一个「可恢复」的错误,让上层 runner 触发重试与纠错提示,
|
||||
// 同时避免 UI 永远停留在“执行中”(runner 会在 recoverable 分支 flush 掉 pending 的 tool_call)。
|
||||
// 不进行名称猜测或映射,避免误执行。
|
||||
func UnknownToolReminderHandler() func(ctx context.Context, name, input string) (string, error) {
|
||||
return func(ctx context.Context, name, input string) (string, error) {
|
||||
_ = ctx
|
||||
_ = input
|
||||
return unknownToolReminderText(strings.TrimSpace(name)), nil
|
||||
requested := strings.TrimSpace(name)
|
||||
// Return a recoverable error that still carries a friendly, bilingual hint.
|
||||
// This will be caught by multiagent runner as "tool not found" and trigger a retry.
|
||||
return "", fmt.Errorf("tool %q not found: %s", requested, unknownToolReminderText(requested))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -990,6 +990,24 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
|
||||
return
|
||||
}
|
||||
|
||||
// 当 Agent 同时发送 thinking_stream_* 和 thinking(带同一 streamId)时,
|
||||
// thinking_stream_* 已经会在 flushThinkingStreams() 聚合落库;
|
||||
// 这里跳过同 streamId 的 thinking,避免 processDetails 双份展示。
|
||||
if eventType == "thinking" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
if tb, exists := thinkingStreams[sid]; exists && tb != nil {
|
||||
if strings.TrimSpace(tb.b.String()) != "" {
|
||||
return
|
||||
}
|
||||
}
|
||||
if flushedThinking[sid] {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存过程详情到数据库(排除 response/done;response 正文已在 messages 表)
|
||||
// response_start/response_delta 已聚合为 planning,不落逐条。
|
||||
if assistantMessageID != "" &&
|
||||
|
||||
@@ -3,7 +3,9 @@ package handler
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -754,6 +756,137 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
// TestOpenAIRequest 测试OpenAI连接请求
|
||||
type TestOpenAIRequest struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// TestOpenAI 测试OpenAI API连接是否可用
|
||||
func (h *ConfigHandler) TestOpenAI(c *gin.Context) {
|
||||
var req TestOpenAIRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.APIKey) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空"})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "模型不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(strings.TrimSpace(req.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
|
||||
// 构造一个最小的 chat completion 请求
|
||||
payload := map[string]interface{}{
|
||||
"model": req.Model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "Hi"},
|
||||
},
|
||||
"max_tokens": 5,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "构造请求失败"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "构造HTTP请求失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(req.APIKey))
|
||||
|
||||
start := time.Now()
|
||||
resp, err := http.DefaultClient.Do(httpReq)
|
||||
latency := time.Since(start)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"error": "连接失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// 尝试提取错误信息
|
||||
var errResp struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
errMsg := string(respBody)
|
||||
if json.Unmarshal(respBody, &errResp) == nil && errResp.Error.Message != "" {
|
||||
errMsg = errResp.Error.Message
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", resp.StatusCode, errMsg),
|
||||
"status_code": resp.StatusCode,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析响应并严格验证是否为有效的 chat completion 响应
|
||||
var chatResp struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &chatResp); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"error": "API 响应不是有效的 JSON,请检查 Base URL 是否正确",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 严格校验:必须包含 choices 且有 assistant 回复
|
||||
if len(chatResp.Choices) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"error": "API 响应缺少 choices 字段,请检查 Base URL 路径是否正确(通常以 /v1 结尾)",
|
||||
})
|
||||
return
|
||||
}
|
||||
if chatResp.ID == "" && chatResp.Model == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"error": "API 响应格式不符合 OpenAI 规范,请检查 Base URL 是否正确",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"model": chatResp.Model,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
})
|
||||
}
|
||||
|
||||
// ApplyConfig 应用配置(重新加载并重启相关服务)
|
||||
func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
// 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求)
|
||||
|
||||
@@ -234,6 +234,18 @@ func (h *GroupHandler) GetGroupConversations(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, groupConvs)
|
||||
}
|
||||
|
||||
// GetAllMappings 批量获取所有分组映射(消除前端 N+1 请求)
|
||||
func (h *GroupHandler) GetAllMappings(c *gin.Context) {
|
||||
mappings, err := h.db.GetAllGroupMappings()
|
||||
if err != nil {
|
||||
h.logger.Error("获取分组映射失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, mappings)
|
||||
}
|
||||
|
||||
// UpdateConversationPinnedRequest 更新对话置顶状态请求
|
||||
type UpdateConversationPinnedRequest struct {
|
||||
Pinned bool `json:"pinned"`
|
||||
|
||||
@@ -246,6 +246,41 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
|
||||
}
|
||||
|
||||
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
|
||||
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
||||
var req struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
result := make(map[string]string, len(req.IDs))
|
||||
for _, id := range req.IDs {
|
||||
// 先从内部MCP服务器查找
|
||||
if exec, exists := h.mcpServer.GetExecution(id); exists {
|
||||
result[id] = exec.ToolName
|
||||
continue
|
||||
}
|
||||
// 再从外部MCP管理器查找
|
||||
if h.externalMCPMgr != nil {
|
||||
if exec, exists := h.externalMCPMgr.GetExecution(id); exists {
|
||||
result[id] = exec.ToolName
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 最后从数据库查找
|
||||
if h.db != nil {
|
||||
if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil {
|
||||
result[id] = exec.ToolName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func (h *MonitorHandler) GetStats(c *gin.Context) {
|
||||
stats := h.loadStats()
|
||||
|
||||
@@ -36,6 +36,16 @@ type RunResult struct {
|
||||
LastReActOutput string
|
||||
}
|
||||
|
||||
// toolCallPendingInfo tracks a tool_call emitted to the UI so we can later
|
||||
// correlate tool_result events (even when the framework omits ToolCallID) and
|
||||
// avoid leaving the UI stuck in "running" state on recoverable errors.
|
||||
type toolCallPendingInfo struct {
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
EinoAgent string
|
||||
EinoRole string
|
||||
}
|
||||
|
||||
// RunDeepAgent 使用 Eino DeepAgent 执行一轮对话(流式事件通过 progress 回调输出)。
|
||||
func RunDeepAgent(
|
||||
ctx context.Context,
|
||||
@@ -223,6 +233,9 @@ func RunDeepAgent(
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
Tools: subTools,
|
||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||
{Invokable: softRecoveryToolCallMiddleware()},
|
||||
},
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
},
|
||||
@@ -278,6 +291,9 @@ func RunDeepAgent(
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
Tools: mainTools,
|
||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||
{Invokable: softRecoveryToolCallMiddleware()},
|
||||
},
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
},
|
||||
@@ -326,6 +342,69 @@ attemptLoop:
|
||||
var einoMainRound int
|
||||
var einoLastAgent string
|
||||
subAgentToolStep := make(map[string]int)
|
||||
// Track tool calls emitted in this attempt so we can:
|
||||
// - attach toolCallId to tool_result when framework omits it
|
||||
// - flush running tool calls as failed when a recoverable tool execution error happens
|
||||
pendingByID := make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent := make(map[string][]string)
|
||||
markPending := func(tc toolCallPendingInfo) {
|
||||
if tc.ToolCallID == "" {
|
||||
return
|
||||
}
|
||||
pendingByID[tc.ToolCallID] = tc
|
||||
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
||||
}
|
||||
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
||||
q := pendingQueueByAgent[agentName]
|
||||
for len(q) > 0 {
|
||||
id := q[0]
|
||||
q = q[1:]
|
||||
pendingQueueByAgent[agentName] = q
|
||||
if tc, ok := pendingByID[id]; ok {
|
||||
delete(pendingByID, id)
|
||||
return tc, true
|
||||
}
|
||||
}
|
||||
return toolCallPendingInfo{}, false
|
||||
}
|
||||
removePendingByID := func(toolCallID string) {
|
||||
if toolCallID == "" {
|
||||
return
|
||||
}
|
||||
delete(pendingByID, toolCallID)
|
||||
// queue cleanup is lazy in popNextPendingForAgent
|
||||
}
|
||||
flushAllPendingAsFailed := func(err error) {
|
||||
if progress == nil {
|
||||
pendingByID = make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
return
|
||||
}
|
||||
msg := ""
|
||||
if err != nil {
|
||||
msg = err.Error()
|
||||
}
|
||||
for _, tc := range pendingByID {
|
||||
toolName := tc.ToolName
|
||||
if strings.TrimSpace(toolName) == "" {
|
||||
toolName = "unknown"
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"success": false,
|
||||
"isError": true,
|
||||
"result": msg,
|
||||
"resultPreview": msg,
|
||||
"toolCallId": tc.ToolCallID,
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": tc.EinoAgent,
|
||||
"einoRole": tc.EinoRole,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
pendingByID = make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
}
|
||||
|
||||
runner := adk.NewRunner(ctx, adk.RunnerConfig{
|
||||
Agent: da,
|
||||
@@ -370,6 +449,9 @@ attemptLoop:
|
||||
logger.Warn("eino: recoverable tool execution error, will retry with corrective hint",
|
||||
zap.Error(ev.Err), zap.Int("attempt", attempt))
|
||||
}
|
||||
// Ensure UI/tool timeline doesn't get stuck at "running" for tool calls that
|
||||
// will never receive a proper tool_result due to the recoverable error.
|
||||
flushAllPendingAsFailed(ev.Err)
|
||||
retryHints = append(retryHints, toolExecutionRetryHint())
|
||||
if progress != nil {
|
||||
progress("eino_recovery", toolExecutionRecoveryTimelineMessage(attempt), map[string]interface{}{
|
||||
@@ -385,6 +467,7 @@ attemptLoop:
|
||||
}
|
||||
|
||||
// Non-recoverable error.
|
||||
flushAllPendingAsFailed(ev.Err)
|
||||
if progress != nil {
|
||||
progress("error", ev.Err.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
@@ -535,7 +618,7 @@ attemptLoop:
|
||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||
lastToolChunk = &schema.Message{ToolCalls: merged}
|
||||
}
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep)
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -543,7 +626,7 @@ attemptLoop:
|
||||
if gerr != nil || msg == nil {
|
||||
continue
|
||||
}
|
||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep)
|
||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
|
||||
if mv.Role == schema.Assistant {
|
||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||
@@ -611,8 +694,31 @@ attemptLoop:
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"source": "eino",
|
||||
}
|
||||
if msg.ToolCallID != "" {
|
||||
data["toolCallId"] = msg.ToolCallID
|
||||
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
||||
// Some framework paths (e.g. UnknownToolsHandler) may omit ToolCallID on tool messages.
|
||||
// Infer from the tool_call emission order for this agent to keep UI state consistent.
|
||||
if toolCallID == "" {
|
||||
// In some internal tool execution paths, ev.AgentName may be empty for tool-role
|
||||
// messages. Try several fallbacks to avoid leaving UI tool_call status stuck.
|
||||
if inferred, ok := popNextPendingForAgent(ev.AgentName); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else {
|
||||
// last resort: pick any pending toolCallID
|
||||
for id := range pendingByID {
|
||||
toolCallID = id
|
||||
delete(pendingByID, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
removePendingByID(toolCallID)
|
||||
}
|
||||
if toolCallID != "" {
|
||||
data["toolCallId"] = toolCallID
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||
}
|
||||
@@ -755,7 +861,14 @@ func toolCallsRichSignature(msg *schema.Message) string {
|
||||
return base + "|" + strings.Join(parts, ";")
|
||||
}
|
||||
|
||||
func tryEmitToolCallsOnce(msg *schema.Message, agentName, orchestratorName, conversationID string, progress func(string, string, interface{}), seen map[string]struct{}, subAgentToolStep map[string]int) {
|
||||
func tryEmitToolCallsOnce(
|
||||
msg *schema.Message,
|
||||
agentName, orchestratorName, conversationID string,
|
||||
progress func(string, string, interface{}),
|
||||
seen map[string]struct{},
|
||||
subAgentToolStep map[string]int,
|
||||
markPending func(toolCallPendingInfo),
|
||||
) {
|
||||
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil {
|
||||
return
|
||||
}
|
||||
@@ -767,10 +880,16 @@ func tryEmitToolCallsOnce(msg *schema.Message, agentName, orchestratorName, conv
|
||||
return
|
||||
}
|
||||
seen[sig] = struct{}{}
|
||||
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, progress, subAgentToolStep)
|
||||
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, progress, subAgentToolStep, markPending)
|
||||
}
|
||||
|
||||
func emitToolCallsFromMessage(msg *schema.Message, agentName, orchestratorName, conversationID string, progress func(string, string, interface{}), subAgentToolStep map[string]int) {
|
||||
func emitToolCallsFromMessage(
|
||||
msg *schema.Message,
|
||||
agentName, orchestratorName, conversationID string,
|
||||
progress func(string, string, interface{}),
|
||||
subAgentToolStep map[string]int,
|
||||
markPending func(toolCallPendingInfo),
|
||||
) {
|
||||
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil {
|
||||
return
|
||||
}
|
||||
@@ -819,6 +938,16 @@ func emitToolCallsFromMessage(msg *schema.Message, agentName, orchestratorName,
|
||||
if toolCallID == "" && tc.Index != nil {
|
||||
toolCallID = fmt.Sprintf("eino-stream-%d", *tc.Index)
|
||||
}
|
||||
// Record pending tool calls for later tool_result correlation / recovery flushing.
|
||||
// We intentionally record even for unknown tools to avoid "running" badge getting stuck.
|
||||
if markPending != nil && toolCallID != "" {
|
||||
markPending(toolCallPendingInfo{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: display,
|
||||
EinoAgent: agentName,
|
||||
EinoRole: role,
|
||||
})
|
||||
}
|
||||
progress("tool_call", fmt.Sprintf("正在调用工具: %s", display), map[string]interface{}{
|
||||
"toolName": display,
|
||||
"arguments": argStr,
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
// maxToolCallRecoveryAttempts 含首次运行:首次 + 自动重试次数。
|
||||
// 例如为 3 表示最多共 3 次完整 DeepAgent 运行(2 次失败后各追加一条纠错提示)。
|
||||
// 该常量同时用于 JSON 参数错误和工具执行错误(如子代理名称不存在)的恢复重试。
|
||||
const maxToolCallRecoveryAttempts = 3
|
||||
const maxToolCallRecoveryAttempts = 5
|
||||
|
||||
// toolCallArgumentsJSONRetryHint 追加在用户消息后,提示模型输出合法 JSON 工具参数(部分云厂商会在流式阶段校验 arguments)。
|
||||
func toolCallArgumentsJSONRetryHint() *schema.Message {
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches
|
||||
// specific recoverable errors from tool execution (JSON parse errors, tool-not-found,
|
||||
// etc.) and converts them into soft errors: nil error + descriptive error content
|
||||
// returned to the LLM. This allows the model to self-correct within the same
|
||||
// iteration rather than crashing the entire graph and requiring a full replay.
|
||||
//
|
||||
// Without this middleware, a JSON parse failure in any tool's InvokableRun propagates
|
||||
// as a hard error through the Eino ToolsNode → [NodeRunError] → ev.Err, which
|
||||
// either triggers the full-replay retry loop (expensive) or terminates the run
|
||||
// entirely once retries are exhausted. With it, the LLM simply sees an error message
|
||||
// in the tool result and can adjust its next tool call accordingly.
|
||||
func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
|
||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
output, err := next(ctx, input)
|
||||
if err == nil {
|
||||
return output, nil
|
||||
}
|
||||
if !isSoftRecoverableToolError(err) {
|
||||
return output, err
|
||||
}
|
||||
// Convert the hard error into a soft error: the LLM will see this
|
||||
// message as the tool's output and can self-correct.
|
||||
msg := buildSoftRecoveryMessage(input.Name, input.Arguments, err)
|
||||
return &compose.ToolOutput{Result: msg}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isSoftRecoverableToolError determines whether a tool execution error should be
|
||||
// silently converted to a tool-result message rather than crashing the graph.
|
||||
func isSoftRecoverableToolError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
|
||||
// JSON unmarshal/parse failures — the model generated truncated or malformed arguments.
|
||||
if isJSONRelatedError(s) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Sub-agent type not found (from deep/task_tool.go)
|
||||
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Tool not found in ToolsNode indexes
|
||||
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isJSONRelatedError checks whether an error string indicates a JSON parsing problem.
|
||||
func isJSONRelatedError(lower string) bool {
|
||||
if !strings.Contains(lower, "json") {
|
||||
return false
|
||||
}
|
||||
jsonIndicators := []string{
|
||||
"unexpected end of json",
|
||||
"unmarshal",
|
||||
"invalid character",
|
||||
"cannot unmarshal",
|
||||
"invalid tool arguments",
|
||||
"failed to unmarshal",
|
||||
"must be in json format",
|
||||
"unexpected eof",
|
||||
}
|
||||
for _, ind := range jsonIndicators {
|
||||
if strings.Contains(lower, ind) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// buildSoftRecoveryMessage creates a bilingual error message that the LLM can act on.
|
||||
func buildSoftRecoveryMessage(toolName, arguments string, err error) string {
|
||||
// Truncate arguments preview to avoid flooding the context.
|
||||
argPreview := arguments
|
||||
if len(argPreview) > 300 {
|
||||
argPreview = argPreview[:300] + "... (truncated)"
|
||||
}
|
||||
|
||||
// Try to determine if it's specifically a JSON parse error for a friendlier message.
|
||||
errStr := err.Error()
|
||||
var jsonErr *json.SyntaxError
|
||||
isJSONErr := strings.Contains(strings.ToLower(errStr), "json") ||
|
||||
strings.Contains(strings.ToLower(errStr), "unmarshal")
|
||||
_ = jsonErr // suppress unused
|
||||
|
||||
if isJSONErr {
|
||||
return fmt.Sprintf(
|
||||
"[Tool Error] The arguments for tool '%s' are not valid JSON and could not be parsed.\n"+
|
||||
"Error: %s\n"+
|
||||
"Arguments received: %s\n\n"+
|
||||
"Please fix the JSON (ensure double-quoted keys, matched braces/brackets, no trailing commas, "+
|
||||
"no truncation) and call the tool again.\n\n"+
|
||||
"[工具错误] 工具 '%s' 的参数不是合法 JSON,无法解析。\n"+
|
||||
"错误:%s\n"+
|
||||
"收到的参数:%s\n\n"+
|
||||
"请修正 JSON(确保双引号键名、括号配对、无尾部逗号、无截断),然后重新调用工具。",
|
||||
toolName, errStr, argPreview,
|
||||
toolName, errStr, argPreview,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"[Tool Error] Tool '%s' execution failed: %s\n"+
|
||||
"Arguments: %s\n\n"+
|
||||
"Please review the available tools and their expected arguments, then retry.\n\n"+
|
||||
"[工具错误] 工具 '%s' 执行失败:%s\n"+
|
||||
"参数:%s\n\n"+
|
||||
"请检查可用工具及其参数要求,然后重试。",
|
||||
toolName, errStr, argPreview,
|
||||
toolName, errStr, argPreview,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
func TestIsSoftRecoverableToolError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unexpected end of JSON input",
|
||||
err: errors.New("unexpected end of JSON input"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "failed to unmarshal task tool input json",
|
||||
err: errors.New("failed to unmarshal task tool input json: unexpected end of JSON input"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool arguments JSON",
|
||||
err: errors.New("invalid tool arguments JSON: unexpected end of JSON input"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "json invalid character",
|
||||
err: errors.New(`invalid character '}' looking for beginning of value in JSON`),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subagent type not found",
|
||||
err: errors.New("subagent type recon_agent not found"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tool not found",
|
||||
err: errors.New("tool nmap_scan not found in toolsNode indexes"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "unrelated network error",
|
||||
err: errors.New("connection refused"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "context cancelled",
|
||||
err: context.Canceled,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "real json unmarshal error",
|
||||
err: func() error {
|
||||
var v map[string]interface{}
|
||||
return json.Unmarshal([]byte(`{"key": `), &v)
|
||||
}(),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isSoftRecoverableToolError(tt.err)
|
||||
if got != tt.expected {
|
||||
t.Errorf("isSoftRecoverableToolError(%v) = %v, want %v", tt.err, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryToolCallMiddleware_PassesThrough(t *testing.T) {
|
||||
mw := softRecoveryToolCallMiddleware()
|
||||
called := false
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
called = true
|
||||
return &compose.ToolOutput{Result: "success"}, nil
|
||||
}
|
||||
wrapped := mw(next)
|
||||
out, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "test_tool",
|
||||
Arguments: `{"key": "value"}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("next endpoint was not called")
|
||||
}
|
||||
if out.Result != "success" {
|
||||
t.Fatalf("expected 'success', got %q", out.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) {
|
||||
mw := softRecoveryToolCallMiddleware()
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
return nil, errors.New("failed to unmarshal task tool input json: unexpected end of JSON input")
|
||||
}
|
||||
wrapped := mw(next)
|
||||
out, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "task",
|
||||
Arguments: `{"subagent_type": "recon`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft recovery), got: %v", err)
|
||||
}
|
||||
if out == nil || out.Result == "" {
|
||||
t.Fatal("expected non-empty recovery message")
|
||||
}
|
||||
if !containsAll(out.Result, "[Tool Error]", "task", "JSON") {
|
||||
t.Fatalf("recovery message missing expected content: %s", out.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryToolCallMiddleware_PropagatesNonRecoverable(t *testing.T) {
|
||||
mw := softRecoveryToolCallMiddleware()
|
||||
origErr := errors.New("connection timeout to remote server")
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
return nil, origErr
|
||||
}
|
||||
wrapped := mw(next)
|
||||
_, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "test_tool",
|
||||
Arguments: `{}`,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error to propagate for non-recoverable errors")
|
||||
}
|
||||
if err != origErr {
|
||||
t.Fatalf("expected original error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func containsAll(s string, subs ...string) bool {
|
||||
for _, sub := range subs {
|
||||
if !contains(s, sub) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func contains(s, sub string) bool {
|
||||
return len(s) >= len(sub) && searchString(s, sub)
|
||||
}
|
||||
|
||||
func searchString(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
BIN
Binary file not shown.
+141
-101
@@ -1,6 +1,7 @@
|
||||
package burp;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@@ -10,6 +11,7 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
||||
|
||||
private CyberStrikeAITab tab;
|
||||
private final CyberStrikeAIClient client = new CyberStrikeAIClient();
|
||||
private String lastInstruction = HttpMessageFormatter.defaultInstruction();
|
||||
|
||||
@Override
|
||||
public void registerExtenderCallbacks(IBurpExtenderCallbacks callbacks) {
|
||||
@@ -36,111 +38,149 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
||||
if (selected == null || selected.length == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
CyberStrikeAIClient.Config cfg = tab.currentConfig();
|
||||
String token = tab.getToken();
|
||||
if (token == null || token.trim().isEmpty()) {
|
||||
JOptionPane.showMessageDialog(tab.getUiComponent(),
|
||||
"Please click Validate first to obtain a token.",
|
||||
"CyberStrikeAI", JOptionPane.WARNING_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
String prompt = HttpMessageFormatter.toPrompt(helpers, selected[0]);
|
||||
String title = HttpMessageFormatter.getRequestTitle(helpers, selected[0]);
|
||||
String agentModeStr = (cfg.agentMode == CyberStrikeAIClient.AgentMode.MULTI) ? "Multi Agent" : "Single Agent";
|
||||
String runId = tab.startNewRun(title, agentModeStr, selected[0]);
|
||||
tab.appendProgressToRun(runId, "\n[server] " + cfg.baseUrl + "\n\n");
|
||||
|
||||
client.streamTest(cfg, token, prompt, new CyberStrikeAIClient.StreamListener() {
|
||||
@Override
|
||||
public void onEvent(String type, String message, String rawJson) {
|
||||
if (type == null) type = "";
|
||||
switch (type) {
|
||||
case "response_delta":
|
||||
case "eino_agent_reply_stream_delta":
|
||||
// delta chunk (content only)
|
||||
tab.appendFinalToRun(runId, message);
|
||||
break;
|
||||
case "response":
|
||||
// final response (full)
|
||||
tab.appendFinalToRun(runId, "\n\n--- Final Response ---\n");
|
||||
tab.appendFinalToRun(runId, message);
|
||||
tab.setFinalResponse(runId, message);
|
||||
break;
|
||||
case "progress":
|
||||
tab.appendProgressToRun(runId, "\n[progress] " + message + "\n");
|
||||
tab.setRunStatus(runId, "running");
|
||||
break;
|
||||
case "cancelled":
|
||||
tab.appendProgressToRun(runId, "\n[cancelled] " + message + "\n");
|
||||
tab.setRunStatus(runId, "cancelled");
|
||||
break;
|
||||
case "error":
|
||||
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
||||
tab.setRunStatus(runId, "error");
|
||||
break;
|
||||
case "thinking_stream_start":
|
||||
if (tab.isShowDebugEvents()) {
|
||||
tab.resetThinkingStream(runId);
|
||||
}
|
||||
break;
|
||||
case "thinking_stream_delta":
|
||||
case "tool_call":
|
||||
case "tool_result":
|
||||
case "tool_result_delta":
|
||||
// debug; hide by default
|
||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||
if ("thinking_stream_delta".equals(type)) {
|
||||
tab.appendThinkingDelta(runId, message);
|
||||
} else {
|
||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
case "conversation":
|
||||
// Capture conversationId for stop/cancel.
|
||||
if (rawJson != null) {
|
||||
String convId = SimpleJson.extractStringField(rawJson, "conversationId");
|
||||
if (convId != null && !convId.trim().isEmpty()) {
|
||||
tab.setRunConversationId(runId, convId);
|
||||
}
|
||||
}
|
||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||
}
|
||||
break;
|
||||
case "done":
|
||||
// handled in onDone too
|
||||
break;
|
||||
default:
|
||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String message, Exception e) {
|
||||
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
||||
tab.setRunStatus(runId, "error");
|
||||
callbacks.printError("CyberStrikeAI stream error: " + message);
|
||||
if (e != null) {
|
||||
callbacks.printError(e.toString());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onDone() {
|
||||
tab.appendProgressToRun(runId, "\n\n[done]\n");
|
||||
tab.setRunStatus(runId, "done");
|
||||
}
|
||||
});
|
||||
sendMessage(selected[0]);
|
||||
});
|
||||
|
||||
items.add(sendItem);
|
||||
return items;
|
||||
}
|
||||
|
||||
private void sendMessage(IHttpRequestResponse msg) {
|
||||
if (msg == null) return;
|
||||
CyberStrikeAIClient.Config cfg = tab.currentConfig();
|
||||
String token = tab.getToken();
|
||||
if (token == null || token.trim().isEmpty()) {
|
||||
JOptionPane.showMessageDialog(tab.getUiComponent(),
|
||||
"Please click Validate first to obtain a token.",
|
||||
"CyberStrikeAI", JOptionPane.WARNING_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
String instruction = showInstructionEditor(tab.getUiComponent(), lastInstruction);
|
||||
if (instruction == null) {
|
||||
return;
|
||||
}
|
||||
lastInstruction = instruction;
|
||||
|
||||
String prompt = HttpMessageFormatter.toPrompt(helpers, msg, instruction);
|
||||
String title = HttpMessageFormatter.getRequestTitle(helpers, msg);
|
||||
String agentModeStr = (cfg.agentMode == CyberStrikeAIClient.AgentMode.MULTI) ? "Multi Agent" : "Single Agent";
|
||||
String runId = tab.startNewRun(title, agentModeStr, msg);
|
||||
tab.appendProgressToRun(runId, "\n[server] " + cfg.baseUrl + "\n\n");
|
||||
|
||||
client.streamTest(cfg, token, prompt, new CyberStrikeAIClient.StreamListener() {
|
||||
@Override
|
||||
public void onEvent(String type, String message, String rawJson) {
|
||||
if (type == null) type = "";
|
||||
switch (type) {
|
||||
case "response_delta":
|
||||
case "eino_agent_reply_stream_delta":
|
||||
tab.appendFinalToRun(runId, message);
|
||||
break;
|
||||
case "response":
|
||||
tab.appendFinalToRun(runId, "\n\n--- Final Response ---\n");
|
||||
tab.appendFinalToRun(runId, message);
|
||||
tab.setFinalResponse(runId, message);
|
||||
break;
|
||||
case "progress":
|
||||
tab.appendProgressToRun(runId, "\n[progress] " + message + "\n");
|
||||
tab.setRunStatus(runId, "running");
|
||||
break;
|
||||
case "cancelled":
|
||||
tab.appendProgressToRun(runId, "\n[cancelled] " + message + "\n");
|
||||
tab.setRunStatus(runId, "cancelled");
|
||||
break;
|
||||
case "error":
|
||||
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
||||
tab.setRunStatus(runId, "error");
|
||||
break;
|
||||
case "thinking_stream_start":
|
||||
if (tab.isShowDebugEvents()) {
|
||||
tab.resetThinkingStream(runId);
|
||||
}
|
||||
break;
|
||||
case "thinking_stream_delta":
|
||||
case "tool_call":
|
||||
case "tool_result":
|
||||
case "tool_result_delta":
|
||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||
if ("thinking_stream_delta".equals(type)) {
|
||||
tab.appendThinkingDelta(runId, message);
|
||||
} else {
|
||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
case "conversation":
|
||||
if (rawJson != null) {
|
||||
String convId = SimpleJson.extractStringField(rawJson, "conversationId");
|
||||
if (convId != null && !convId.trim().isEmpty()) {
|
||||
tab.setRunConversationId(runId, convId);
|
||||
}
|
||||
}
|
||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||
}
|
||||
break;
|
||||
case "done":
|
||||
break;
|
||||
default:
|
||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(String message, Exception e) {
|
||||
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
||||
tab.setRunStatus(runId, "error");
|
||||
callbacks.printError("CyberStrikeAI stream error: " + message);
|
||||
if (e != null) {
|
||||
callbacks.printError(e.toString());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onDone() {
|
||||
tab.appendProgressToRun(runId, "\n\n[done]\n");
|
||||
tab.setRunStatus(runId, "done");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private static String showInstructionEditor(Component parent, String initialValue) {
|
||||
JTextArea editor = new JTextArea(
|
||||
initialValue == null || initialValue.trim().isEmpty()
|
||||
? HttpMessageFormatter.defaultInstruction()
|
||||
: initialValue,
|
||||
6,
|
||||
70
|
||||
);
|
||||
editor.setLineWrap(true);
|
||||
editor.setWrapStyleWord(true);
|
||||
editor.setFont(new Font(Font.SANS_SERIF, Font.PLAIN, 13));
|
||||
|
||||
JPanel panel = new JPanel(new BorderLayout(0, 8));
|
||||
panel.add(new JLabel("Edit instruction before sending:"), BorderLayout.NORTH);
|
||||
panel.add(new JScrollPane(editor), BorderLayout.CENTER);
|
||||
|
||||
int result = JOptionPane.showConfirmDialog(
|
||||
parent,
|
||||
panel,
|
||||
"Customize Prompt Instruction",
|
||||
JOptionPane.OK_CANCEL_OPTION,
|
||||
JOptionPane.PLAIN_MESSAGE
|
||||
);
|
||||
if (result != JOptionPane.OK_OPTION) {
|
||||
return null;
|
||||
}
|
||||
String value = editor.getText();
|
||||
if (value == null || value.trim().isEmpty()) {
|
||||
return HttpMessageFormatter.defaultInstruction();
|
||||
}
|
||||
return value.trim();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+15
-1
@@ -5,6 +5,8 @@ import java.util.List;
|
||||
|
||||
final class HttpMessageFormatter {
|
||||
private HttpMessageFormatter() {}
|
||||
private static final String DEFAULT_INSTRUCTION =
|
||||
"针对该流量做web渗透测试,并输出测试结果,要求:只针对该接口流量做测试,切勿拓展其他接口";
|
||||
|
||||
static String getRequestTitle(IExtensionHelpers helpers, IHttpRequestResponse msg) {
|
||||
IRequestInfo reqInfo = helpers.analyzeRequest(msg);
|
||||
@@ -22,7 +24,15 @@ final class HttpMessageFormatter {
|
||||
return method + " " + host + shortPath + q;
|
||||
}
|
||||
|
||||
static String defaultInstruction() {
|
||||
return DEFAULT_INSTRUCTION;
|
||||
}
|
||||
|
||||
static String toPrompt(IExtensionHelpers helpers, IHttpRequestResponse msg) {
|
||||
return toPrompt(helpers, msg, DEFAULT_INSTRUCTION);
|
||||
}
|
||||
|
||||
static String toPrompt(IExtensionHelpers helpers, IHttpRequestResponse msg, String instruction) {
|
||||
IRequestInfo reqInfo = helpers.analyzeRequest(msg);
|
||||
String method = reqInfo.getMethod();
|
||||
String url = reqInfo.getUrl() != null ? reqInfo.getUrl().toString() : "(unknown)";
|
||||
@@ -53,8 +63,12 @@ final class HttpMessageFormatter {
|
||||
+ respBody;
|
||||
}
|
||||
|
||||
String prefix = (instruction == null || instruction.trim().isEmpty())
|
||||
? DEFAULT_INSTRUCTION
|
||||
: instruction.trim();
|
||||
|
||||
return ""
|
||||
+ "针对该流量做web渗透测试,并输出测试结果,要求:只针对该接口流量做测试,切勿拓展其他接口\n\n"
|
||||
+ prefix + "\n\n"
|
||||
+ "[Target]\n"
|
||||
+ method + " " + url + "\n\n"
|
||||
+ "[Request]\n"
|
||||
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,293 @@
|
||||
name: "quake_search"
|
||||
command: "python3"
|
||||
args:
|
||||
- "-c"
|
||||
- |
|
||||
import sys
|
||||
import json
|
||||
import requests
|
||||
import os
|
||||
|
||||
# ==================== Quake配置 ====================
|
||||
# 请在此处配置您的Quake API Token
|
||||
# 您也可以在环境变量中设置:QUAKE_API_KEY
|
||||
# enable 默认为 false,需开启才能调用该MCP
|
||||
QUAKE_API_KEY = "" # 请填写您的Quake API Token
|
||||
# ==================================================
|
||||
|
||||
# Quake API基础URL
|
||||
base_url = "https://quake.360.cn/api/v3/search/quake_service"
|
||||
|
||||
# 解析参数(从JSON字符串或命令行参数)
|
||||
def parse_args():
|
||||
# 尝试从第一个参数读取JSON配置
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
arg1 = str(sys.argv[1])
|
||||
config = json.loads(arg1)
|
||||
if isinstance(config, dict):
|
||||
return config
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 传统位置参数方式(向后兼容)
|
||||
# 参数位置:query=1, size=2, start=3, fields=4, latest=5
|
||||
config = {}
|
||||
if len(sys.argv) > 1:
|
||||
config["query"] = str(sys.argv[1])
|
||||
if len(sys.argv) > 2:
|
||||
try:
|
||||
config["size"] = int(sys.argv[2])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if len(sys.argv) > 3:
|
||||
try:
|
||||
config["start"] = int(sys.argv[3])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if len(sys.argv) > 4:
|
||||
config["fields"] = str(sys.argv[4])
|
||||
if len(sys.argv) > 5:
|
||||
val = sys.argv[5]
|
||||
if isinstance(val, str):
|
||||
config["latest"] = val.lower() in ("true", "1", "yes")
|
||||
else:
|
||||
config["latest"] = bool(val)
|
||||
return config
|
||||
|
||||
# 标准化 fields 参数:支持字符串和数组
|
||||
def normalize_fields(fields_value):
|
||||
if fields_value is None:
|
||||
return None
|
||||
|
||||
if isinstance(fields_value, str):
|
||||
raw = fields_value.strip()
|
||||
if not raw:
|
||||
return None
|
||||
return [x.strip() for x in raw.split(",") if x.strip()]
|
||||
|
||||
if isinstance(fields_value, list):
|
||||
output = []
|
||||
for item in fields_value:
|
||||
text = str(item).strip()
|
||||
if text:
|
||||
output.append(text)
|
||||
return output or None
|
||||
|
||||
return None
|
||||
|
||||
try:
|
||||
config = parse_args()
|
||||
|
||||
if not isinstance(config, dict):
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"参数解析错误: 期望字典类型,但得到 {type(config).__name__}",
|
||||
"type": "TypeError"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
api_key = os.getenv("QUAKE_API_KEY", QUAKE_API_KEY).strip()
|
||||
query = str(config.get("query", "")).strip()
|
||||
|
||||
if not api_key:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": "缺少Quake配置: api_key(Quake API Token)",
|
||||
"required_config": ["api_key"],
|
||||
"note": "请在YAML文件的QUAKE_API_KEY配置项中填写Token,或在环境变量QUAKE_API_KEY中设置。Token可在Quake用户中心获取。"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
if not query:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": "缺少必需参数: query(搜索查询语句)",
|
||||
"required_params": ["query"],
|
||||
"examples": [
|
||||
'domain:"example.com"',
|
||||
'ip:"1.1.1.1"',
|
||||
'port:443',
|
||||
'service.name:"http"',
|
||||
'port:22 AND country_cn:"中国"'
|
||||
]
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"query": query
|
||||
}
|
||||
|
||||
# 可选参数 size(通常最大100)
|
||||
if "size" in config and config["size"] is not None:
|
||||
try:
|
||||
size = int(config["size"])
|
||||
if size > 0:
|
||||
data["size"] = size
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 可选参数 start(分页偏移,默认0)
|
||||
if "start" in config and config["start"] is not None:
|
||||
try:
|
||||
start = int(config["start"])
|
||||
if start >= 0:
|
||||
data["start"] = start
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# fields 映射到 Quake 的 include 字段
|
||||
include_fields = normalize_fields(config.get("fields"))
|
||||
if include_fields:
|
||||
data["include"] = include_fields
|
||||
|
||||
# latest 参数,默认 true(取最新索引结果)
|
||||
latest_value = config.get("latest", True)
|
||||
if isinstance(latest_value, bool):
|
||||
data["latest"] = latest_value
|
||||
elif isinstance(latest_value, str):
|
||||
data["latest"] = latest_value.lower() in ("true", "1", "yes")
|
||||
elif isinstance(latest_value, (int, float)):
|
||||
data["latest"] = latest_value != 0
|
||||
else:
|
||||
data["latest"] = True
|
||||
|
||||
headers = {
|
||||
"X-QuakeToken": api_key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(base_url, json=data, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
result_data = response.json()
|
||||
|
||||
# Quake API code==0 表示成功
|
||||
if result_data.get("code") != 0:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"Quake API错误: {result_data.get('message', '未知错误')}",
|
||||
"error_code": result_data.get("code", "unknown"),
|
||||
"suggestion": "请检查API Token、查询语法和账户积分是否正常"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
results = result_data.get("data", [])
|
||||
meta = result_data.get("meta", {})
|
||||
pagination = meta.get("pagination", {}) if isinstance(meta, dict) else {}
|
||||
|
||||
output = {
|
||||
"status": "success",
|
||||
"query": query,
|
||||
"size": data.get("size", pagination.get("size", len(results))),
|
||||
"start": data.get("start", pagination.get("page_index", 0)),
|
||||
"total": result_data.get("total_count", pagination.get("total", 0)),
|
||||
"results_count": len(results),
|
||||
"fields": include_fields or "all",
|
||||
"results": results,
|
||||
"message": f"成功获取 {len(results)} 条结果"
|
||||
}
|
||||
|
||||
print(json.dumps(output, ensure_ascii=False, indent=2))
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"请求失败: {str(e)}",
|
||||
"suggestion": "请检查网络连通性或Quake API服务状态"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"执行出错: {str(e)}",
|
||||
"type": type(e).__name__
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
enabled: false
|
||||
short_description: "Quake网络空间搜索接口,支持自定义query、size、fields"
|
||||
description: |
|
||||
Quake(360 网络空间测绘)资产搜索工具,调用 Quake API v3 实时检索互联网资产。
|
||||
|
||||
**主要功能:**
|
||||
- 支持 Quake DSL 查询语法(query)
|
||||
- 支持返回数量控制(size)
|
||||
- 支持字段裁剪(fields,对应 Quake include)
|
||||
- 支持分页偏移(start)
|
||||
|
||||
**鉴权方式:**
|
||||
- Header 使用 `X-QuakeToken`
|
||||
- 可在本文件中填写 `QUAKE_API_KEY`,或通过环境变量 `QUAKE_API_KEY` 注入
|
||||
|
||||
**常见查询示例:**
|
||||
- `domain:"example.com"`
|
||||
- `ip:"1.1.1.1"`
|
||||
- `port:443`
|
||||
- `service.name:"http" AND country_cn:"中国"`
|
||||
|
||||
**注意事项:**
|
||||
- API 调用会消耗积分,请按需控制 `size`
|
||||
- `fields` 会映射到请求体 `include` 字段,多个字段用英文逗号分隔
|
||||
- 如遇语法报错,请先在 Quake 控制台验证 DSL
|
||||
parameters:
|
||||
- name: "query"
|
||||
type: "string"
|
||||
description: |
|
||||
Quake DSL 查询语句(必需)。
|
||||
|
||||
**示例:**
|
||||
- `domain:"example.com"`
|
||||
- `ip:"1.1.1.1"`
|
||||
- `port:443`
|
||||
- `service.name:"http" AND country_cn:"中国"`
|
||||
required: true
|
||||
position: 1
|
||||
format: "positional"
|
||||
- name: "size"
|
||||
type: "int"
|
||||
description: |
|
||||
返回结果数量(可选)。
|
||||
|
||||
建议范围:1-100(具体受账户权限/接口限制影响)。
|
||||
required: false
|
||||
position: 2
|
||||
format: "positional"
|
||||
default: 10
|
||||
- name: "start"
|
||||
type: "int"
|
||||
description: |
|
||||
分页起始偏移(可选),从 0 开始。
|
||||
required: false
|
||||
position: 3
|
||||
format: "positional"
|
||||
default: 0
|
||||
- name: "fields"
|
||||
type: "string"
|
||||
description: |
|
||||
返回字段(可选),多个字段用英文逗号分隔。
|
||||
|
||||
该参数会映射到 Quake 请求体中的 `include` 字段。
|
||||
**示例:**
|
||||
- `ip,port`
|
||||
- `ip,port,service.name,service.http.title,location.country_cn`
|
||||
required: false
|
||||
position: 4
|
||||
format: "positional"
|
||||
default: "ip,port"
|
||||
- name: "latest"
|
||||
type: "bool"
|
||||
description: |
|
||||
是否优先返回最新索引结果(可选)。
|
||||
默认 `true`。
|
||||
required: false
|
||||
position: 5
|
||||
format: "positional"
|
||||
default: true
|
||||
@@ -0,0 +1,403 @@
|
||||
name: "shodan_search"
|
||||
command: "python3"
|
||||
args:
|
||||
- "-c"
|
||||
- |
|
||||
import sys
|
||||
import json
|
||||
import requests
|
||||
import os
|
||||
import math
|
||||
|
||||
# ==================== Shodan配置 ====================
|
||||
# 请在此处配置您的Shodan API Key
|
||||
# 您也可以在环境变量中设置:SHODAN_API_KEY
|
||||
# enable 默认为 false,需开启才能调用该MCP
|
||||
SHODAN_API_KEY = "" # 请替换为您自己的Shodan API Key
|
||||
# ==================================================
|
||||
|
||||
# Shodan API基础URL
|
||||
base_url = "https://api.shodan.io"
|
||||
|
||||
# 解析参数(从JSON字符串或命令行参数)
|
||||
def parse_args():
|
||||
# 尝试从第一个参数读取JSON配置
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
arg1 = str(sys.argv[1])
|
||||
config = json.loads(arg1)
|
||||
if isinstance(config, dict):
|
||||
return config
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 传统位置参数方式(向后兼容)
|
||||
# 兼容两种序列:
|
||||
# 1) query,page,facets,minify,fields,count_only,size
|
||||
# 2) query,page,minify,fields,count_only,size (facets省略时执行器会压缩参数)
|
||||
config = {}
|
||||
if len(sys.argv) > 1:
|
||||
config["query"] = str(sys.argv[1])
|
||||
if len(sys.argv) > 2:
|
||||
try:
|
||||
config["page"] = int(sys.argv[2])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
def is_bool_like(val):
|
||||
if isinstance(val, bool):
|
||||
return True
|
||||
if not isinstance(val, str):
|
||||
return False
|
||||
return val.strip().lower() in ("true", "false", "1", "0", "yes", "no")
|
||||
|
||||
remaining = [str(x) for x in sys.argv[3:]]
|
||||
if remaining:
|
||||
# facets 省略时,第一个剩余参数通常是 minify(布尔)
|
||||
first_is_bool = is_bool_like(remaining[0])
|
||||
idx = 0
|
||||
if not first_is_bool:
|
||||
config["facets"] = remaining[idx]
|
||||
idx += 1
|
||||
|
||||
if idx < len(remaining):
|
||||
val = remaining[idx]
|
||||
config["minify"] = val.lower() in ("true", "1", "yes")
|
||||
idx += 1
|
||||
|
||||
if idx < len(remaining):
|
||||
config["fields"] = remaining[idx]
|
||||
idx += 1
|
||||
|
||||
if idx < len(remaining):
|
||||
val = remaining[idx]
|
||||
config["count_only"] = val.lower() in ("true", "1", "yes")
|
||||
idx += 1
|
||||
|
||||
if idx < len(remaining):
|
||||
try:
|
||||
config["size"] = int(remaining[idx])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
def normalize_bool(value, default_value):
|
||||
if value is None:
|
||||
return default_value
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ("true", "1", "yes")
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
return default_value
|
||||
|
||||
try:
|
||||
config = parse_args()
|
||||
|
||||
if not isinstance(config, dict):
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"参数解析错误: 期望字典类型,但得到 {type(config).__name__}",
|
||||
"type": "TypeError"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
api_key = os.getenv("SHODAN_API_KEY", SHODAN_API_KEY).strip()
|
||||
query = str(config.get("query", "")).strip()
|
||||
|
||||
if not api_key:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": "缺少Shodan配置: api_key(Shodan API密钥)",
|
||||
"required_config": ["api_key"],
|
||||
"note": "请在YAML文件的SHODAN_API_KEY配置项中填写您的API密钥,或在环境变量SHODAN_API_KEY中设置。API密钥可在Shodan账户页面查看: https://account.shodan.io/"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
if not query:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": "缺少必需参数: query(搜索查询语句)",
|
||||
"required_params": ["query"],
|
||||
"examples": [
|
||||
"product:nginx",
|
||||
"apache country:DE",
|
||||
"port:22",
|
||||
"ssl.cert.subject.cn:example.com",
|
||||
"org:\"Amazon\" port:443"
|
||||
]
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
count_only = normalize_bool(config.get("count_only"), False)
|
||||
minify = normalize_bool(config.get("minify"), True)
|
||||
requested_size = config.get("size", None)
|
||||
if requested_size is not None:
|
||||
try:
|
||||
requested_size = int(requested_size)
|
||||
if requested_size <= 0:
|
||||
requested_size = None
|
||||
else:
|
||||
# 防止单次请求过大导致额度和响应时间问题
|
||||
requested_size = min(requested_size, 1000)
|
||||
except (ValueError, TypeError):
|
||||
requested_size = None
|
||||
|
||||
# 根据 count_only 选择搜索端点
|
||||
endpoint = "/shodan/host/count" if count_only else "/shodan/host/search"
|
||||
url = f"{base_url}{endpoint}"
|
||||
|
||||
params = {
|
||||
"key": api_key,
|
||||
"query": query
|
||||
}
|
||||
|
||||
# 可选参数 facets(search 和 count 都支持)
|
||||
if "facets" in config and config["facets"]:
|
||||
facets_value = str(config["facets"]).strip()
|
||||
if facets_value:
|
||||
params["facets"] = facets_value
|
||||
|
||||
# search 接口的可选参数
|
||||
if not count_only:
|
||||
if "page" in config and config["page"] is not None:
|
||||
try:
|
||||
page = int(config["page"])
|
||||
if page > 0:
|
||||
params["page"] = page
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
minify_effective = minify
|
||||
|
||||
if "fields" in config and config["fields"]:
|
||||
fields_value = str(config["fields"]).strip()
|
||||
if fields_value:
|
||||
params["fields"] = fields_value
|
||||
# Shodan API约束:fields 与 minify=true 互斥
|
||||
minify_effective = False
|
||||
|
||||
params["minify"] = "true" if minify_effective else "false"
|
||||
|
||||
try:
|
||||
if count_only:
|
||||
response = requests.get(url, params=params, timeout=30)
|
||||
response.raise_for_status()
|
||||
result_data = response.json()
|
||||
|
||||
if isinstance(result_data, dict) and result_data.get("error"):
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"Shodan API错误: {result_data.get('error', '未知错误')}",
|
||||
"suggestion": "请检查API密钥、查询语法和账户查询额度"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
output = {
|
||||
"status": "success",
|
||||
"mode": "count",
|
||||
"query": query,
|
||||
"total": result_data.get("total", 0),
|
||||
"facets": result_data.get("facets", {}),
|
||||
"size": requested_size,
|
||||
"note": "count模式仅返回统计,不返回明细结果",
|
||||
"message": "统计查询完成(未返回资产明细)"
|
||||
}
|
||||
else:
|
||||
start_page = int(params.get("page", 1))
|
||||
# Shodan search 每页固定最多100条
|
||||
# 如果未指定 size,则保持原始行为(单页)
|
||||
target_size = requested_size if requested_size else 100
|
||||
pages_needed = 1 if not requested_size else max(1, int(math.ceil(target_size / 100.0)))
|
||||
|
||||
all_matches = []
|
||||
last_result_data = {}
|
||||
current_page = start_page
|
||||
pages_fetched = 0
|
||||
|
||||
for _ in range(pages_needed):
|
||||
page_params = dict(params)
|
||||
page_params["page"] = current_page
|
||||
|
||||
response = requests.get(url, params=page_params, timeout=30)
|
||||
response.raise_for_status()
|
||||
result_data = response.json()
|
||||
last_result_data = result_data if isinstance(result_data, dict) else {}
|
||||
pages_fetched += 1
|
||||
|
||||
if isinstance(last_result_data, dict) and last_result_data.get("error"):
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"Shodan API错误: {last_result_data.get('error', '未知错误')}",
|
||||
"suggestion": "请检查API密钥、查询语法和账户查询额度"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
page_matches = last_result_data.get("matches", []) if isinstance(last_result_data, dict) else []
|
||||
if not page_matches:
|
||||
break
|
||||
|
||||
all_matches.extend(page_matches)
|
||||
if len(all_matches) >= target_size:
|
||||
break
|
||||
current_page += 1
|
||||
|
||||
matches = all_matches[:target_size]
|
||||
output = {
|
||||
"status": "success",
|
||||
"mode": "search",
|
||||
"query": query,
|
||||
"page": start_page,
|
||||
"size": target_size,
|
||||
"pages_fetched": pages_fetched,
|
||||
"total": last_result_data.get("total", 0),
|
||||
"results_count": len(matches),
|
||||
"facets": last_result_data.get("facets", {}),
|
||||
"results": matches,
|
||||
"message": f"成功获取 {len(matches)} 条结果"
|
||||
}
|
||||
|
||||
print(json.dumps(output, ensure_ascii=False, indent=2))
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
response_body = ""
|
||||
status_code = None
|
||||
if hasattr(e, "response") and e.response is not None:
|
||||
status_code = e.response.status_code
|
||||
try:
|
||||
response_body = e.response.text[:500]
|
||||
except Exception:
|
||||
response_body = ""
|
||||
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"请求失败: {str(e)}",
|
||||
"status_code": status_code,
|
||||
"response": response_body,
|
||||
"suggestion": "请检查网络连接、Shodan API状态、API密钥与查询额度"
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"message": f"执行出错: {str(e)}",
|
||||
"type": type(e).__name__
|
||||
}
|
||||
print(json.dumps(error_result, ensure_ascii=False, indent=2))
|
||||
sys.exit(1)
|
||||
enabled: false
|
||||
short_description: "Shodan网络空间搜索,支持search与count模式"
|
||||
description: |
|
||||
Shodan 资产搜索工具,基于官方 Developer API 实现,支持快速检索和统计分析。
|
||||
|
||||
**主要功能:**
|
||||
- 使用 `/shodan/host/search` 进行资产搜索
|
||||
- 使用 `/shodan/host/count` 进行无明细统计(节省查询信用)
|
||||
- 支持按 `size` 控制返回条数(自动翻页聚合)
|
||||
- 支持分页(page)
|
||||
- 支持分面统计(facets)
|
||||
- 支持结果字段裁剪(fields)
|
||||
- 支持 `minify` 控制返回数据体积
|
||||
|
||||
**鉴权方式:**
|
||||
- Query 参数使用 `key`
|
||||
- 可在本文件中填写 `SHODAN_API_KEY`,或通过环境变量 `SHODAN_API_KEY` 注入
|
||||
|
||||
**查询语法示例:**
|
||||
- `product:nginx`
|
||||
- `apache country:DE`
|
||||
- `port:22`
|
||||
- `org:"Amazon" port:443`
|
||||
- `ssl.cert.subject.cn:example.com`
|
||||
|
||||
**注意事项:**
|
||||
- 带过滤器的查询通常会消耗 query credits
|
||||
- 翻页(超过第1页)会额外消耗额度
|
||||
- `size` 大于 100 时会自动请求更多页(每页最多 100)
|
||||
- `size` 最大限制为 1000(防止过量请求)
|
||||
- `count_only=true` 使用统计接口,不返回 matches 明细
|
||||
parameters:
|
||||
- name: "query"
|
||||
type: "string"
|
||||
description: |
|
||||
Shodan 搜索语句(必需)。
|
||||
|
||||
支持 Shodan filter 语法(`filter:value`)与关键字组合。
|
||||
示例:
|
||||
- `product:nginx`
|
||||
- `apache country:DE`
|
||||
- `port:22`
|
||||
- `org:"Amazon" port:443`
|
||||
required: true
|
||||
position: 1
|
||||
format: "positional"
|
||||
- name: "page"
|
||||
type: "int"
|
||||
description: |
|
||||
页码(可选,仅 search 模式生效),从 1 开始,默认 1。
|
||||
required: false
|
||||
position: 2
|
||||
format: "positional"
|
||||
default: 1
|
||||
- name: "facets"
|
||||
type: "string"
|
||||
description: |
|
||||
分面统计字段(可选)。
|
||||
|
||||
多个字段用英文逗号分隔,也可指定数量:
|
||||
- `org,os`
|
||||
- `country:20,org:10`
|
||||
required: false
|
||||
position: 3
|
||||
format: "positional"
|
||||
- name: "minify"
|
||||
type: "bool"
|
||||
description: |
|
||||
是否精简返回字段(可选,仅 search 模式生效)。
|
||||
默认 `true`。
|
||||
required: false
|
||||
position: 4
|
||||
format: "positional"
|
||||
default: true
|
||||
- name: "fields"
|
||||
type: "string"
|
||||
description: |
|
||||
指定返回字段(可选,仅 search 模式生效)。
|
||||
|
||||
多个字段用英文逗号分隔,例如:
|
||||
- `ip_str,port,org,hostnames,http.title`
|
||||
- `tags,http.title,http.favicon.hash`
|
||||
required: false
|
||||
position: 5
|
||||
format: "positional"
|
||||
- name: "count_only"
|
||||
type: "bool"
|
||||
description: |
|
||||
是否仅统计总数(可选)。
|
||||
|
||||
- `false`(默认):调用 `/shodan/host/search` 返回明细
|
||||
- `true`:调用 `/shodan/host/count` 仅返回 total 和 facets
|
||||
required: false
|
||||
position: 6
|
||||
format: "positional"
|
||||
default: false
|
||||
- name: "size"
|
||||
type: "int"
|
||||
description: |
|
||||
返回结果数量(可选,仅 search 模式生效)。
|
||||
|
||||
- 支持 `10 / 20 / 100 / n`
|
||||
- Shodan 单页最多 100,超过 100 时会自动翻页拼接
|
||||
- 为避免额度和时延问题,最大值限制为 1000
|
||||
- 未传时默认返回单页结果(最多 100 条)
|
||||
required: false
|
||||
position: 7
|
||||
format: "positional"
|
||||
@@ -1302,7 +1302,13 @@
|
||||
"maxRetriesHint": "Retries on rate limit or server error",
|
||||
"retryDelay": "Retry delay (ms)",
|
||||
"retryDelayPlaceholder": "1000",
|
||||
"retryDelayHint": "Delay between retries (ms)"
|
||||
"retryDelayHint": "Delay between retries (ms)",
|
||||
"testConnection": "Test Connection",
|
||||
"testFillRequired": "Please fill in API Key and Model first",
|
||||
"testing": "Testing connection...",
|
||||
"testSuccess": "Connection successful",
|
||||
"testFailed": "Connection failed",
|
||||
"testError": "Test error"
|
||||
},
|
||||
"settingsTerminal": {
|
||||
"title": "Terminal",
|
||||
|
||||
@@ -1302,7 +1302,13 @@
|
||||
"maxRetriesHint": "最大重试次数(默认 3),遇到速率限制或服务器错误时自动重试",
|
||||
"retryDelay": "重试间隔(毫秒)",
|
||||
"retryDelayPlaceholder": "1000",
|
||||
"retryDelayHint": "重试间隔毫秒数(默认 1000),每次重试会递增延迟"
|
||||
"retryDelayHint": "重试间隔毫秒数(默认 1000),每次重试会递增延迟",
|
||||
"testConnection": "测试连接",
|
||||
"testFillRequired": "请先填写 API Key 和模型",
|
||||
"testing": "测试中...",
|
||||
"testSuccess": "连接成功",
|
||||
"testFailed": "连接失败",
|
||||
"testError": "测试出错"
|
||||
},
|
||||
"settingsTerminal": {
|
||||
"title": "终端",
|
||||
|
||||
+102
-78
@@ -1494,11 +1494,14 @@ function addMessage(role, content, mcpExecutionIds = null, progressId = null, cr
|
||||
mcpExecutionIds.forEach((execId, index) => {
|
||||
const detailBtn = document.createElement('button');
|
||||
detailBtn.className = 'mcp-detail-btn';
|
||||
detailBtn.dataset.execId = execId;
|
||||
detailBtn.dataset.execIndex = String(index + 1);
|
||||
detailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.callNumber', { n: index + 1 }) : '调用 #' + (index + 1)) + '</span>';
|
||||
detailBtn.onclick = () => showMCPDetail(execId);
|
||||
buttonsContainer.appendChild(detailBtn);
|
||||
updateButtonWithToolName(detailBtn, execId, index + 1);
|
||||
});
|
||||
// 使用批量 API 一次性获取所有工具名称(消除 N 次单独请求)
|
||||
batchUpdateButtonToolNames(buttonsContainer, mcpExecutionIds);
|
||||
|
||||
mcpSection.appendChild(buttonsContainer);
|
||||
contentWrapper.appendChild(mcpSection);
|
||||
@@ -1861,6 +1864,34 @@ async function updateButtonWithToolName(button, executionId, index) {
|
||||
}
|
||||
}
|
||||
|
||||
// 批量获取工具名称并更新按钮(消除 N 次单独 API 请求,合并为 1 次)
|
||||
async function batchUpdateButtonToolNames(buttonsContainer, executionIds) {
|
||||
if (!executionIds || executionIds.length === 0) return;
|
||||
try {
|
||||
const response = await apiFetch('/api/monitor/executions/names', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ ids: executionIds }),
|
||||
});
|
||||
if (!response.ok) return;
|
||||
const nameMap = await response.json(); // { execId: toolName }
|
||||
// 更新对应按钮的文本
|
||||
const buttons = buttonsContainer.querySelectorAll('.mcp-detail-btn[data-exec-id]');
|
||||
buttons.forEach(btn => {
|
||||
const execId = btn.dataset.execId;
|
||||
const index = btn.dataset.execIndex;
|
||||
const toolName = nameMap[execId];
|
||||
if (toolName) {
|
||||
const displayToolName = toolName.includes('::') ? toolName.split('::')[1] : toolName;
|
||||
const span = btn.querySelector('span');
|
||||
if (span) span.textContent = `${displayToolName} #${index}`;
|
||||
}
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('批量获取工具名称失败:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// 显示MCP调用详情
|
||||
async function showMCPDetail(executionId) {
|
||||
try {
|
||||
@@ -2380,15 +2411,14 @@ async function loadConversation(conversationId) {
|
||||
}
|
||||
|
||||
// 获取当前对话所属的分组ID(用于高亮显示)
|
||||
// 确保分组映射已加载
|
||||
// 确保分组映射已加载(使用缓存避免重复请求)
|
||||
if (Object.keys(conversationGroupMappingCache).length === 0) {
|
||||
await loadConversationGroupMapping();
|
||||
}
|
||||
currentConversationGroupId = conversationGroupMappingCache[conversationId] || null;
|
||||
|
||||
// 无论是否在分组详情页面,都刷新分组列表,确保高亮状态正确
|
||||
// 这样可以清除之前分组的高亮状态,确保UI状态一致
|
||||
await loadGroups();
|
||||
|
||||
// 异步刷新分组列表高亮状态(不阻塞消息渲染)
|
||||
loadGroups();
|
||||
|
||||
// 更新当前对话ID
|
||||
currentConversationId = conversationId;
|
||||
@@ -2430,13 +2460,15 @@ async function loadConversation(conversationId) {
|
||||
}
|
||||
}
|
||||
|
||||
// 加载消息
|
||||
// 加载消息 — 分批渲染避免长时间阻塞主线程
|
||||
if (conversation.messages && conversation.messages.length > 0) {
|
||||
conversation.messages.forEach(msg => {
|
||||
// 检查消息内容是否为"处理中...",如果是,检查processDetails中是否有错误或取消事件
|
||||
const FIRST_BATCH = 20; // 首批同步渲染(用户可见区域)
|
||||
const BATCH_SIZE = 10; // 后续每批条数
|
||||
|
||||
// 渲染单条消息的辅助函数
|
||||
const renderOneMessage = (msg) => {
|
||||
let displayContent = msg.content;
|
||||
if (msg.role === 'assistant' && msg.content === '处理中...' && msg.processDetails && msg.processDetails.length > 0) {
|
||||
// 查找最后一个error或cancelled事件
|
||||
for (let i = msg.processDetails.length - 1; i >= 0; i--) {
|
||||
const detail = msg.processDetails[i];
|
||||
if (detail.eventType === 'error' || detail.eventType === 'cancelled') {
|
||||
@@ -2445,47 +2477,63 @@ async function loadConversation(conversationId) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 传递消息的创建时间
|
||||
|
||||
const messageId = addMessage(msg.role, displayContent, msg.mcpExecutionIds || [], null, msg.createdAt);
|
||||
// 绑定后端 messageId,供按需加载过程详情使用
|
||||
const messageEl = document.getElementById(messageId);
|
||||
if (messageEl && msg && msg.id) {
|
||||
messageEl.dataset.backendMessageId = String(msg.id);
|
||||
attachDeleteTurnButton(messageEl);
|
||||
}
|
||||
// 对于助手消息,总是渲染过程详情(即使没有processDetails也要显示展开详情按钮)
|
||||
if (msg.role === 'assistant') {
|
||||
// 延迟一下,确保消息已经渲染
|
||||
setTimeout(() => {
|
||||
// 如果后端未返回 processDetails 字段,传 null 表示“尚未加载,点击展开时再请求”
|
||||
const hasField = msg && Object.prototype.hasOwnProperty.call(msg, 'processDetails');
|
||||
renderProcessDetails(messageId, hasField ? (msg.processDetails || []) : null);
|
||||
// 如果有过程详情,检查是否有错误或取消事件,如果有,确保详情默认折叠
|
||||
if (msg.processDetails && msg.processDetails.length > 0) {
|
||||
const hasErrorOrCancelled = msg.processDetails.some(d =>
|
||||
d.eventType === 'error' || d.eventType === 'cancelled'
|
||||
);
|
||||
if (hasErrorOrCancelled) {
|
||||
collapseAllProgressDetails(messageId, null);
|
||||
}
|
||||
const hasField = msg && Object.prototype.hasOwnProperty.call(msg, 'processDetails');
|
||||
renderProcessDetails(messageId, hasField ? (msg.processDetails || []) : null);
|
||||
if (msg.processDetails && msg.processDetails.length > 0) {
|
||||
const hasErrorOrCancelled = msg.processDetails.some(d =>
|
||||
d.eventType === 'error' || d.eventType === 'cancelled'
|
||||
);
|
||||
if (hasErrorOrCancelled) {
|
||||
collapseAllProgressDetails(messageId, null);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const msgs = conversation.messages;
|
||||
const firstBatch = msgs.slice(0, FIRST_BATCH);
|
||||
const rest = msgs.slice(FIRST_BATCH);
|
||||
|
||||
// 首批同步渲染
|
||||
firstBatch.forEach(renderOneMessage);
|
||||
|
||||
// 剩余消息通过 requestAnimationFrame 分批渲染,避免阻塞 UI
|
||||
if (rest.length > 0) {
|
||||
const savedConvId = conversationId;
|
||||
let offset = 0;
|
||||
const renderNextBatch = () => {
|
||||
// 如果用户已经切换到其他对话,停止渲染
|
||||
if (currentConversationId !== savedConvId) return;
|
||||
const batch = rest.slice(offset, offset + BATCH_SIZE);
|
||||
batch.forEach(renderOneMessage);
|
||||
offset += BATCH_SIZE;
|
||||
if (offset < rest.length) {
|
||||
requestAnimationFrame(renderNextBatch);
|
||||
} else {
|
||||
// 所有消息渲染完毕,滚动到底部
|
||||
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
||||
}
|
||||
};
|
||||
requestAnimationFrame(renderNextBatch);
|
||||
}
|
||||
} else {
|
||||
const readyMsgEmpty = typeof window.t === 'function' ? window.t('chat.systemReadyMessage') : '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。';
|
||||
addMessage('assistant', readyMsgEmpty, null, null, null, { systemReadyMessage: true });
|
||||
}
|
||||
|
||||
// 滚动到底部
|
||||
|
||||
// 滚动到底部(首批渲染后立即滚动,剩余批次渲染后会再次滚动)
|
||||
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
||||
|
||||
|
||||
// 添加攻击链按钮
|
||||
addAttackChainButton(conversationId);
|
||||
|
||||
// 刷新对话列表
|
||||
loadConversations();
|
||||
} catch (error) {
|
||||
console.error('加载对话失败:', error);
|
||||
alert('加载对话失败: ' + error.message);
|
||||
@@ -4421,20 +4469,17 @@ async function loadGroups() {
|
||||
async function loadConversationsWithGroups(searchQuery = '') {
|
||||
const loadSeq = ++conversationsListLoadSeq;
|
||||
try {
|
||||
// 总是重新加载分组列表和分组映射,确保缓存是最新的
|
||||
// 这样可以正确处理分组被删除后的情况
|
||||
await loadGroups();
|
||||
if (loadSeq !== conversationsListLoadSeq) return;
|
||||
await loadConversationGroupMapping();
|
||||
if (loadSeq !== conversationsListLoadSeq) return;
|
||||
|
||||
// 如果有搜索关键词,使用更大的limit以获取所有匹配结果
|
||||
const limit = (searchQuery && searchQuery.trim()) ? 1000 : 100;
|
||||
// 并行加载分组列表、分组映射和对话列表(消除串行等待)
|
||||
const limit = (searchQuery && searchQuery.trim()) ? 100 : 100;
|
||||
let url = `/api/conversations?limit=${limit}`;
|
||||
if (searchQuery && searchQuery.trim()) {
|
||||
url += '&search=' + encodeURIComponent(searchQuery.trim());
|
||||
}
|
||||
const response = await apiFetch(url);
|
||||
const [,, response] = await Promise.all([
|
||||
loadGroups(),
|
||||
loadConversationGroupMapping(),
|
||||
apiFetch(url),
|
||||
]);
|
||||
if (loadSeq !== conversationsListLoadSeq) return;
|
||||
|
||||
const listContainer = document.getElementById('conversations-list');
|
||||
@@ -5432,48 +5477,27 @@ async function removeConversationFromGroup(convId, groupId) {
|
||||
// 加载对话分组映射
|
||||
async function loadConversationGroupMapping() {
|
||||
try {
|
||||
// 获取所有分组,然后获取每个分组的对话
|
||||
let groups;
|
||||
if (Array.isArray(groupsCache) && groupsCache.length > 0) {
|
||||
groups = groupsCache;
|
||||
} else {
|
||||
const response = await apiFetch('/api/groups');
|
||||
if (!response.ok) {
|
||||
// 如果API请求失败,使用空数组,不打印警告(这是正常错误处理)
|
||||
groups = [];
|
||||
} else {
|
||||
groups = await response.json();
|
||||
// 确保groups是有效数组,只在真正异常时才打印警告
|
||||
if (!Array.isArray(groups)) {
|
||||
// 只在返回的不是数组且不是null/undefined时才打印警告(可能是后端返回了错误格式)
|
||||
if (groups !== null && groups !== undefined) {
|
||||
console.warn('loadConversationGroupMapping: groups不是有效数组,使用空数组', groups);
|
||||
}
|
||||
groups = [];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用批量 API 一次性获取所有映射(消除 N+1 串行请求)
|
||||
const response = await apiFetch('/api/groups/mappings');
|
||||
|
||||
// 保存待保留的映射
|
||||
const preservedMappings = { ...pendingGroupMappings };
|
||||
|
||||
|
||||
conversationGroupMappingCache = {};
|
||||
|
||||
for (const group of groups) {
|
||||
const response = await apiFetch(`/api/groups/${group.id}/conversations`);
|
||||
const conversations = await response.json();
|
||||
// 确保conversations是有效数组
|
||||
if (Array.isArray(conversations)) {
|
||||
conversations.forEach(conv => {
|
||||
conversationGroupMappingCache[conv.id] = group.id;
|
||||
if (response.ok) {
|
||||
const mappings = await response.json();
|
||||
if (Array.isArray(mappings)) {
|
||||
mappings.forEach(m => {
|
||||
conversationGroupMappingCache[m.conversationId] = m.groupId;
|
||||
// 如果这个对话在待保留映射中,从待保留映射中移除(因为已经从后端加载了)
|
||||
if (preservedMappings[conv.id] === group.id) {
|
||||
delete pendingGroupMappings[conv.id];
|
||||
if (preservedMappings[m.conversationId] === m.groupId) {
|
||||
delete pendingGroupMappings[m.conversationId];
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 恢复待保留的映射(这些是后端API尚未同步的映射)
|
||||
Object.assign(conversationGroupMappingCache, preservedMappings);
|
||||
} catch (error) {
|
||||
|
||||
@@ -74,6 +74,17 @@ if (typeof window !== 'undefined') {
|
||||
// 存储工具调用ID到DOM元素的映射,用于更新执行状态
|
||||
const toolCallStatusMap = new Map();
|
||||
|
||||
function finalizeOutstandingToolCallsForProgress(progressId, finalStatus) {
|
||||
if (!progressId) return;
|
||||
const pid = String(progressId);
|
||||
for (const [toolCallId, mapping] of Array.from(toolCallStatusMap.entries())) {
|
||||
if (!mapping) continue;
|
||||
if (mapping.progressId != null && String(mapping.progressId) !== pid) continue;
|
||||
updateToolCallStatus(toolCallId, finalStatus);
|
||||
toolCallStatusMap.delete(toolCallId);
|
||||
}
|
||||
}
|
||||
|
||||
// 模型流式输出缓存:progressId -> { assistantId, buffer }
|
||||
const responseStreamStateByProgressId = new Map();
|
||||
|
||||
@@ -388,6 +399,11 @@ function integrateProgressToMCPSection(progressId, assistantMessageId, mcpExecut
|
||||
const progressElement = document.getElementById(progressId);
|
||||
if (!progressElement) return;
|
||||
|
||||
// Ensure any "running" tool_call badges are closed before we snapshot timeline HTML.
|
||||
// Otherwise, once the progress element is removed, later 'done' events may not be able
|
||||
// to update the original timeline DOM and the copied HTML would stay "执行中".
|
||||
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
|
||||
|
||||
const mcpIds = Array.isArray(mcpExecutionIds) ? mcpExecutionIds : [];
|
||||
|
||||
// 获取时间线内容
|
||||
@@ -444,13 +460,16 @@ function integrateProgressToMCPSection(progressId, assistantMessageId, mcpExecut
|
||||
mcpIds.forEach((execId, index) => {
|
||||
const detailBtn = document.createElement('button');
|
||||
detailBtn.className = 'mcp-detail-btn';
|
||||
detailBtn.dataset.execId = execId;
|
||||
detailBtn.dataset.execIndex = String(index + 1);
|
||||
detailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.callNumber', { n: index + 1 }) : '调用 #' + (index + 1)) + '</span>';
|
||||
detailBtn.onclick = () => showMCPDetail(execId);
|
||||
buttonsContainer.appendChild(detailBtn);
|
||||
if (typeof updateButtonWithToolName === 'function') {
|
||||
updateButtonWithToolName(detailBtn, execId, index + 1);
|
||||
}
|
||||
});
|
||||
// 使用批量 API 一次性获取所有工具名称(消除 N 次单独请求)
|
||||
if (typeof batchUpdateButtonToolNames === 'function') {
|
||||
batchUpdateButtonToolNames(buttonsContainer, mcpIds);
|
||||
}
|
||||
}
|
||||
if (!buttonsContainer.querySelector('.process-detail-btn')) {
|
||||
const progressDetailBtn = document.createElement('button');
|
||||
@@ -937,6 +956,9 @@ function handleStreamEvent(event, progressElement, progressId,
|
||||
message: event.message || '',
|
||||
data: event.data
|
||||
});
|
||||
// If the backend triggers a recovery run, any "running" tool_call items in this progress
|
||||
// should be closed to avoid being stuck forever.
|
||||
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -958,7 +980,8 @@ function handleStreamEvent(event, progressElement, progressId,
|
||||
if (toolCallId && toolCallItemId) {
|
||||
toolCallStatusMap.set(toolCallId, {
|
||||
itemId: toolCallItemId,
|
||||
timeline: timeline
|
||||
timeline: timeline,
|
||||
progressId: progressId
|
||||
});
|
||||
|
||||
// 添加执行中状态指示器
|
||||
@@ -1224,6 +1247,8 @@ function handleStreamEvent(event, progressElement, progressId,
|
||||
|
||||
// 立即刷新任务状态
|
||||
loadActiveTasks();
|
||||
// Close any remaining running tool calls for this progress.
|
||||
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
|
||||
break;
|
||||
|
||||
case 'response_start': {
|
||||
@@ -1337,9 +1362,23 @@ function handleStreamEvent(event, progressElement, progressId,
|
||||
updateAssistantBubbleContent(assistantIdFinal, event.message, true);
|
||||
}
|
||||
|
||||
// 移除 response_start/response_delta 阶段创建的「规划中」占位条目。
|
||||
// 该条目属于 UI-only 的流式展示,不应被拷贝到最终的过程详情里;
|
||||
// 否则会出现“不刷新页面仍显示规划中,刷新后消失”的不一致。
|
||||
if (streamState && streamState.itemId) {
|
||||
const planningItem = document.getElementById(streamState.itemId);
|
||||
if (planningItem && planningItem.parentNode) {
|
||||
planningItem.parentNode.removeChild(planningItem);
|
||||
}
|
||||
}
|
||||
|
||||
// 最终回复时隐藏进度卡片(多代理模式下,迭代过程已完整展示)
|
||||
hideProgressMessageForFinalReply(progressId);
|
||||
|
||||
// Before integrating/removing the progress DOM, close any outstanding running tool calls
|
||||
// so the copied timeline HTML reflects the final status.
|
||||
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
|
||||
|
||||
// 将进度详情集成到工具调用区域(放在最终 response 之后,保证时间线已完整)
|
||||
integrateProgressToMCPSection(progressId, assistantIdFinal, mcpIds);
|
||||
responseStreamStateByProgressId.delete(progressId);
|
||||
@@ -1403,6 +1442,8 @@ function handleStreamEvent(event, progressElement, progressId,
|
||||
|
||||
// 立即刷新任务状态(执行失败时任务状态会更新)
|
||||
loadActiveTasks();
|
||||
// Close any remaining running tool calls for this progress.
|
||||
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
|
||||
break;
|
||||
|
||||
case 'done':
|
||||
@@ -1438,6 +1479,8 @@ function handleStreamEvent(event, progressElement, progressId,
|
||||
|
||||
// 立即刷新任务状态(确保任务状态同步)
|
||||
loadActiveTasks();
|
||||
// Close any remaining running tool calls for this progress (best-effort).
|
||||
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
|
||||
|
||||
// 延迟再次刷新任务状态(确保后端已完成状态更新)
|
||||
setTimeout(() => {
|
||||
|
||||
@@ -959,6 +959,57 @@ async function applySettings() {
|
||||
}
|
||||
}
|
||||
|
||||
// 测试OpenAI连接
|
||||
async function testOpenAIConnection() {
|
||||
const btn = document.getElementById('test-openai-btn');
|
||||
const resultEl = document.getElementById('test-openai-result');
|
||||
|
||||
const baseUrl = document.getElementById('openai-base-url').value.trim();
|
||||
const apiKey = document.getElementById('openai-api-key').value.trim();
|
||||
const model = document.getElementById('openai-model').value.trim();
|
||||
|
||||
if (!apiKey || !model) {
|
||||
resultEl.style.color = 'var(--danger-color, #e53e3e)';
|
||||
resultEl.textContent = typeof window.t === 'function' ? window.t('settingsBasic.testFillRequired') : '请先填写 API Key 和模型';
|
||||
return;
|
||||
}
|
||||
|
||||
btn.style.pointerEvents = 'none';
|
||||
btn.style.opacity = '0.5';
|
||||
resultEl.style.color = 'var(--text-muted, #888)';
|
||||
resultEl.textContent = typeof window.t === 'function' ? window.t('settingsBasic.testing') : '测试中...';
|
||||
|
||||
try {
|
||||
const response = await apiFetch('/api/config/test-openai', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
base_url: baseUrl,
|
||||
api_key: apiKey,
|
||||
model: model
|
||||
})
|
||||
});
|
||||
|
||||
const result = await response.json();
|
||||
|
||||
if (result.success) {
|
||||
resultEl.style.color = 'var(--success-color, #38a169)';
|
||||
const latency = result.latency_ms ? ` (${result.latency_ms}ms)` : '';
|
||||
const modelInfo = result.model ? ` [${result.model}]` : '';
|
||||
resultEl.textContent = (typeof window.t === 'function' ? window.t('settingsBasic.testSuccess') : '连接成功') + modelInfo + latency;
|
||||
} else {
|
||||
resultEl.style.color = 'var(--danger-color, #e53e3e)';
|
||||
resultEl.textContent = (typeof window.t === 'function' ? window.t('settingsBasic.testFailed') : '连接失败') + ': ' + (result.error || '未知错误');
|
||||
}
|
||||
} catch (error) {
|
||||
resultEl.style.color = 'var(--danger-color, #e53e3e)';
|
||||
resultEl.textContent = (typeof window.t === 'function' ? window.t('settingsBasic.testError') : '测试出错') + ': ' + error.message;
|
||||
} finally {
|
||||
btn.style.pointerEvents = '';
|
||||
btn.style.opacity = '';
|
||||
}
|
||||
}
|
||||
|
||||
// 保存工具配置(独立函数,用于MCP管理页面)
|
||||
async function saveToolsConfig() {
|
||||
try {
|
||||
|
||||
@@ -1371,6 +1371,10 @@
|
||||
<label for="openai-model"><span data-i18n="settingsBasic.model">模型</span> <span style="color: red;">*</span></label>
|
||||
<input type="text" id="openai-model" data-i18n="settingsBasic.modelPlaceholder" data-i18n-attr="placeholder" placeholder="gpt-4" required />
|
||||
</div>
|
||||
<div style="display: flex; align-items: center; gap: 8px; margin-top: 2px;">
|
||||
<a href="javascript:void(0)" id="test-openai-btn" onclick="testOpenAIConnection()" style="font-size: 0.8125rem; color: var(--accent-color, #3182ce); text-decoration: none; cursor: pointer; user-select: none;" data-i18n="settingsBasic.testConnection">测试连接</a>
|
||||
<span id="test-openai-result" style="font-size: 0.8125rem;"></span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user