Files
CyberStrikeAI/internal/app/app.go
T
2025-12-25 22:12:41 +08:00

655 lines
22 KiB
Go

package app
import (
"context"
"database/sql"
"fmt"
"net/http"
"os"
"path/filepath"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/handler"
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// App is the application root: it holds every long-lived component that New
// constructs so that Run and Shutdown can operate on them.
type App struct {
// config is the configuration passed to New; Run reads its MCP and Server sections.
config *config.Config
// logger is the application logger passed to New.
logger *logger.Logger
// router is the gin engine with all routes registered; Run serves it.
router *gin.Engine
// mcpServer is the internal MCP server; Run can also expose it on its own listener.
mcpServer *mcp.Server
// externalMCPMgr manages external MCP clients; Shutdown stops them.
externalMCPMgr *mcp.ExternalMCPManager
// agent is the agent created in New.
agent *agent.Agent
// executor is the security tool executor whose tools are registered on mcpServer.
executor *security.Executor
// db is the conversation database opened in New.
db *database.DB
knowledgeDB *database.DB // knowledge-base database connection (only set when a separate database is used)
// auth is the authentication manager used by the route middleware.
auth *security.AuthManager
}
// New builds the whole application from cfg and returns an App that is ready
// for Run.
//
// Construction order: authentication manager, conversation database, internal
// MCP server and tool executor, external MCP manager, result storage, agent,
// the optional knowledge-base module, HTTP handlers and finally the routes.
//
// The configuration file path handed to the handlers is os.Args[1] when
// present and "config.yaml" otherwise.
//
// New mutates cfg: the generated-password fields are cleared after the warning
// is printed, and empty knowledge embedding credentials are filled in from the
// OpenAI section.
//
// If initialization fails after the conversation database has been opened,
// the database is closed and the external MCP clients are stopped before the
// error is returned, so a failed New does not leave resources behind.
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
	gin.SetMode(gin.ReleaseMode)
	router := gin.Default()

	// CORS middleware.
	router.Use(corsMiddleware())

	// Authentication manager.
	authManager, err := security.NewAuthManager(cfg.Auth.Password, cfg.Auth.SessionDurationHours)
	if err != nil {
		return nil, fmt.Errorf("初始化认证失败: %w", err)
	}

	// Conversation database; fall back to the default path when none is configured.
	dbPath := cfg.Database.Path
	if dbPath == "" {
		dbPath = "data/conversations.db"
	}
	// Make sure the parent directory exists.
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return nil, fmt.Errorf("创建数据库目录失败: %w", err)
	}
	db, err := database.NewDB(dbPath, log.Logger)
	if err != nil {
		return nil, fmt.Errorf("初始化数据库失败: %w", err)
	}

	// ready flips to true only on the success path. Until then every deferred
	// rollback below fires, which releases resources on any early return.
	ready := false
	defer func() {
		if ready {
			return
		}
		if cerr := db.Close(); cerr != nil {
			log.Logger.Warn("关闭数据库连接失败", zap.Error(cerr))
		}
	}()

	// Internal MCP server, persisted in the conversation database.
	mcpServer := mcp.NewServerWithStorage(log.Logger, db)

	// Security tool executor and its tools.
	executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
	executor.RegisterTools(mcpServer)

	// Vulnerability recording tool.
	registerVulnerabilityTool(mcpServer, db, log.Logger)

	// Show the auto-generated password once, then wipe it from the in-memory config.
	if cfg.Auth.GeneratedPassword != "" {
		config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
		cfg.Auth.GeneratedPassword = ""
		cfg.Auth.GeneratedPasswordPersisted = false
		cfg.Auth.GeneratedPasswordPersistErr = ""
	}

	// External MCP manager, sharing the storage of the internal MCP server.
	externalMCPMgr := mcp.NewExternalMCPManagerWithStorage(log.Logger, db)
	// Registered after the database rollback so it runs first (defers are LIFO):
	// clients are stopped before the storage they use is closed.
	defer func() {
		if !ready {
			externalMCPMgr.StopAll()
		}
	}()
	if cfg.ExternalMCP.Servers != nil {
		externalMCPMgr.LoadConfigs(&cfg.ExternalMCP)
		// Start every enabled external MCP client.
		externalMCPMgr.StartAllEnabled()
	}

	// Result storage directory, defaulting to "tmp".
	resultStorageDir := "tmp"
	if cfg.Agent.ResultStorageDir != "" {
		resultStorageDir = cfg.Agent.ResultStorageDir
	}
	if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
		return nil, fmt.Errorf("创建结果存储目录失败: %w", err)
	}
	resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
	if err != nil {
		return nil, fmt.Errorf("初始化结果存储失败: %w", err)
	}

	// Agent; a non-positive iteration limit falls back to 30.
	maxIterations := cfg.Agent.MaxIterations
	if maxIterations <= 0 {
		maxIterations = 30
	}
	// Named ag rather than agent so the local does not shadow the agent package.
	ag := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
	ag.SetResultStorage(resultStorage)
	// The executor needs the same storage for its query tools.
	executor.SetResultStorage(resultStorage)

	// Optional knowledge-base module. All of these stay nil when it is disabled.
	var (
		knowledgeManager   *knowledge.Manager
		knowledgeRetriever *knowledge.Retriever
		knowledgeIndexer   *knowledge.Indexer
		knowledgeHandler   *handler.KnowledgeHandler
		knowledgeDBConn    *database.DB
	)
	log.Logger.Info("检查知识库配置", zap.Bool("enabled", cfg.Knowledge.Enabled))
	if cfg.Knowledge.Enabled {
		// Decide which database holds the knowledge base.
		knowledgeDBPath := cfg.Database.KnowledgeDBPath
		var knowledgeDB *sql.DB
		if knowledgeDBPath != "" {
			// Dedicated knowledge database.
			if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil {
				return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err)
			}
			// Plain assignment (not :=) so knowledgeDBConn is the outer variable.
			knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, log.Logger)
			if err != nil {
				return nil, fmt.Errorf("初始化知识库数据库失败: %w", err)
			}
			knowledgeDB = knowledgeDBConn.DB
			log.Logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath))
		} else {
			// Backward compatibility: reuse the conversation database.
			knowledgeDB = db.DB
			log.Logger.Info("使用会话数据库存储知识库数据(建议配置knowledge_db_path以分离数据)")
		}

		// Knowledge manager.
		knowledgeManager = knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, log.Logger)

		// Embedder: inherit API key / base URL from the OpenAI section when the
		// knowledge section leaves them empty.
		if cfg.Knowledge.Embedding.APIKey == "" {
			cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey
		}
		if cfg.Knowledge.Embedding.BaseURL == "" {
			cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL
		}
		httpClient := &http.Client{
			Timeout: 30 * time.Minute,
		}
		openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, log.Logger)
		embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, log.Logger)

		// Retriever.
		retrievalConfig := &knowledge.RetrievalConfig{
			TopK:                cfg.Knowledge.Retrieval.TopK,
			SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold,
			HybridWeight:        cfg.Knowledge.Retrieval.HybridWeight,
		}
		knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger)

		// Indexer.
		knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger)

		// Expose knowledge retrieval as an MCP tool.
		knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)

		// Knowledge-base API handler.
		knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
		log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))

		// Scan the knowledge base and build the index in the background so
		// startup is not blocked. The goroutine is not tied to Shutdown.
		go func() {
			if err := knowledgeManager.ScanKnowledgeBase(); err != nil {
				log.Logger.Warn("扫描知识库失败", zap.Error(err))
				return
			}
			// Skip the automatic rebuild when an index already exists.
			hasIndex, err := knowledgeIndexer.HasIndex()
			if err != nil {
				log.Logger.Warn("检查索引状态失败", zap.Error(err))
				return
			}
			if hasIndex {
				log.Logger.Info("检测到已有知识库索引,跳过自动重建。如需重建,请手动点击重建索引按钮")
				return
			}
			// No index yet: build it now.
			log.Logger.Info("未检测到知识库索引,开始自动构建索引")
			ctx := context.Background()
			if err := knowledgeIndexer.RebuildIndex(ctx); err != nil {
				log.Logger.Warn("重建知识库索引失败", zap.Error(err))
			}
		}()
	}

	// Configuration file path: first CLI argument, or the default.
	configPath := "config.yaml"
	if len(os.Args) > 1 {
		configPath = os.Args[1]
	}

	// Handlers.
	agentHandler := handler.NewAgentHandler(ag, db, log.Logger)
	// With the knowledge base enabled, give the agent handler the manager so it
	// can record retrieval logs.
	if knowledgeManager != nil {
		agentHandler.SetKnowledgeManager(knowledgeManager)
	}
	monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
	// Lets the monitor include execution records of external MCP tools.
	monitorHandler.SetExternalMCPManager(externalMCPMgr)
	conversationHandler := handler.NewConversationHandler(db, log.Logger)
	groupHandler := handler.NewGroupHandler(db, log.Logger)
	authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
	attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
	vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
	configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, ag, attackChainHandler, externalMCPMgr, log.Logger)
	// With the knowledge base enabled, hand the config handler a registrar so
	// ApplyConfig can re-register the knowledge tool.
	if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil {
		// The closure captures knowledgeRetriever and knowledgeManager.
		registrar := func() error {
			knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger)
			return nil
		}
		configHandler.SetKnowledgeToolRegistrar(registrar)
	}
	externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)

	// Routes.
	setupRoutes(
		router,
		authHandler,
		agentHandler,
		monitorHandler,
		conversationHandler,
		groupHandler,
		configHandler,
		externalMCPHandler,
		attackChainHandler,
		knowledgeHandler,
		vulnerabilityHandler,
		mcpServer,
		authManager,
	)

	app := &App{
		config:         cfg,
		logger:         log,
		router:         router,
		mcpServer:      mcpServer,
		externalMCPMgr: externalMCPMgr,
		agent:          ag,
		executor:       executor,
		db:             db,
		knowledgeDB:    knowledgeDBConn,
		auth:           authManager,
	}
	// Ownership of the resources now belongs to the App; disarm the rollbacks.
	ready = true
	return app, nil
}
// Run starts the application's listeners and returns the result of serving
// the main HTTP router.
//
// When config.MCP.Enabled is set, a second HTTP listener exposing the MCP
// server at "/mcp" is started in a background goroutine on MCP.Host:MCP.Port.
// A failure of that listener is only logged; it does not stop the main server.
// No handle to that listener is kept, so Shutdown does not stop it.
func (a *App) Run() error {
	// Standalone MCP listener (optional).
	if a.config.MCP.Enabled {
		go func() {
			mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
			a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))

			mux := http.NewServeMux()
			mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP)

			// An explicit http.Server instead of http.ListenAndServe, which has
			// no timeouts at all. Only the header read is bounded, so a client
			// that never finishes its request headers cannot hold a connection
			// open forever, while long-running request bodies and responses are
			// unaffected.
			srv := &http.Server{
				Addr:              mcpAddr,
				Handler:           mux,
				ReadHeaderTimeout: 10 * time.Second,
			}
			if err := srv.ListenAndServe(); err != nil {
				a.logger.Error("MCP服务器启动失败", zap.Error(err))
			}
		}()
	}

	// Main HTTP server.
	addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
	a.logger.Info("启动HTTP服务器", zap.String("address", addr))
	return a.router.Run(addr)
}
// Shutdown releases the resources owned by the App: it stops all external MCP
// clients, then closes the dedicated knowledge-base database (if one was
// opened) and finally the conversation database.
//
// Close errors are logged as warnings and do not abort the remaining steps.
func (a *App) Shutdown() {
	// Stop every external MCP client first; they share the conversation
	// database as storage, so they must go before it is closed.
	if a.externalMCPMgr != nil {
		a.externalMCPMgr.StopAll()
	}

	// Dedicated knowledge-base database. The field is nil when the knowledge
	// base is disabled or shares the conversation database.
	if a.knowledgeDB != nil {
		if err := a.knowledgeDB.Close(); err != nil {
			a.logger.Logger.Warn("关闭知识库数据库连接失败", zap.Error(err))
		}
	}

	// Conversation database. Previously this handle was never closed.
	if a.db != nil {
		if err := a.db.Close(); err != nil {
			a.logger.Logger.Warn("关闭数据库连接失败", zap.Error(err))
		}
	}
}
// setupRoutes registers every HTTP route on router.
//
// Everything under /api except POST /api/auth/login is guarded by
// security.AuthMiddleware. The knowledge-base routes are only registered when
// knowledgeHandler is non-nil. Static assets are served from ./web/static,
// templates are loaded from web/templates/*, and "/" renders index.html.
func setupRoutes(
	router *gin.Engine,
	authHandler *handler.AuthHandler,
	agentHandler *handler.AgentHandler,
	monitorHandler *handler.MonitorHandler,
	conversationHandler *handler.ConversationHandler,
	groupHandler *handler.GroupHandler,
	configHandler *handler.ConfigHandler,
	externalMCPHandler *handler.ExternalMCPHandler,
	attackChainHandler *handler.AttackChainHandler,
	knowledgeHandler *handler.KnowledgeHandler,
	vulnerabilityHandler *handler.VulnerabilityHandler,
	mcpServer *mcp.Server,
	authManager *security.AuthManager,
) {
	// requireAuth builds a fresh auth middleware instance on every call.
	requireAuth := func() gin.HandlerFunc {
		return security.AuthMiddleware(authManager)
	}

	apiGroup := router.Group("/api")

	// Authentication endpoints; login is the only one reachable without a session.
	authGroup := apiGroup.Group("/auth")
	authGroup.POST("/login", authHandler.Login)
	authGroup.POST("/logout", requireAuth(), authHandler.Logout)
	authGroup.POST("/change-password", requireAuth(), authHandler.ChangePassword)
	authGroup.GET("/validate", requireAuth(), authHandler.Validate)

	// Every remaining API route sits behind the auth middleware.
	secured := apiGroup.Group("", requireAuth())

	// Agent loop: blocking call, streaming variant, cancellation and task list.
	secured.POST("/agent-loop", agentHandler.AgentLoop)
	secured.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
	secured.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
	secured.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)

	// Conversation history.
	secured.POST("/conversations", conversationHandler.CreateConversation)
	secured.GET("/conversations", conversationHandler.ListConversations)
	secured.GET("/conversations/:id", conversationHandler.GetConversation)
	secured.PUT("/conversations/:id", conversationHandler.UpdateConversation)
	secured.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
	secured.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)

	// Conversation groups.
	secured.POST("/groups", groupHandler.CreateGroup)
	secured.GET("/groups", groupHandler.ListGroups)
	secured.GET("/groups/:id", groupHandler.GetGroup)
	secured.PUT("/groups/:id", groupHandler.UpdateGroup)
	secured.DELETE("/groups/:id", groupHandler.DeleteGroup)
	secured.PUT("/groups/:id/pinned", groupHandler.UpdateGroupPinned)
	secured.GET("/groups/:id/conversations", groupHandler.GetGroupConversations)
	secured.POST("/groups/conversations", groupHandler.AddConversationToGroup)
	secured.DELETE("/groups/:id/conversations/:conversationId", groupHandler.RemoveConversationFromGroup)
	secured.PUT("/groups/:id/conversations/:conversationId/pinned", groupHandler.UpdateConversationPinnedInGroup)

	// Monitoring.
	secured.GET("/monitor", monitorHandler.Monitor)
	secured.GET("/monitor/execution/:id", monitorHandler.GetExecution)
	secured.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
	secured.GET("/monitor/stats", monitorHandler.GetStats)

	// Configuration management.
	secured.GET("/config", configHandler.GetConfig)
	secured.GET("/config/tools", configHandler.GetTools)
	secured.PUT("/config", configHandler.UpdateConfig)
	secured.POST("/config/apply", configHandler.ApplyConfig)

	// External MCP management.
	secured.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
	secured.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
	secured.GET("/external-mcp/:name", externalMCPHandler.GetExternalMCP)
	secured.PUT("/external-mcp/:name", externalMCPHandler.AddOrUpdateExternalMCP)
	secured.DELETE("/external-mcp/:name", externalMCPHandler.DeleteExternalMCP)
	secured.POST("/external-mcp/:name/start", externalMCPHandler.StartExternalMCP)
	secured.POST("/external-mcp/:name/stop", externalMCPHandler.StopExternalMCP)

	// Attack-chain visualisation.
	secured.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain)
	secured.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain)

	// Knowledge-base management, present only when the module is enabled.
	if knowledgeHandler != nil {
		secured.GET("/knowledge/categories", knowledgeHandler.GetCategories)
		secured.GET("/knowledge/items", knowledgeHandler.GetItems)
		secured.GET("/knowledge/items/:id", knowledgeHandler.GetItem)
		secured.POST("/knowledge/items", knowledgeHandler.CreateItem)
		secured.PUT("/knowledge/items/:id", knowledgeHandler.UpdateItem)
		secured.DELETE("/knowledge/items/:id", knowledgeHandler.DeleteItem)
		secured.GET("/knowledge/index-status", knowledgeHandler.GetIndexStatus)
		secured.POST("/knowledge/index", knowledgeHandler.RebuildIndex)
		secured.POST("/knowledge/scan", knowledgeHandler.ScanKnowledgeBase)
		secured.GET("/knowledge/retrieval-logs", knowledgeHandler.GetRetrievalLogs)
		secured.DELETE("/knowledge/retrieval-logs/:id", knowledgeHandler.DeleteRetrievalLog)
		secured.POST("/knowledge/search", knowledgeHandler.Search)
	}

	// Vulnerability management.
	secured.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
	secured.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
	secured.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
	secured.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability)
	secured.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
	secured.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)

	// MCP endpoint: hand the raw request and writer to the MCP server.
	secured.POST("/mcp", func(ginCtx *gin.Context) {
		mcpServer.HandleHTTP(ginCtx.Writer, ginCtx.Request)
	})

	// Static assets and HTML templates.
	router.Static("/static", "./web/static")
	router.LoadHTMLGlob("web/templates/*")

	// Front-end entry page.
	router.GET("/", func(ginCtx *gin.Context) {
		ginCtx.HTML(http.StatusOK, "index.html", nil)
	})
}
// registerVulnerabilityTool registers the "record_vulnerability" tool on
// mcpServer.
//
// The tool handler reads conversation_id, title and severity from its
// arguments (all three must be non-empty strings, and severity must be one of
// critical/high/medium/low/info), takes the remaining fields as optional
// strings, and stores a new vulnerability with status "open" through
// db.CreateVulnerability. Validation and storage failures are reported as a
// tool result with IsError set; the handler's Go error is always nil.
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
	// stringProp builds the schema entry for a plain string property.
	stringProp := func(desc string) map[string]interface{} {
		return map[string]interface{}{
			"type":        "string",
			"description": desc,
		}
	}

	// severity is a string property restricted to a fixed set of values.
	severityProp := stringProp("漏洞严重程度:critical(严重)、high(高)、medium(中)、low(低)、info(信息)")
	severityProp["enum"] = []string{"critical", "high", "medium", "low", "info"}

	properties := map[string]interface{}{
		"title":              stringProp("漏洞标题(必需)"),
		"description":        stringProp("漏洞详细描述"),
		"severity":           severityProp,
		"vulnerability_type": stringProp("漏洞类型,如:SQL注入、XSS、CSRF、命令注入等"),
		"target":             stringProp("受影响的目标(URL、IP地址、服务等)"),
		"proof":              stringProp("漏洞证明(POC、截图、请求/响应等)"),
		"impact":             stringProp("漏洞影响说明"),
		"recommendation":     stringProp("修复建议"),
	}

	tool := mcp.Tool{
		Name:             "record_vulnerability",
		Description:      "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
		ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
		InputSchema: map[string]interface{}{
			"type":       "object",
			"properties": properties,
			"required":   []string{"title", "severity"},
		},
	}

	// textResult wraps msg in a single text content item.
	textResult := func(msg string, isError bool) (*mcp.ToolResult, error) {
		return &mcp.ToolResult{
			Content: []mcp.Content{
				{Type: "text", Text: msg},
			},
			IsError: isError,
		}, nil
	}

	record := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
		// stringArg yields "" when the key is absent or not a string.
		stringArg := func(key string) string {
			value, _ := args[key].(string)
			return value
		}

		// conversation_id is expected in the arguments (the original comment
		// says the agent adds it automatically).
		conversationID := stringArg("conversation_id")
		if conversationID == "" {
			return textResult("错误: conversation_id 未设置。这是系统错误,请重试。", true)
		}

		title := stringArg("title")
		if title == "" {
			return textResult("错误: title 参数必需且不能为空", true)
		}

		severity := stringArg("severity")
		if severity == "" {
			return textResult("错误: severity 参数必需且不能为空", true)
		}

		// Only the five known severity levels are accepted.
		switch severity {
		case "critical", "high", "medium", "low", "info":
		default:
			return textResult(fmt.Sprintf("错误: severity 必须是 critical、high、medium、low 或 info 之一,当前值: %s", severity), true)
		}

		// Optional fields default to the empty string.
		vuln := &database.Vulnerability{
			ConversationID: conversationID,
			Title:          title,
			Description:    stringArg("description"),
			Severity:       severity,
			Status:         "open",
			Type:           stringArg("vulnerability_type"),
			Target:         stringArg("target"),
			Proof:          stringArg("proof"),
			Impact:         stringArg("impact"),
			Recommendation: stringArg("recommendation"),
		}

		created, err := db.CreateVulnerability(vuln)
		if err != nil {
			logger.Error("记录漏洞失败", zap.Error(err))
			return textResult(fmt.Sprintf("记录漏洞失败: %v", err), true)
		}

		logger.Info("漏洞记录成功",
			zap.String("id", created.ID),
			zap.String("title", created.Title),
			zap.String("severity", created.Severity),
			zap.String("conversation_id", conversationID),
		)

		summary := fmt.Sprintf("漏洞已成功记录!\n\n漏洞ID: %s\n标题: %s\n严重程度: %s\n状态: %s\n\n你可以在漏洞管理页面查看和管理此漏洞。", created.ID, created.Title, created.Severity, created.Status)
		return textResult(summary, false)
	}

	mcpServer.RegisterTool(tool, record)
	logger.Info("漏洞记录工具注册成功")
}
// corsMiddleware returns a gin middleware that adds permissive CORS headers to
// every response and answers preflight (OPTIONS) requests with 204 No Content
// without invoking later handlers.
//
// Any origin is allowed via the "*" wildcard. Access-Control-Allow-Credentials
// is deliberately not sent: browsers reject credentialed cross-origin
// requests when the allowed origin is "*", so sending both headers was
// contradictory and had no effect. Requests that authenticate with an
// explicit Authorization header are unaffected, since that header is listed in
// Access-Control-Allow-Headers.
func corsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		header := c.Writer.Header()
		header.Set("Access-Control-Allow-Origin", "*")
		header.Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
		header.Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")

		// Preflight: the headers above are the whole answer.
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	}
}