diff --git a/internal/database/vulnerability.go b/internal/database/vulnerability.go index 20b2d8da..ea0328eb 100644 --- a/internal/database/vulnerability.go +++ b/internal/database/vulnerability.go @@ -244,18 +244,24 @@ func (db *DB) DeleteVulnerability(id string) error { return nil } -// GetVulnerabilityStats 获取漏洞统计 -func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) { +// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致) +func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) { stats := make(map[string]interface{}) + where := "WHERE 1=1" + args := []interface{}{} + if conversationID != "" { + where += " AND conversation_id = ?" + args = append(args, conversationID) + } + if taskID != "" { + where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))" + args = append(args, taskID, taskID) + } + // 总漏洞数 var totalCount int - query := "SELECT COUNT(*) FROM vulnerabilities" - args := []interface{}{} - if conversationID != "" { - query += " WHERE conversation_id = ?" - args = append(args, conversationID) - } + query := "SELECT COUNT(*) FROM vulnerabilities " + where err := db.QueryRow(query, args...).Scan(&totalCount) if err != nil { return nil, fmt.Errorf("获取总漏洞数失败: %w", err) @@ -263,11 +269,7 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface stats["total"] = totalCount // 按严重程度统计 - severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities" - if conversationID != "" { - severityQuery += " WHERE conversation_id = ?" - } - severityQuery += " GROUP BY severity" + severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities " + where + " GROUP BY severity" rows, err := db.Query(severityQuery, args...) if err != nil { @@ -287,11 +289,7 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface stats["by_severity"] = severityStats // 按状态统计 - statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities" - if conversationID != "" { - statusQuery += " WHERE conversation_id = ?" - } - statusQuery += " GROUP BY status" + statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities " + where + " GROUP BY status" rows, err = db.Query(statusQuery, args...) if err != nil { diff --git a/internal/handler/vulnerability.go b/internal/handler/vulnerability.go index f7335a46..d9531976 100644 --- a/internal/handler/vulnerability.go +++ b/internal/handler/vulnerability.go @@ -268,8 +268,9 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) { // GetVulnerabilityStats 获取漏洞统计 func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) { conversationID := c.Query("conversation_id") + taskID := c.Query("task_id") - stats, err := h.db.GetVulnerabilityStats(conversationID) + stats, err := h.db.GetVulnerabilityStats(conversationID, taskID) if err != nil { h.logger.Error("获取漏洞统计失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})