From 0597838217ef34cc90d95c0f64bc1d61c8c415cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Tue, 28 Apr 2026 01:04:58 +0800 Subject: [PATCH] Add files via upload --- internal/app/app.go | 2 + internal/handler/openapi.go | 2 +- internal/handler/vulnerability.go | 203 +++++++++++++++++++++++++++++- 3 files changed, 204 insertions(+), 3 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 0f4f154a..b0b25178 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -901,6 +901,8 @@ func setupRoutes( // 漏洞管理 protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities) + protected.GET("/vulnerabilities/export", vulnerabilityHandler.ExportVulnerabilities) + protected.GET("/vulnerabilities/filter-options", vulnerabilityHandler.GetVulnerabilityFilterOptions) protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats) protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability) protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability) diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go index cf082d9d..45216daa 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -6197,7 +6197,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) { } // 获取漏洞列表 - vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "") + vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "") if err != nil { h.logger.Warn("获取漏洞列表失败", zap.Error(err)) vulnList = []*database.Vulnerability{} diff --git a/internal/handler/vulnerability.go b/internal/handler/vulnerability.go index 9975efa7..f7335a46 100644 --- a/internal/handler/vulnerability.go +++ b/internal/handler/vulnerability.go @@ -1,8 +1,11 @@ package handler import ( + "fmt" "net/http" "strconv" + "strings" + "time" "cyberstrike-ai/internal/database" "github.com/gin-gonic/gin" @@ -26,6 +29,8 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability // CreateVulnerabilityRequest 创建漏洞请求 type CreateVulnerabilityRequest struct { ConversationID string `json:"conversation_id" binding:"required"` + ConversationTag string `json:"conversation_tag"` + TaskTag string `json:"task_tag"` Title string `json:"title" binding:"required"` Description string `json:"description"` Severity string `json:"severity" binding:"required"` @@ -47,6 +52,8 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) { vuln := &database.Vulnerability{ ConversationID: req.ConversationID, + ConversationTag: req.ConversationTag, + TaskTag: req.TaskTag, Title: req.Title, Description: req.Description, Severity: req.Severity, @@ -100,6 +107,9 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { conversationID := c.Query("conversation_id") severity := c.Query("severity") status := c.Query("status") + taskID := c.Query("task_id") + conversationTag := c.Query("conversation_tag") + taskTag := c.Query("task_tag") limit, _ := strconv.Atoi(limitStr) offset, _ := strconv.Atoi(offsetStr) @@ -121,7 +131,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { } // 获取总数 - total, err := h.db.CountVulnerabilities(id, conversationID, severity, status) + total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag) if err != nil { h.logger.Error("获取漏洞总数失败", zap.Error(err)) // 继续执行,使用0作为总数 @@ -129,7 +139,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { } // 获取漏洞列表 - vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status) + vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag) if err != nil { h.logger.Error("获取漏洞列表失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -160,6 +170,8 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { // UpdateVulnerabilityRequest 更新漏洞请求 type UpdateVulnerabilityRequest struct { + ConversationTag string `json:"conversation_tag"` + TaskTag string `json:"task_tag"` Title string `json:"title"` Description string `json:"description"` Severity string `json:"severity"` @@ -189,6 +201,12 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) { } // 更新字段 + if req.ConversationTag != "" { + existing.ConversationTag = req.ConversationTag + } + if req.TaskTag != "" { + existing.TaskTag = req.TaskTag + } if req.Title != "" { existing.Title = req.Title } @@ -261,3 +279,184 @@ func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) { c.JSON(http.StatusOK, stats) } +// GetVulnerabilityFilterOptions 获取漏洞筛选建议项 +func (h *VulnerabilityHandler) GetVulnerabilityFilterOptions(c *gin.Context) { + options, err := h.db.GetVulnerabilityFilterOptions() + if err != nil { + h.logger.Error("获取漏洞筛选建议失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, options) +} + +// ExportVulnerabilities 导出漏洞(支持按对话/任务分组,汇总或拆分) +func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) { + groupBy := c.DefaultQuery("group_by", "conversation") + mode := c.DefaultQuery("mode", "summary") + if groupBy != "conversation" && groupBy != "task" { + c.JSON(http.StatusBadRequest, gin.H{"error": "group_by 仅支持 conversation 或 task"}) + return + } + if mode != "summary" && mode != "split" { + c.JSON(http.StatusBadRequest, gin.H{"error": "mode 仅支持 summary 或 split"}) + return + } + + id := c.Query("id") + conversationID := c.Query("conversation_id") + severity := c.Query("severity") + status := c.Query("status") + taskID := c.Query("task_id") + conversationTag := c.Query("conversation_tag") + taskTag := c.Query("task_tag") + + total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if total == 0 { + c.JSON(http.StatusOK, gin.H{"mode": mode, "group_by": groupBy, "total": 0, "files": []any{}}) + return + } + + items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + type exportFile struct { + FileName string `json:"filename"` + Content string `json:"content"` + } + grouped := map[string][]*database.Vulnerability{} + for _, v := range items { + key := v.ConversationID + if groupBy == "conversation" { + if strings.TrimSpace(v.ConversationTag) != "" { + key = strings.TrimSpace(v.ConversationTag) + } + } else { + key = firstNonEmpty(v.TaskTag, v.TaskID, v.TaskQueueID, "unassigned-task") + } + grouped[key] = append(grouped[key], v) + } + + files := make([]exportFile, 0) + nowStr := time.Now().Format("20060102-150405") + if mode == "summary" { + var b strings.Builder + b.WriteString("# 漏洞批量导出报告\n\n") + b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))) + b.WriteString(fmt.Sprintf("- 分组维度: %s\n", groupBy)) + b.WriteString(fmt.Sprintf("- 漏洞总数: %d\n", len(items))) + b.WriteString(fmt.Sprintf("- 分组数: %d\n\n", len(grouped))) + for group, list := range grouped { + b.WriteString(fmt.Sprintf("## %s (%d)\n\n", group, len(list))) + for _, v := range list { + appendVulnerabilityMarkdown(&b, v, "###") + } + } + files = append(files, exportFile{ + FileName: fmt.Sprintf("vulnerability-report-%s-%s.md", groupBy, nowStr), + Content: b.String(), + }) + } else { + for group, list := range grouped { + var b strings.Builder + b.WriteString(fmt.Sprintf("# 漏洞报告 - %s\n\n", group)) + b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))) + b.WriteString(fmt.Sprintf("- 漏洞数量: %d\n\n", len(list))) + for _, v := range list { + appendVulnerabilityMarkdown(&b, v, "##") + } + files = append(files, exportFile{ + FileName: fmt.Sprintf("vulnerability-%s-%s.md", sanitizeExportName(group), nowStr), + Content: b.String(), + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "mode": mode, + "group_by": groupBy, + "total": len(items), + "files": files, + }) +} + +// appendVulnerabilityMarkdown 单条漏洞的 Markdown 片段(与单文件下载字段对齐,缺省字段不写) +func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability, titleHeading string) { + b.WriteString(fmt.Sprintf("%s %s\n\n", titleHeading, v.Title)) + b.WriteString(fmt.Sprintf("- 漏洞ID: `%s`\n", v.ID)) + b.WriteString(fmt.Sprintf("- 严重程度: %s\n", v.Severity)) + b.WriteString(fmt.Sprintf("- 状态: %s\n", v.Status)) + if v.Type != "" { + b.WriteString(fmt.Sprintf("- 类型: %s\n", v.Type)) + } + if v.Target != "" { + b.WriteString(fmt.Sprintf("- 目标: %s\n", v.Target)) + } + b.WriteString(fmt.Sprintf("- 对话ID: `%s`\n", v.ConversationID)) + if v.ConversationTag != "" { + b.WriteString(fmt.Sprintf("- 对话标签: %s\n", v.ConversationTag)) + } + if v.TaskTag != "" { + b.WriteString(fmt.Sprintf("- 任务标签: %s\n", v.TaskTag)) + } + if v.TaskID != "" { + b.WriteString(fmt.Sprintf("- 任务ID: `%s`\n", v.TaskID)) + } + if v.TaskQueueID != "" { + b.WriteString(fmt.Sprintf("- 任务队列ID: `%s`\n", v.TaskQueueID)) + } + if !v.CreatedAt.IsZero() { + b.WriteString(fmt.Sprintf("- 创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05"))) + } + if !v.UpdatedAt.IsZero() { + b.WriteString(fmt.Sprintf("- 更新时间: %s\n", v.UpdatedAt.Format("2006-01-02 15:04:05"))) + } + if v.Description != "" { + b.WriteString("\n#### 描述\n\n") + b.WriteString(v.Description) + b.WriteString("\n") + } + if v.Proof != "" { + b.WriteString("\n#### 证明(POC)\n\n```\n") + b.WriteString(v.Proof) + b.WriteString("\n```\n") + } + if v.Impact != "" { + b.WriteString("\n#### 影响\n\n") + b.WriteString(v.Impact) + b.WriteString("\n") + } + if v.Recommendation != "" { + b.WriteString("\n#### 修复建议\n\n") + b.WriteString(v.Recommendation) + b.WriteString("\n") + } + b.WriteString("\n") +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + trimmed := strings.TrimSpace(v) + if trimmed != "" { + return trimmed + } + } + return "" +} + +func sanitizeExportName(raw string) string { + name := strings.TrimSpace(raw) + if name == "" { + return "unknown" + } + replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-") + return replacer.Replace(name) +} +