Add files via upload

This commit is contained in:
公明
2025-12-25 22:12:41 +08:00
committed by GitHub
parent 99cf5e78a9
commit 025704cbf7
11 changed files with 1683 additions and 28 deletions
+54 -16
View File
@@ -20,18 +20,19 @@ import (
// Agent AI代理
type Agent struct {
openAIClient *openai.Client
config *config.OpenAIConfig
agentConfig *config.AgentConfig
memoryCompressor *MemoryCompressor
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
logger *zap.Logger
maxIterations int
resultStorage ResultStorage // 结果存储
largeResultThreshold int // 大结果阈值(字节)
mu sync.RWMutex // 添加互斥锁以支持并发更新
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
openAIClient *openai.Client
config *config.OpenAIConfig
agentConfig *config.AgentConfig
memoryCompressor *MemoryCompressor
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
logger *zap.Logger
maxIterations int
resultStorage ResultStorage // 结果存储
largeResultThreshold int // 大结果阈值(字节)
mu sync.RWMutex // 添加互斥锁以支持并发更新
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
currentConversationID string // 当前对话ID(用于自动传递给工具)
}
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
@@ -301,11 +302,20 @@ type ProgressCallback func(eventType, message string, data interface{})
// AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, nil)
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil)
}
// AgentLoopWithProgress 执行Agent循环(带进度回调
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, callback ProgressCallback) (*AgentLoopResult, error) {
// AgentLoopWithConversationID 执行Agent循环(带对话ID
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil)
}
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback) (*AgentLoopResult, error) {
// 设置当前对话ID
a.mu.Lock()
a.currentConversationID = conversationID
a.mu.Unlock()
// 发送进度更新
sendProgress := func(eventType, message string, data interface{}) {
if callback != nil {
@@ -388,7 +398,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
5. 如果确实无法使用某个工具,向用户说明问题,并建议替代方案或手动操作
6. 不要因为单个工具失败就停止整个测试流程,尝试其他方法继续完成任务
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。`
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
漏洞记录要求:
- 当你发现有效漏洞时,必须使用 record_vulnerability 工具记录漏洞详情
- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
- 严重程度评估标准:
* critical(严重):可导致系统完全被控制、数据泄露、服务中断等
* high(高):可导致敏感信息泄露、权限提升、重要功能被绕过等
* medium(中):可导致部分信息泄露、功能受限、需要特定条件才能利用等
* low(低):影响较小,难以利用或影响范围有限
* info(信息):安全配置问题、信息泄露但不直接可利用等
- 确保漏洞证明(proof)包含足够的证据,如请求/响应、截图、命令输出等
- 在记录漏洞后,继续测试以发现更多问题`
messages := []ChatMessage{
{
@@ -1112,6 +1134,22 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
zap.Any("args", args),
)
// 如果是record_vulnerability工具,自动添加conversation_id
if toolName == "record_vulnerability" {
a.mu.RLock()
conversationID := a.currentConversationID
a.mu.RUnlock()
if conversationID != "" {
args["conversation_id"] = conversationID
a.logger.Debug("自动添加conversation_id到record_vulnerability工具",
zap.String("conversation_id", conversationID),
)
} else {
a.logger.Warn("record_vulnerability工具调用时conversation_id为空")
}
}
var result *mcp.ToolResult
var executionID string
var err error
+203
View File
@@ -77,6 +77,9 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
// 注册工具
executor.RegisterTools(mcpServer)
// 注册漏洞记录工具
registerVulnerabilityTool(mcpServer, db, log.Logger)
if cfg.Auth.GeneratedPassword != "" {
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
cfg.Auth.GeneratedPassword = ""
@@ -237,6 +240,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
// 如果知识库已启用,设置知识库工具注册器,以便在ApplyConfig时重新注册知识库工具
if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil {
@@ -261,6 +265,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
externalMCPHandler,
attackChainHandler,
knowledgeHandler,
vulnerabilityHandler,
mcpServer,
authManager,
)
@@ -330,6 +335,7 @@ func setupRoutes(
externalMCPHandler *handler.ExternalMCPHandler,
attackChainHandler *handler.AttackChainHandler,
knowledgeHandler *handler.KnowledgeHandler,
vulnerabilityHandler *handler.VulnerabilityHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
) {
@@ -417,6 +423,14 @@ func setupRoutes(
protected.POST("/knowledge/search", knowledgeHandler.Search)
}
// 漏洞管理
protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability)
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
// MCP端点
protected.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request)
@@ -433,6 +447,195 @@ func setupRoutes(
})
}
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: "record_vulnerability",
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"title": map[string]interface{}{
"type": "string",
"description": "漏洞标题(必需)",
},
"description": map[string]interface{}{
"type": "string",
"description": "漏洞详细描述",
},
"severity": map[string]interface{}{
"type": "string",
"description": "漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)",
"enum": []string{"critical", "high", "medium", "low", "info"},
},
"vulnerability_type": map[string]interface{}{
"type": "string",
"description": "漏洞类型,如:SQL注入、XSS、CSRF、命令注入等",
},
"target": map[string]interface{}{
"type": "string",
"description": "受影响的目标(URL、IP地址、服务等)",
},
"proof": map[string]interface{}{
"type": "string",
"description": "漏洞证明(POC、截图、请求/响应等)",
},
"impact": map[string]interface{}{
"type": "string",
"description": "漏洞影响说明",
},
"recommendation": map[string]interface{}{
"type": "string",
"description": "修复建议",
},
},
"required": []string{"title", "severity"},
},
}
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
// 从参数中获取conversation_id(由Agent自动添加)
conversationID, _ := args["conversation_id"].(string)
if conversationID == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: conversation_id 未设置。这是系统错误,请重试。",
},
},
IsError: true,
}, nil
}
title, ok := args["title"].(string)
if !ok || title == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: title 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
severity, ok := args["severity"].(string)
if !ok || severity == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: severity 参数必需且不能为空",
},
},
IsError: true,
}, nil
}
// 验证严重程度
validSeverities := map[string]bool{
"critical": true,
"high": true,
"medium": true,
"low": true,
"info": true,
}
if !validSeverities[severity] {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity),
},
},
IsError: true,
}, nil
}
// 获取可选参数
description := ""
if d, ok := args["description"].(string); ok {
description = d
}
vulnType := ""
if t, ok := args["vulnerability_type"].(string); ok {
vulnType = t
}
target := ""
if t, ok := args["target"].(string); ok {
target = t
}
proof := ""
if p, ok := args["proof"].(string); ok {
proof = p
}
impact := ""
if i, ok := args["impact"].(string); ok {
impact = i
}
recommendation := ""
if r, ok := args["recommendation"].(string); ok {
recommendation = r
}
// 创建漏洞记录
vuln := &database.Vulnerability{
ConversationID: conversationID,
Title: title,
Description: description,
Severity: severity,
Status: "open",
Type: vulnType,
Target: target,
Proof: proof,
Impact: impact,
Recommendation: recommendation,
}
created, err := db.CreateVulnerability(vuln)
if err != nil {
logger.Error("记录漏洞失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("记录漏洞失败: %v", err),
},
},
IsError: true,
}, nil
}
logger.Info("漏洞记录成功",
zap.String("id", created.ID),
zap.String("title", created.Title),
zap.String("severity", created.Severity),
zap.String("conversation_id", conversationID),
)
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status),
},
},
IsError: false,
}, nil
}
mcpServer.RegisterTool(tool, handler)
logger.Info("漏洞记录工具注册成功")
}
// corsMiddleware CORS中间件
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
+27
View File
@@ -170,6 +170,25 @@ func (db *DB) initTables() error {
UNIQUE(conversation_id, group_id)
);`
// 创建漏洞表
createVulnerabilitiesTable := `
CREATE TABLE IF NOT EXISTS vulnerabilities (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
title TEXT NOT NULL,
description TEXT,
severity TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'open',
vulnerability_type TEXT,
target TEXT,
proof TEXT,
impact TEXT,
recommendation TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);`
// 创建索引
createIndexes := `
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
@@ -189,6 +208,10 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_conversation ON conversation_group_mappings(conversation_id);
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id);
CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
`
if _, err := db.Exec(createConversationsTable); err != nil {
@@ -231,6 +254,10 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建conversation_group_mappings表失败: %w", err)
}
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
}
// 为已有表添加新字段(如果不存在)- 必须在创建索引之前
if err := db.migrateConversationsTable(); err != nil {
db.logger.Warn("迁移conversations表失败", zap.Error(err))
+246
View File
@@ -0,0 +1,246 @@
package database
import (
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Vulnerability 漏洞
type Vulnerability struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info
Status string `json:"status"` // open, confirmed, fixed, false_positive
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// CreateVulnerability 创建漏洞
func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
if vuln.ID == "" {
vuln.ID = uuid.New().String()
}
if vuln.Status == "" {
vuln.Status = "open"
}
now := time.Now()
if vuln.CreatedAt.IsZero() {
vuln.CreatedAt = now
}
vuln.UpdatedAt = now
query := `
INSERT INTO vulnerabilities (
id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(
query,
vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("创建漏洞失败: %w", err)
}
return vuln, nil
}
// GetVulnerability 获取漏洞
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability
query := `
SELECT id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
FROM vulnerabilities
WHERE id = ?
`
err := db.QueryRow(query, id).Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("漏洞不存在")
}
return nil, fmt.Errorf("获取漏洞失败: %w", err)
}
return &vuln, nil
}
// ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, conversationID, severity, status string) ([]*Vulnerability, error) {
query := `
SELECT id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
FROM vulnerabilities
WHERE 1=1
`
args := []interface{}{}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询漏洞列表失败: %w", err)
}
defer rows.Close()
var vulnerabilities []*Vulnerability
for rows.Next() {
var vuln Vulnerability
err := rows.Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
db.logger.Warn("扫描漏洞记录失败", zap.Error(err))
continue
}
vulnerabilities = append(vulnerabilities, &vuln)
}
return vulnerabilities, nil
}
// UpdateVulnerability 更新漏洞
func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
vuln.UpdatedAt = time.Now()
query := `
UPDATE vulnerabilities
SET title = ?, description = ?, severity = ?, status = ?,
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
recommendation = ?, updated_at = ?
WHERE id = ?
`
_, err := db.Exec(
query,
vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
vuln.Recommendation, vuln.UpdatedAt, id,
)
if err != nil {
return fmt.Errorf("更新漏洞失败: %w", err)
}
return nil
}
// DeleteVulnerability 删除漏洞
func (db *DB) DeleteVulnerability(id string) error {
_, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除漏洞失败: %w", err)
}
return nil
}
// GetVulnerabilityStats 获取漏洞统计
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) {
stats := make(map[string]interface{})
// 总漏洞数
var totalCount int
query := "SELECT COUNT(*) FROM vulnerabilities"
args := []interface{}{}
if conversationID != "" {
query += " WHERE conversation_id = ?"
args = append(args, conversationID)
}
err := db.QueryRow(query, args...).Scan(&totalCount)
if err != nil {
return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
}
stats["total"] = totalCount
// 按严重程度统计
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities"
if conversationID != "" {
severityQuery += " WHERE conversation_id = ?"
}
severityQuery += " GROUP BY severity"
rows, err := db.Query(severityQuery, args...)
if err != nil {
return nil, fmt.Errorf("获取严重程度统计失败: %w", err)
}
defer rows.Close()
severityStats := make(map[string]int)
for rows.Next() {
var severity string
var count int
if err := rows.Scan(&severity, &count); err != nil {
continue
}
severityStats[severity] = count
}
stats["by_severity"] = severityStats
// 按状态统计
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities"
if conversationID != "" {
statusQuery += " WHERE conversation_id = ?"
}
statusQuery += " GROUP BY status"
rows, err = db.Query(statusQuery, args...)
if err != nil {
return nil, fmt.Errorf("获取状态统计失败: %w", err)
}
defer rows.Close()
statusStats := make(map[string]int)
for rows.Next() {
var status string
var count int
if err := rows.Scan(&status, &count); err != nil {
continue
}
statusStats[status] = count
}
stats["by_status"] = statusStats
return stats, nil
}
+3 -3
View File
@@ -117,8 +117,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 执行Agent Loop,传入历史消息
result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
// 执行Agent Loop,传入历史消息和对话ID
result, err := h.agent.AgentLoopWithConversationID(c.Request.Context(), req.Message, agentHistoryMessages, conversationID)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
@@ -492,7 +492,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断
sendEvent("progress", "正在分析您的请求...", nil)
result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, progressCallback)
result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, conversationID, progressCallback)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
cause := context.Cause(baseCtx)
+212
View File
@@ -0,0 +1,212 @@
package handler
import (
"net/http"
"strconv"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// VulnerabilityHandler 漏洞处理器
type VulnerabilityHandler struct {
db *database.DB
logger *zap.Logger
}
// NewVulnerabilityHandler 创建新的漏洞处理器
func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler {
return &VulnerabilityHandler{
db: db,
logger: logger,
}
}
// CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
Title string `json:"title" binding:"required"`
Description string `json:"description"`
Severity string `json:"severity" binding:"required"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// CreateVulnerability 创建漏洞
func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
var req CreateVulnerabilityRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
vuln := &database.Vulnerability{
ConversationID: req.ConversationID,
Title: req.Title,
Description: req.Description,
Severity: req.Severity,
Status: req.Status,
Type: req.Type,
Target: req.Target,
Proof: req.Proof,
Impact: req.Impact,
Recommendation: req.Recommendation,
}
created, err := h.db.CreateVulnerability(vuln)
if err != nil {
h.logger.Error("创建漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, created)
}
// GetVulnerability 获取漏洞
func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) {
id := c.Param("id")
vuln, err := h.db.GetVulnerability(id)
if err != nil {
h.logger.Error("获取漏洞失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
return
}
c.JSON(http.StatusOK, vuln)
}
// ListVulnerabilities 列出漏洞
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50")
offsetStr := c.DefaultQuery("offset", "0")
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
if limit <= 0 || limit > 100 {
limit = 50
}
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, conversationID, severity, status)
if err != nil {
h.logger.Error("获取漏洞列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, vulnerabilities)
}
// UpdateVulnerabilityRequest 更新漏洞请求
type UpdateVulnerabilityRequest struct {
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// UpdateVulnerability 更新漏洞
func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
id := c.Param("id")
var req UpdateVulnerabilityRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 获取现有漏洞
existing, err := h.db.GetVulnerability(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
return
}
// 更新字段
if req.Title != "" {
existing.Title = req.Title
}
if req.Description != "" {
existing.Description = req.Description
}
if req.Severity != "" {
existing.Severity = req.Severity
}
if req.Status != "" {
existing.Status = req.Status
}
if req.Type != "" {
existing.Type = req.Type
}
if req.Target != "" {
existing.Target = req.Target
}
if req.Proof != "" {
existing.Proof = req.Proof
}
if req.Impact != "" {
existing.Impact = req.Impact
}
if req.Recommendation != "" {
existing.Recommendation = req.Recommendation
}
if err := h.db.UpdateVulnerability(id, existing); err != nil {
h.logger.Error("更新漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 返回更新后的漏洞
updated, err := h.db.GetVulnerability(id)
if err != nil {
h.logger.Error("获取更新后的漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, updated)
}
// DeleteVulnerability 删除漏洞
func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteVulnerability(id); err != nil {
h.logger.Error("删除漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// GetVulnerabilityStats 获取漏洞统计
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
conversationID := c.Query("conversation_id")
stats, err := h.db.GetVulnerabilityStats(conversationID)
if err != nil {
h.logger.Error("获取漏洞统计失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, stats)
}