package handler import ( "net/http" "strconv" "cyberstrike-ai/internal/database" "github.com/gin-gonic/gin" "go.uber.org/zap" ) // VulnerabilityHandler 漏洞处理器 type VulnerabilityHandler struct { db *database.DB logger *zap.Logger } // NewVulnerabilityHandler 创建新的漏洞处理器 func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler { return &VulnerabilityHandler{ db: db, logger: logger, } } // CreateVulnerabilityRequest 创建漏洞请求 type CreateVulnerabilityRequest struct { ConversationID string `json:"conversation_id" binding:"required"` Title string `json:"title" binding:"required"` Description string `json:"description"` Severity string `json:"severity" binding:"required"` Status string `json:"status"` Type string `json:"type"` Target string `json:"target"` Proof string `json:"proof"` Impact string `json:"impact"` Recommendation string `json:"recommendation"` } // CreateVulnerability 创建漏洞 func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) { var req CreateVulnerabilityRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } vuln := &database.Vulnerability{ ConversationID: req.ConversationID, Title: req.Title, Description: req.Description, Severity: req.Severity, Status: req.Status, Type: req.Type, Target: req.Target, Proof: req.Proof, Impact: req.Impact, Recommendation: req.Recommendation, } created, err := h.db.CreateVulnerability(vuln) if err != nil { h.logger.Error("创建漏洞失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, created) } // GetVulnerability 获取漏洞 func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) { id := c.Param("id") vuln, err := h.db.GetVulnerability(id) if err != nil { h.logger.Error("获取漏洞失败", zap.Error(err)) c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"}) return } c.JSON(http.StatusOK, vuln) } // ListVulnerabilitiesResponse 漏洞列表响应 type ListVulnerabilitiesResponse struct { Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"` Total int `json:"total"` Page int `json:"page"` PageSize int `json:"page_size"` TotalPages int `json:"total_pages"` } // ListVulnerabilities 列出漏洞 func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { limitStr := c.DefaultQuery("limit", "20") offsetStr := c.DefaultQuery("offset", "0") pageStr := c.Query("page") id := c.Query("id") conversationID := c.Query("conversation_id") severity := c.Query("severity") status := c.Query("status") limit, _ := strconv.Atoi(limitStr) offset, _ := strconv.Atoi(offsetStr) page := 1 // 如果提供了page参数,优先使用page计算offset if pageStr != "" { if p, err := strconv.Atoi(pageStr); err == nil && p > 0 { page = p offset = (page - 1) * limit } } if limit <= 0 || limit > 100 { limit = 20 } if offset < 0 { offset = 0 } // 获取总数 total, err := h.db.CountVulnerabilities(id, conversationID, severity, status) if err != nil { h.logger.Error("获取漏洞总数失败", zap.Error(err)) // 继续执行,使用0作为总数 total = 0 } // 获取漏洞列表 vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status) if err != nil { h.logger.Error("获取漏洞列表失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 计算总页数 totalPages := (total + limit - 1) / limit if totalPages == 0 { totalPages = 1 } // 如果使用offset计算page,需要重新计算 if pageStr == "" { page = (offset / limit) + 1 } response := ListVulnerabilitiesResponse{ Vulnerabilities: vulnerabilities, Total: total, Page: page, PageSize: limit, TotalPages: totalPages, } c.JSON(http.StatusOK, response) } // UpdateVulnerabilityRequest 更新漏洞请求 type UpdateVulnerabilityRequest struct { Title string `json:"title"` Description string `json:"description"` Severity string `json:"severity"` Status string `json:"status"` Type string `json:"type"` Target string `json:"target"` Proof string `json:"proof"` Impact string `json:"impact"` Recommendation string `json:"recommendation"` } // UpdateVulnerability 更新漏洞 func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) { id := c.Param("id") var req UpdateVulnerabilityRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // 获取现有漏洞 existing, err := h.db.GetVulnerability(id) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"}) return } // 更新字段 if req.Title != "" { existing.Title = req.Title } if req.Description != "" { existing.Description = req.Description } if req.Severity != "" { existing.Severity = req.Severity } if req.Status != "" { existing.Status = req.Status } if req.Type != "" { existing.Type = req.Type } if req.Target != "" { existing.Target = req.Target } if req.Proof != "" { existing.Proof = req.Proof } if req.Impact != "" { existing.Impact = req.Impact } if req.Recommendation != "" { existing.Recommendation = req.Recommendation } if err := h.db.UpdateVulnerability(id, existing); err != nil { h.logger.Error("更新漏洞失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // 返回更新后的漏洞 updated, err := h.db.GetVulnerability(id) if err != nil { h.logger.Error("获取更新后的漏洞失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, updated) } // DeleteVulnerability 删除漏洞 func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) { id := c.Param("id") if err := h.db.DeleteVulnerability(id); err != nil { h.logger.Error("删除漏洞失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) } // GetVulnerabilityStats 获取漏洞统计 func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) { conversationID := c.Query("conversation_id") stats, err := h.db.GetVulnerabilityStats(conversationID) if err != nil { h.logger.Error("获取漏洞统计失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, stats) }