From 6685076dfb28dd82af2862dc660d9ad8926292f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 22 May 2026 11:35:02 +0800 Subject: [PATCH] Add files via upload --- internal/handler/openapi.go | 2 +- internal/handler/vulnerability.go | 46 +++++++++++++++++-------------- internal/handler/webshell.go | 9 ++++-- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/internal/handler/openapi.go b/internal/handler/openapi.go index da785f0c..15de9ab1 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -6254,7 +6254,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) { } // 获取漏洞列表 - vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "") + vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID}) if err != nil { h.logger.Warn("获取漏洞列表失败", zap.Error(err)) vulnList = []*database.Vulnerability{} diff --git a/internal/handler/vulnerability.go b/internal/handler/vulnerability.go index fd8f7819..f9e83395 100644 --- a/internal/handler/vulnerability.go +++ b/internal/handler/vulnerability.go @@ -110,18 +110,29 @@ type ListVulnerabilitiesResponse struct { TotalPages int `json:"total_pages"` } +func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter { + q := strings.TrimSpace(c.Query("q")) + if q == "" { + q = strings.TrimSpace(c.Query("search")) + } + return database.VulnerabilityListFilter{ + ID: c.Query("id"), + Search: q, + ConversationID: c.Query("conversation_id"), + Severity: c.Query("severity"), + Status: c.Query("status"), + TaskID: c.Query("task_id"), + ConversationTag: c.Query("conversation_tag"), + TaskTag: c.Query("task_tag"), + } +} + // ListVulnerabilities 列出漏洞 func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { limitStr := c.DefaultQuery("limit", "20") offsetStr := c.DefaultQuery("offset", "0") pageStr := c.Query("page") - id := c.Query("id") - conversationID := c.Query("conversation_id") - severity := c.Query("severity") - status := c.Query("status") - taskID := c.Query("task_id") - conversationTag := c.Query("conversation_tag") - taskTag := c.Query("task_tag") + filter := parseVulnerabilityListFilter(c) limit, _ := strconv.Atoi(limitStr) offset, _ := strconv.Atoi(offsetStr) @@ -143,7 +154,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { } // 获取总数 - total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag) + total, err := h.db.CountVulnerabilities(filter) if err != nil { h.logger.Error("获取漏洞总数失败", zap.Error(err)) // 继续执行,使用0作为总数 @@ -151,7 +162,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) { } // 获取漏洞列表 - vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag) + vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter) if err != nil { h.logger.Error("获取漏洞列表失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -295,10 +306,9 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) { // GetVulnerabilityStats 获取漏洞统计 func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) { - conversationID := c.Query("conversation_id") - taskID := c.Query("task_id") + filter := parseVulnerabilityListFilter(c) - stats, err := h.db.GetVulnerabilityStats(conversationID, taskID) + stats, err := h.db.GetVulnerabilityStats(filter) if err != nil { h.logger.Error("获取漏洞统计失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -332,15 +342,9 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) { return } - id := c.Query("id") - conversationID := c.Query("conversation_id") - severity := c.Query("severity") - status := c.Query("status") - taskID := c.Query("task_id") - conversationTag := c.Query("conversation_tag") - taskTag := c.Query("task_tag") + filter := parseVulnerabilityListFilter(c) - total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag) + total, err := h.db.CountVulnerabilities(filter) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -350,7 +354,7 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) { return } - items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag) + items, err := h.db.ListVulnerabilities(total, 0, filter) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/internal/handler/webshell.go b/internal/handler/webshell.go index 3b95b896..031f3075 100644 --- a/internal/handler/webshell.go +++ b/internal/handler/webshell.go @@ -2,6 +2,7 @@ package handler import ( "bytes" + "crypto/tls" "database/sql" "encoding/base64" "encoding/json" @@ -318,8 +319,12 @@ func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler { return &WebShellHandler{ logger: logger, client: &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{DisableKeepAlives: false}, + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: false, + // WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。 + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy + }, }, db: db, }