package app import ( "context" "database/sql" "fmt" "net/http" "os" "path/filepath" "time" "cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/config" "cyberstrike-ai/internal/database" "cyberstrike-ai/internal/handler" "cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/security" "cyberstrike-ai/internal/storage" "github.com/gin-gonic/gin" "go.uber.org/zap" ) // App 应用 type App struct { config *config.Config logger *logger.Logger router *gin.Engine mcpServer *mcp.Server externalMCPMgr *mcp.ExternalMCPManager agent *agent.Agent executor *security.Executor db *database.DB knowledgeDB *database.DB // 知识库数据库连接(如果使用独立数据库) auth *security.AuthManager } // New 创建新应用 func New(cfg *config.Config, log *logger.Logger) (*App, error) { gin.SetMode(gin.ReleaseMode) router := gin.Default() // CORS中间件 router.Use(corsMiddleware()) // 认证管理器 authManager, err := security.NewAuthManager(cfg.Auth.Password, cfg.Auth.SessionDurationHours) if err != nil { return nil, fmt.Errorf("初始化认证失败: %w", err) } // 初始化数据库 dbPath := cfg.Database.Path if dbPath == "" { dbPath = "data/conversations.db" } // 确保目录存在 if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { return nil, fmt.Errorf("创建数据库目录失败: %w", err) } db, err := database.NewDB(dbPath, log.Logger) if err != nil { return nil, fmt.Errorf("初始化数据库失败: %w", err) } // 创建MCP服务器(带数据库持久化) mcpServer := mcp.NewServerWithStorage(log.Logger, db) // 创建安全工具执行器 executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) // 注册工具 executor.RegisterTools(mcpServer) if cfg.Auth.GeneratedPassword != "" { config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr) cfg.Auth.GeneratedPassword = "" cfg.Auth.GeneratedPasswordPersisted = false cfg.Auth.GeneratedPasswordPersistErr = "" } // 创建外部MCP管理器(使用与内部MCP服务器相同的存储) externalMCPMgr := mcp.NewExternalMCPManagerWithStorage(log.Logger, db) if cfg.ExternalMCP.Servers != nil { externalMCPMgr.LoadConfigs(&cfg.ExternalMCP) // 启动所有启用的外部MCP客户端 externalMCPMgr.StartAllEnabled() } // 初始化结果存储 resultStorageDir := "tmp" if cfg.Agent.ResultStorageDir != "" { resultStorageDir = cfg.Agent.ResultStorageDir } // 确保存储目录存在 if err := os.MkdirAll(resultStorageDir, 0755); err != nil { return nil, fmt.Errorf("创建结果存储目录失败: %w", err) } // 创建结果存储实例 resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger) if err != nil { return nil, fmt.Errorf("初始化结果存储失败: %w", err) } // 创建Agent maxIterations := cfg.Agent.MaxIterations if maxIterations <= 0 { maxIterations = 30 // 默认值 } agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations) // 设置结果存储到Agent agent.SetResultStorage(resultStorage) // 设置结果存储到Executor(用于查询工具) executor.SetResultStorage(resultStorage) // 初始化知识库模块(如果启用) var knowledgeManager *knowledge.Manager var knowledgeRetriever *knowledge.Retriever var knowledgeIndexer *knowledge.Indexer var knowledgeHandler *handler.KnowledgeHandler var knowledgeDBConn *database.DB log.Logger.Info("检查知识库配置", zap.Bool("enabled", cfg.Knowledge.Enabled)) if cfg.Knowledge.Enabled { // 确定知识库数据库路径 knowledgeDBPath := cfg.Database.KnowledgeDBPath var knowledgeDB *sql.DB if knowledgeDBPath != "" { // 使用独立的知识库数据库 // 确保目录存在 if err := os.MkdirAll(filepath.Dir(knowledgeDBPath), 0755); err != nil { return nil, fmt.Errorf("创建知识库数据库目录失败: %w", err) } var err error knowledgeDBConn, err = database.NewKnowledgeDB(knowledgeDBPath, log.Logger) if err != nil { return nil, fmt.Errorf("初始化知识库数据库失败: %w", err) } knowledgeDB = knowledgeDBConn.DB log.Logger.Info("使用独立的知识库数据库", zap.String("path", knowledgeDBPath)) } else { // 向后兼容:使用会话数据库 knowledgeDB = db.DB log.Logger.Info("使用会话数据库存储知识库数据(建议配置knowledge_db_path以分离数据)") } // 创建知识库管理器 knowledgeManager = knowledge.NewManager(knowledgeDB, cfg.Knowledge.BasePath, log.Logger) // 创建嵌入器 // 使用OpenAI配置的API Key(如果知识库配置中没有指定) if cfg.Knowledge.Embedding.APIKey == "" { cfg.Knowledge.Embedding.APIKey = cfg.OpenAI.APIKey } if cfg.Knowledge.Embedding.BaseURL == "" { cfg.Knowledge.Embedding.BaseURL = cfg.OpenAI.BaseURL } httpClient := &http.Client{ Timeout: 30 * time.Minute, } openAIClient := openai.NewClient(&cfg.OpenAI, httpClient, log.Logger) embedder := knowledge.NewEmbedder(&cfg.Knowledge, &cfg.OpenAI, openAIClient, log.Logger) // 创建检索器 retrievalConfig := &knowledge.RetrievalConfig{ TopK: cfg.Knowledge.Retrieval.TopK, SimilarityThreshold: cfg.Knowledge.Retrieval.SimilarityThreshold, HybridWeight: cfg.Knowledge.Retrieval.HybridWeight, } knowledgeRetriever = knowledge.NewRetriever(knowledgeDB, embedder, retrievalConfig, log.Logger) // 创建索引器 knowledgeIndexer = knowledge.NewIndexer(knowledgeDB, embedder, log.Logger) // 注册知识检索工具到MCP服务器 knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger) // 创建知识库API处理器 knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger) log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil)) // 扫描知识库并建立索引(异步) go func() { if err := knowledgeManager.ScanKnowledgeBase(); err != nil { log.Logger.Warn("扫描知识库失败", zap.Error(err)) return } // 检查是否已有索引,如果有则跳过自动重建 hasIndex, err := knowledgeIndexer.HasIndex() if err != nil { log.Logger.Warn("检查索引状态失败", zap.Error(err)) return } if hasIndex { log.Logger.Info("检测到已有知识库索引,跳过自动重建。如需重建,请手动点击重建索引按钮") return } // 只有在没有索引时才自动重建 log.Logger.Info("未检测到知识库索引,开始自动构建索引") ctx := context.Background() if err := knowledgeIndexer.RebuildIndex(ctx); err != nil { log.Logger.Warn("重建知识库索引失败", zap.Error(err)) } }() } // 获取配置文件路径 configPath := "config.yaml" if len(os.Args) > 1 { configPath = os.Args[1] } // 创建处理器 agentHandler := handler.NewAgentHandler(agent, db, log.Logger) // 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志 if knowledgeManager != nil { agentHandler.SetKnowledgeManager(knowledgeManager) } monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger) monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 conversationHandler := handler.NewConversationHandler(db, log.Logger) authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) // 如果知识库已启用,设置知识库工具注册器,以便在ApplyConfig时重新注册知识库工具 if cfg.Knowledge.Enabled && knowledgeRetriever != nil && knowledgeManager != nil { // 创建闭包,捕获knowledgeRetriever和knowledgeManager的引用 registrar := func() error { knowledge.RegisterKnowledgeTool(mcpServer, knowledgeRetriever, knowledgeManager, log.Logger) return nil } configHandler.SetKnowledgeToolRegistrar(registrar) } externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) // 设置路由 setupRoutes( router, authHandler, agentHandler, monitorHandler, conversationHandler, configHandler, externalMCPHandler, attackChainHandler, knowledgeHandler, mcpServer, authManager, ) return &App{ config: cfg, logger: log, router: router, mcpServer: mcpServer, externalMCPMgr: externalMCPMgr, agent: agent, executor: executor, db: db, knowledgeDB: knowledgeDBConn, auth: authManager, }, nil } // Run 启动应用 func (a *App) Run() error { // 启动MCP服务器(如果启用) if a.config.MCP.Enabled { go func() { mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port) a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr)) mux := http.NewServeMux() mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP) if err := http.ListenAndServe(mcpAddr, mux); err != nil { a.logger.Error("MCP服务器启动失败", zap.Error(err)) } }() } // 启动主服务器 addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port) a.logger.Info("启动HTTP服务器", zap.String("address", addr)) return a.router.Run(addr) } // Shutdown 关闭应用 func (a *App) Shutdown() { // 停止所有外部MCP客户端 if a.externalMCPMgr != nil { a.externalMCPMgr.StopAll() } // 关闭知识库数据库连接(如果使用独立数据库) if a.knowledgeDB != nil { if err := a.knowledgeDB.Close(); err != nil { a.logger.Logger.Warn("关闭知识库数据库连接失败", zap.Error(err)) } } } // setupRoutes 设置路由 func setupRoutes( router *gin.Engine, authHandler *handler.AuthHandler, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, configHandler *handler.ConfigHandler, externalMCPHandler *handler.ExternalMCPHandler, attackChainHandler *handler.AttackChainHandler, knowledgeHandler *handler.KnowledgeHandler, mcpServer *mcp.Server, authManager *security.AuthManager, ) { // API路由 api := router.Group("/api") // 认证相关路由 authRoutes := api.Group("/auth") { authRoutes.POST("/login", authHandler.Login) authRoutes.POST("/logout", security.AuthMiddleware(authManager), authHandler.Logout) authRoutes.POST("/change-password", security.AuthMiddleware(authManager), authHandler.ChangePassword) authRoutes.GET("/validate", security.AuthMiddleware(authManager), authHandler.Validate) } protected := api.Group("") protected.Use(security.AuthMiddleware(authManager)) { // Agent Loop protected.POST("/agent-loop", agentHandler.AgentLoop) // Agent Loop 流式输出 protected.POST("/agent-loop/stream", agentHandler.AgentLoopStream) // Agent Loop 取消与任务列表 protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop) protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks) // 对话历史 protected.POST("/conversations", conversationHandler.CreateConversation) protected.GET("/conversations", conversationHandler.ListConversations) protected.GET("/conversations/:id", conversationHandler.GetConversation) protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation) // 监控 protected.GET("/monitor", monitorHandler.Monitor) protected.GET("/monitor/execution/:id", monitorHandler.GetExecution) protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution) protected.GET("/monitor/stats", monitorHandler.GetStats) // 配置管理 protected.GET("/config", configHandler.GetConfig) protected.GET("/config/tools", configHandler.GetTools) protected.PUT("/config", configHandler.UpdateConfig) protected.POST("/config/apply", configHandler.ApplyConfig) // 外部MCP管理 protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs) protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats) protected.GET("/external-mcp/:name", externalMCPHandler.GetExternalMCP) protected.PUT("/external-mcp/:name", externalMCPHandler.AddOrUpdateExternalMCP) protected.DELETE("/external-mcp/:name", externalMCPHandler.DeleteExternalMCP) protected.POST("/external-mcp/:name/start", externalMCPHandler.StartExternalMCP) protected.POST("/external-mcp/:name/stop", externalMCPHandler.StopExternalMCP) // 攻击链可视化 protected.GET("/attack-chain/:conversationId", attackChainHandler.GetAttackChain) protected.POST("/attack-chain/:conversationId/regenerate", attackChainHandler.RegenerateAttackChain) // 知识库管理(如果启用) if knowledgeHandler != nil { protected.GET("/knowledge/categories", knowledgeHandler.GetCategories) protected.GET("/knowledge/items", knowledgeHandler.GetItems) protected.GET("/knowledge/items/:id", knowledgeHandler.GetItem) protected.POST("/knowledge/items", knowledgeHandler.CreateItem) protected.PUT("/knowledge/items/:id", knowledgeHandler.UpdateItem) protected.DELETE("/knowledge/items/:id", knowledgeHandler.DeleteItem) protected.GET("/knowledge/index-status", knowledgeHandler.GetIndexStatus) protected.POST("/knowledge/index", knowledgeHandler.RebuildIndex) protected.POST("/knowledge/scan", knowledgeHandler.ScanKnowledgeBase) protected.GET("/knowledge/retrieval-logs", knowledgeHandler.GetRetrievalLogs) protected.POST("/knowledge/search", knowledgeHandler.Search) } // MCP端点 protected.POST("/mcp", func(c *gin.Context) { mcpServer.HandleHTTP(c.Writer, c.Request) }) } // 静态文件 router.Static("/static", "./web/static") router.LoadHTMLGlob("web/templates/*") // 前端页面 router.GET("/", func(c *gin.Context) { c.HTML(http.StatusOK, "index.html", nil) }) } // corsMiddleware CORS中间件 func corsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(204) return } c.Next() } }