package handler

import (
	"fmt"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"time"

	"cyberstrike-ai/internal/database"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

// VulnerabilityHandler serves the HTTP endpoints that create, read, update,
// delete, summarize and export vulnerability records.
type VulnerabilityHandler struct {
	db     *database.DB
	logger *zap.Logger
}

// NewVulnerabilityHandler returns a VulnerabilityHandler that uses db for
// storage and logger for error reporting.
func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler {
	return &VulnerabilityHandler{
		db:     db,
		logger: logger,
	}
}

// CreateVulnerabilityRequest is the JSON body accepted by CreateVulnerability.
// conversation_id, title and severity are required by the binding tags.
type CreateVulnerabilityRequest struct {
	ConversationID  string `json:"conversation_id" binding:"required"`
	ConversationTag string `json:"conversation_tag"`
	TaskTag         string `json:"task_tag"`
	Title           string `json:"title" binding:"required"`
	Description     string `json:"description"`
	Severity        string `json:"severity" binding:"required"`
	Status          string `json:"status"`
	Type            string `json:"type"`
	Target          string `json:"target"`
	Proof           string `json:"proof"`
	Impact          string `json:"impact"`
	Recommendation  string `json:"recommendation"`
}

// CreateVulnerability binds the JSON body, stores a new vulnerability and
// responds with the stored record. It answers 400 when binding fails and 500
// when the database call fails.
func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
	var req CreateVulnerabilityRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	vuln := &database.Vulnerability{
		ConversationID:  req.ConversationID,
		ConversationTag: req.ConversationTag,
		TaskTag:         req.TaskTag,
		Title:           req.Title,
		Description:     req.Description,
		Severity:        req.Severity,
		Status:          req.Status,
		Type:            req.Type,
		Target:          req.Target,
		Proof:           req.Proof,
		Impact:          req.Impact,
		Recommendation:  req.Recommendation,
	}

	created, err := h.db.CreateVulnerability(vuln)
	if err != nil {
		h.logger.Error("创建漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, created)
}

// GetVulnerability responds with the vulnerability identified by the "id"
// path parameter. Any lookup error is logged and reported as 404.
func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) {
	id := c.Param("id")

	vuln, err := h.db.GetVulnerability(id)
	if err != nil {
		h.logger.Error("获取漏洞失败", zap.Error(err))
		c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
		return
	}

	c.JSON(http.StatusOK, vuln)
}

// ListVulnerabilitiesResponse is the paginated payload returned by
// ListVulnerabilities.
type ListVulnerabilitiesResponse struct {
	Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"`
	Total           int                       `json:"total"`
	Page            int                       `json:"page"`
	PageSize        int                       `json:"page_size"`
	TotalPages      int                       `json:"total_pages"`
}

// ListVulnerabilities responds with one page of vulnerabilities.
//
// Query parameters:
//   - limit: page size, default 20; values that are unparsable or outside
//     1..100 fall back to 20.
//   - offset: number of records to skip, default 0; negative or unparsable
//     values are treated as 0.
//   - page: 1-based page number; when it is a positive integer it takes
//     precedence over offset.
//   - id, conversation_id, severity, status, task_id, conversation_tag,
//     task_tag: filters passed through to the database layer.
//
// A failure to count matching rows is logged and reported as a total of 0
// rather than failing the request; a failure to list rows answers 500.
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
	const (
		defaultLimit = 20
		maxLimit     = 100
	)

	id := c.Query("id")
	conversationID := c.Query("conversation_id")
	severity := c.Query("severity")
	status := c.Query("status")
	taskID := c.Query("task_id")
	conversationTag := c.Query("conversation_tag")
	taskTag := c.Query("task_tag")

	// Normalize limit before any arithmetic uses it: deriving the offset from
	// an unvalidated limit would skip the wrong number of rows whenever the
	// requested limit is replaced by the default.
	limit, err := strconv.Atoi(c.DefaultQuery("limit", "20"))
	if err != nil || limit <= 0 || limit > maxLimit {
		limit = defaultLimit
	}

	offset, err := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if err != nil || offset < 0 {
		offset = 0
	}

	// A valid page parameter wins over offset.
	page, pageGiven := 1, false
	if p, perr := strconv.Atoi(c.Query("page")); perr == nil && p > 0 {
		page, pageGiven = p, true
		offset = (page - 1) * limit
	}
	if offset < 0 {
		offset = 0
	}
	if !pageGiven {
		// Only offset was usable, so derive the page number from it.
		page = offset/limit + 1
	}

	total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
	if err != nil {
		// Best effort: the list itself is still useful without a total.
		h.logger.Error("获取漏洞总数失败", zap.Error(err))
		total = 0
	}

	vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
	if err != nil {
		h.logger.Error("获取漏洞列表失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Ceiling division; an empty result still reports one (empty) page.
	totalPages := (total + limit - 1) / limit
	if totalPages == 0 {
		totalPages = 1
	}

	response := ListVulnerabilitiesResponse{
		Vulnerabilities: vulnerabilities,
		Total:           total,
		Page:            page,
		PageSize:        limit,
		TotalPages:      totalPages,
	}

	c.JSON(http.StatusOK, response)
}

// UpdateVulnerabilityRequest is the JSON body accepted by UpdateVulnerability.
// Every field is optional.
type UpdateVulnerabilityRequest struct {
	ConversationTag string `json:"conversation_tag"`
	TaskTag         string `json:"task_tag"`
	Title           string `json:"title"`
	Description     string `json:"description"`
	Severity        string `json:"severity"`
	Status          string `json:"status"`
	Type            string `json:"type"`
	Target          string `json:"target"`
	Proof           string `json:"proof"`
	Impact          string `json:"impact"`
	Recommendation  string `json:"recommendation"`
}

// UpdateVulnerability applies a partial update to the vulnerability identified
// by the "id" path parameter and responds with the re-read record.
//
// Only non-empty request fields overwrite stored values, so this endpoint
// cannot be used to clear a field. It answers 400 when binding fails, 404 when
// the record cannot be loaded and 500 when saving or re-reading fails.
func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
	id := c.Param("id")

	var req UpdateVulnerabilityRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	existing, err := h.db.GetVulnerability(id)
	if err != nil {
		h.logger.Error("获取漏洞失败", zap.Error(err))
		c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
		return
	}

	// patch overwrites *dst only when the request supplied a value.
	patch := func(dst *string, value string) {
		if value != "" {
			*dst = value
		}
	}
	patch(&existing.ConversationTag, req.ConversationTag)
	patch(&existing.TaskTag, req.TaskTag)
	patch(&existing.Title, req.Title)
	patch(&existing.Description, req.Description)
	patch(&existing.Severity, req.Severity)
	patch(&existing.Status, req.Status)
	patch(&existing.Type, req.Type)
	patch(&existing.Target, req.Target)
	patch(&existing.Proof, req.Proof)
	patch(&existing.Impact, req.Impact)
	patch(&existing.Recommendation, req.Recommendation)

	if err := h.db.UpdateVulnerability(id, existing); err != nil {
		h.logger.Error("更新漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Re-read so the response reflects what the database actually stored.
	updated, err := h.db.GetVulnerability(id)
	if err != nil {
		h.logger.Error("获取更新后的漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, updated)
}

// DeleteVulnerability deletes the vulnerability identified by the "id" path
// parameter. It answers 500 when the database call fails.
func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
	id := c.Param("id")

	if err := h.db.DeleteVulnerability(id); err != nil {
		h.logger.Error("删除漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}

// GetVulnerabilityStats responds with the statistics the database layer
// returns for the optional "conversation_id" and "task_id" query parameters.
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
	conversationID := c.Query("conversation_id")
	taskID := c.Query("task_id")

	stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
	if err != nil {
		h.logger.Error("获取漏洞统计失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, stats)
}

// GetVulnerabilityFilterOptions responds with the filter suggestions returned
// by the database layer.
func (h *VulnerabilityHandler) GetVulnerabilityFilterOptions(c *gin.Context) {
	options, err := h.db.GetVulnerabilityFilterOptions()
	if err != nil {
		h.logger.Error("获取漏洞筛选建议失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, options)
}

// ExportVulnerabilities renders every vulnerability matching the query filters
// as Markdown and responds with the generated files as JSON.
//
// Query parameters:
//   - group_by: "conversation" (default) or "task". Conversation groups are
//     keyed by the trimmed conversation tag, falling back to the conversation
//     ID; task groups are keyed by the first non-empty of task tag, task ID,
//     task queue ID, or "unassigned-task".
//   - mode: "summary" (default) produces one file containing all groups;
//     "split" produces one file per group.
//   - id, conversation_id, severity, status, task_id, conversation_tag,
//     task_tag: filters passed through to the database layer.
//
// Groups are emitted in sorted key order so repeated exports of the same data
// are identical apart from the timestamp. It answers 400 for an unsupported
// group_by or mode and 500 when a database call fails.
func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
	groupBy := c.DefaultQuery("group_by", "conversation")
	mode := c.DefaultQuery("mode", "summary")
	if groupBy != "conversation" && groupBy != "task" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "group_by 仅支持 conversation 或 task"})
		return
	}
	if mode != "summary" && mode != "split" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "mode 仅支持 summary 或 split"})
		return
	}

	id := c.Query("id")
	conversationID := c.Query("conversation_id")
	severity := c.Query("severity")
	status := c.Query("status")
	taskID := c.Query("task_id")
	conversationTag := c.Query("conversation_tag")
	taskTag := c.Query("task_tag")

	// The count is used as the list limit so that every match is exported.
	total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
	if err != nil {
		h.logger.Error("获取漏洞总数失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if total == 0 {
		c.JSON(http.StatusOK, gin.H{"mode": mode, "group_by": groupBy, "total": 0, "files": []any{}})
		return
	}

	items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
	if err != nil {
		h.logger.Error("获取漏洞列表失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	type exportFile struct {
		FileName string `json:"filename"`
		Content  string `json:"content"`
	}

	grouped := make(map[string][]*database.Vulnerability)
	for _, v := range items {
		var key string
		if groupBy == "conversation" {
			key = v.ConversationID
			if tag := strings.TrimSpace(v.ConversationTag); tag != "" {
				key = tag
			}
		} else {
			key = firstNonEmpty(v.TaskTag, v.TaskID, v.TaskQueueID, "unassigned-task")
		}
		grouped[key] = append(grouped[key], v)
	}

	// Map iteration order is unspecified; sort the keys for stable output.
	groups := make([]string, 0, len(grouped))
	for group := range grouped {
		groups = append(groups, group)
	}
	sort.Strings(groups)

	// Take a single timestamp so file names and report headers agree.
	now := time.Now()
	exportedAt := now.Format("2006-01-02 15:04:05")
	stamp := now.Format("20060102-150405")

	// Non-nil so an empty result encodes as [] rather than null.
	files := make([]exportFile, 0, len(groups))

	if mode == "summary" {
		var b strings.Builder
		b.WriteString("# 漏洞批量导出报告\n\n")
		fmt.Fprintf(&b, "- 导出时间: %s\n", exportedAt)
		fmt.Fprintf(&b, "- 分组维度: %s\n", groupBy)
		fmt.Fprintf(&b, "- 漏洞总数: %d\n", len(items))
		fmt.Fprintf(&b, "- 分组数: %d\n\n", len(grouped))
		for _, group := range groups {
			list := grouped[group]
			fmt.Fprintf(&b, "## %s (%d)\n\n", group, len(list))
			for _, v := range list {
				appendVulnerabilityMarkdown(&b, v, "###")
			}
		}
		files = append(files, exportFile{
			FileName: fmt.Sprintf("vulnerability-report-%s-%s.md", groupBy, stamp),
			Content:  b.String(),
		})
	} else {
		// Distinct group names can sanitize to the same string; track the
		// names already handed out and add a numeric suffix on collision.
		used := make(map[string]struct{}, len(groups))
		for _, group := range groups {
			list := grouped[group]

			var b strings.Builder
			fmt.Fprintf(&b, "# 漏洞报告 - %s\n\n", group)
			fmt.Fprintf(&b, "- 导出时间: %s\n", exportedAt)
			fmt.Fprintf(&b, "- 漏洞数量: %d\n\n", len(list))
			for _, v := range list {
				appendVulnerabilityMarkdown(&b, v, "##")
			}

			base := sanitizeExportName(group)
			name := base
			for n := 2; ; n++ {
				if _, taken := used[name]; !taken {
					break
				}
				name = fmt.Sprintf("%s-%d", base, n)
			}
			used[name] = struct{}{}

			files = append(files, exportFile{
				FileName: fmt.Sprintf("vulnerability-%s-%s.md", name, stamp),
				Content:  b.String(),
			})
		}
	}

	c.JSON(http.StatusOK, gin.H{
		"mode":     mode,
		"group_by": groupBy,
		"total":    len(items),
		"files":    files,
	})
}
片段(与单文件下载字段对齐,缺省字段不写) func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability, titleHeading string) { b.WriteString(fmt.Sprintf("%s %s\n\n", titleHeading, v.Title)) b.WriteString(fmt.Sprintf("- 漏洞ID: `%s`\n", v.ID)) b.WriteString(fmt.Sprintf("- 严重程度: %s\n", v.Severity)) b.WriteString(fmt.Sprintf("- 状态: %s\n", v.Status)) if v.Type != "" { b.WriteString(fmt.Sprintf("- 类型: %s\n", v.Type)) } if v.Target != "" { b.WriteString(fmt.Sprintf("- 目标: %s\n", v.Target)) } b.WriteString(fmt.Sprintf("- 对话ID: `%s`\n", v.ConversationID)) if v.ConversationTag != "" { b.WriteString(fmt.Sprintf("- 对话标签: %s\n", v.ConversationTag)) } if v.TaskTag != "" { b.WriteString(fmt.Sprintf("- 任务标签: %s\n", v.TaskTag)) } if v.TaskID != "" { b.WriteString(fmt.Sprintf("- 任务ID: `%s`\n", v.TaskID)) } if v.TaskQueueID != "" { b.WriteString(fmt.Sprintf("- 任务队列ID: `%s`\n", v.TaskQueueID)) } if !v.CreatedAt.IsZero() { b.WriteString(fmt.Sprintf("- 创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05"))) } if !v.UpdatedAt.IsZero() { b.WriteString(fmt.Sprintf("- 更新时间: %s\n", v.UpdatedAt.Format("2006-01-02 15:04:05"))) } if v.Description != "" { b.WriteString("\n#### 描述\n\n") b.WriteString(v.Description) b.WriteString("\n") } if v.Proof != "" { b.WriteString("\n#### 证明(POC)\n\n```\n") b.WriteString(v.Proof) b.WriteString("\n```\n") } if v.Impact != "" { b.WriteString("\n#### 影响\n\n") b.WriteString(v.Impact) b.WriteString("\n") } if v.Recommendation != "" { b.WriteString("\n#### 修复建议\n\n") b.WriteString(v.Recommendation) b.WriteString("\n") } b.WriteString("\n") } func firstNonEmpty(values ...string) string { for _, v := range values { trimmed := strings.TrimSpace(v) if trimmed != "" { return trimmed } } return "" } func sanitizeExportName(raw string) string { name := strings.TrimSpace(raw) if name == "" { return "unknown" } replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-") return replacer.Replace(name) }