mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-03 19:17:55 +02:00
567 lines
16 KiB
Go
567 lines
16 KiB
Go
package handler
|
|
|
|
import (
	"fmt"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"time"

	"cyberstrike-ai/internal/audit"
	"cyberstrike-ai/internal/database"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)
|
|
|
|
// VulnerabilityHandler 漏洞处理器
|
|
type VulnerabilityHandler struct {
|
|
db *database.DB
|
|
logger *zap.Logger
|
|
audit *audit.Service
|
|
}
|
|
|
|
// SetAudit wires platform audit logging.
|
|
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
|
|
h.audit = s
|
|
}
|
|
|
|
// NewVulnerabilityHandler 创建新的漏洞处理器
|
|
func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler {
|
|
return &VulnerabilityHandler{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// CreateVulnerabilityRequest is the JSON body accepted by CreateVulnerability.
//
// conversation_id, title and severity are mandatory (enforced by gin's
// `binding:"required"`). Every other field is optional and stored as sent,
// except project_id, which CreateVulnerability trims of surrounding whitespace.
type CreateVulnerabilityRequest struct {
	ConversationID  string `json:"conversation_id" binding:"required"`
	ProjectID       string `json:"project_id"`
	ConversationTag string `json:"conversation_tag"`
	TaskTag         string `json:"task_tag"`
	Title           string `json:"title" binding:"required"`
	Description     string `json:"description"`
	Severity        string `json:"severity" binding:"required"`
	Status          string `json:"status"`
	Type            string `json:"type"`
	Target          string `json:"target"`
	Preconditions   string `json:"preconditions"`
	ReproSteps      string `json:"reproduction_steps"` // JSON key differs from the Go field name
	Evidence        string `json:"evidence"`
	Impact          string `json:"impact"`
	Recommendation  string `json:"recommendation"`
	RetestNotes     string `json:"retest_notes"`
}
|
|
|
|
// CreateVulnerability handles creation of a vulnerability record.
//
// The request body is bound into a CreateVulnerabilityRequest; binding or
// validation failures produce 400 with the binder's error text. The record is
// persisted through the database layer (500 with the error text on failure).
// When auditing is enabled a "create" entry carrying the severity and title is
// recorded. On success the record returned by the database layer is sent back
// with 200.
func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
	var body CreateVulnerabilityRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	record, err := h.db.CreateVulnerability(&database.Vulnerability{
		ConversationID:  body.ConversationID,
		ProjectID:       strings.TrimSpace(body.ProjectID),
		ConversationTag: body.ConversationTag,
		TaskTag:         body.TaskTag,
		Title:           body.Title,
		Description:     body.Description,
		Severity:        body.Severity,
		Status:          body.Status,
		Type:            body.Type,
		Target:          body.Target,
		Preconditions:   body.Preconditions,
		ReproSteps:      body.ReproSteps,
		Evidence:        body.Evidence,
		Impact:          body.Impact,
		Recommendation:  body.Recommendation,
		RetestNotes:     body.RetestNotes,
	})
	if err != nil {
		h.logger.Error("创建漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if h.audit != nil {
		meta := map[string]interface{}{"severity": record.Severity, "title": record.Title}
		h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", record.ID, meta)
	}
	c.JSON(http.StatusOK, record)
}
|
|
|
|
// GetVulnerability responds with the vulnerability identified by the :id path
// parameter.
//
// NOTE(review): every lookup error, including genuine database failures, is
// reported to the client as 404 "漏洞不存在"; the underlying error is only logged.
func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) {
	found, lookupErr := h.db.GetVulnerability(c.Param("id"))
	if lookupErr != nil {
		h.logger.Error("获取漏洞失败", zap.Error(lookupErr))
		c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
		return
	}
	c.JSON(http.StatusOK, found)
}
|
|
|
|
// ListVulnerabilitiesResponse 漏洞列表响应
|
|
type ListVulnerabilitiesResponse struct {
|
|
Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"`
|
|
Total int `json:"total"`
|
|
Page int `json:"page"`
|
|
PageSize int `json:"page_size"`
|
|
TotalPages int `json:"total_pages"`
|
|
}
|
|
|
|
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
|
|
q := strings.TrimSpace(c.Query("q"))
|
|
if q == "" {
|
|
q = strings.TrimSpace(c.Query("search"))
|
|
}
|
|
return database.VulnerabilityListFilter{
|
|
ProjectID: c.Query("project_id"),
|
|
ID: c.Query("id"),
|
|
Search: q,
|
|
ConversationID: c.Query("conversation_id"),
|
|
Severity: c.Query("severity"),
|
|
Status: c.Query("status"),
|
|
TaskID: c.Query("task_id"),
|
|
ConversationTag: c.Query("conversation_tag"),
|
|
TaskTag: c.Query("task_tag"),
|
|
}
|
|
}
|
|
|
|
// ListVulnerabilities 列出漏洞
|
|
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
|
limitStr := c.DefaultQuery("limit", "20")
|
|
offsetStr := c.DefaultQuery("offset", "0")
|
|
pageStr := c.Query("page")
|
|
filter := parseVulnerabilityListFilter(c)
|
|
|
|
limit, _ := strconv.Atoi(limitStr)
|
|
offset, _ := strconv.Atoi(offsetStr)
|
|
page := 1
|
|
|
|
// 如果提供了page参数,优先使用page计算offset
|
|
if pageStr != "" {
|
|
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
|
|
page = p
|
|
offset = (page - 1) * limit
|
|
}
|
|
}
|
|
|
|
if limit <= 0 || limit > 100 {
|
|
limit = 20
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
// 获取总数
|
|
total, err := h.db.CountVulnerabilities(filter)
|
|
if err != nil {
|
|
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
|
// 继续执行,使用0作为总数
|
|
total = 0
|
|
}
|
|
|
|
// 获取漏洞列表
|
|
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
|
|
if err != nil {
|
|
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// 计算总页数
|
|
totalPages := (total + limit - 1) / limit
|
|
if totalPages == 0 {
|
|
totalPages = 1
|
|
}
|
|
|
|
// 如果使用offset计算page,需要重新计算
|
|
if pageStr == "" {
|
|
page = (offset / limit) + 1
|
|
}
|
|
|
|
response := ListVulnerabilitiesResponse{
|
|
Vulnerabilities: vulnerabilities,
|
|
Total: total,
|
|
Page: page,
|
|
PageSize: limit,
|
|
TotalPages: totalPages,
|
|
}
|
|
|
|
c.JSON(http.StatusOK, response)
|
|
}
|
|
|
|
// UpdateVulnerabilityRequest is the JSON body accepted by UpdateVulnerability.
//
// Every field is a pointer so that "not sent" (nil) can be told apart from
// "sent as empty string". UpdateVulnerability only overwrites the stored
// values whose pointers are non-nil.
type UpdateVulnerabilityRequest struct {
	ProjectID       *string `json:"project_id"`
	ConversationTag *string `json:"conversation_tag"`
	TaskTag         *string `json:"task_tag"`
	Title           *string `json:"title"`
	Description     *string `json:"description"`
	Severity        *string `json:"severity"`
	Status          *string `json:"status"`
	Type            *string `json:"type"`
	Target          *string `json:"target"`
	Preconditions   *string `json:"preconditions"`
	ReproSteps      *string `json:"reproduction_steps"` // JSON key differs from the Go field name
	Evidence        *string `json:"evidence"`
	Impact          *string `json:"impact"`
	Recommendation  *string `json:"recommendation"`
	RetestNotes     *string `json:"retest_notes"`
}
|
|
|
|
// UpdateVulnerability applies a partial update to the vulnerability identified
// by the :id path parameter.
//
// Status codes:
//   - 400: the body does not bind into an UpdateVulnerabilityRequest.
//   - 404: the record cannot be loaded (any lookup error).
//   - 500: the update or the re-read fails.
//
// Only fields present (non-nil) in the request are changed; project_id is
// whitespace-trimmed, every other field is stored verbatim. On success an
// "update" audit entry is recorded when auditing is enabled, and the record as
// re-read from the database is returned with 200.
func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
	id := c.Param("id")

	var req UpdateVulnerabilityRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	current, err := h.db.GetVulnerability(id)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
		return
	}

	// project_id is the only field normalized before it is stored.
	if req.ProjectID != nil {
		current.ProjectID = strings.TrimSpace(*req.ProjectID)
	}
	// All remaining fields are copied verbatim, in declaration order, when sent.
	patches := []struct {
		from *string
		to   *string
	}{
		{req.ConversationTag, &current.ConversationTag},
		{req.TaskTag, &current.TaskTag},
		{req.Title, &current.Title},
		{req.Description, &current.Description},
		{req.Severity, &current.Severity},
		{req.Status, &current.Status},
		{req.Type, &current.Type},
		{req.Target, &current.Target},
		{req.Preconditions, &current.Preconditions},
		{req.ReproSteps, &current.ReproSteps},
		{req.Evidence, &current.Evidence},
		{req.Impact, &current.Impact},
		{req.Recommendation, &current.Recommendation},
		{req.RetestNotes, &current.RetestNotes},
	}
	for _, p := range patches {
		if p.from != nil {
			*p.to = *p.from
		}
	}

	if saveErr := h.db.UpdateVulnerability(id, current); saveErr != nil {
		h.logger.Error("更新漏洞失败", zap.Error(saveErr))
		c.JSON(http.StatusInternalServerError, gin.H{"error": saveErr.Error()})
		return
	}

	// Re-read so the response reflects whatever the database layer stored.
	refreshed, err := h.db.GetVulnerability(id)
	if err != nil {
		h.logger.Error("获取更新后的漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if h.audit != nil {
		meta := map[string]interface{}{
			"severity":   refreshed.Severity,
			"status":     refreshed.Status,
			"project_id": refreshed.ProjectID,
		}
		h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, meta)
	}
	c.JSON(http.StatusOK, refreshed)
}
|
|
|
|
// DeleteVulnerability removes the vulnerability identified by the :id path
// parameter. It responds 500 with the error text if the database layer
// reports a failure. Otherwise it records a "delete" audit entry (when
// auditing is enabled) and responds 200 with a confirmation message.
func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
	id := c.Param("id")

	if delErr := h.db.DeleteVulnerability(id); delErr != nil {
		h.logger.Error("删除漏洞失败", zap.Error(delErr))
		c.JSON(http.StatusInternalServerError, gin.H{"error": delErr.Error()})
		return
	}

	if h.audit != nil {
		entry := audit.Entry{
			Category:     "vulnerability",
			Action:       "delete",
			Result:       "success",
			ResourceType: "vulnerability",
			ResourceID:   id,
			Message:      "删除漏洞记录",
		}
		h.audit.Record(c, entry)
	}

	c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
