From 9f862ce7218f32a481c3c5f59129854405fc2e72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Mon, 17 Nov 2025 02:43:36 +0800 Subject: [PATCH] Add files via upload --- README.md | 2 + README_CN.md | 2 + internal/app/app.go | 7 + internal/attackchain/builder.go | 1032 +++++++++++++++++++++++++++++ internal/database/attackchain.go | 168 +++++ internal/database/database.go | 42 ++ internal/handler/attackchain.go | 152 +++++ web/static/css/style.css | 138 ++++ web/static/js/app.js | 1052 +++++++++++++++++++++++++++++- web/templates/index.html | 63 ++ 10 files changed, 2646 insertions(+), 12 deletions(-) create mode 100644 internal/attackchain/builder.go create mode 100644 internal/database/attackchain.go create mode 100644 internal/handler/attackchain.go diff --git a/README.md b/README.md index cc3e6e02..654c4d2b 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ ![Preview](./img/外部MCP接入.png) ## Changelog +- 2025.11.15 Added attack chain visualization feature: automatically build attack chains from conversations using AI analysis, visualize tool execution flows, vulnerability discovery paths, and relationships between nodes, support interactive graph exploration with risk scoring - 2025.11.15 Added large result pagination feature: when tool execution results exceed the threshold (default 200KB), automatically save to file and return execution ID, support paginated queries, keyword search, conditional filtering, and regex matching through query_execution_result tool, effectively solving the problem of overly long single responses and improving large file processing capabilities - 2025.11.15 Added external MCP integration feature: support for integrating external MCP servers to extend tool capabilities, supports both stdio and HTTP transport modes, tool-level enable/disable control, complete configuration guide and management APIs - 2025.11.14 Performance optimizations: optimized tool lookup from O(n) to O(1) using index map, added automatic cleanup mechanism for execution records to prevent memory leaks, and added pagination support for database queries @@ -36,6 +37,7 @@ - 📊 **Conversation History Management** - Complete conversation history records, supports viewing, deletion, and management - ⚙️ **Visual Configuration Management** - Web interface for system settings, supports real-time loading and saving configurations with required field validation - 📄 **Large Result Pagination** - When tool execution results exceed the threshold, automatically save to file, support paginated queries, keyword search, conditional filtering, and regex matching, effectively solving the problem of overly long single responses, with examples for various tools (head, tail, grep, sed, etc.) for segmented reading +- 🔗 **Attack Chain Visualization** - Automatically build and visualize attack chains from conversations, showing tool execution flows, vulnerability discovery paths, and relationships between targets, tools, vulnerabilities, and discoveries, with AI-powered analysis and interactive graph exploration ### Tool Integration - 🔌 **MCP Protocol Support** - Complete MCP protocol implementation, supports tool registration, invocation, and monitoring diff --git a/README_CN.md b/README_CN.md index 8252cfed..b273b9b9 100644 --- a/README_CN.md +++ b/README_CN.md @@ -15,6 +15,7 @@ ![详情预览](./img/外部MCP接入.png) ## 更新日志 +- 2025.11.17 新增攻击链可视化功能:基于AI分析自动从对话中构建攻击链,可视化展示工具执行流程、漏洞发现路径和节点间关联关系,支持交互式图谱探索和风险评分 - 2025.11.15 新增大结果分段读取功能:当工具执行结果超过阈值(默认200KB)时,自动保存到文件并返回执行ID,支持通过 query_execution_result 工具进行分页查询、关键词搜索、条件过滤和正则表达式匹配,有效解决单次返回过长的问题,提升大文件处理能力 - 2025.11.15 新增外部 MCP 接入功能:支持接入外部 MCP 服务器扩展工具能力,支持 stdio 和 HTTP 两种传输模式,支持工具级别的启用/禁用控制,提供完整的配置指南和管理接口 - 2025.11.14 性能优化:工具查找从 O(n) 优化为 O(1)(使用索引映射),添加执行记录自动清理机制防止内存泄漏,数据库查询支持分页加载 @@ -36,6 +37,7 @@ - 📊 **对话历史管理** - 完整的对话历史记录,支持查看、删除和管理 - ⚙️ **可视化配置管理** - Web界面配置系统设置,支持实时加载和保存配置,必填项验证 - 📄 **大结果分段读取** - 当工具执行结果超过阈值时自动保存,支持分页查询、关键词搜索、条件过滤和正则表达式匹配,有效解决单次返回过长问题,提供多种工具(head、tail、grep、sed等)的分段读取示例 +- 🔗 **攻击链可视化** - 基于AI分析自动从对话中构建攻击链,可视化展示工具执行流程、漏洞发现路径以及目标、工具、漏洞、发现之间的关联关系,支持交互式图谱探索和风险评分 ### 工具集成 - 🔌 **MCP协议支持** - 完整实现MCP协议,支持工具注册、调用、监控 diff --git a/internal/app/app.go b/internal/app/app.go index addef393..b7963094 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -130,6 +130,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, externalMCPMgr, log.Logger) externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) + attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) // 设置路由 setupRoutes( @@ -140,6 +141,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { conversationHandler, configHandler, externalMCPHandler, + attackChainHandler, mcpServer, authManager, ) @@ -198,6 +200,7 @@ func setupRoutes( conversationHandler *handler.ConversationHandler, configHandler *handler.ConfigHandler, externalMCPHandler *handler.ExternalMCPHandler, + attackChainHandler *handler.AttackChainHandler, mcpServer *mcp.Server, authManager *security.AuthManager, ) { @@ -251,6 +254,10 @@ func setupRoutes( protected.POST("/external-mcp/:name/start", externalMCPHandler.StartExternalMCP) protected.POST("/external-mcp/:name/stop", externalMCPHandler.StopExternalMCP) + // 攻击链可视化 + protected.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain) + protected.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain) + // MCP端点 protected.POST("/mcp", func(c *gin.Context) { mcpServer.HandleHTTP(c.Writer, c.Request) diff --git a/internal/attackchain/builder.go b/internal/attackchain/builder.go new file mode 100644 index 00000000..10f7bf05 --- /dev/null +++ b/internal/attackchain/builder.go @@ -0,0 +1,1032 @@ +package attackchain + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/mcp" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Builder 攻击链构建器 +type Builder struct { + db *database.DB + logger *zap.Logger + openAIClient *http.Client + openAIConfig *config.OpenAIConfig +} + +// Node 攻击链节点(使用database包的类型) +type Node = database.AttackChainNode + +// Edge 攻击链边(使用database包的类型) +type Edge = database.AttackChainEdge + +// Chain 完整的攻击链 +type Chain struct { + Nodes []Node `json:"nodes"` + Edges []Edge `json:"edges"` +} + +// NewBuilder 创建新的攻击链构建器 +func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *Builder { + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + } + + return &Builder{ + db: db, + logger: logger, + openAIClient: &http.Client{Timeout: 5 * time.Minute, Transport: transport}, + openAIConfig: openAIConfig, + } +} + +// BuildChainFromConversation 从对话构建攻击链(一次性生成整个图) +func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) { + b.logger.Info("开始构建攻击链(一次性生成)", zap.String("conversationId", conversationID)) + + // 1. 获取对话消息和工具执行记录 + messages, err := b.db.GetMessages(conversationID) + if err != nil { + return nil, fmt.Errorf("获取对话消息失败: %w", err) + } + + executions, err := b.getToolExecutionsByConversation(conversationID) + if err != nil { + return nil, fmt.Errorf("获取工具执行记录失败: %w", err) + } + + // 获取过程详情 + processDetailsMap, err := b.db.GetProcessDetailsByConversation(conversationID) + if err != nil { + b.logger.Warn("获取过程详情失败", zap.Error(err)) + processDetailsMap = make(map[string][]database.ProcessDetail) + } + + if len(executions) == 0 && len(messages) == 0 { + b.logger.Info("对话中没有数据", zap.String("conversationId", conversationID)) + return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil + } + + // 2. 准备上下文数据 + contextData, err := b.prepareContextData(messages, executions, processDetailsMap) + if err != nil { + return nil, fmt.Errorf("准备上下文数据失败: %w", err) + } + + // 3. 一次性生成攻击链(带重试和压缩机制) + chain, err := b.generateChainWithRetry(ctx, contextData, 5) + if err != nil { + return nil, fmt.Errorf("生成攻击链失败: %w", err) + } + + // 4. 保存到数据库 + if err := b.saveChain(conversationID, chain.Nodes, chain.Edges); err != nil { + b.logger.Warn("保存攻击链失败", zap.Error(err)) + // 不返回错误,继续返回结果 + } + + b.logger.Info("攻击链构建完成", + zap.String("conversationId", conversationID), + zap.Int("nodes", len(chain.Nodes)), + zap.Int("edges", len(chain.Edges))) + + return chain, nil +} + +// getToolExecutionsByConversation 获取对话的工具执行记录 +func (b *Builder) getToolExecutionsByConversation(conversationID string) ([]*mcp.ToolExecution, error) { + // 通过conversation_id关联messages,再通过mcp_execution_ids关联tool_executions + // 简化实现:直接查询所有工具执行记录,然后过滤(实际应该优化查询) + allExecutions, err := b.db.LoadToolExecutions() + if err != nil { + return nil, err + } + + // 获取对话的消息,提取mcp_execution_ids + messages, err := b.db.GetMessages(conversationID) + if err != nil { + return nil, err + } + + // 收集所有execution IDs + executionIDSet := make(map[string]bool) + for _, msg := range messages { + if len(msg.MCPExecutionIDs) > 0 { + for _, id := range msg.MCPExecutionIDs { + executionIDSet[id] = true + } + } + } + + // 过滤执行记录 + var filteredExecutions []*mcp.ToolExecution + for _, exec := range allExecutions { + if executionIDSet[exec.ID] { + filteredExecutions = append(filteredExecutions, exec) + } + } + + // 按时间排序 + sort.Slice(filteredExecutions, func(i, j int) bool { + return filteredExecutions[i].StartTime.Before(filteredExecutions[j].StartTime) + }) + + return filteredExecutions, nil +} + +// saveChain 保存攻击链到数据库 +func (b *Builder) saveChain(conversationID string, nodes []Node, edges []Edge) error { + // 先删除旧的攻击链数据 + if err := b.db.DeleteAttackChain(conversationID); err != nil { + b.logger.Warn("删除旧攻击链失败", zap.Error(err)) + } + + // 保存节点 + for _, node := range nodes { + metadataJSON, _ := json.Marshal(node.Metadata) + if err := b.db.SaveAttackChainNode(conversationID, node.ID, node.Type, node.Label, node.ToolExecutionID, string(metadataJSON), node.RiskScore); err != nil { + b.logger.Warn("保存攻击链节点失败", zap.String("nodeId", node.ID), zap.Error(err)) + } + } + + // 保存边 + for _, edge := range edges { + if err := b.db.SaveAttackChainEdge(conversationID, edge.ID, edge.Source, edge.Target, edge.Type, edge.Weight); err != nil { + b.logger.Warn("保存攻击链边失败", zap.String("edgeId", edge.ID), zap.Error(err)) + } + } + + return nil +} + +// LoadChainFromDatabase 从数据库加载攻击链 +func (b *Builder) LoadChainFromDatabase(conversationID string) (*Chain, error) { + nodes, err := b.db.LoadAttackChainNodes(conversationID) + if err != nil { + return nil, fmt.Errorf("加载攻击链节点失败: %w", err) + } + + edges, err := b.db.LoadAttackChainEdges(conversationID) + if err != nil { + return nil, fmt.Errorf("加载攻击链边失败: %w", err) + } + + return &Chain{ + Nodes: nodes, + Edges: edges, + }, nil +} + +// ContextData 上下文数据(用于一次性生成攻击链) +type ContextData struct { + Messages []database.Message `json:"messages"` + Executions []*mcp.ToolExecution `json:"executions"` + ProcessDetails map[string][]database.ProcessDetail `json:"process_details"` + SummarizedItems map[string]string `json:"summarized_items"` // 已总结的项目(key: 原始ID, value: 总结内容) +} + +// prepareContextData 准备上下文数据 +func (b *Builder) prepareContextData(messages []database.Message, executions []*mcp.ToolExecution, processDetails map[string][]database.ProcessDetail) (*ContextData, error) { + return &ContextData{ + Messages: messages, + Executions: executions, + ProcessDetails: processDetails, + SummarizedItems: make(map[string]string), + }, nil +} + +// generateChainWithRetry 生成攻击链(带重试和压缩机制) +func (b *Builder) generateChainWithRetry(ctx context.Context, contextData *ContextData, maxRetries int) (*Chain, error) { + for attempt := 0; attempt < maxRetries; attempt++ { + b.logger.Info("尝试生成攻击链", + zap.Int("attempt", attempt+1), + zap.Int("maxRetries", maxRetries)) + + // 构建提示词 + prompt, err := b.buildChainGenerationPrompt(contextData) + if err != nil { + return nil, fmt.Errorf("构建提示词失败: %w", err) + } + + // 调用AI生成攻击链 + chainJSON, err := b.callAIForChainGeneration(ctx, prompt) + if err != nil { + // 检查是否是上下文过长错误 + if strings.Contains(err.Error(), "context length") || strings.Contains(err.Error(), "too long") || strings.Contains(err.Error(), "context length exceeded") { + b.logger.Warn("上下文过长,尝试压缩", + zap.Int("attempt", attempt+1), + zap.Error(err)) + + // 压缩最长的子节点 + if err := b.compressLongestItem(ctx, contextData); err != nil { + return nil, fmt.Errorf("压缩上下文失败: %w", err) + } + + // 重试 + continue + } + + return nil, fmt.Errorf("AI生成失败: %w", err) + } + + // 解析JSON(传入executions用于ID映射) + chain, err := b.parseChainJSON(chainJSON, contextData.Executions) + if err != nil { + return nil, fmt.Errorf("解析攻击链JSON失败: %w", err) + } + + return chain, nil + } + + return nil, fmt.Errorf("生成攻击链失败:超过最大重试次数 %d", maxRetries) +} + +// buildChainGenerationPrompt 构建攻击链生成提示词 +func (b *Builder) buildChainGenerationPrompt(contextData *ContextData) (string, error) { + var promptBuilder strings.Builder + + promptBuilder.WriteString(`你是一个专业的安全测试分析师。请根据以下对话和工具执行记录,生成清晰、有教育意义的攻击链图。 + +## 核心原则 + +**目标:让不懂渗透测试的同学可以通过这个攻击链路学习到知识,而不是无数个节点看花眼。** + +## 任务要求 + +1. **节点类型(简化,只保留3种)**: + - **target(目标)**:从用户输入中提取测试目标(IP、域名、URL等) + - **重要:如果对话中测试了多个不同的目标(如先测试A网页,后测试B网页),必须为每个不同的目标创建独立的target节点** + - 每个target节点只关联属于它的action节点(通过工具执行参数中的目标来判断) + - 不同目标的action节点之间**不应该**建立关联关系 + - **action(行动)**:**工具执行 + AI分析结果 = 一个action节点** + - 将每个工具执行和AI对该工具结果的分析合并为一个action节点 + - 节点标签应该清晰描述"做了什么"和"发现了什么"(例如:"使用Nmap扫描端口,发现22、80、443端口开放") + - 只包含**有效的、成功的**工具执行(忽略错误、失败、无效的执行) + - **重要:action节点必须关联到正确的target节点(通过工具执行参数判断目标)** + - **vulnerability(漏洞)**:从工具执行结果和AI分析中提取的**真实漏洞**(不是所有发现都是漏洞) + +2. **过滤规则(重要!)**: + - **忽略所有错误/失败的节点**: + - 工具执行错误(Error字段不为空,或Result.IsError为true) + - 工具执行结果为空或无效 + - AI分析中明确标记为"失败"、"错误"、"无效"的内容 + - **只保留有价值的节点**: + - 成功执行的工具 + - 有实际发现的工具执行 + - 真实存在的漏洞 + +3. **建立清晰的关联关系**: + - target → action:目标指向属于它的所有行动(通过工具执行参数判断目标) + - action → action:行动之间的逻辑顺序(按时间顺序,但只连接有逻辑关系的) + - **重要:只连接属于同一目标的action节点,不同目标的action节点之间不应该连接** + - action → vulnerability:行动发现的漏洞 + - vulnerability → vulnerability:漏洞间的因果关系(如SQL注入 → 信息泄露) + - **重要:只连接属于同一目标的漏洞,不同目标的漏洞之间不应该连接** + +4. **节点属性**: + - 每个节点需要:id, type, label, risk_score, metadata + - action节点需要: + - tool_name: 工具名称 + - tool_intent: 工具调用意图(如"端口扫描"、"漏洞扫描") + - ai_analysis: AI对工具结果的分析总结(简洁,不超过100字) + - findings: 关键发现(列表) + - vulnerability节点需要:type, description, severity, location + +## 对话数据 + +`) + + // 添加消息 + promptBuilder.WriteString("\n### 对话消息\n\n") + for i, msg := range contextData.Messages { + promptBuilder.WriteString(fmt.Sprintf("消息%d [%s]:\n", i+1, msg.Role)) + + // 检查是否已总结 + if summary, ok := contextData.SummarizedItems[msg.ID]; ok { + promptBuilder.WriteString(fmt.Sprintf("[已总结] %s\n\n", summary)) + } else { + content := msg.Content + if len(content) > 5000 { + content = content[:5000] + "..." + } + promptBuilder.WriteString(fmt.Sprintf("%s\n\n", content)) + } + + // 添加过程详情 + if details, ok := contextData.ProcessDetails[msg.ID]; ok { + for _, detail := range details { + if detail.EventType == "thinking" { + thinkingText := detail.Message + if summary, ok := contextData.SummarizedItems[detail.ID]; ok { + thinkingText = "[已总结] " + summary + } else if len(thinkingText) > 2000 { + thinkingText = thinkingText[:2000] + "..." + } + promptBuilder.WriteString(fmt.Sprintf("思考过程: %s\n", thinkingText)) + } + } + } + promptBuilder.WriteString("\n") + } + + // 添加工具执行记录(关联对应的AI回复) + promptBuilder.WriteString("\n### 工具执行记录(包含对应的AI分析)\n\n") + + // 构建工具执行ID到消息的映射(找到工具执行后AI的回复) + execToMessageMap := b.buildExecutionToMessageMap(contextData) + + for i, exec := range contextData.Executions { + // 检查是否是错误/失败的执行 + isError := exec.Error != "" || (exec.Result != nil && exec.Result.IsError) + if isError { + promptBuilder.WriteString(fmt.Sprintf("执行%d [%s] (ID: %s) - **已忽略(执行失败/错误)**\n\n", i+1, exec.ToolName, exec.ID)) + continue + } + + promptBuilder.WriteString(fmt.Sprintf("执行%d [%s] (ID: %s):\n", i+1, exec.ToolName, exec.ID)) + promptBuilder.WriteString(fmt.Sprintf("参数: %s\n", b.formatArguments(exec.Arguments))) + + // 检查是否已总结 + var resultText string + if exec.Result != nil { + for _, content := range exec.Result.Content { + if content.Type == "text" { + resultText += content.Text + "\n" + } + } + } + + // 检查结果是否为空或无效 + if resultText == "" || strings.TrimSpace(resultText) == "" { + promptBuilder.WriteString("结果: **已忽略(结果为空)**\n\n") + continue + } + + if summary, ok := contextData.SummarizedItems[exec.ID]; ok { + promptBuilder.WriteString(fmt.Sprintf("工具执行结果: [已总结] %s\n", summary)) + } else { + if len(resultText) > 5000 { + resultText = resultText[:5000] + "..." + } + promptBuilder.WriteString(fmt.Sprintf("工具执行结果: %s\n", resultText)) + } + + // 添加对应的AI分析(工具执行后AI的回复) + if aiMessage, ok := execToMessageMap[exec.ID]; ok { + aiContent := aiMessage.Content + if len(aiContent) > 2000 { + aiContent = aiContent[:2000] + "..." + } + promptBuilder.WriteString(fmt.Sprintf("AI分析: %s\n", aiContent)) + } + + promptBuilder.WriteString("\n") + } + + promptBuilder.WriteString(` + +## 输出格式 + +请以JSON格式返回攻击链,格式如下: + +{ + "nodes": [ + { + "id": "node_1", + "type": "target|action|vulnerability", + "label": "节点标签(清晰、简洁,action节点要描述"做了什么"和"发现了什么")", + "risk_score": 0-100, + "tool_execution_id": "执行记录的真实ID(action节点必须使用上面执行记录中的ID字段)", + "metadata": { + "target": "目标(target节点)", + "tool_name": "工具名称(action节点)", + "tool_intent": "工具调用意图(action节点,如"端口扫描"、"漏洞扫描")", + "ai_analysis": "AI对工具结果的分析总结(action节点,不超过100字)", + "findings": ["发现1", "发现2"](action节点,关键发现列表), + "vulnerability_type": "漏洞类型(vulnerability节点)", + "description": "描述(vulnerability节点)", + "severity": "critical|high|medium|low(vulnerability节点)", + "location": "漏洞位置(vulnerability节点)" + } + } + ], + "edges": [ + { + "source": "node_1", + "target": "node_2", + "type": "leads_to|discovers|enables", + "weight": 1-5 + } + ] +} + +## 重要要求 + +1. **节点合并**: + - 每个工具执行和对应的AI分析必须合并为一个action节点 + - action节点的label要清晰描述"做了什么"和"发现了什么" + - 例如:"使用Nmap扫描192.168.1.1,发现22、80、443端口开放" + +2. **过滤无效节点**: + - **必须忽略**所有错误/失败的执行(已在上面标记为"已忽略"的) + - **必须忽略**结果为空或无效的执行 + - 只保留有价值的、成功的节点 + +3. **简化结构**: + - 只创建target、action、vulnerability三种节点 + - 不要创建discovery、decision等节点 + - 让攻击链清晰、有教育意义 + +4. **关联关系**: + - target → action:目标指向属于它的所有行动(通过工具执行参数判断目标) + - action → action:按时间顺序连接,但只连接有逻辑关系的 + - **重要:只连接属于同一目标的action节点,不同目标的action节点之间不应该连接** + - action → vulnerability:行动发现的漏洞 + - vulnerability → vulnerability:漏洞间的因果关系 + - **重要:只连接属于同一目标的漏洞,不同目标的漏洞之间不应该连接** + +5. **多目标处理(重要!)**: + - 如果对话中测试了多个不同的目标(如先测试A网页,后测试B网页),必须: + - 为每个不同的目标创建独立的target节点 + - 每个target节点只关联属于它的action和vulnerability节点 + - 不同目标的节点之间**不应该**建立任何关联关系 + - 这样会形成多个独立的攻击链分支,每个分支对应一个测试目标 + +6. **节点数量控制**: + - 如果节点太多(>20个),优先保留最重要的节点 + - 合并相似的action节点(如同一工具的连续调用,如果结果相似) + +只返回JSON,不要包含其他解释文字。`) + + return promptBuilder.String(), nil +} + +// buildExecutionToMessageMap 构建工具执行ID到AI消息的映射 +// 找到每个工具执行后AI的回复消息 +func (b *Builder) buildExecutionToMessageMap(contextData *ContextData) map[string]database.Message { + execToMessageMap := make(map[string]database.Message) + + // 遍历消息,找到包含工具执行ID的消息(通常是assistant消息) + for _, msg := range contextData.Messages { + if msg.Role != "assistant" { + continue + } + + // 检查消息中是否引用了工具执行ID + // 通常工具执行后,AI会在回复中引用这些执行ID + for _, execID := range msg.MCPExecutionIDs { + // 找到对应的工具执行 + for _, exec := range contextData.Executions { + if exec.ID == execID { + // 如果这个执行还没有关联的消息,或者当前消息时间更晚,则更新 + if existingMsg, exists := execToMessageMap[execID]; !exists || msg.CreatedAt.After(existingMsg.CreatedAt) { + execToMessageMap[execID] = msg + } + break + } + } + } + } + + // 如果通过MCPExecutionIDs找不到,尝试按时间顺序匹配 + // 找到每个工具执行后最近的assistant消息 + for _, exec := range contextData.Executions { + if _, exists := execToMessageMap[exec.ID]; exists { + continue + } + + // 找到执行时间之后最近的assistant消息 + var closestMsg *database.Message + for i := range contextData.Messages { + msg := &contextData.Messages[i] + if msg.Role == "assistant" && msg.CreatedAt.After(exec.StartTime) { + if closestMsg == nil || msg.CreatedAt.Before(closestMsg.CreatedAt) { + closestMsg = msg + } + } + } + + if closestMsg != nil { + execToMessageMap[exec.ID] = *closestMsg + } + } + + return execToMessageMap +} + +// formatArguments 格式化工具参数 +func (b *Builder) formatArguments(args map[string]interface{}) string { + if args == nil { + return "{}" + } + jsonData, _ := json.Marshal(args) + return string(jsonData) +} + +// compressLongestItem 压缩最长的子节点 +func (b *Builder) compressLongestItem(ctx context.Context, contextData *ContextData) error { + var longestID string + var longestType string + var longestContent string + maxLength := 0 + + // 查找最长的消息 + for _, msg := range contextData.Messages { + if _, alreadySummarized := contextData.SummarizedItems[msg.ID]; alreadySummarized { + continue + } + length := len(msg.Content) + if length > maxLength { + maxLength = length + longestID = msg.ID + longestType = "message" + longestContent = msg.Content + } + } + + // 查找最长的工具执行结果 + for _, exec := range contextData.Executions { + if _, alreadySummarized := contextData.SummarizedItems[exec.ID]; alreadySummarized { + continue + } + if exec.Result != nil { + var resultText string + for _, content := range exec.Result.Content { + if content.Type == "text" { + resultText += content.Text + "\n" + } + } + length := len(resultText) + if length > maxLength { + maxLength = length + longestID = exec.ID + longestType = "execution" + longestContent = resultText + } + } + } + + // 查找最长的思考过程 + for _, details := range contextData.ProcessDetails { + for _, detail := range details { + if detail.EventType == "thinking" { + if _, alreadySummarized := contextData.SummarizedItems[detail.ID]; alreadySummarized { + continue + } + length := len(detail.Message) + if length > maxLength { + maxLength = length + longestID = detail.ID + longestType = "thinking" + longestContent = detail.Message + } + } + } + } + + if longestID == "" { + return fmt.Errorf("没有找到需要压缩的内容") + } + + b.logger.Info("压缩最长子节点", + zap.String("id", longestID), + zap.String("type", longestType), + zap.Int("length", maxLength)) + + // 使用AI总结 + summary, err := b.summarizeContent(ctx, longestType, longestContent) + if err != nil { + return fmt.Errorf("总结内容失败: %w", err) + } + + // 保存总结 + contextData.SummarizedItems[longestID] = summary + + b.logger.Info("压缩完成", + zap.String("id", longestID), + zap.Int("originalLength", maxLength), + zap.Int("summaryLength", len(summary))) + + return nil +} + +// summarizeContent 总结内容 +func (b *Builder) summarizeContent(ctx context.Context, contentType, content string) (string, error) { + var prompt string + switch contentType { + case "message": + prompt = fmt.Sprintf(`请总结以下AI回复的关键信息,保留所有重要的安全发现、漏洞信息和测试结果。用简洁的中文总结,不超过500字。 + +AI回复: +%s + +总结:`, content) + case "execution": + prompt = fmt.Sprintf(`请总结以下工具执行结果的关键信息,保留所有发现的漏洞、重要发现和测试结果。用简洁的中文总结,不超过500字。 + +工具执行结果: +%s + +总结:`, content) + case "thinking": + prompt = fmt.Sprintf(`请总结以下AI思考过程的关键决策和思路,保留所有重要的决策点和测试策略。用简洁的中文总结,不超过300字。 + +思考过程: +%s + +总结:`, content) + default: + return "", fmt.Errorf("未知的内容类型: %s", contentType) + } + + requestBody := map[string]interface{}{ + "model": b.openAIConfig.Model, + "messages": []map[string]interface{}{ + { + "role": "system", + "content": "你是一个专业的安全测试分析师,擅长总结安全测试相关的信息。请用简洁的中文总结关键信息。", + }, + { + "role": "user", + "content": prompt, + }, + }, + "temperature": 0.3, + "max_tokens": 1000, + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return "", fmt.Errorf("序列化请求失败: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("创建请求失败: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey) + + resp, err := b.openAIClient.Do(req) + if err != nil { + return "", fmt.Errorf("请求失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, string(body)) + } + + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + + if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { + return "", fmt.Errorf("解析响应失败: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return "", fmt.Errorf("API未返回有效响应") + } + + return strings.TrimSpace(apiResponse.Choices[0].Message.Content), nil +} + +// callAIForChainGeneration 调用AI生成攻击链 +func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) { + requestBody := map[string]interface{}{ + "model": b.openAIConfig.Model, + "messages": []map[string]interface{}{ + { + "role": "system", + "content": "你是一个专业的安全测试分析师,擅长构建攻击链图。请严格按照JSON格式返回攻击链数据。", + }, + { + "role": "user", + "content": prompt, + }, + }, + "temperature": 0.3, + "max_tokens": 8000, + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return "", fmt.Errorf("序列化请求失败: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", b.openAIConfig.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("创建请求失败: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+b.openAIConfig.APIKey) + + resp, err := b.openAIClient.Do(req) + if err != nil { + // 检查是否是上下文过长错误 + if strings.Contains(err.Error(), "context") || strings.Contains(err.Error(), "length") { + return "", fmt.Errorf("context length exceeded") + } + return "", fmt.Errorf("请求失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + bodyStr := string(body) + // 检查是否是上下文过长错误 + if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") { + return "", fmt.Errorf("context length exceeded") + } + return "", fmt.Errorf("API返回错误: %d, %s", resp.StatusCode, bodyStr) + } + + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + + if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { + return "", fmt.Errorf("解析响应失败: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return "", fmt.Errorf("API未返回有效响应") + } + + content := strings.TrimSpace(apiResponse.Choices[0].Message.Content) + // 尝试提取JSON(可能包含markdown代码块) + content = strings.TrimPrefix(content, "```json") + content = strings.TrimPrefix(content, "```") + content = strings.TrimSuffix(content, "```") + content = strings.TrimSpace(content) + + return content, nil +} + +// ChainJSON 攻击链JSON结构 +type ChainJSON struct { + Nodes []struct { + ID string `json:"id"` + Type string `json:"type"` + Label string `json:"label"` + RiskScore int `json:"risk_score"` + ToolExecutionID string `json:"tool_execution_id,omitempty"` + Metadata map[string]interface{} `json:"metadata"` + } `json:"nodes"` + Edges []struct { + Source string `json:"source"` + Target string `json:"target"` + Type string `json:"type"` + Weight int `json:"weight"` + } `json:"edges"` +} + +// parseChainJSON 解析攻击链JSON +func (b *Builder) parseChainJSON(chainJSON string, executions []*mcp.ToolExecution) (*Chain, error) { + var chainData ChainJSON + if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil { + return nil, fmt.Errorf("解析JSON失败: %w", err) + } + + // 创建execution ID映射(AI可能返回简单的索引或ID,需要映射到真实的execution ID) + executionMap := make(map[string]string) // AI返回的ID -> 真实execution ID + for i, exec := range executions { + // 支持多种可能的AI返回格式 + executionMap[fmt.Sprintf("exec_%d", i+1)] = exec.ID + executionMap[fmt.Sprintf("execution_%d", i+1)] = exec.ID + executionMap[exec.ID] = exec.ID // 如果AI直接返回真实ID + executionMap[fmt.Sprintf("tool_%d", i+1)] = exec.ID // AI可能用tool_1格式 + executionMap[fmt.Sprintf("执行%d", i+1)] = exec.ID // 中文格式 + executionMap[fmt.Sprintf("执行_%d", i+1)] = exec.ID + } + + // 创建节点ID映射(AI返回的ID -> 新的UUID) + nodeIDMap := make(map[string]string) + + // 转换为Chain结构,并过滤无效节点 + nodes := make([]Node, 0, len(chainData.Nodes)) + for _, n := range chainData.Nodes { + // 过滤无效节点 + if b.shouldFilterNode(n, executions) { + b.logger.Info("过滤无效节点", + zap.String("nodeID", n.ID), + zap.String("nodeType", n.Type), + zap.String("label", n.Label)) + continue + } + + // 生成新的UUID节点ID + newNodeID := fmt.Sprintf("node_%s", uuid.New().String()) + nodeIDMap[n.ID] = newNodeID + + node := Node{ + ID: newNodeID, + Type: n.Type, + Label: n.Label, + RiskScore: n.RiskScore, + Metadata: n.Metadata, + } + if node.Metadata == nil { + node.Metadata = make(map[string]interface{}) + } + + // 处理tool_execution_id:如果是action或vulnerability节点,需要映射到真实的execution ID + if n.ToolExecutionID != "" { + if realExecID, ok := executionMap[n.ToolExecutionID]; ok { + node.ToolExecutionID = realExecID + } else { + // 检查是否是真实的execution ID(UUID格式) + // 如果是,直接使用;如果不是,尝试从节点ID推断 + if len(n.ToolExecutionID) > 20 { // UUID通常很长 + node.ToolExecutionID = n.ToolExecutionID + } else { + // 可能是简单的ID,尝试从节点ID推断 + if realExecID, ok := executionMap[n.ID]; ok { + node.ToolExecutionID = realExecID + } else { + b.logger.Warn("无法映射tool_execution_id", + zap.String("nodeID", n.ID), + zap.String("toolExecutionID", n.ToolExecutionID)) + // 对于action节点,如果没有有效的execution ID,清空它(避免外键约束失败) + if n.Type == "action" { + node.ToolExecutionID = "" + } + } + } + } + } else if n.Type == "action" || n.Type == "vulnerability" { + // 如果AI没有提供tool_execution_id,尝试从节点ID推断 + // 例如:tool_1 -> 查找exec_1 + if realExecID, ok := executionMap[n.ID]; ok { + node.ToolExecutionID = realExecID + } else { + b.logger.Warn("action/vulnerability节点缺少tool_execution_id", + zap.String("nodeID", n.ID), + zap.String("nodeType", n.Type)) + } + } + + nodes = append(nodes, node) + } + + // 转换边,更新source和target为新的节点ID + edges := make([]Edge, 0, len(chainData.Edges)) + for _, e := range chainData.Edges { + sourceID, ok := nodeIDMap[e.Source] + if !ok { + b.logger.Warn("边的源节点ID未找到", zap.String("source", e.Source)) + continue + } + + targetID, ok := nodeIDMap[e.Target] + if !ok { + b.logger.Warn("边的目标节点ID未找到", zap.String("target", e.Target)) + continue + } + + edge := Edge{ + ID: fmt.Sprintf("edge_%s", uuid.New().String()), + Source: sourceID, + Target: targetID, + Type: e.Type, + Weight: e.Weight, + } + edges = append(edges, edge) + } + + // 过滤掉指向已删除节点的边 + filteredEdges := make([]Edge, 0, len(edges)) + for _, edge := range edges { + // 检查source和target节点是否都存在 + sourceExists := false + targetExists := false + for _, node := range nodes { + if node.ID == edge.Source { + sourceExists = true + } + if node.ID == edge.Target { + targetExists = true + } + } + + if sourceExists && targetExists { + filteredEdges = append(filteredEdges, edge) + } else { + b.logger.Warn("过滤无效边", + zap.String("edgeID", edge.ID), + zap.String("source", edge.Source), + zap.String("target", edge.Target), + zap.Bool("sourceExists", sourceExists), + zap.Bool("targetExists", targetExists)) + } + } + + return &Chain{ + Nodes: nodes, + Edges: filteredEdges, + }, nil +} + +// shouldFilterNode 判断是否应该过滤掉这个节点 +func (b *Builder) shouldFilterNode(n struct { + ID string `json:"id"` + Type string `json:"type"` + Label string `json:"label"` + RiskScore int `json:"risk_score"` + ToolExecutionID string `json:"tool_execution_id,omitempty"` + Metadata map[string]interface{} `json:"metadata"` +}, executions []*mcp.ToolExecution) bool { + // 只允许target、action、vulnerability三种节点类型 + if n.Type != "target" && n.Type != "action" && n.Type != "vulnerability" { + return true + } + + // 对于action节点,检查对应的工具执行是否有效 + if n.Type == "action" { + if n.ToolExecutionID == "" { + // 没有关联工具执行的action节点,可能是无效的 + return true + } + + // 查找对应的工具执行 + var exec *mcp.ToolExecution + for _, e := range executions { + if e.ID == n.ToolExecutionID { + exec = e + break + } + } + + if exec == nil { + // 找不到对应的工具执行,可能是无效的 + return true + } + + // 检查工具执行是否错误或失败 + if exec.Error != "" || (exec.Result != nil && exec.Result.IsError) { + return true + } + + // 检查工具执行结果是否为空 + if exec.Result == nil || len(exec.Result.Content) == 0 { + return true + } + + // 检查结果文本是否为空 + var resultText string + for _, content := range exec.Result.Content { + if content.Type == "text" { + resultText += content.Text + } + } + if strings.TrimSpace(resultText) == "" { + return true + } + } + + // 检查节点标签是否为空或无效 + if strings.TrimSpace(n.Label) == "" { + return true + } + + // 检查标签中是否包含错误/失败的关键词 + labelLower := strings.ToLower(n.Label) + errorKeywords := []string{"错误", "失败", "无效", "error", "failed", "invalid", "empty", "空"} + for _, keyword := range errorKeywords { + if strings.Contains(labelLower, keyword) { + // 如果标签明确表示错误,但节点类型不是vulnerability,则过滤 + if n.Type != "vulnerability" { + return true + } + } + } + + return false +} diff --git a/internal/database/attackchain.go b/internal/database/attackchain.go new file mode 100644 index 00000000..c8529e70 --- /dev/null +++ b/internal/database/attackchain.go @@ -0,0 +1,168 @@ +package database + +import ( + "database/sql" + "encoding/json" + "fmt" + + "go.uber.org/zap" +) + +// AttackChainNode 攻击链节点 +type AttackChainNode struct { + ID string `json:"id"` + Type string `json:"type"` // tool, vulnerability, target, exploit + Label string `json:"label"` + ToolExecutionID string `json:"tool_execution_id,omitempty"` + Metadata map[string]interface{} `json:"metadata"` + RiskScore int `json:"risk_score"` +} + +// AttackChainEdge 攻击链边 +type AttackChainEdge struct { + ID string `json:"id"` + Source string `json:"source"` + Target string `json:"target"` + Type string `json:"type"` // leads_to, exploits, enables, depends_on + Weight int `json:"weight"` +} + +// SaveAttackChainNode 保存攻击链节点 +func (db *DB) SaveAttackChainNode(conversationID, nodeID, nodeType, nodeName, toolExecutionID, metadata string, riskScore int) error { + var toolExecID sql.NullString + if toolExecutionID != "" { + toolExecID = sql.NullString{String: toolExecutionID, Valid: true} + } + + var metadataJSON sql.NullString + if metadata != "" { + metadataJSON = sql.NullString{String: metadata, Valid: true} + } + + query := ` + INSERT OR REPLACE INTO attack_chain_nodes + (id, conversation_id, node_type, node_name, tool_execution_id, metadata, risk_score, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + ` + + _, err := db.Exec(query, nodeID, conversationID, nodeType, nodeName, toolExecID, metadataJSON, riskScore) + if err != nil { + db.logger.Error("保存攻击链节点失败", zap.Error(err), zap.String("nodeId", nodeID)) + return err + } + + return nil +} + +// SaveAttackChainEdge 保存攻击链边 +func (db *DB) SaveAttackChainEdge(conversationID, edgeID, sourceNodeID, targetNodeID, edgeType string, weight int) error { + query := ` + INSERT OR REPLACE INTO attack_chain_edges + (id, conversation_id, source_node_id, target_node_id, edge_type, weight, created_at) + VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + ` + + _, err := db.Exec(query, edgeID, conversationID, sourceNodeID, targetNodeID, edgeType, weight) + if err != nil { + db.logger.Error("保存攻击链边失败", zap.Error(err), zap.String("edgeId", edgeID)) + return err + } + + return nil +} + +// LoadAttackChainNodes 加载攻击链节点 +func (db *DB) LoadAttackChainNodes(conversationID string) ([]AttackChainNode, error) { + query := ` + SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score + FROM attack_chain_nodes + WHERE conversation_id = ? + ORDER BY created_at ASC + ` + + rows, err := db.Query(query, conversationID) + if err != nil { + return nil, fmt.Errorf("查询攻击链节点失败: %w", err) + } + defer rows.Close() + + var nodes []AttackChainNode + for rows.Next() { + var node AttackChainNode + var toolExecID sql.NullString + var metadataJSON sql.NullString + + err := rows.Scan(&node.ID, &node.Type, &node.Label, &toolExecID, &metadataJSON, &node.RiskScore) + if err != nil { + db.logger.Warn("扫描攻击链节点失败", zap.Error(err)) + continue + } + + if toolExecID.Valid { + node.ToolExecutionID = toolExecID.String + } + + if metadataJSON.Valid && metadataJSON.String != "" { + if err := json.Unmarshal([]byte(metadataJSON.String), &node.Metadata); err != nil { + db.logger.Warn("解析节点元数据失败", zap.Error(err)) + node.Metadata = make(map[string]interface{}) + } + } else { + node.Metadata = make(map[string]interface{}) + } + + nodes = append(nodes, node) + } + + return nodes, nil +} + +// LoadAttackChainEdges 加载攻击链边 +func (db *DB) LoadAttackChainEdges(conversationID string) ([]AttackChainEdge, error) { + query := ` + SELECT id, source_node_id, target_node_id, edge_type, weight + FROM attack_chain_edges + WHERE conversation_id = ? + ORDER BY created_at ASC + ` + + rows, err := db.Query(query, conversationID) + if err != nil { + return nil, fmt.Errorf("查询攻击链边失败: %w", err) + } + defer rows.Close() + + var edges []AttackChainEdge + for rows.Next() { + var edge AttackChainEdge + + err := rows.Scan(&edge.ID, &edge.Source, &edge.Target, &edge.Type, &edge.Weight) + if err != nil { + db.logger.Warn("扫描攻击链边失败", zap.Error(err)) + continue + } + + edges = append(edges, edge) + } + + return edges, nil +} + +// DeleteAttackChain 删除对话的攻击链数据 +func (db *DB) DeleteAttackChain(conversationID string) error { + // 先删除边(因为有外键约束) + _, err := db.Exec("DELETE FROM attack_chain_edges WHERE conversation_id = ?", conversationID) + if err != nil { + db.logger.Warn("删除攻击链边失败", zap.Error(err)) + } + + // 再删除节点 + _, err = db.Exec("DELETE FROM attack_chain_nodes WHERE conversation_id = ?", conversationID) + if err != nil { + db.logger.Error("删除攻击链节点失败", zap.Error(err), zap.String("conversationId", conversationID)) + return err + } + + return nil +} + diff --git a/internal/database/database.go b/internal/database/database.go index c782e6de..55676264 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -101,6 +101,36 @@ func (db *DB) initTables() error { updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP );` + // 创建攻击链节点表 + createAttackChainNodesTable := ` + CREATE TABLE IF NOT EXISTS attack_chain_nodes ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + node_type TEXT NOT NULL, + node_name TEXT NOT NULL, + tool_execution_id TEXT, + metadata TEXT, + risk_score INTEGER DEFAULT 0, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (tool_execution_id) REFERENCES tool_executions(id) ON DELETE SET NULL + );` + + // 创建攻击链边表 + createAttackChainEdgesTable := ` + CREATE TABLE IF NOT EXISTS attack_chain_edges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + source_node_id TEXT NOT NULL, + target_node_id TEXT NOT NULL, + edge_type TEXT NOT NULL, + weight INTEGER DEFAULT 1, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE, + FOREIGN KEY (source_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE, + FOREIGN KEY (target_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE + );` + // 创建索引 createIndexes := ` CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id); @@ -110,6 +140,10 @@ func (db *DB) initTables() error { CREATE INDEX IF NOT EXISTS idx_tool_executions_tool_name ON tool_executions(tool_name); CREATE INDEX IF NOT EXISTS idx_tool_executions_start_time ON tool_executions(start_time); CREATE INDEX IF NOT EXISTS idx_tool_executions_status ON tool_executions(status); + CREATE INDEX IF NOT EXISTS idx_chain_nodes_conversation ON attack_chain_nodes(conversation_id); + CREATE INDEX IF NOT EXISTS idx_chain_edges_conversation ON attack_chain_edges(conversation_id); + CREATE INDEX IF NOT EXISTS idx_chain_edges_source ON attack_chain_edges(source_node_id); + CREATE INDEX IF NOT EXISTS idx_chain_edges_target ON attack_chain_edges(target_node_id); ` if _, err := db.Exec(createConversationsTable); err != nil { @@ -132,6 +166,14 @@ func (db *DB) initTables() error { return fmt.Errorf("创建tool_stats表失败: %w", err) } + if _, err := db.Exec(createAttackChainNodesTable); err != nil { + return fmt.Errorf("创建attack_chain_nodes表失败: %w", err) + } + + if _, err := db.Exec(createAttackChainEdgesTable); err != nil { + return fmt.Errorf("创建attack_chain_edges表失败: %w", err) + } + if _, err := db.Exec(createIndexes); err != nil { return fmt.Errorf("创建索引失败: %w", err) } diff --git a/internal/handler/attackchain.go b/internal/handler/attackchain.go new file mode 100644 index 00000000..e018c004 --- /dev/null +++ b/internal/handler/attackchain.go @@ -0,0 +1,152 @@ +package handler + +import ( + "context" + "net/http" + "sync" + "time" + + "cyberstrike-ai/internal/attackchain" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// AttackChainHandler 攻击链处理器 +type AttackChainHandler struct { + db *database.DB + logger *zap.Logger + openAIConfig *config.OpenAIConfig + // 用于防止同一对话的并发生成 + generatingLocks sync.Map // map[string]*sync.Mutex +} + +// NewAttackChainHandler 创建新的攻击链处理器 +func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *AttackChainHandler { + return &AttackChainHandler{ + db: db, + logger: logger, + openAIConfig: openAIConfig, + } +} + +// GetAttackChain 获取攻击链(按需生成) +// GET /api/attack-chain/:conversationId +func (h *AttackChainHandler) GetAttackChain(c *gin.Context) { + conversationID := c.Param("conversationId") + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) + return + } + + // 检查对话是否存在 + _, err := h.db.GetConversation(conversationID) + if err != nil { + h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + // 先尝试从数据库加载(如果已生成过) + builder := attackchain.NewBuilder(h.db, h.openAIConfig, h.logger) + chain, err := builder.LoadChainFromDatabase(conversationID) + if err == nil && len(chain.Nodes) > 0 { + // 如果已存在,直接返回 + h.logger.Info("返回已存在的攻击链", zap.String("conversationId", conversationID)) + c.JSON(http.StatusOK, chain) + return + } + + // 如果不存在,则生成新的攻击链(按需生成) + // 使用锁机制防止同一对话的并发生成 + lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) + lock := lockInterface.(*sync.Mutex) + + // 尝试获取锁,如果正在生成则返回错误 + acquired := lock.TryLock() + if !acquired { + h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID)) + c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"}) + return + } + defer lock.Unlock() + + // 再次检查是否已生成(可能在等待锁的过程中已经生成完成) + chain, err = builder.LoadChainFromDatabase(conversationID) + if err == nil && len(chain.Nodes) > 0 { + h.logger.Info("返回已存在的攻击链(在锁等待期间已生成)", zap.String("conversationId", conversationID)) + c.JSON(http.StatusOK, chain) + return + } + + h.logger.Info("开始生成攻击链", zap.String("conversationId", conversationID)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + chain, err = builder.BuildChainFromConversation(ctx, conversationID) + if err != nil { + h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()}) + return + } + + // 生成完成后,从锁映射中删除(可选,保留也可以用于防止短时间内重复生成) + // h.generatingLocks.Delete(conversationID) + + c.JSON(http.StatusOK, chain) +} + +// RegenerateAttackChain 重新生成攻击链 +// POST /api/attack-chain/:conversationId/regenerate +func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) { + conversationID := c.Param("conversationId") + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"}) + return + } + + // 检查对话是否存在 + _, err := h.db.GetConversation(conversationID) + if err != nil { + h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + // 删除旧的攻击链 + if err := h.db.DeleteAttackChain(conversationID); err != nil { + h.logger.Warn("删除旧攻击链失败", zap.Error(err)) + } + + // 使用锁机制防止并发生成 + lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) + lock := lockInterface.(*sync.Mutex) + + acquired := lock.TryLock() + if !acquired { + h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID)) + c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"}) + return + } + defer lock.Unlock() + + // 生成新的攻击链 + h.logger.Info("重新生成攻击链", zap.String("conversationId", conversationID)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + builder := attackchain.NewBuilder(h.db, h.openAIConfig, h.logger) + chain, err := builder.BuildChainFromConversation(ctx, conversationID) + if err != nil { + h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()}) + return + } + + c.JSON(http.StatusOK, chain) +} + diff --git a/web/static/css/style.css b/web/static/css/style.css index bb845c16..07ed93f5 100644 --- a/web/static/css/style.css +++ b/web/static/css/style.css @@ -2579,3 +2579,141 @@ header { grid-template-columns: 1fr; } } + +/* ==================== 攻击链可视化样式 ==================== */ + +.attack-chain-modal-content { + max-width: 95vw; + width: 95vw; + height: 90vh; + max-height: 90vh; + display: flex; + flex-direction: column; +} + +.attack-chain-body { + display: flex; + flex-direction: column; + flex: 1; + overflow: hidden; + padding: 0; +} + +.attack-chain-controls { + padding: 16px; + border-bottom: 1px solid var(--border-color); + display: flex; + justify-content: space-between; + align-items: center; + flex-wrap: wrap; + gap: 16px; + background: var(--bg-secondary); +} + +.attack-chain-info { + font-size: 0.875rem; + color: var(--text-secondary); +} + +.attack-chain-legend { + display: flex; + gap: 16px; + flex-wrap: wrap; +} + +.legend-item { + display: flex; + align-items: center; + gap: 6px; + font-size: 0.875rem; +} + +.legend-color { + width: 16px; + height: 16px; + border-radius: 4px; + border: 1px solid var(--border-color); +} + +.attack-chain-container { + flex: 1; + min-height: 0; + background: var(--bg-primary); + border: 1px solid var(--border-color); + position: relative; +} + +.attack-chain-details { + padding: 16px; + border-top: 1px solid var(--border-color); + background: var(--bg-secondary); + max-height: 200px; + overflow-y: auto; +} + +.attack-chain-details h3 { + margin: 0 0 12px 0; + font-size: 1rem; + color: var(--text-primary); +} + +.node-detail-item { + margin-bottom: 8px; + font-size: 0.875rem; +} + +.node-detail-item strong { + color: var(--text-primary); + margin-right: 8px; +} + +.node-detail-item code { + background: var(--bg-tertiary); + padding: 2px 6px; + border-radius: 3px; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; + font-size: 0.8125rem; +} + +.metadata-pre { + background: var(--bg-tertiary); + padding: 8px; + border-radius: 4px; + font-size: 0.8125rem; + overflow-x: auto; + margin-top: 4px; +} + +.modal-header-actions { + display: flex; + gap: 8px; + align-items: center; +} + +.loading-spinner { + display: flex; + align-items: center; + justify-content: center; + height: 100%; + color: var(--text-secondary); + font-size: 1rem; +} + +.empty-message { + display: flex; + align-items: center; + justify-content: center; + height: 100%; + color: var(--text-secondary); + font-size: 1rem; +} + +.error-message { + display: flex; + align-items: center; + justify-content: center; + height: 100%; + color: var(--error-color); + font-size: 1rem; + padding: 20px; +} diff --git a/web/static/js/app.js b/web/static/js/app.js index 8b9b94e3..b50d3bef 100644 --- a/web/static/js/app.js +++ b/web/static/js/app.js @@ -1750,6 +1750,9 @@ async function loadConversation(conversationId) { // 滚动到底部 messagesDiv.scrollTop = messagesDiv.scrollHeight; + // 添加攻击链按钮 + addAttackChainButton(conversationId); + // 刷新对话列表 loadConversations(); } catch (error) { @@ -2879,27 +2882,24 @@ function renderMonitorPagination() { const { page, totalPages, total, pageSize } = monitorState.pagination; - // 如果只有一页或没有数据,不显示分页 - if (totalPages <= 1 || total === 0) { - return; - } - + // 始终显示分页控件 const pagination = document.createElement('div'); pagination.className = 'monitor-pagination'; - const startItem = (page - 1) * pageSize + 1; - const endItem = Math.min(page * pageSize, total); + // 处理没有数据的情况 + const startItem = total === 0 ? 0 : (page - 1) * pageSize + 1; + const endItem = total === 0 ? 0 : Math.min(page * pageSize, total); pagination.innerHTML = `
显示 ${startItem}-${endItem} / 共 ${total} 条记录
- - - 第 ${page} / ${totalPages} 页 - - + + + 第 ${page} / ${totalPages || 1} 页 + +
`; @@ -3481,3 +3481,1031 @@ openSettings = async function() { await originalOpenSettings(); await loadExternalMCPs(); }; + +// ==================== 攻击链可视化功能 ==================== + +let attackChainCytoscape = null; +let currentAttackChainConversationId = null; +let isAttackChainLoading = false; // 防止重复加载 + +// 添加攻击链按钮 +function addAttackChainButton(conversationId) { + // 检查是否已存在按钮 + let attackChainBtn = document.getElementById('attack-chain-btn'); + if (!attackChainBtn) { + attackChainBtn = document.createElement('button'); + attackChainBtn.id = 'attack-chain-btn'; + attackChainBtn.className = 'btn-secondary'; + attackChainBtn.style.marginLeft = '10px'; + attackChainBtn.innerHTML = '🔗 攻击链'; + attackChainBtn.onclick = () => showAttackChain(conversationId); + + // 在消息区域上方添加按钮容器 + const chatMessages = document.getElementById('chat-messages'); + if (chatMessages) { + // 检查是否已有按钮容器 + let btnContainer = document.getElementById('attack-chain-btn-container'); + if (!btnContainer) { + btnContainer = document.createElement('div'); + btnContainer.id = 'attack-chain-btn-container'; + btnContainer.style.padding = '10px'; + btnContainer.style.borderBottom = '1px solid var(--border-color)'; + btnContainer.style.background = 'var(--bg-secondary)'; + chatMessages.parentNode.insertBefore(btnContainer, chatMessages); + } + btnContainer.innerHTML = ''; + btnContainer.appendChild(attackChainBtn); + } + } else { + attackChainBtn.onclick = () => showAttackChain(conversationId); + } +} + +// 显示攻击链模态框 +async function showAttackChain(conversationId) { + // 防止重复点击 + if (isAttackChainLoading) { + console.log('攻击链正在加载中,请稍候...'); + return; + } + + currentAttackChainConversationId = conversationId; + const modal = document.getElementById('attack-chain-modal'); + if (!modal) { + console.error('攻击链模态框未找到'); + return; + } + + modal.style.display = 'block'; + + // 清空容器 + const container = document.getElementById('attack-chain-container'); + if (container) { + container.innerHTML = '
加载中...
'; + } + + // 隐藏详情面板 + const detailsPanel = document.getElementById('attack-chain-details'); + if (detailsPanel) { + detailsPanel.style.display = 'none'; + } + + // 禁用重新生成按钮 + const regenerateBtn = document.querySelector('button[onclick="regenerateAttackChain()"]'); + if (regenerateBtn) { + regenerateBtn.disabled = true; + regenerateBtn.style.opacity = '0.5'; + regenerateBtn.style.cursor = 'not-allowed'; + } + + // 加载攻击链数据 + await loadAttackChain(conversationId); +} + +// 加载攻击链数据 +async function loadAttackChain(conversationId) { + if (isAttackChainLoading) { + return; // 防止重复调用 + } + + isAttackChainLoading = true; + + try { + const response = await apiFetch(`/api/attack-chain/${conversationId}`); + + if (!response.ok) { + // 处理 409 Conflict(正在生成中) + if (response.status === 409) { + const error = await response.json(); + const container = document.getElementById('attack-chain-container'); + if (container) { + container.innerHTML = ` +
+
⏳ 攻击链正在生成中...
+
+ 请稍候,生成完成后将自动显示 +
+ +
+ `; + } + // 5秒后自动刷新(允许刷新,但保持加载状态防止重复点击) + setTimeout(() => { + refreshAttackChain(); + }, 5000); + // 在 409 情况下,保持 isAttackChainLoading = true,防止重复点击 + // 但允许 refreshAttackChain 调用 loadAttackChain 来检查状态 + // 注意:不重置 isAttackChainLoading,保持加载状态 + // 恢复按钮状态(虽然保持加载状态,但允许用户手动刷新) + const regenerateBtn = document.querySelector('button[onclick="regenerateAttackChain()"]'); + if (regenerateBtn) { + regenerateBtn.disabled = false; + regenerateBtn.style.opacity = '1'; + regenerateBtn.style.cursor = 'pointer'; + } + return; // 提前返回,不执行 finally 块中的 isAttackChainLoading = false + } + + const error = await response.json(); + throw new Error(error.error || '加载攻击链失败'); + } + + const chainData = await response.json(); + + // 渲染攻击链 + renderAttackChain(chainData); + + // 更新统计信息 + updateAttackChainStats(chainData); + + // 成功加载后,重置加载状态 + isAttackChainLoading = false; + + } catch (error) { + console.error('加载攻击链失败:', error); + const container = document.getElementById('attack-chain-container'); + if (container) { + container.innerHTML = `
加载失败: ${error.message}
`; + } + // 错误时也重置加载状态 + isAttackChainLoading = false; + } finally { + // 恢复重新生成按钮 + const regenerateBtn = document.querySelector('button[onclick="regenerateAttackChain()"]'); + if (regenerateBtn) { + regenerateBtn.disabled = false; + regenerateBtn.style.opacity = '1'; + regenerateBtn.style.cursor = 'pointer'; + } + } +} + +// 渲染攻击链 +function renderAttackChain(chainData) { + const container = document.getElementById('attack-chain-container'); + if (!container) { + return; + } + + // 清空容器 + container.innerHTML = ''; + + if (!chainData.nodes || chainData.nodes.length === 0) { + container.innerHTML = '
暂无攻击链数据
'; + return; + } + + // 计算图的复杂度(用于动态调整布局和样式) + const nodeCount = chainData.nodes.length; + const edgeCount = chainData.edges.length; + const isComplexGraph = nodeCount > 20 || edgeCount > 30; + + // 准备Cytoscape数据 + const elements = []; + + // 添加节点,并预计算文字颜色和边框颜色 + chainData.nodes.forEach(node => { + const riskScore = node.risk_score || 0; + // 根据风险分数计算文字颜色和边框颜色 + let textColor, borderColor, textOutlineWidth, textOutlineColor; + if (riskScore >= 80) { + // 红色背景:白色文字,白色边框 + textColor = '#fff'; + borderColor = '#fff'; + textOutlineWidth = 1; + textOutlineColor = '#333'; + } else if (riskScore >= 60) { + // 橙色背景:白色文字,白色边框 + textColor = '#fff'; + borderColor = '#fff'; + textOutlineWidth = 1; + textOutlineColor = '#333'; + } else if (riskScore >= 40) { + // 黄色背景:深色文字,深色边框 + textColor = '#333'; + borderColor = '#cc9900'; + textOutlineWidth = 2; + textOutlineColor = '#fff'; + } else { + // 绿色背景:深绿色文字,深色边框 + textColor = '#1a5a1a'; + borderColor = '#5a8a5a'; + textOutlineWidth = 2; + textOutlineColor = '#fff'; + } + + elements.push({ + data: { + id: node.id, + label: node.label, + type: node.type, + riskScore: riskScore, + toolExecutionId: node.tool_execution_id || '', + metadata: node.metadata || {}, + textColor: textColor, + borderColor: borderColor, + textOutlineWidth: textOutlineWidth, + textOutlineColor: textOutlineColor + } + }); + }); + + // 添加边 + chainData.edges.forEach(edge => { + elements.push({ + data: { + id: edge.id, + source: edge.source, + target: edge.target, + type: edge.type || 'leads_to', + weight: edge.weight || 1 + } + }); + }); + + // 初始化Cytoscape + attackChainCytoscape = cytoscape({ + container: container, + elements: elements, + style: [ + { + selector: 'node', + style: { + 'label': 'data(label)', + // 统一节点大小,减少布局混乱(根据复杂度调整) + 'width': nodeCount > 20 ? 60 : 'mapData(riskScore, 0, 100, 45, 75)', + 'height': nodeCount > 20 ? 60 : 'mapData(riskScore, 0, 100, 45, 75)', + 'shape': function(ele) { + const type = ele.data('type'); + if (type === 'vulnerability') return 'diamond'; + if (type === 'action') return 'round-rectangle'; + if (type === 'target') return 'star'; + return 'ellipse'; + }, + 'background-color': function(ele) { + const riskScore = ele.data('riskScore') || 0; + if (riskScore >= 80) return '#ff4444'; // 红色 + if (riskScore >= 60) return '#ff8800'; // 橙色 + if (riskScore >= 40) return '#ffbb00'; // 黄色 + return '#88cc00'; // 绿色 + }, + // 使用预计算的颜色数据 + 'color': 'data(textColor)', + 'font-size': nodeCount > 20 ? '11px' : '12px', // 复杂图使用更小字体 + 'font-weight': 'bold', + 'text-valign': 'center', + 'text-halign': 'center', + 'text-wrap': 'wrap', + 'text-max-width': nodeCount > 20 ? '80px' : '100px', // 复杂图限制文本宽度 + 'border-width': 2, + 'border-color': 'data(borderColor)', + 'overlay-padding': '4px', + 'text-outline-width': 'data(textOutlineWidth)', + 'text-outline-color': 'data(textOutlineColor)' + } + }, + { + selector: 'edge', + style: { + 'width': 'mapData(weight, 1, 5, 1.5, 3)', + 'line-color': function(ele) { + const type = ele.data('type'); + if (type === 'discovers') return '#3498db'; // 浅蓝:action发现vulnerability + if (type === 'targets') return '#0066ff'; // 蓝色:target指向action + if (type === 'enables') return '#e74c3c'; // 深红:vulnerability间的因果关系 + if (type === 'leads_to') return '#666'; // 灰色:action之间的逻辑顺序 + return '#999'; + }, + 'target-arrow-color': function(ele) { + const type = ele.data('type'); + if (type === 'discovers') return '#3498db'; + if (type === 'targets') return '#0066ff'; + if (type === 'enables') return '#e74c3c'; + if (type === 'leads_to') return '#666'; + return '#999'; + }, + 'target-arrow-shape': 'triangle', + 'target-arrow-size': 8, + // 对于复杂图,使用straight样式减少交叉;简单图使用bezier更美观 + 'curve-style': isComplexGraph ? 'straight' : 'bezier', + 'control-point-step-size': isComplexGraph ? 40 : 60, // bezier控制点间距 + 'control-point-distance': isComplexGraph ? 30 : 50, // bezier控制点距离 + 'opacity': isComplexGraph ? 0.5 : 0.7, // 复杂图降低不透明度,减少视觉混乱 + 'line-style': 'solid' + } + }, + { + selector: 'node:selected', + style: { + 'border-width': 4, + 'border-color': '#0066ff' + } + } + ], + userPanningEnabled: true, + userZoomingEnabled: true, + boxSelectionEnabled: true + }); + + // 注册dagre布局(确保依赖已加载) + let layoutName = 'breadthfirst'; // 默认布局 + let layoutOptions = { + name: 'breadthfirst', + directed: true, + spacingFactor: isComplexGraph ? 2.5 : 2.0, + padding: 30 + }; + + if (typeof cytoscape !== 'undefined' && typeof cytoscapeDagre !== 'undefined') { + try { + cytoscape.use(cytoscapeDagre); + layoutName = 'dagre'; + // 根据图的复杂度调整布局参数 + layoutOptions = { + name: 'dagre', + rankDir: 'TB', // 从上到下 + spacingFactor: isComplexGraph ? 2.5 : 2.0, // 增加整体间距 + nodeSep: isComplexGraph ? 80 : 60, // 增加节点间距 + edgeSep: isComplexGraph ? 40 : 30, // 增加边间距 + rankSep: isComplexGraph ? 120 : 100, // 增加层级间距 + nodeDimensionsIncludeLabels: true, // 考虑标签大小 + animate: false, + padding: 40 // 增加边距 + }; + } catch (e) { + console.warn('dagre布局注册失败,使用默认布局:', e); + } + } else { + console.warn('dagre布局插件未加载,使用默认布局'); + } + + // 应用布局 + attackChainCytoscape.layout(layoutOptions).run(); + + // 布局完成后,调整视图以适应所有节点 + attackChainCytoscape.fit(undefined, 50); // 50px padding + + // 添加点击事件 + attackChainCytoscape.on('tap', 'node', function(evt) { + const node = evt.target; + showNodeDetails(node.data()); + }); + + // 添加悬停效果 + attackChainCytoscape.on('mouseover', 'node', function(evt) { + const node = evt.target; + node.style('opacity', 0.8); + }); + + attackChainCytoscape.on('mouseout', 'node', function(evt) { + const node = evt.target; + node.style('opacity', 1); + }); +} + +// 显示节点详情 +function showNodeDetails(nodeData) { + const detailsPanel = document.getElementById('attack-chain-details'); + const detailsContent = document.getElementById('attack-chain-details-content'); + + if (!detailsPanel || !detailsContent) { + return; + } + + detailsPanel.style.display = 'block'; + + let html = ` +
+ 节点ID: ${nodeData.id} +
+
+ 类型: ${getNodeTypeLabel(nodeData.type)} +
+
+ 标签: ${escapeHtml(nodeData.label)} +
+
+ 风险评分: ${nodeData.riskScore}/100 +
+ `; + + // 显示action节点信息(工具执行 + AI分析) + if (nodeData.type === 'action' && nodeData.metadata) { + if (nodeData.metadata.tool_name) { + html += ` +
+ 工具名称: ${escapeHtml(nodeData.metadata.tool_name)} +
+ `; + } + if (nodeData.metadata.tool_intent) { + html += ` +
+ 工具意图: ${escapeHtml(nodeData.metadata.tool_intent)} +
+ `; + } + if (nodeData.metadata.ai_analysis) { + html += ` +
+ AI分析:
${escapeHtml(nodeData.metadata.ai_analysis)}
+
+ `; + } + if (nodeData.metadata.findings && Array.isArray(nodeData.metadata.findings) && nodeData.metadata.findings.length > 0) { + html += ` +
+ 关键发现: + +
+ `; + } + } + + // 显示目标信息(如果是目标节点) + if (nodeData.type === 'target' && nodeData.metadata && nodeData.metadata.target) { + html += ` +
+ 测试目标: ${escapeHtml(nodeData.metadata.target)} +
+ `; + } + + // 显示漏洞信息(如果是漏洞节点) + if (nodeData.type === 'vulnerability' && nodeData.metadata) { + if (nodeData.metadata.vulnerability_type) { + html += ` +
+ 漏洞类型: ${escapeHtml(nodeData.metadata.vulnerability_type)} +
+ `; + } + if (nodeData.metadata.description) { + html += ` +
+ 描述: ${escapeHtml(nodeData.metadata.description)} +
+ `; + } + if (nodeData.metadata.severity) { + html += ` +
+ 严重程度: ${escapeHtml(nodeData.metadata.severity)} +
+ `; + } + if (nodeData.metadata.location) { + html += ` +
+ 位置: ${escapeHtml(nodeData.metadata.location)} +
+ `; + } + } + + if (nodeData.toolExecutionId) { + html += ` +
+ 工具执行ID: ${nodeData.toolExecutionId} +
+ `; + } + + if (nodeData.metadata && Object.keys(nodeData.metadata).length > 0) { + html += ` +
+ 完整元数据: +
${JSON.stringify(nodeData.metadata, null, 2)}
+
+ `; + } + + detailsContent.innerHTML = html; +} + +// 转义HTML +function escapeHtml(text) { + const div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; +} + +// 获取严重程度颜色 +function getSeverityColor(severity) { + const colors = { + 'critical': '#ff0000', + 'high': '#ff4444', + 'medium': '#ff8800', + 'low': '#ffbb00' + }; + return colors[severity.toLowerCase()] || '#666'; +} + +// 获取节点类型标签 +function getNodeTypeLabel(type) { + const labels = { + 'action': '行动', + 'vulnerability': '漏洞', + 'target': '目标' + }; + return labels[type] || type; +} + +// 更新统计信息 +function updateAttackChainStats(chainData) { + const statsElement = document.getElementById('attack-chain-stats'); + if (statsElement) { + const nodeCount = chainData.nodes ? chainData.nodes.length : 0; + const edgeCount = chainData.edges ? chainData.edges.length : 0; + statsElement.textContent = `节点: ${nodeCount} | 边: ${edgeCount}`; + } +} + +// 关闭攻击链模态框 +function closeAttackChainModal() { + const modal = document.getElementById('attack-chain-modal'); + if (modal) { + modal.style.display = 'none'; + } + + // 清理Cytoscape实例 + if (attackChainCytoscape) { + attackChainCytoscape.destroy(); + attackChainCytoscape = null; + } + + currentAttackChainConversationId = null; +} + +// 刷新攻击链(重新加载) +// 注意:此函数允许在加载过程中调用,用于检查生成状态 +function refreshAttackChain() { + if (currentAttackChainConversationId) { + // 临时允许刷新,即使正在加载中(用于检查生成状态) + const wasLoading = isAttackChainLoading; + isAttackChainLoading = false; // 临时重置,允许刷新 + loadAttackChain(currentAttackChainConversationId).finally(() => { + // 如果之前正在加载(409 情况),恢复加载状态 + // 否则保持 false(正常完成) + if (wasLoading) { + // 检查是否仍然需要保持加载状态(如果还是 409,会在 loadAttackChain 中处理) + // 这里我们假设如果成功加载,则重置状态 + // 如果还是 409,loadAttackChain 会保持 isAttackChainLoading = true + } + }); + } +} + +// 重新生成攻击链 +async function regenerateAttackChain() { + if (!currentAttackChainConversationId) { + return; + } + + // 防止重复点击 + if (isAttackChainLoading) { + console.log('攻击链正在生成中,请稍候...'); + return; + } + + isAttackChainLoading = true; + + const container = document.getElementById('attack-chain-container'); + if (container) { + container.innerHTML = '
重新生成中...
'; + } + + // 禁用重新生成按钮 + const regenerateBtn = document.querySelector('button[onclick="regenerateAttackChain()"]'); + if (regenerateBtn) { + regenerateBtn.disabled = true; + regenerateBtn.style.opacity = '0.5'; + regenerateBtn.style.cursor = 'not-allowed'; + } + + try { + // 调用重新生成接口 + const response = await apiFetch(`/api/attack-chain/${currentAttackChainConversationId}/regenerate`, { + method: 'POST' + }); + + if (!response.ok) { + // 处理 409 Conflict(正在生成中) + if (response.status === 409) { + const error = await response.json(); + if (container) { + container.innerHTML = ` +
+
⏳ 攻击链正在生成中...
+
+ 请稍候,生成完成后将自动显示 +
+ +
+ `; + } + // 5秒后自动刷新 + setTimeout(() => { + if (isAttackChainLoading) { + refreshAttackChain(); + } + }, 5000); + return; + } + + const error = await response.json(); + throw new Error(error.error || '重新生成攻击链失败'); + } + + const chainData = await response.json(); + + // 渲染攻击链 + renderAttackChain(chainData); + + // 更新统计信息 + updateAttackChainStats(chainData); + + } catch (error) { + console.error('重新生成攻击链失败:', error); + if (container) { + container.innerHTML = `
重新生成失败: ${error.message}
`; + } + } finally { + isAttackChainLoading = false; + + // 恢复重新生成按钮 + if (regenerateBtn) { + regenerateBtn.disabled = false; + regenerateBtn.style.opacity = '1'; + regenerateBtn.style.cursor = 'pointer'; + } + } +} + +// 导出攻击链 +function exportAttackChain(format) { + if (!attackChainCytoscape) { + alert('请先加载攻击链'); + return; + } + + // 确保图形已经渲染完成(使用小延迟) + setTimeout(() => { + try { + if (format === 'png') { + try { + const pngPromise = attackChainCytoscape.png({ + output: 'blob', + bg: 'white', + full: true, + scale: 1 + }); + + // 处理 Promise + if (pngPromise && typeof pngPromise.then === 'function') { + pngPromise.then(blob => { + if (!blob) { + throw new Error('PNG导出返回空数据'); + } + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `attack-chain-${currentAttackChainConversationId || 'export'}-${Date.now()}.png`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + setTimeout(() => URL.revokeObjectURL(url), 100); + }).catch(err => { + console.error('导出PNG失败:', err); + alert('导出PNG失败: ' + (err.message || '未知错误')); + }); + } else { + // 如果不是 Promise,直接使用 + const url = URL.createObjectURL(pngPromise); + const a = document.createElement('a'); + a.href = url; + a.download = `attack-chain-${currentAttackChainConversationId || 'export'}-${Date.now()}.png`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + setTimeout(() => URL.revokeObjectURL(url), 100); + } + } catch (err) { + console.error('PNG导出错误:', err); + alert('导出PNG失败: ' + (err.message || '未知错误')); + } + } else if (format === 'svg') { + try { + // Cytoscape.js 3.x 不直接支持 .svg() 方法 + // 使用替代方案:从 Cytoscape 数据手动构建 SVG + const container = attackChainCytoscape.container(); + if (!container) { + throw new Error('无法获取容器元素'); + } + + // 获取所有节点和边 + const nodes = attackChainCytoscape.nodes(); + const edges = attackChainCytoscape.edges(); + + if (nodes.length === 0) { + throw new Error('没有节点可导出'); + } + + // 计算所有节点的实际边界(包括节点大小) + let minX = Infinity, minY = Infinity, maxX = -Infinity, maxY = -Infinity; + nodes.forEach(node => { + const pos = node.position(); + const nodeWidth = node.width(); + const nodeHeight = node.height(); + const size = Math.max(nodeWidth, nodeHeight) / 2; + + minX = Math.min(minX, pos.x - size); + minY = Math.min(minY, pos.y - size); + maxX = Math.max(maxX, pos.x + size); + maxY = Math.max(maxY, pos.y + size); + }); + + // 也考虑边的范围 + edges.forEach(edge => { + const sourcePos = edge.source().position(); + const targetPos = edge.target().position(); + minX = Math.min(minX, sourcePos.x, targetPos.x); + minY = Math.min(minY, sourcePos.y, targetPos.y); + maxX = Math.max(maxX, sourcePos.x, targetPos.x); + maxY = Math.max(maxY, sourcePos.y, targetPos.y); + }); + + // 添加边距 + const padding = 50; + minX -= padding; + minY -= padding; + maxX += padding; + maxY += padding; + + const width = maxX - minX; + const height = maxY - minY; + + // 创建 SVG 元素 + const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); + svg.setAttribute('width', width.toString()); + svg.setAttribute('height', height.toString()); + svg.setAttribute('xmlns', 'http://www.w3.org/2000/svg'); + svg.setAttribute('viewBox', `${minX} ${minY} ${width} ${height}`); + + // 添加白色背景矩形 + const bgRect = document.createElementNS('http://www.w3.org/2000/svg', 'rect'); + bgRect.setAttribute('x', minX.toString()); + bgRect.setAttribute('y', minY.toString()); + bgRect.setAttribute('width', width.toString()); + bgRect.setAttribute('height', height.toString()); + bgRect.setAttribute('fill', 'white'); + svg.appendChild(bgRect); + + // 创建 defs 用于箭头标记 + const defs = document.createElementNS('http://www.w3.org/2000/svg', 'defs'); + + // 添加边的箭头标记(为不同类型的边创建不同的箭头) + const edgeTypes = ['discovers', 'targets', 'enables', 'leads_to']; + edgeTypes.forEach((type, index) => { + let color = '#999'; + if (type === 'discovers') color = '#3498db'; + else if (type === 'targets') color = '#0066ff'; + else if (type === 'enables') color = '#e74c3c'; + else if (type === 'leads_to') color = '#666'; + + const marker = document.createElementNS('http://www.w3.org/2000/svg', 'marker'); + marker.setAttribute('id', `arrowhead-${type}`); + marker.setAttribute('markerWidth', '10'); + marker.setAttribute('markerHeight', '10'); + marker.setAttribute('refX', '9'); + marker.setAttribute('refY', '3'); + marker.setAttribute('orient', 'auto'); + const polygon = document.createElementNS('http://www.w3.org/2000/svg', 'polygon'); + polygon.setAttribute('points', '0 0, 10 3, 0 6'); + polygon.setAttribute('fill', color); + marker.appendChild(polygon); + defs.appendChild(marker); + }); + svg.appendChild(defs); + + // 添加边(先绘制,这样节点会在上面) + edges.forEach(edge => { + const sourcePos = edge.source().position(); + const targetPos = edge.target().position(); + const edgeData = edge.data(); + const edgeType = edgeData.type || 'leads_to'; + + // 获取边的样式 + let lineColor = '#999'; + if (edgeType === 'discovers') lineColor = '#3498db'; + else if (edgeType === 'targets') lineColor = '#0066ff'; + else if (edgeType === 'enables') lineColor = '#e74c3c'; + else if (edgeType === 'leads_to') lineColor = '#666'; + + // 创建路径(支持曲线) + const path = document.createElementNS('http://www.w3.org/2000/svg', 'path'); + // 简单的直线路径(可以改进为曲线) + const midX = (sourcePos.x + targetPos.x) / 2; + const midY = (sourcePos.y + targetPos.y) / 2; + const dx = targetPos.x - sourcePos.x; + const dy = targetPos.y - sourcePos.y; + const offset = Math.min(30, Math.sqrt(dx * dx + dy * dy) * 0.3); + + // 使用二次贝塞尔曲线 + const controlX = midX + (dy > 0 ? -offset : offset); + const controlY = midY + (dx > 0 ? offset : -offset); + path.setAttribute('d', `M ${sourcePos.x} ${sourcePos.y} Q ${controlX} ${controlY} ${targetPos.x} ${targetPos.y}`); + path.setAttribute('stroke', lineColor); + path.setAttribute('stroke-width', '2'); + path.setAttribute('fill', 'none'); + path.setAttribute('marker-end', `url(#arrowhead-${edgeType})`); + svg.appendChild(path); + }); + + // 添加节点 + nodes.forEach(node => { + const pos = node.position(); + const nodeData = node.data(); + const riskScore = nodeData.riskScore || 0; + const nodeWidth = node.width(); + const nodeHeight = node.height(); + const size = Math.max(nodeWidth, nodeHeight) / 2; + + // 确定节点颜色 + let bgColor = '#88cc00'; + let textColor = '#1a5a1a'; + let borderColor = '#5a8a5a'; + if (riskScore >= 80) { + bgColor = '#ff4444'; + textColor = '#fff'; + borderColor = '#fff'; + } else if (riskScore >= 60) { + bgColor = '#ff8800'; + textColor = '#fff'; + borderColor = '#fff'; + } else if (riskScore >= 40) { + bgColor = '#ffbb00'; + textColor = '#333'; + borderColor = '#cc9900'; + } + + // 确定节点形状 + const nodeType = nodeData.type; + let shapeElement; + if (nodeType === 'vulnerability') { + // 菱形 + shapeElement = document.createElementNS('http://www.w3.org/2000/svg', 'polygon'); + const points = [ + `${pos.x},${pos.y - size}`, + `${pos.x + size},${pos.y}`, + `${pos.x},${pos.y + size}`, + `${pos.x - size},${pos.y}` + ].join(' '); + shapeElement.setAttribute('points', points); + } else if (nodeType === 'target') { + // 星形(五角星) + shapeElement = document.createElementNS('http://www.w3.org/2000/svg', 'polygon'); + const points = []; + for (let i = 0; i < 5; i++) { + const angle = (i * 4 * Math.PI / 5) - Math.PI / 2; + const x = pos.x + size * Math.cos(angle); + const y = pos.y + size * Math.sin(angle); + points.push(`${x},${y}`); + } + shapeElement.setAttribute('points', points.join(' ')); + } else { + // 圆角矩形 + shapeElement = document.createElementNS('http://www.w3.org/2000/svg', 'rect'); + shapeElement.setAttribute('x', (pos.x - size).toString()); + shapeElement.setAttribute('y', (pos.y - size).toString()); + shapeElement.setAttribute('width', (size * 2).toString()); + shapeElement.setAttribute('height', (size * 2).toString()); + shapeElement.setAttribute('rx', '5'); + shapeElement.setAttribute('ry', '5'); + } + + shapeElement.setAttribute('fill', bgColor); + shapeElement.setAttribute('stroke', borderColor); + shapeElement.setAttribute('stroke-width', '2'); + svg.appendChild(shapeElement); + + // 添加文本标签(使用文本描边提高可读性) + const label = (nodeData.label || nodeData.id || '').toString(); + const maxLength = 15; + + // 创建文本组,包含描边和填充 + const textGroup = document.createElementNS('http://www.w3.org/2000/svg', 'g'); + textGroup.setAttribute('text-anchor', 'middle'); + textGroup.setAttribute('dominant-baseline', 'middle'); + + // 处理长文本(简单换行) + let lines = []; + if (label.length > maxLength) { + const words = label.split(' '); + let currentLine = ''; + words.forEach(word => { + if ((currentLine + word).length <= maxLength) { + currentLine += (currentLine ? ' ' : '') + word; + } else { + if (currentLine) lines.push(currentLine); + currentLine = word; + } + }); + if (currentLine) lines.push(currentLine); + lines = lines.slice(0, 2); // 最多两行 + } else { + lines = [label]; + } + + // 确定文本描边颜色(与原始渲染一致) + let textOutlineColor = '#fff'; + let textOutlineWidth = 2; + if (riskScore >= 80 || riskScore >= 60) { + // 红色/橙色背景:白色文字,白色描边,深色轮廓 + textOutlineColor = '#333'; + textOutlineWidth = 1; + } else if (riskScore >= 40) { + // 黄色背景:深色文字,白色描边 + textOutlineColor = '#fff'; + textOutlineWidth = 2; + } else { + // 绿色背景:深绿色文字,白色描边 + textOutlineColor = '#fff'; + textOutlineWidth = 2; + } + + // 为每行文本创建描边和填充 + lines.forEach((line, i) => { + const textY = pos.y + (i - (lines.length - 1) / 2) * 16; + + // 描边文本(用于提高对比度,模拟text-outline效果) + const strokeText = document.createElementNS('http://www.w3.org/2000/svg', 'text'); + strokeText.setAttribute('x', pos.x.toString()); + strokeText.setAttribute('y', textY.toString()); + strokeText.setAttribute('fill', 'none'); + strokeText.setAttribute('stroke', textOutlineColor); + strokeText.setAttribute('stroke-width', textOutlineWidth.toString()); + strokeText.setAttribute('stroke-linejoin', 'round'); + strokeText.setAttribute('stroke-linecap', 'round'); + strokeText.setAttribute('font-size', '14px'); + strokeText.setAttribute('font-weight', 'bold'); + strokeText.setAttribute('font-family', 'Arial, sans-serif'); + strokeText.setAttribute('text-anchor', 'middle'); + strokeText.setAttribute('dominant-baseline', 'middle'); + strokeText.textContent = line; + textGroup.appendChild(strokeText); + + // 填充文本(实际可见的文本) + const fillText = document.createElementNS('http://www.w3.org/2000/svg', 'text'); + fillText.setAttribute('x', pos.x.toString()); + fillText.setAttribute('y', textY.toString()); + fillText.setAttribute('fill', textColor); + fillText.setAttribute('font-size', '14px'); + fillText.setAttribute('font-weight', 'bold'); + fillText.setAttribute('font-family', 'Arial, sans-serif'); + fillText.setAttribute('text-anchor', 'middle'); + fillText.setAttribute('dominant-baseline', 'middle'); + fillText.textContent = line; + textGroup.appendChild(fillText); + }); + + svg.appendChild(textGroup); + }); + + // 将 SVG 转换为字符串 + const serializer = new XMLSerializer(); + let svgString = serializer.serializeToString(svg); + + // 确保有 XML 声明 + if (!svgString.startsWith('\n' + svgString; + } + + const blob = new Blob([svgString], { type: 'image/svg+xml;charset=utf-8' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `attack-chain-${currentAttackChainConversationId || 'export'}-${Date.now()}.svg`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + setTimeout(() => URL.revokeObjectURL(url), 100); + } catch (err) { + console.error('SVG导出错误:', err); + alert('导出SVG失败: ' + (err.message || '未知错误')); + } + } else { + alert('不支持的导出格式: ' + format); + } + } catch (error) { + console.error('导出失败:', error); + alert('导出失败: ' + (error.message || '未知错误')); + } + }, 100); // 小延迟确保图形已渲染 +} diff --git a/web/templates/index.html b/web/templates/index.html index e921f76d..326098d9 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -304,10 +304,73 @@ + + + + + + + + + +