diff --git a/README.md b/README.md index b700a19c..5c486ec1 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ ![Preview](./img/mcp-stdio2.png) ## Changelog +- 2025.11.13 Added authentication for the web mode, including automatic password generation and in-app password change - 2025.11.13 Added `Settings` feature in the frontend - 2025.11.13 Added MCP Stdio mode support, now seamlessly integrated and usable in code editors, CLI, and automation scripts - 2025.11.12 Added task stop functionality, optimized frontend @@ -38,6 +39,7 @@ - 💾 **Data Persistence** - SQLite database stores conversation history and process details - 📝 **Detailed Logging** - Structured logging for easy debugging and troubleshooting - 🔒 **Secure Execution** - Tool execution isolation, error handling, and timeout control +- 🔐 **Password-Protected Web Interface** - Unified authentication middleware secures every API call with configurable session duration ## 📁 Project Structure @@ -138,6 +140,11 @@ openai: base_url: "https://api.openai.com/v1" # Or use other compatible API addresses model: "gpt-4" # Or "deepseek-chat", "gpt-3.5-turbo", etc. +# Authentication configuration +auth: + password: "" # Leave empty to auto-generate a strong password on first launch + session_duration_hours: 12 # Login validity (hours) + # Server configuration server: host: "0.0.0.0" @@ -216,6 +223,8 @@ You will see: - API Key, Base URL, and Model are required fields (marked with *), must be filled for normal use - Configuration is automatically saved to the `config.yaml` file - Opening settings automatically loads the latest configuration from the current configuration file +- If `auth.password` is empty, the server generates a random strong password on first launch, writes it back to `config.yaml`, and prints it in the terminal with a security warning +- The web UI prompts for this password when you first open it; you can change it anytime in **Settings → Security** ## ⚙️ Configuration @@ -239,6 +248,11 @@ The system provides a visual configuration management interface. You can access ### Complete Configuration Example ```yaml +# Authentication +auth: + password: "change-me" # Web login password + session_duration_hours: 12 # Session validity (hours) + # Server configuration server: host: "0.0.0.0" # Listen address @@ -318,6 +332,14 @@ Define tool configurations directly in `config.yaml` under `security.tools`. **Note**: If both `tools_dir` and `tools` are configured, tools in `tools_dir` take priority. +### Authentication & Security + +- **Login Workflow**: Every web/API request (except `/api/auth/login`) is protected by a unified middleware. Obtain a token through `/api/auth/login` with the configured password, then include `Authorization: Bearer ` in subsequent requests. +- **Automatic Password Generation**: When `auth.password` is empty, the server generates a 24-character strong password on startup, writes it back to `config.yaml`, and prints the password with bilingual security warnings in the terminal. +- **Session Control**: Sessions expire according to `auth.session_duration_hours`. After expiration or password change, clients must log in again. +- **Password Rotation**: Use **Settings → Security** in the web UI (or call `/api/auth/change-password`) to update the password. The change revokes all existing sessions instantly. +- **MCP Port**: The standalone MCP server (default `8081`) remains authentication-free for IDE integrations. Restrict network access to this port if required. + ## 🚀 Usage Examples ### Conversational Penetration Testing @@ -630,6 +652,9 @@ CyberStrikeAI supports two MCP transport modes: - Suitable for web applications and other HTTP clients - Default listen address: `0.0.0.0:8081/mcp` - Accessible via `/api/mcp` endpoint +- 🌐 Remote-friendly: expose a single endpoint that IDEs, web apps, or automation running on other machines can reach over the network. +- 🧩 Easy reuse: no extra binaries—just point any HTTP-capable client (curl, Postman, cloud automations) to the service. +- 🔁 Always-on workflow: runs together with the main web server, so the same deployment handles UI, API, and MCP traffic. #### MCP HTTP Mode (IDE Integration) @@ -661,6 +686,21 @@ You can connect IDEs such as Cursor or Claude Desktop directly to the built-in H - Fully compliant with JSON-RPC 2.0 specification - Supports string, number, and null types for id field - Properly handles notification messages +- 🔒 Isolated execution: the stdio binary is built and launched separately, so you can run it with least-privilege policies and tighter filesystem/network permissions. +- 🪟 No network exposure: data stays inside the local process boundary—perfect when you do not want an HTTP port listening on your machine. +- 🧰 Editor-first experience: Cursor, Claude Desktop, and other IDEs expect stdio transports for local tooling, enabling plug-and-play integration with minimal setup. +- 🧱 Defense in depth: using both transports in parallel lets you pick the safest option per workflow—stdio for local, HTTP for remote or shared deployments. + +#### Mode comparison: pick what fits your workflow + +| Aspect | `mcp-http` | `mcp-stdio` | +|---------------------|-----------------------------------------------|------------------------------------------------------------------| +| Transport | HTTP/HTTPS over the network | Standard input/output streams | +| Deployment | Runs inside the main server process | Compiled as a standalone binary | +| Isolation & safety | Depends on server hardening (firewall, auth) | Sandboxed by OS process boundaries, no socket exposure | +| Remote access | ✅ Accessible across machines | ❌ Local only (unless tunneled manually) | +| IDE integration | Works with HTTP-capable clients | Native fit for Cursor/Claude Desktop stdio connectors | +| Best use case | Remote automations, shared services | Local development, high-trust / locked-down environments | ### Supported Methods diff --git a/README_CN.md b/README_CN.md index a36b1cee..aaf35cfb 100644 --- a/README_CN.md +++ b/README_CN.md @@ -7,9 +7,10 @@ ![详情预览](./img/mcp-stdio2.png) ## 更新日志 - - 2025.11.13 在前端新增`设置`功能; - - 2025.11.13 新增 MCP Stdio 模式支持,现可在代码编辑器、CLI 及自动化脚本等多种场景下,无缝集成并使用全套安全工具; - - 2025.11.12 增加了任务停止功能,优化前端; +- 2025.11.13 Web 端新增统一鉴权,支持自动生成强密码与前端修改密码; +- 2025.11.13 在前端新增`设置`功能; +- 2025.11.13 新增 MCP Stdio 模式支持,现可在代码编辑器、CLI 及自动化脚本等多种场景下,无缝集成并使用全套安全工具; +- 2025.11.12 增加了任务停止功能,优化前端; ## ✨ 功能特性 @@ -36,6 +37,7 @@ - 💾 **数据持久化** - SQLite数据库存储对话历史和过程详情 - 📝 **详细日志** - 结构化日志记录,便于调试和问题排查 - 🔒 **安全执行** - 工具执行隔离,错误处理和超时控制 +- 🔐 **登录鉴权保护** - Web 端与 API 统一鉴权,中间件校验会话,并支持可配置会话有效期 ## 📁 项目结构 @@ -130,6 +132,11 @@ go mod download 编辑 `config.yaml` 文件,设置您的API配置: ```yaml +# 身份认证配置 +auth: + password: "" # 可留空,首次启动自动生成强密码 + session_duration_hours: 12 # 登录有效期(小时) + # OpenAI兼容API配置(支持OpenAI、DeepSeek、Claude等) openai: api_key: "sk-your-api-key-here" # 替换为您的API Key @@ -214,6 +221,8 @@ go run cmd/server/main.go -config /path/to/config.yaml - API Key、Base URL 和模型是必填项(标记为 *),必须填写才能正常使用 - 配置会自动保存到 `config.yaml` 文件中 - 打开设置时会自动加载当前配置文件中的最新配置 +- 如果 `auth.password` 留空,程序首次启动会自动生成 24 位强密码,写回 `config.yaml` 并在终端输出中英文安全提示 +- Web 首次访问会弹出登录框,请使用该密码登录;可在 **设置 → 安全设置** 中随时修改密码 ## ⚙️ 配置说明 @@ -237,6 +246,11 @@ go run cmd/server/main.go -config /path/to/config.yaml ### 完整配置示例 ```yaml +# 身份认证 +auth: + password: "change-me" # Web 登录密码 + session_duration_hours: 12 # 会话有效期(小时) + # 服务器配置 server: host: "0.0.0.0" # 监听地址 @@ -316,6 +330,14 @@ parameters: **注意:** 如果同时配置了 `tools_dir` 和 `tools`,`tools_dir` 中的工具优先。 +### 身份认证与安全 + +- **登录流程**:除 `/api/auth/login` 外,所有 `/api` 接口均需携带 `Authorization: Bearer ` 请求头。登录成功后返回的 token 会由前端自动缓存。 +- **自动生成密码**:`auth.password` 为空时,启动会生成 24 位随机强密码,写回配置文件并在终端输出中英文提示,请务必妥善保管。 +- **会话控制**:会话有效期由 `auth.session_duration_hours` 控制。过期或修改密码后需重新登录。 +- **密码修改**:在 **设置 → 安全设置** 中即可修改密码,或直接调用 `/api/auth/change-password` 接口;修改会立即使所有旧会话失效。 +- **MCP 端口**:独立 MCP 服务器(默认 `8081`)为了兼容 IDE 插件暂未启用鉴权,建议通过网络层限制访问范围。 + ## 🚀 使用示例 ### 对话式渗透测试 @@ -629,6 +651,9 @@ CyberStrikeAI 支持两种 MCP 传输模式: - 适用于 Web 应用和其他 HTTP 客户端 - 默认监听地址:`0.0.0.0:8081/mcp` - 可通过 `/api/mcp` 端点访问 +- 🌐 便于远程:可对外暴露单个 HTTP 端口,IDE、Web 应用或其他机器上的自动化流程都能直接访问。 +- 🧩 易于复用:无需额外二进制,只要支持 HTTP 的客户端(例如 curl、Postman、云端任务)都能复用同一个服务。 +- 🔁 持续服务:与主 Web 服务同进程运行,部署一次即可同时提供 UI、API 和 MCP 能力。 #### MCP HTTP 模式(IDE 集成) @@ -660,6 +685,21 @@ CyberStrikeAI 支持两种 MCP 传输模式: - 完全符合 JSON-RPC 2.0 规范 - 支持字符串、数字和 null 类型的 id 字段 - 正确处理通知(notification)消息 +- 🔒 更强隔离:以独立二进制方式运行,可结合最小权限策略、独立运行账号来限制文件/网络访问,安全性更高。 +- 🪟 无需暴露端口:所有通信都在本地进程内完成,适合不希望在本机开启额外 HTTP 监听端口的场景。 +- 🧰 IDE 优先体验:Cursor、Claude Desktop 等 IDE 的自定义 MCP 首选 stdio 传输,配置简单即插即用。 +- 🧱 多层防护:HTTP 适合远程共享场景,stdio 适合本地高安全场景,同时保留可根据工作流自由选择。 + +#### 模式对比:按需选择 + +| 对比维度 | `mcp-http`(HTTP 模式) | `mcp-stdio`(stdio 模式) | +|------------------|-----------------------------------------------|-------------------------------------------------------------------| +| 传输协议 | 基于网络的 HTTP/HTTPS | 标准输入输出流 | +| 部署方式 | 与主服务器同进程运行 | 独立编译为单独可执行文件 | +| 隔离与安全 | 依赖服务端加固(防火墙、认证、网络策略) | 借助操作系统进程隔离,无需暴露监听端口 | +| 远程访问 | ✅ 可跨机器访问 | ❌ 仅限本地(除非手动隧道转发) | +| IDE 集成 | 适用于支持 HTTP 的客户端 | 原生适配 Cursor / Claude Desktop 等 stdio 连接器 | +| 最佳使用场景 | 远程自动化、共享服务、云端部署 | 本地开发、对安全隔离要求较高的环境 | ### 支持的方法 diff --git a/config.yaml b/config.yaml index 3e1b9cc4..1a003fd6 100644 --- a/config.yaml +++ b/config.yaml @@ -9,6 +9,10 @@ server: host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080 +# 认证配置 +auth: + password: # Web 登录密码,请修改为强密码 + session_duration_hours: 12 # 登录有效期(小时),超时后需重新登录 # 日志配置 log: level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误) @@ -16,7 +20,7 @@ log: # MCP 协议配置 # MCP (Model Context Protocol) 用于工具注册和调用 mcp: - enabled: true # 是否启用 MCP 服务器 + enabled: false # 是否启用 MCP 服务器 host: 0.0.0.0 # MCP 服务器监听地址 port: 8081 # MCP 服务器端口 # AI 模型配置(支持 OpenAI 兼容 API) diff --git a/internal/app/app.go b/internal/app/app.go index 8be46ea2..43be259f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -27,6 +27,7 @@ type App struct { agent *agent.Agent executor *security.Executor db *database.DB + auth *security.AuthManager } // New 创建新应用 @@ -37,6 +38,12 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { // CORS中间件 router.Use(corsMiddleware()) + // 认证管理器 + authManager, err := security.NewAuthManager(cfg.Auth.Password, cfg.Auth.SessionDurationHours) + if err != nil { + return nil, fmt.Errorf("初始化认证失败: %w", err) + } + // 初始化数据库 dbPath := cfg.Database.Path if dbPath == "" { @@ -62,6 +69,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { // 注册工具 executor.RegisterTools(mcpServer) + if cfg.Auth.GeneratedPassword != "" { + config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr) + cfg.Auth.GeneratedPassword = "" + cfg.Auth.GeneratedPasswordPersisted = false + cfg.Auth.GeneratedPasswordPersistErr = "" + } + // 创建Agent maxIterations := cfg.Agent.MaxIterations if maxIterations <= 0 { @@ -69,20 +83,30 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { } agent := agent.NewAgent(&cfg.OpenAI, mcpServer, log.Logger, maxIterations) - // 创建处理器 - agentHandler := handler.NewAgentHandler(agent, db, log.Logger) - monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger) - conversationHandler := handler.NewConversationHandler(db, log.Logger) - // 获取配置文件路径 configPath := "config.yaml" if len(os.Args) > 1 { configPath = os.Args[1] } + + // 创建处理器 + agentHandler := handler.NewAgentHandler(agent, db, log.Logger) + monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger) + conversationHandler := handler.NewConversationHandler(db, log.Logger) + authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, log.Logger) // 设置路由 - setupRoutes(router, agentHandler, monitorHandler, conversationHandler, configHandler, mcpServer) + setupRoutes( + router, + authHandler, + agentHandler, + monitorHandler, + conversationHandler, + configHandler, + mcpServer, + authManager, + ) return &App{ config: cfg, @@ -92,6 +116,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) { agent: agent, executor: executor, db: db, + auth: authManager, }, nil } @@ -120,37 +145,58 @@ func (a *App) Run() error { } // setupRoutes 设置路由 -func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, configHandler *handler.ConfigHandler, mcpServer *mcp.Server) { +func setupRoutes( + router *gin.Engine, + authHandler *handler.AuthHandler, + agentHandler *handler.AgentHandler, + monitorHandler *handler.MonitorHandler, + conversationHandler *handler.ConversationHandler, + configHandler *handler.ConfigHandler, + mcpServer *mcp.Server, + authManager *security.AuthManager, +) { // API路由 api := router.Group("/api") + + // 认证相关路由 + authRoutes := api.Group("/auth") + { + authRoutes.POST("/login", authHandler.Login) + authRoutes.POST("/logout", security.AuthMiddleware(authManager), authHandler.Logout) + authRoutes.POST("/change-password", security.AuthMiddleware(authManager), authHandler.ChangePassword) + authRoutes.GET("/validate", security.AuthMiddleware(authManager), authHandler.Validate) + } + + protected := api.Group("") + protected.Use(security.AuthMiddleware(authManager)) { // Agent Loop - api.POST("/agent-loop", agentHandler.AgentLoop) + protected.POST("/agent-loop", agentHandler.AgentLoop) // Agent Loop 流式输出 - api.POST("/agent-loop/stream", agentHandler.AgentLoopStream) + protected.POST("/agent-loop/stream", agentHandler.AgentLoopStream) // Agent Loop 取消与任务列表 - api.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop) - api.GET("/agent-loop/tasks", agentHandler.ListAgentTasks) + protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop) + protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks) // 对话历史 - api.POST("/conversations", conversationHandler.CreateConversation) - api.GET("/conversations", conversationHandler.ListConversations) - api.GET("/conversations/:id", conversationHandler.GetConversation) - api.DELETE("/conversations/:id", conversationHandler.DeleteConversation) + protected.POST("/conversations", conversationHandler.CreateConversation) + protected.GET("/conversations", conversationHandler.ListConversations) + protected.GET("/conversations/:id", conversationHandler.GetConversation) + protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation) // 监控 - api.GET("/monitor", monitorHandler.Monitor) - api.GET("/monitor/execution/:id", monitorHandler.GetExecution) - api.GET("/monitor/stats", monitorHandler.GetStats) - api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities) + protected.GET("/monitor", monitorHandler.Monitor) + protected.GET("/monitor/execution/:id", monitorHandler.GetExecution) + protected.GET("/monitor/stats", monitorHandler.GetStats) + protected.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities) // 配置管理 - api.GET("/config", configHandler.GetConfig) - api.PUT("/config", configHandler.UpdateConfig) - api.POST("/config/apply", configHandler.ApplyConfig) + protected.GET("/config", configHandler.GetConfig) + protected.PUT("/config", configHandler.UpdateConfig) + protected.POST("/config/apply", configHandler.ApplyConfig) // MCP端点 - api.POST("/mcp", func(c *gin.Context) { + protected.POST("/mcp", func(c *gin.Context) { mcpServer.HandleHTTP(c.Writer, c.Request) }) } diff --git a/internal/config/config.go b/internal/config/config.go index 424eb4eb..d73b4c1b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,8 @@ package config import ( + "crypto/rand" + "encoding/base64" "fmt" "os" "path/filepath" @@ -17,6 +19,7 @@ type Config struct { Agent AgentConfig `yaml:"agent"` Security SecurityConfig `yaml:"security"` Database DatabaseConfig `yaml:"database"` + Auth AuthConfig `yaml:"auth"` } type ServerConfig struct { @@ -42,8 +45,8 @@ type OpenAIConfig struct { } type SecurityConfig struct { - Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具 - ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式) + Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具 + ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式) } type DatabaseConfig struct { @@ -54,29 +57,36 @@ type AgentConfig struct { MaxIterations int `yaml:"max_iterations" json:"max_iterations"` } +type AuthConfig struct { + Password string `yaml:"password" json:"password"` + SessionDurationHours int `yaml:"session_duration_hours" json:"session_duration_hours"` + GeneratedPassword string `yaml:"-" json:"-"` + GeneratedPasswordPersisted bool `yaml:"-" json:"-"` + GeneratedPasswordPersistErr string `yaml:"-" json:"-"` +} type ToolConfig struct { Name string `yaml:"name"` Command string `yaml:"command"` - Args []string `yaml:"args,omitempty"` // 固定参数(可选) + Args []string `yaml:"args,omitempty"` // 固定参数(可选) ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗) - Description string `yaml:"description"` // 详细描述(用于工具文档) + Description string `yaml:"description"` // 详细描述(用于工具文档) Enabled bool `yaml:"enabled"` - Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) + Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选) } // ParameterConfig 参数配置 type ParameterConfig struct { - Name string `yaml:"name"` // 参数名称 - Type string `yaml:"type"` // 参数类型: string, int, bool, array - Description string `yaml:"description"` // 参数描述 - Required bool `yaml:"required,omitempty"` // 是否必需 - Default interface{} `yaml:"default,omitempty"` // 默认值 - Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p" - Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始) - Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template" - Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}" - Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举) + Name string `yaml:"name"` // 参数名称 + Type string `yaml:"type"` // 参数类型: string, int, bool, array + Description string `yaml:"description"` // 参数描述 + Required bool `yaml:"required,omitempty"` // 是否必需 + Default interface{} `yaml:"default,omitempty"` // 默认值 + Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p" + Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始) + Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template" + Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}" + Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举) } func Load(path string) (*Config, error) { @@ -90,65 +100,189 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("解析配置文件失败: %w", err) } + if cfg.Auth.SessionDurationHours <= 0 { + cfg.Auth.SessionDurationHours = 12 + } + + if strings.TrimSpace(cfg.Auth.Password) == "" { + password, err := generateStrongPassword(24) + if err != nil { + return nil, fmt.Errorf("生成默认密码失败: %w", err) + } + + cfg.Auth.Password = password + cfg.Auth.GeneratedPassword = password + + if err := PersistAuthPassword(path, password); err != nil { + cfg.Auth.GeneratedPasswordPersisted = false + cfg.Auth.GeneratedPasswordPersistErr = err.Error() + } else { + cfg.Auth.GeneratedPasswordPersisted = true + } + } + // 如果配置了工具目录,从目录加载工具配置 if cfg.Security.ToolsDir != "" { configDir := filepath.Dir(path) toolsDir := cfg.Security.ToolsDir - + // 如果是相对路径,相对于配置文件所在目录 if !filepath.IsAbs(toolsDir) { toolsDir = filepath.Join(configDir, toolsDir) } - + tools, err := LoadToolsFromDir(toolsDir) if err != nil { return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err) } - + // 合并工具配置:目录中的工具优先,主配置中的工具作为补充 existingTools := make(map[string]bool) for _, tool := range tools { existingTools[tool.Name] = true } - + // 添加主配置中不存在于目录中的工具(向后兼容) for _, tool := range cfg.Security.Tools { if !existingTools[tool.Name] { tools = append(tools, tool) } } - + cfg.Security.Tools = tools } return &cfg, nil } +func generateStrongPassword(length int) (string, error) { + if length <= 0 { + length = 24 + } + + bytesLen := length + randomBytes := make([]byte, bytesLen) + if _, err := rand.Read(randomBytes); err != nil { + return "", err + } + + password := base64.RawURLEncoding.EncodeToString(randomBytes) + if len(password) > length { + password = password[:length] + } + return password, nil +} + +func PersistAuthPassword(path, password string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + + lines := strings.Split(string(data), "\n") + inAuthBlock := false + authIndent := -1 + + for i, line := range lines { + trimmed := strings.TrimSpace(line) + if !inAuthBlock { + if strings.HasPrefix(trimmed, "auth:") { + inAuthBlock = true + authIndent = len(line) - len(strings.TrimLeft(line, " ")) + } + continue + } + + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + leadingSpaces := len(line) - len(strings.TrimLeft(line, " ")) + if leadingSpaces <= authIndent { + // 离开 auth 块 + inAuthBlock = false + authIndent = -1 + // 继续寻找其它 auth 块(理论上没有) + if strings.HasPrefix(trimmed, "auth:") { + inAuthBlock = true + authIndent = leadingSpaces + } + continue + } + + if strings.HasPrefix(strings.TrimSpace(line), "password:") { + prefix := line[:len(line)-len(strings.TrimLeft(line, " "))] + comment := "" + if idx := strings.Index(line, "#"); idx >= 0 { + comment = strings.TrimRight(line[idx:], " ") + } + + newLine := fmt.Sprintf("%spassword: %s", prefix, password) + if comment != "" { + if !strings.HasPrefix(comment, " ") { + newLine += " " + } + newLine += comment + } + lines[i] = newLine + break + } + } + + return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644) +} + +func PrintGeneratedPasswordWarning(password string, persisted bool, persistErr string) { + if strings.TrimSpace(password) == "" { + return + } + + if persisted { + fmt.Println("[CyberStrikeAI] ✅ 已为您自动生成并写入 Web 登录密码。") + } else { + if persistErr != "" { + fmt.Printf("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码: %s\n", persistErr) + } else { + fmt.Println("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码。") + } + fmt.Println("请手动将以下随机密码写入 config.yaml 的 auth.password:") + } + + fmt.Println("----------------------------------------------------------------") + fmt.Println("CyberStrikeAI Auto-Generated Web Password") + fmt.Printf("Password: %s\n", password) + fmt.Println("WARNING: Anyone with this password can fully control CyberStrikeAI.") + fmt.Println("Please store it securely and change it in config.yaml as soon as possible.") + fmt.Println("警告:持有此密码的人将拥有对 CyberStrikeAI 的完全控制权限。") + fmt.Println("请妥善保管,并尽快在 config.yaml 中修改 auth.password!") + fmt.Println("----------------------------------------------------------------") +} + // LoadToolsFromDir 从目录加载所有工具配置文件 func LoadToolsFromDir(dir string) ([]ToolConfig, error) { var tools []ToolConfig - + // 检查目录是否存在 if _, err := os.Stat(dir); os.IsNotExist(err) { return tools, nil // 目录不存在时返回空列表,不报错 } - + // 读取目录中的所有 .yaml 和 .yml 文件 entries, err := os.ReadDir(dir) if err != nil { return nil, fmt.Errorf("读取工具目录失败: %w", err) } - + for _, entry := range entries { if entry.IsDir() { continue } - + name := entry.Name() if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") { continue } - + filePath := filepath.Join(dir, name) tool, err := LoadToolFromFile(filePath) if err != nil { @@ -156,10 +290,10 @@ func LoadToolsFromDir(dir string) ([]ToolConfig, error) { fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err) continue } - + tools = append(tools, *tool) } - + return tools, nil } @@ -169,12 +303,12 @@ func LoadToolFromFile(path string) (*ToolConfig, error) { if err != nil { return nil, fmt.Errorf("读取文件失败: %w", err) } - + var tool ToolConfig if err := yaml.Unmarshal(data, &tool); err != nil { return nil, fmt.Errorf("解析工具配置失败: %w", err) } - + // 验证必需字段 if tool.Name == "" { return nil, fmt.Errorf("工具名称不能为空") @@ -182,7 +316,7 @@ func LoadToolFromFile(path string) (*ToolConfig, error) { if tool.Command == "" { return nil, fmt.Errorf("工具命令不能为空") } - + return &tool, nil } @@ -215,6 +349,8 @@ func Default() *Config { Database: DatabaseConfig{ Path: "data/conversations.db", }, + Auth: AuthConfig{ + SessionDurationHours: 12, + }, } } - diff --git a/internal/handler/auth.go b/internal/handler/auth.go new file mode 100644 index 00000000..6af09008 --- /dev/null +++ b/internal/handler/auth.go @@ -0,0 +1,156 @@ +package handler + +import ( + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/security" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// AuthHandler handles authentication-related endpoints. +type AuthHandler struct { + manager *security.AuthManager + config *config.Config + configPath string + logger *zap.Logger +} + +// NewAuthHandler creates a new AuthHandler. +func NewAuthHandler(manager *security.AuthManager, cfg *config.Config, configPath string, logger *zap.Logger) *AuthHandler { + return &AuthHandler{ + manager: manager, + config: cfg, + configPath: configPath, + logger: logger, + } +} + +type loginRequest struct { + Password string `json:"password" binding:"required"` +} + +type changePasswordRequest struct { + OldPassword string `json:"oldPassword"` + NewPassword string `json:"newPassword"` +} + +// Login verifies password and returns a session token. +func (h *AuthHandler) Login(c *gin.Context) { + var req loginRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"}) + return + } + + token, expiresAt, err := h.manager.Authenticate(req.Password) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "token": token, + "expires_at": expiresAt.UTC().Format(time.RFC3339), + "session_duration_hr": h.manager.SessionDurationHours(), + }) +} + +// Logout revokes the current session token. +func (h *AuthHandler) Logout(c *gin.Context) { + token := c.GetString(security.ContextAuthTokenKey) + if token == "" { + authHeader := c.GetHeader("Authorization") + if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") { + token = strings.TrimSpace(authHeader[7:]) + } else { + token = strings.TrimSpace(authHeader) + } + } + + h.manager.RevokeToken(token) + c.JSON(http.StatusOK, gin.H{"message": "已退出登录"}) +} + +// ChangePassword updates the login password. +func (h *AuthHandler) ChangePassword(c *gin.Context) { + var req changePasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "参数无效"}) + return + } + + oldPassword := strings.TrimSpace(req.OldPassword) + newPassword := strings.TrimSpace(req.NewPassword) + + if oldPassword == "" || newPassword == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码和新密码均不能为空"}) + return + } + + if len(newPassword) < 8 { + c.JSON(http.StatusBadRequest, gin.H{"error": "新密码长度至少需要 8 位"}) + return + } + + if oldPassword == newPassword { + c.JSON(http.StatusBadRequest, gin.H{"error": "新密码不能与旧密码相同"}) + return + } + + if !h.manager.CheckPassword(oldPassword) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "当前密码不正确"}) + return + } + + if err := config.PersistAuthPassword(h.configPath, newPassword); err != nil { + if h.logger != nil { + h.logger.Error("保存新密码失败", zap.Error(err)) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "保存新密码失败,请重试"}) + return + } + + if err := h.manager.UpdateConfig(newPassword, h.config.Auth.SessionDurationHours); err != nil { + if h.logger != nil { + h.logger.Error("更新认证配置失败", zap.Error(err)) + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "更新认证配置失败"}) + return + } + + h.config.Auth.Password = newPassword + h.config.Auth.GeneratedPassword = "" + h.config.Auth.GeneratedPasswordPersisted = false + h.config.Auth.GeneratedPasswordPersistErr = "" + + if h.logger != nil { + h.logger.Info("登录密码已更新,所有会话已失效") + } + + c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"}) +} + +// Validate returns the current session status. +func (h *AuthHandler) Validate(c *gin.Context) { + token := c.GetString(security.ContextAuthTokenKey) + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "会话无效"}) + return + } + + session, ok := h.manager.ValidateToken(token) + if !ok { + c.JSON(http.StatusUnauthorized, gin.H{"error": "会话已过期"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "token": session.Token, + "expires_at": session.ExpiresAt.UTC().Format(time.RFC3339), + }) +} diff --git a/internal/security/auth_manager.go b/internal/security/auth_manager.go new file mode 100644 index 00000000..3b9bd17b --- /dev/null +++ b/internal/security/auth_manager.go @@ -0,0 +1,132 @@ +package security + +import ( + "errors" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +// Predefined errors for authentication operations. +var ( + ErrInvalidPassword = errors.New("invalid password") +) + +// Session represents an authenticated user session. +type Session struct { + Token string + ExpiresAt time.Time +} + +// AuthManager manages password-based authentication and session lifecycle. +type AuthManager struct { + password string + sessionDuration time.Duration + + mu sync.RWMutex + sessions map[string]Session +} + +// NewAuthManager creates a new AuthManager instance. +func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) { + if strings.TrimSpace(password) == "" { + return nil, errors.New("auth password must be configured") + } + + if sessionDurationHours <= 0 { + sessionDurationHours = 12 + } + + return &AuthManager{ + password: password, + sessionDuration: time.Duration(sessionDurationHours) * time.Hour, + sessions: make(map[string]Session), + }, nil +} + +// Authenticate validates the password and creates a new session. +func (a *AuthManager) Authenticate(password string) (string, time.Time, error) { + if password != a.password { + return "", time.Time{}, ErrInvalidPassword + } + + token := uuid.NewString() + expiresAt := time.Now().Add(a.sessionDuration) + + a.mu.Lock() + a.sessions[token] = Session{ + Token: token, + ExpiresAt: expiresAt, + } + a.mu.Unlock() + + return token, expiresAt, nil +} + +// ValidateToken checks whether the provided token is still valid. +func (a *AuthManager) ValidateToken(token string) (Session, bool) { + if strings.TrimSpace(token) == "" { + return Session{}, false + } + + a.mu.RLock() + session, ok := a.sessions[token] + a.mu.RUnlock() + if !ok { + return Session{}, false + } + + if time.Now().After(session.ExpiresAt) { + a.mu.Lock() + delete(a.sessions, token) + a.mu.Unlock() + return Session{}, false + } + + return session, true +} + +// CheckPassword verifies whether the provided password matches the current password. +func (a *AuthManager) CheckPassword(password string) bool { + a.mu.RLock() + defer a.mu.RUnlock() + return password == a.password +} + +// RevokeToken invalidates the specified token. +func (a *AuthManager) RevokeToken(token string) { + if strings.TrimSpace(token) == "" { + return + } + + a.mu.Lock() + delete(a.sessions, token) + a.mu.Unlock() +} + +// SessionDurationHours returns the configured session duration in hours. +func (a *AuthManager) SessionDurationHours() int { + return int(a.sessionDuration / time.Hour) +} + +// UpdateConfig updates the password and session duration, revoking existing sessions. +func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error { + password = strings.TrimSpace(password) + if password == "" { + return errors.New("auth password must be configured") + } + + if sessionDurationHours <= 0 { + sessionDurationHours = 12 + } + + a.mu.Lock() + defer a.mu.Unlock() + + a.password = password + a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour + a.sessions = make(map[string]Session) + return nil +} diff --git a/internal/security/auth_middleware.go b/internal/security/auth_middleware.go new file mode 100644 index 00000000..e7924a7a --- /dev/null +++ b/internal/security/auth_middleware.go @@ -0,0 +1,51 @@ +package security + +import ( + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +const ( + ContextAuthTokenKey = "authToken" + ContextSessionExpiry = "authSessionExpiry" +) + +// AuthMiddleware enforces authentication on protected routes. +func AuthMiddleware(manager *AuthManager) gin.HandlerFunc { + return func(c *gin.Context) { + token := extractTokenFromRequest(c) + session, ok := manager.ValidateToken(token) + if !ok { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": "未授权访问,请先登录", + }) + return + } + + c.Set(ContextAuthTokenKey, session.Token) + c.Set(ContextSessionExpiry, session.ExpiresAt) + c.Next() + } +} + +func extractTokenFromRequest(c *gin.Context) string { + authHeader := c.GetHeader("Authorization") + if authHeader != "" { + if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") { + return strings.TrimSpace(authHeader[7:]) + } + return strings.TrimSpace(authHeader) + } + + if token := c.Query("token"); token != "" { + return strings.TrimSpace(token) + } + + if cookie, err := c.Cookie("auth_token"); err == nil { + return strings.TrimSpace(cookie) + } + + return "" +} diff --git a/web/static/css/style.css b/web/static/css/style.css index 7d6500a4..c0009fa7 100644 --- a/web/static/css/style.css +++ b/web/static/css/style.css @@ -56,7 +56,7 @@ body { header { background: var(--primary-color); color: white; - padding: 24px 32px; + padding: 16px 24px; border-bottom: 1px solid rgba(255, 255, 255, 0.1); flex-shrink: 0; } @@ -289,7 +289,7 @@ header { flex-direction: column; flex: 1; min-width: 0; - background: var(--bg-primary); + background: var(--bg-secondary); overflow: hidden; height: 100%; } @@ -755,6 +755,67 @@ header { transform: translateY(0); } +/* 登录遮罩 */ +.login-overlay { + position: fixed; + inset: 0; + background: rgba(0, 0, 0, 0.6); + backdrop-filter: blur(6px); + display: none; + align-items: center; + justify-content: center; + z-index: 1200; + padding: 24px; +} + +.login-card { + width: 100%; + max-width: 360px; + background: var(--bg-primary); + border-radius: 12px; + padding: 32px 28px; + box-shadow: var(--shadow-lg); + border: 1px solid var(--border-color); + display: flex; + flex-direction: column; + gap: 20px; +} + +.login-header h2 { + margin: 0; + font-size: 1.5rem; + color: var(--text-primary); +} + +.login-subtitle { + margin: 8px 0 0 0; + font-size: 0.9375rem; + color: var(--text-secondary); +} + +.login-form { + display: flex; + flex-direction: column; + gap: 16px; +} + +.login-error { + color: var(--error-color); + background: rgba(220, 53, 69, 0.08); + border: 1px solid rgba(220, 53, 69, 0.4); + border-radius: 6px; + padding: 10px 12px; + font-size: 0.875rem; +} + +.login-submit { + width: 100%; + justify-content: center; + display: inline-flex; + align-items: center; + gap: 8px; +} + /* 模态框样式 */ .modal { display: none; @@ -1251,9 +1312,12 @@ header { display: none; align-items: center; gap: 12px; - padding: 12px 20px; - background: rgba(0, 102, 255, 0.06); - border-bottom: 1px solid rgba(0, 102, 255, 0.15); + padding: 10px 16px; + margin: 12px 0; + background: var(--bg-primary); + border: 1px solid rgba(0, 102, 255, 0.15); + border-radius: 10px; + box-shadow: var(--shadow-sm); color: var(--text-primary); } @@ -1261,20 +1325,22 @@ header { display: flex; align-items: center; justify-content: space-between; - gap: 16px; + gap: 12px; background: var(--bg-primary); border: 1px solid rgba(0, 102, 255, 0.2); border-radius: 8px; padding: 8px 12px; flex: 1; min-width: 0; + box-shadow: inset 0 1px 1px rgba(0, 0, 0, 0.03); } .active-task-info { display: flex; align-items: center; - gap: 8px; + gap: 6px; min-width: 0; + flex-wrap: wrap; } .active-task-status { @@ -1288,7 +1354,7 @@ header { } .active-task-message { - font-size: 0.875rem; + font-size: 0.85rem; color: var(--text-primary); overflow: hidden; text-overflow: ellipsis; @@ -1299,7 +1365,7 @@ header { .active-task-actions { display: flex; align-items: center; - gap: 10px; + gap: 8px; flex-shrink: 0; } @@ -1409,6 +1475,19 @@ header { box-shadow: 0 0 0 3px rgba(220, 53, 69, 0.2); } +.form-actions { + display: flex; + justify-content: flex-end; + gap: 12px; + margin-top: 4px; +} + +.password-hint { + font-size: 0.8125rem; + color: var(--text-muted); + margin-top: 8px; +} + .tools-controls { display: flex; flex-direction: column; diff --git a/web/static/js/app.js b/web/static/js/app.js index b52d5511..5680abbd 100644 --- a/web/static/js/app.js +++ b/web/static/js/app.js @@ -1,3 +1,9 @@ +const AUTH_STORAGE_KEY = 'cyberstrike-auth'; +let authToken = null; +let authTokenExpiry = null; +let authPromise = null; +let authPromiseResolvers = []; +let isAppInitialized = false; // 当前对话ID let currentConversationId = null; @@ -7,6 +13,280 @@ const progressTaskState = new Map(); let activeTaskInterval = null; const ACTIVE_TASK_REFRESH_INTERVAL = 10000; // 10秒检查一次,提供更实时的任务状态反馈 +function isTokenValid() { + return !!authToken && authTokenExpiry instanceof Date && authTokenExpiry.getTime() > Date.now(); +} + +function saveAuth(token, expiresAt) { + const expiry = expiresAt instanceof Date ? expiresAt : new Date(expiresAt); + authToken = token; + authTokenExpiry = expiry; + try { + localStorage.setItem(AUTH_STORAGE_KEY, JSON.stringify({ + token, + expiresAt: expiry.toISOString(), + })); + } catch (error) { + console.warn('无法持久化认证信息:', error); + } +} + +function clearAuthStorage() { + authToken = null; + authTokenExpiry = null; + try { + localStorage.removeItem(AUTH_STORAGE_KEY); + } catch (error) { + console.warn('无法清除认证信息:', error); + } +} + +function loadAuthFromStorage() { + try { + const raw = localStorage.getItem(AUTH_STORAGE_KEY); + if (!raw) { + return false; + } + const stored = JSON.parse(raw); + if (!stored.token || !stored.expiresAt) { + clearAuthStorage(); + return false; + } + const expiry = new Date(stored.expiresAt); + if (Number.isNaN(expiry.getTime())) { + clearAuthStorage(); + return false; + } + authToken = stored.token; + authTokenExpiry = expiry; + return isTokenValid(); + } catch (error) { + console.error('读取认证信息失败:', error); + clearAuthStorage(); + return false; + } +} + +function resolveAuthPromises(success) { + authPromiseResolvers.forEach(resolve => resolve(success)); + authPromiseResolvers = []; + authPromise = null; +} + +function showLoginOverlay(message = '') { + const overlay = document.getElementById('login-overlay'); + const errorBox = document.getElementById('login-error'); + const passwordInput = document.getElementById('login-password'); + if (!overlay) { + return; + } + overlay.style.display = 'flex'; + if (errorBox) { + if (message) { + errorBox.textContent = message; + errorBox.style.display = 'block'; + } else { + errorBox.textContent = ''; + errorBox.style.display = 'none'; + } + } + setTimeout(() => { + if (passwordInput) { + passwordInput.focus(); + } + }, 100); +} + +function hideLoginOverlay() { + const overlay = document.getElementById('login-overlay'); + const errorBox = document.getElementById('login-error'); + const passwordInput = document.getElementById('login-password'); + if (overlay) { + overlay.style.display = 'none'; + } + if (errorBox) { + errorBox.textContent = ''; + errorBox.style.display = 'none'; + } + if (passwordInput) { + passwordInput.value = ''; + } +} + +function ensureAuthPromise() { + if (!authPromise) { + authPromise = new Promise(resolve => { + authPromiseResolvers.push(resolve); + }); + } + return authPromise; +} + +async function ensureAuthenticated() { + if (isTokenValid()) { + return true; + } + showLoginOverlay(); + await ensureAuthPromise(); + return true; +} + +function handleUnauthorized({ message = '认证已过期,请重新登录', silent = false } = {}) { + clearAuthStorage(); + authPromise = null; + authPromiseResolvers = []; + if (!silent) { + showLoginOverlay(message); + } else { + showLoginOverlay(); + } + return false; +} + +async function apiFetch(url, options = {}) { + await ensureAuthenticated(); + const opts = { ...options }; + const headers = new Headers(options && options.headers ? options.headers : undefined); + if (authToken && !headers.has('Authorization')) { + headers.set('Authorization', `Bearer ${authToken}`); + } + opts.headers = headers; + + const response = await fetch(url, opts); + if (response.status === 401) { + handleUnauthorized(); + throw new Error('未授权访问'); + } + return response; +} + +async function submitLogin(event) { + event.preventDefault(); + const passwordInput = document.getElementById('login-password'); + const errorBox = document.getElementById('login-error'); + const submitBtn = document.querySelector('.login-submit'); + + if (!passwordInput) { + return; + } + + const password = passwordInput.value.trim(); + if (!password) { + if (errorBox) { + errorBox.textContent = '请输入密码'; + errorBox.style.display = 'block'; + } + return; + } + + if (submitBtn) { + submitBtn.disabled = true; + } + + try { + const response = await fetch('/api/auth/login', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ password }), + }); + const result = await response.json().catch(() => ({})); + if (!response.ok || !result.token) { + if (errorBox) { + errorBox.textContent = result.error || '登录失败,请检查密码'; + errorBox.style.display = 'block'; + } + return; + } + + saveAuth(result.token, result.expires_at); + hideLoginOverlay(); + resolveAuthPromises(true); + if (!isAppInitialized) { + await bootstrapApp(); + } else { + await refreshAppData(); + } + } catch (error) { + console.error('登录失败:', error); + if (errorBox) { + errorBox.textContent = '登录失败,请稍后重试'; + errorBox.style.display = 'block'; + } + } finally { + if (submitBtn) { + submitBtn.disabled = false; + } + } +} + +async function refreshAppData(showTaskErrors = false) { + await Promise.allSettled([ + loadConversations(), + loadActiveTasks(showTaskErrors), + ]); +} + +async function bootstrapApp() { + if (!isAppInitialized) { + initializeChatUI(); + isAppInitialized = true; + } + await refreshAppData(); +} + +function initializeChatUI() { + const chatInputEl = document.getElementById('chat-input'); + if (chatInputEl) { + chatInputEl.style.height = '44px'; + } + + const messagesDiv = document.getElementById('chat-messages'); + if (messagesDiv && messagesDiv.childElementCount === 0) { + addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。'); + } + + loadActiveTasks(true); + if (activeTaskInterval) { + clearInterval(activeTaskInterval); + } + activeTaskInterval = setInterval(() => loadActiveTasks(), ACTIVE_TASK_REFRESH_INTERVAL); +} + +function setupLoginUI() { + const loginForm = document.getElementById('login-form'); + if (loginForm) { + loginForm.addEventListener('submit', submitLogin); + } +} + +async function initializeApp() { + setupLoginUI(); + const hasStoredAuth = loadAuthFromStorage(); + if (hasStoredAuth && isTokenValid()) { + try { + const response = await apiFetch('/api/auth/validate', { + method: 'GET', + }); + if (response.ok) { + hideLoginOverlay(); + resolveAuthPromises(true); + await bootstrapApp(); + return; + } + } catch (error) { + console.warn('本地会话已失效,需重新登录'); + } + } + + clearAuthStorage(); + showLoginOverlay(); +} + +document.addEventListener('DOMContentLoaded', initializeApp); + + function registerProgressTask(progressId, conversationId = null) { const state = progressTaskState.get(progressId) || {}; state.conversationId = conversationId !== undefined && conversationId !== null @@ -45,7 +325,7 @@ function finalizeProgressTask(progressId, finalLabel = '已完成') { } async function requestCancel(conversationId) { - const response = await fetch('/api/agent-loop/cancel', { + const response = await apiFetch('/api/agent-loop/cancel', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -81,7 +361,7 @@ async function sendMessage() { let mcpExecutionIds = []; try { - const response = await fetch('/api/agent-loop/stream', { + const response = await apiFetch('/api/agent-loop/stream', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -950,7 +1230,7 @@ chatInput.addEventListener('keydown', function(e) { // 显示MCP调用详情 async function showMCPDetail(executionId) { try { - const response = await fetch(`/api/monitor/execution/${executionId}`); + const response = await apiFetch(`/api/monitor/execution/${executionId}`); const exec = await response.json(); if (response.ok) { @@ -1064,7 +1344,7 @@ function startNewConversation() { // 加载对话列表 async function loadConversations() { try { - const response = await fetch('/api/conversations?limit=50'); + const response = await apiFetch('/api/conversations?limit=50'); const conversations = await response.json(); const listContainer = document.getElementById('conversations-list'); @@ -1179,7 +1459,7 @@ async function loadConversations() { // 加载对话 async function loadConversation(conversationId) { try { - const response = await fetch(`/api/conversations/${conversationId}`); + const response = await apiFetch(`/api/conversations/${conversationId}`); const conversation = await response.json(); if (!response.ok) { @@ -1230,7 +1510,7 @@ async function deleteConversation(conversationId) { } try { - const response = await fetch(`/api/conversations/${conversationId}`, { + const response = await apiFetch(`/api/conversations/${conversationId}`, { method: 'DELETE' }); @@ -1268,7 +1548,7 @@ function updateActiveConversation() { async function loadActiveTasks(showErrors = false) { const bar = document.getElementById('active-tasks-bar'); try { - const response = await fetch('/api/agent-loop/tasks'); + const response = await apiFetch('/api/agent-loop/tasks'); const result = await response.json().catch(() => ({})); if (!response.ok) { @@ -1386,7 +1666,7 @@ window.onclick = function(event) { // 加载配置 async function loadConfig() { try { - const response = await fetch('/api/config'); + const response = await apiFetch('/api/config'); if (!response.ok) { throw new Error('获取配置失败'); } @@ -1520,7 +1800,7 @@ async function applySettings() { }); // 更新配置 - const updateResponse = await fetch('/api/config', { + const updateResponse = await apiFetch('/api/config', { method: 'PUT', headers: { 'Content-Type': 'application/json' @@ -1534,7 +1814,7 @@ async function applySettings() { } // 应用配置 - const applyResponse = await fetch('/api/config/apply', { + const applyResponse = await apiFetch('/api/config/apply', { method: 'POST' }); @@ -1551,24 +1831,85 @@ async function applySettings() { } } -// 页面加载时初始化 -document.addEventListener('DOMContentLoaded', function() { - // 加载对话列表 - loadConversations(); - - // 初始化 textarea 高度 - const chatInput = document.getElementById('chat-input'); - if (chatInput) { - chatInput.style.height = '44px'; - } - - // 添加欢迎消息 - addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。'); - - loadActiveTasks(true); - if (activeTaskInterval) { - clearInterval(activeTaskInterval); - } - activeTaskInterval = setInterval(() => loadActiveTasks(), ACTIVE_TASK_REFRESH_INTERVAL); -}); +function resetPasswordForm() { + const currentInput = document.getElementById('auth-current-password'); + const newInput = document.getElementById('auth-new-password'); + const confirmInput = document.getElementById('auth-confirm-password'); + + [currentInput, newInput, confirmInput].forEach(input => { + if (input) { + input.value = ''; + input.classList.remove('error'); + } + }); +} + +async function changePassword() { + const currentInput = document.getElementById('auth-current-password'); + const newInput = document.getElementById('auth-new-password'); + const confirmInput = document.getElementById('auth-confirm-password'); + const submitBtn = document.querySelector('.change-password-submit'); + + [currentInput, newInput, confirmInput].forEach(input => input && input.classList.remove('error')); + + const currentPassword = currentInput?.value.trim() || ''; + const newPassword = newInput?.value.trim() || ''; + const confirmPassword = confirmInput?.value.trim() || ''; + + let hasError = false; + + if (!currentPassword) { + currentInput?.classList.add('error'); + hasError = true; + } + + if (!newPassword || newPassword.length < 8) { + newInput?.classList.add('error'); + hasError = true; + } + + if (newPassword !== confirmPassword) { + confirmInput?.classList.add('error'); + hasError = true; + } + + if (hasError) { + alert('请正确填写当前密码和新密码,新密码至少 8 位且需要两次输入一致。'); + return; + } + + if (submitBtn) { + submitBtn.disabled = true; + } + + try { + const response = await apiFetch('/api/auth/change-password', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + oldPassword: currentPassword, + newPassword: newPassword + }) + }); + + const result = await response.json().catch(() => ({})); + if (!response.ok) { + throw new Error(result.error || '修改密码失败'); + } + + alert('密码已更新,请使用新密码重新登录。'); + resetPasswordForm(); + handleUnauthorized({ message: '密码已更新,请使用新密码重新登录。', silent: false }); + closeSettings(); + } catch (error) { + console.error('修改密码失败:', error); + alert('修改密码失败: ' + error.message); + } finally { + if (submitBtn) { + submitBtn.disabled = false; + } + } +} diff --git a/web/templates/index.html b/web/templates/index.html index 75ad0bd4..2eff8884 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -7,6 +7,23 @@ + +
@@ -106,6 +123,30 @@
+ + +
+

安全设置

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+

修改密码后,需要使用新密码重新登录。

+
+