mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 10:16:32 +02:00
264 lines
6.9 KiB
Go
264 lines
6.9 KiB
Go
package handler
|
|
|
|
import (
|
|
"net/http"
|
|
"strconv"
|
|
|
|
"cyberstrike-ai/internal/database"
|
|
"github.com/gin-gonic/gin"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// VulnerabilityHandler 漏洞处理器
|
|
type VulnerabilityHandler struct {
|
|
db *database.DB
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewVulnerabilityHandler 创建新的漏洞处理器
|
|
func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler {
|
|
return &VulnerabilityHandler{
|
|
db: db,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// CreateVulnerabilityRequest is the JSON body accepted when creating a
// vulnerability. The conversation_id, title and severity fields carry a
// `binding:"required"` tag, so gin's ShouldBindJSON rejects requests that omit
// them; every other field is optional and defaults to the empty string.
type CreateVulnerabilityRequest struct {
	// Required: conversation the finding belongs to.
	ConversationID string `json:"conversation_id" binding:"required"`
	// Required: short human-readable title.
	Title string `json:"title" binding:"required"`
	// Optional free-form description.
	Description string `json:"description"`
	// Required: severity label (allowed values are not validated here).
	Severity string `json:"severity" binding:"required"`
	// Optional lifecycle status.
	Status string `json:"status"`
	// Optional vulnerability category.
	Type string `json:"type"`
	// Optional affected target.
	Target string `json:"target"`
	// Optional proof / evidence.
	Proof string `json:"proof"`
	// Optional impact statement.
	Impact string `json:"impact"`
	// Optional remediation advice.
	Recommendation string `json:"recommendation"`
}
|
|
|
|
// CreateVulnerability 创建漏洞
|
|
func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
|
var req CreateVulnerabilityRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
vuln := &database.Vulnerability{
|
|
ConversationID: req.ConversationID,
|
|
Title: req.Title,
|
|
Description: req.Description,
|
|
Severity: req.Severity,
|
|
Status: req.Status,
|
|
Type: req.Type,
|
|
Target: req.Target,
|
|
Proof: req.Proof,
|
|
Impact: req.Impact,
|
|
Recommendation: req.Recommendation,
|
|
}
|
|
|
|
created, err := h.db.CreateVulnerability(vuln)
|
|
if err != nil {
|
|
h.logger.Error("创建漏洞失败", zap.Error(err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, created)
|
|
}
|
|
|
|
// GetVulnerability 获取漏洞
|
|
func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) {
|
|
id := c.Param("id")
|
|
|
|
vuln, err := h.db.GetVulnerability(id)
|
|
if err != nil {
|
|
h.logger.Error("获取漏洞失败", zap.Error(err))
|
|
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, vuln)
|
|
}
|
|
|
|
// ListVulnerabilitiesResponse 漏洞列表响应
|
|
type ListVulnerabilitiesResponse struct {
|
|
Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"`
|
|
Total int `json:"total"`
|
|
Page int `json:"page"`
|
|
PageSize int `json:"page_size"`
|
|
TotalPages int `json:"total_pages"`
|
|
}
|
|
|
|
// ListVulnerabilities 列出漏洞
|
|
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
|
limitStr := c.DefaultQuery("limit", "20")
|
|
offsetStr := c.DefaultQuery("offset", "0")
|
|
pageStr := c.Query("page")
|
|
id := c.Query("id")
|
|
conversationID := c.Query("conversation_id")
|
|
severity := c.Query("severity")
|
|
status := c.Query("status")
|
|
|
|
limit, _ := strconv.Atoi(limitStr)
|
|
offset, _ := strconv.Atoi(offsetStr)
|
|
page := 1
|
|
|
|
// 如果提供了page参数,优先使用page计算offset
|
|
if pageStr != "" {
|
|
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
|
|
page = p
|
|
offset = (page - 1) * limit
|
|
}
|
|
}
|
|
|
|
if limit <= 0 || limit > 100 {
|
|
limit = 20
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
|
|
// 获取总数
|
|
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status)
|
|
if err != nil {
|
|
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
|
// 继续执行,使用0作为总数
|
|
total = 0
|
|
}
|
|
|
|
// 获取漏洞列表
|
|
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status)
|
|
if err != nil {
|
|
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// 计算总页数
|
|
totalPages := (total + limit - 1) / limit
|
|
if totalPages == 0 {
|
|
totalPages = 1
|
|
}
|
|
|
|
// 如果使用offset计算page,需要重新计算
|
|
if pageStr == "" {
|
|
page = (offset / limit) + 1
|
|
}
|
|
|
|
response := ListVulnerabilitiesResponse{
|
|
Vulnerabilities: vulnerabilities,
|
|
Total: total,
|
|
Page: page,
|
|
PageSize: limit,
|
|
TotalPages: totalPages,
|
|
}
|
|
|
|
c.JSON(http.StatusOK, response)
|
|
}
|
|
|
|
// UpdateVulnerabilityRequest is the JSON body accepted when updating a
// vulnerability. All fields are optional; UpdateVulnerability treats an empty
// string as "leave the stored value unchanged".
type UpdateVulnerabilityRequest struct {
	// New title, or "" to keep the current one.
	Title string `json:"title"`
	// New description, or "" to keep the current one.
	Description string `json:"description"`
	// New severity label, or "" to keep the current one.
	Severity string `json:"severity"`
	// New status, or "" to keep the current one.
	Status string `json:"status"`
	// New category, or "" to keep the current one.
	Type string `json:"type"`
	// New target, or "" to keep the current one.
	Target string `json:"target"`
	// New proof, or "" to keep the current one.
	Proof string `json:"proof"`
	// New impact statement, or "" to keep the current one.
	Impact string `json:"impact"`
	// New recommendation, or "" to keep the current one.
	Recommendation string `json:"recommendation"`
}
|
|
|
|
// UpdateVulnerability 更新漏洞
|
|
func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
|
id := c.Param("id")
|
|
|
|
var req UpdateVulnerabilityRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// 获取现有漏洞
|
|
existing, err := h.db.GetVulnerability(id)
|
|
if err != nil {
|
|
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
|
|
return
|
|
}
|
|
|
|
// 更新字段
|
|
if req.Title != "" {
|
|
existing.Title = req.Title
|
|
}
|
|
if req.Description != "" {
|
|
existing.Description = req.Description
|
|
}
|
|
if req.Severity != "" {
|
|
existing.Severity = req.Severity
|
|
}
|
|
if req.Status != "" {
|
|
existing.Status = req.Status
|
|
}
|
|
if req.Type != "" {
|
|
existing.Type = req.Type
|
|
}
|
|
if req.Target != "" {
|
|
existing.Target = req.Target
|
|
}
|
|
if req.Proof != "" {
|
|
existing.Proof = req.Proof
|
|
}
|
|
if req.Impact != "" {
|
|
existing.Impact = req.Impact
|
|
}
|
|
if req.Recommendation != "" {
|
|
existing.Recommendation = req.Recommendation
|
|
}
|
|
|
|
if err := h.db.UpdateVulnerability(id, existing); err != nil {
|
|
h.logger.Error("更新漏洞失败", zap.Error(err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// 返回更新后的漏洞
|
|
updated, err := h.db.GetVulnerability(id)
|
|
if err != nil {
|
|
h.logger.Error("获取更新后的漏洞失败", zap.Error(err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, updated)
|
|
}
|
|
|
|
// DeleteVulnerability 删除漏洞
|
|
func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
|
|
id := c.Param("id")
|
|
|
|
if err := h.db.DeleteVulnerability(id); err != nil {
|
|
h.logger.Error("删除漏洞失败", zap.Error(err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
|
}
|
|
|
|
// GetVulnerabilityStats 获取漏洞统计
|
|
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
|
conversationID := c.Query("conversation_id")
|
|
|
|
stats, err := h.db.GetVulnerabilityStats(conversationID)
|
|
if err != nil {
|
|
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, stats)
|
|
}
|
|
|