|
|
|
|
// BatchDeleteVulnerabilities deletes every vulnerability matching the list
// filter query parameters (see parseVulnerabilityListFilter).
//
// Matches are counted first; when there are none the handler answers 200 with
// "deleted": 0 without calling the delete. Count or delete failures yield 500.
// On success a "delete_batch" audit entry with the deleted count and the
// filter is recorded (when auditing is enabled) and the count is returned.
//
// NOTE(review): if the database layer treats empty filter fields as "no
// constraint", a request without any filter parameters deletes every record —
// confirm this is intended.
func (h *VulnerabilityHandler) BatchDeleteVulnerabilities(c *gin.Context) {
	filter := parseVulnerabilityListFilter(c)

	matched, err := h.db.CountVulnerabilities(filter)
	switch {
	case err != nil:
		h.logger.Error("统计待删除漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	case matched == 0:
		c.JSON(http.StatusOK, gin.H{"message": "当前筛选条件下没有可删除的漏洞", "deleted": 0})
		return
	}

	deleted, err := h.db.DeleteVulnerabilitiesByFilter(filter)
	if err != nil {
		h.logger.Error("批量删除漏洞失败", zap.Error(err), zap.Int("count", matched))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if h.audit != nil {
		meta := map[string]interface{}{"deleted": deleted, "filter": filter}
		h.audit.RecordOK(c, "vulnerability", "delete_batch", "批量删除漏洞记录", "vulnerability", "", meta)
	}

	c.JSON(http.StatusOK, gin.H{"message": "批量删除成功", "deleted": deleted})
}
|
|
|
|
// GetVulnerabilityStats responds with the statistics computed by the database
// layer for the records matching the list filter query parameters, or 500
// with the error text if that computation fails.
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
	stats, statsErr := h.db.GetVulnerabilityStats(parseVulnerabilityListFilter(c))
	if statsErr != nil {
		h.logger.Error("获取漏洞统计失败", zap.Error(statsErr))
		c.JSON(http.StatusInternalServerError, gin.H{"error": statsErr.Error()})
		return
	}
	c.JSON(http.StatusOK, stats)
}
|
|
|
|
// GetVulnerabilityFilterOptions responds with the filter suggestions provided
// by the database layer (used to populate the list filters), or 500 with the
// error text on failure.
func (h *VulnerabilityHandler) GetVulnerabilityFilterOptions(c *gin.Context) {
	opts, optsErr := h.db.GetVulnerabilityFilterOptions()
	if optsErr != nil {
		h.logger.Error("获取漏洞筛选建议失败", zap.Error(optsErr))
		c.JSON(http.StatusInternalServerError, gin.H{"error": optsErr.Error()})
		return
	}
	c.JSON(http.StatusOK, opts)
}
|
|
|
|
// ExportVulnerabilities renders the vulnerabilities matching the list filter
// query parameters as Markdown. The file(s) are returned in a JSON envelope
// {"mode", "group_by", "total", "files": [{"filename", "content"}]}.
//
// Extra query parameters:
//   - group_by: "conversation" (default) or "task"; anything else is a 400.
//   - mode: "summary" (default) produces one report containing every group;
//     "split" produces one file per group; anything else is a 400.
//
// Groups are emitted in ascending key order so repeated exports of the same
// data produce the same output. Within a group, records keep the order
// returned by the database layer. All timestamps in one export come from a
// single clock reading. Database failures yield 500 with the error text.
func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
	groupBy := c.DefaultQuery("group_by", "conversation")
	mode := c.DefaultQuery("mode", "summary")
	if groupBy != "conversation" && groupBy != "task" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "group_by 仅支持 conversation 或 task"})
		return
	}
	if mode != "summary" && mode != "split" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "mode 仅支持 summary 或 split"})
		return
	}

	filter := parseVulnerabilityListFilter(c)

	total, err := h.db.CountVulnerabilities(filter)
	if err != nil {
		h.logger.Error("统计待导出漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if total == 0 {
		c.JSON(http.StatusOK, gin.H{"mode": mode, "group_by": groupBy, "total": 0, "files": []any{}})
		return
	}

	// Fetch every match in one query; the count above is the page size.
	items, err := h.db.ListVulnerabilities(total, 0, filter)
	if err != nil {
		h.logger.Error("获取待导出漏洞失败", zap.Error(err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	type exportFile struct {
		FileName string `json:"filename"`
		Content  string `json:"content"`
	}

	groups, keys := groupVulnerabilitiesForExport(items, groupBy)

	now := time.Now()
	stamp := now.Format("20060102-150405")
	exportedAt := now.Format("2006-01-02 15:04:05")

	// Non-nil so an (unexpectedly) empty result still encodes as [] not null.
	files := make([]exportFile, 0, len(keys))
	if mode == "summary" {
		var b strings.Builder
		b.WriteString("# 漏洞批量导出报告\n\n")
		fmt.Fprintf(&b, "- 导出时间: %s\n", exportedAt)
		fmt.Fprintf(&b, "- 分组维度: %s\n", groupBy)
		fmt.Fprintf(&b, "- 漏洞总数: %d\n", len(items))
		fmt.Fprintf(&b, "- 分组数: %d\n\n", len(keys))
		for _, key := range keys {
			list := groups[key]
			fmt.Fprintf(&b, "## %s (%d)\n\n", key, len(list))
			for _, v := range list {
				appendVulnerabilityMarkdown(&b, v, "###")
			}
		}
		files = append(files, exportFile{
			FileName: fmt.Sprintf("vulnerability-report-%s-%s.md", groupBy, stamp),
			Content:  b.String(),
		})
	} else {
		for _, key := range keys {
			list := groups[key]
			var b strings.Builder
			fmt.Fprintf(&b, "# 漏洞报告 - %s\n\n", key)
			fmt.Fprintf(&b, "- 导出时间: %s\n", exportedAt)
			fmt.Fprintf(&b, "- 漏洞数量: %d\n\n", len(list))
			for _, v := range list {
				appendVulnerabilityMarkdown(&b, v, "##")
			}
			files = append(files, exportFile{
				FileName: fmt.Sprintf("vulnerability-%s-%s.md", sanitizeExportName(key), stamp),
				Content:  b.String(),
			})
		}
	}

	c.JSON(http.StatusOK, gin.H{
		"mode":     mode,
		"group_by": groupBy,
		"total":    len(items),
		"files":    files,
	})
}

// groupVulnerabilitiesForExport buckets items by export group. It returns the
// buckets together with their keys sorted ascending, because Go map iteration
// order is randomized and would otherwise shuffle the report on every export.
//
// For groupBy "conversation" the key is the trimmed conversation tag, falling
// back to the conversation ID. For any other groupBy it is the first non-blank
// (trimmed) value of task tag, task ID or task queue ID, else "unassigned-task".
func groupVulnerabilitiesForExport(items []*database.Vulnerability, groupBy string) (map[string][]*database.Vulnerability, []string) {
	groups := make(map[string][]*database.Vulnerability)
	var keys []string
	for _, v := range items {
		var key string
		if groupBy == "conversation" {
			key = v.ConversationID
			if tag := strings.TrimSpace(v.ConversationTag); tag != "" {
				key = tag
			}
		} else {
			key = firstNonEmpty(v.TaskTag, v.TaskID, v.TaskQueueID, "unassigned-task")
		}
		if _, seen := groups[key]; !seen {
			keys = append(keys, key)
		}
		groups[key] = append(groups[key], v)
	}
	sort.Strings(keys)
	return groups, keys
}
|
|
|
|
// appendVulnerabilityMarkdown writes the Markdown section for one
// vulnerability to b. Its fields are aligned with the single-record download;
// optional fields that are empty are omitted.
//
// titleHeading is the Markdown heading prefix for the record title (e.g.
// "##" or "###"); the free-text sub-sections always use "####". Evidence is
// wrapped in a fenced code block whose fence is longer than any backtick run
// inside the evidence, so PoC text containing ``` cannot terminate the block
// early and corrupt the rest of the report.
func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability, titleHeading string) {
	fmt.Fprintf(b, "%s %s\n\n", titleHeading, v.Title)
	fmt.Fprintf(b, "- 漏洞ID: `%s`\n", v.ID)
	fmt.Fprintf(b, "- 严重程度: %s\n", v.Severity)
	fmt.Fprintf(b, "- 状态: %s\n", v.Status)
	if v.Type != "" {
		fmt.Fprintf(b, "- 类型: %s\n", v.Type)
	}
	if v.Target != "" {
		fmt.Fprintf(b, "- 目标: %s\n", v.Target)
	}
	fmt.Fprintf(b, "- 对话ID: `%s`\n", v.ConversationID)
	if v.ConversationTag != "" {
		fmt.Fprintf(b, "- 对话标签: %s\n", v.ConversationTag)
	}
	if v.TaskTag != "" {
		fmt.Fprintf(b, "- 任务标签: %s\n", v.TaskTag)
	}
	if v.TaskID != "" {
		fmt.Fprintf(b, "- 任务ID: `%s`\n", v.TaskID)
	}
	if v.TaskQueueID != "" {
		fmt.Fprintf(b, "- 任务队列ID: `%s`\n", v.TaskQueueID)
	}
	if !v.CreatedAt.IsZero() {
		fmt.Fprintf(b, "- 创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05"))
	}
	if !v.UpdatedAt.IsZero() {
		fmt.Fprintf(b, "- 更新时间: %s\n", v.UpdatedAt.Format("2006-01-02 15:04:05"))
	}

	// Free-text sections, emitted in this order and only when non-empty.
	sections := []struct {
		heading string
		body    string
		fenced  bool
	}{
		{"描述", v.Description, false},
		{"前置条件", v.Preconditions, false},
		{"复现步骤", v.ReproSteps, false},
		{"证据 / POC", v.Evidence, true},
		{"影响", v.Impact, false},
		{"修复建议", v.Recommendation, false},
		{"复测方式", v.RetestNotes, false},
	}
	for _, s := range sections {
		if s.body == "" {
			continue
		}
		fmt.Fprintf(b, "\n#### %s\n\n", s.heading)
		if s.fenced {
			fence := markdownCodeFence(s.body)
			fmt.Fprintf(b, "%s\n%s\n%s\n", fence, s.body, fence)
			continue
		}
		b.WriteString(s.body)
		b.WriteString("\n")
	}
	b.WriteString("\n")
}

// markdownCodeFence returns a backtick fence that content cannot close early.
// The fence is at least three backticks (the CommonMark minimum) and one
// longer than the longest run of consecutive backticks in content.
func markdownCodeFence(content string) string {
	longest, run := 0, 0
	for _, r := range content {
		if r != '`' {
			run = 0
			continue
		}
		run++
		if run > longest {
			longest = run
		}
	}
	n := longest + 1
	if n < 3 {
		n = 3
	}
	return strings.Repeat("`", n)
}
|
|
|
|
// firstNonEmpty returns the first argument that is not blank, with its
// surrounding whitespace removed. It returns "" when every argument is blank
// or no arguments are given.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if s := strings.TrimSpace(values[i]); s != "" {
			return s
		}
	}
	return ""
}
|
|
|
|
// exportNameReplacer maps characters that are reserved in file names on
// common platforms (notably Windows) to "-". A strings.Replacer is safe for
// concurrent use, so one shared instance serves every export.
var exportNameReplacer = strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")

// sanitizeExportName turns a free-form group label into text that is safe to
// embed in an export file name.
//
// Surrounding whitespace is trimmed, and a label that is empty after trimming
// becomes "unknown". The characters / \ : * ? " < > | are replaced with "-".
// So are ASCII control characters (U+0000–U+001F and U+007F), which Windows
// rejects in file names and which can appear in user-entered tags (for
// example an embedded newline). All other characters, including non-ASCII
// text, are kept.
func sanitizeExportName(raw string) string {
	name := strings.TrimSpace(raw)
	if name == "" {
		return "unknown"
	}
	name = exportNameReplacer.Replace(name)
	return strings.Map(func(r rune) rune {
		if r < 0x20 || r == 0x7f {
			return '-'
		}
		return r
	}, name)
}
|