mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-02 18:55:52 +02:00
Add files via upload
This commit is contained in:
+69
-23
@@ -27,6 +27,7 @@ type App struct {
|
||||
agent *agent.Agent
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
auth *security.AuthManager
|
||||
}
|
||||
|
||||
// New 创建新应用
|
||||
@@ -37,6 +38,12 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
// CORS中间件
|
||||
router.Use(corsMiddleware())
|
||||
|
||||
// 认证管理器
|
||||
authManager, err := security.NewAuthManager(cfg.Auth.Password, cfg.Auth.SessionDurationHours)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化认证失败: %w", err)
|
||||
}
|
||||
|
||||
// 初始化数据库
|
||||
dbPath := cfg.Database.Path
|
||||
if dbPath == "" {
|
||||
@@ -62,6 +69,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
// 注册工具
|
||||
executor.RegisterTools(mcpServer)
|
||||
|
||||
if cfg.Auth.GeneratedPassword != "" {
|
||||
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
|
||||
cfg.Auth.GeneratedPassword = ""
|
||||
cfg.Auth.GeneratedPasswordPersisted = false
|
||||
cfg.Auth.GeneratedPasswordPersistErr = ""
|
||||
}
|
||||
|
||||
// 创建Agent
|
||||
maxIterations := cfg.Agent.MaxIterations
|
||||
if maxIterations <= 0 {
|
||||
@@ -69,20 +83,30 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
}
|
||||
agent := agent.NewAgent(&cfg.OpenAI, mcpServer, log.Logger, maxIterations)
|
||||
|
||||
// 创建处理器
|
||||
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger)
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
|
||||
// 获取配置文件路径
|
||||
configPath := "config.yaml"
|
||||
if len(os.Args) > 1 {
|
||||
configPath = os.Args[1]
|
||||
}
|
||||
|
||||
// 创建处理器
|
||||
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger)
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, log.Logger)
|
||||
|
||||
// 设置路由
|
||||
setupRoutes(router, agentHandler, monitorHandler, conversationHandler, configHandler, mcpServer)
|
||||
setupRoutes(
|
||||
router,
|
||||
authHandler,
|
||||
agentHandler,
|
||||
monitorHandler,
|
||||
conversationHandler,
|
||||
configHandler,
|
||||
mcpServer,
|
||||
authManager,
|
||||
)
|
||||
|
||||
return &App{
|
||||
config: cfg,
|
||||
@@ -92,6 +116,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
agent: agent,
|
||||
executor: executor,
|
||||
db: db,
|
||||
auth: authManager,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -120,37 +145,58 @@ func (a *App) Run() error {
|
||||
}
|
||||
|
||||
// setupRoutes 设置路由
|
||||
func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, configHandler *handler.ConfigHandler, mcpServer *mcp.Server) {
|
||||
func setupRoutes(
|
||||
router *gin.Engine,
|
||||
authHandler *handler.AuthHandler,
|
||||
agentHandler *handler.AgentHandler,
|
||||
monitorHandler *handler.MonitorHandler,
|
||||
conversationHandler *handler.ConversationHandler,
|
||||
configHandler *handler.ConfigHandler,
|
||||
mcpServer *mcp.Server,
|
||||
authManager *security.AuthManager,
|
||||
) {
|
||||
// API路由
|
||||
api := router.Group("/api")
|
||||
|
||||
// 认证相关路由
|
||||
authRoutes := api.Group("/auth")
|
||||
{
|
||||
authRoutes.POST("/login", authHandler.Login)
|
||||
authRoutes.POST("/logout", security.AuthMiddleware(authManager), authHandler.Logout)
|
||||
authRoutes.POST("/change-password", security.AuthMiddleware(authManager), authHandler.ChangePassword)
|
||||
authRoutes.GET("/validate", security.AuthMiddleware(authManager), authHandler.Validate)
|
||||
}
|
||||
|
||||
protected := api.Group("")
|
||||
protected.Use(security.AuthMiddleware(authManager))
|
||||
{
|
||||
// Agent Loop
|
||||
api.POST("/agent-loop", agentHandler.AgentLoop)
|
||||
protected.POST("/agent-loop", agentHandler.AgentLoop)
|
||||
// Agent Loop 流式输出
|
||||
api.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
|
||||
protected.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
|
||||
// Agent Loop 取消与任务列表
|
||||
api.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
|
||||
api.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
|
||||
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
|
||||
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
|
||||
|
||||
// 对话历史
|
||||
api.POST("/conversations", conversationHandler.CreateConversation)
|
||||
api.GET("/conversations", conversationHandler.ListConversations)
|
||||
api.GET("/conversations/:id", conversationHandler.GetConversation)
|
||||
api.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
|
||||
protected.POST("/conversations", conversationHandler.CreateConversation)
|
||||
protected.GET("/conversations", conversationHandler.ListConversations)
|
||||
protected.GET("/conversations/:id", conversationHandler.GetConversation)
|
||||
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
|
||||
|
||||
// 监控
|
||||
api.GET("/monitor", monitorHandler.Monitor)
|
||||
api.GET("/monitor/execution/:id", monitorHandler.GetExecution)
|
||||
api.GET("/monitor/stats", monitorHandler.GetStats)
|
||||
api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities)
|
||||
protected.GET("/monitor", monitorHandler.Monitor)
|
||||
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
|
||||
protected.GET("/monitor/stats", monitorHandler.GetStats)
|
||||
protected.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities)
|
||||
|
||||
// 配置管理
|
||||
api.GET("/config", configHandler.GetConfig)
|
||||
api.PUT("/config", configHandler.UpdateConfig)
|
||||
api.POST("/config/apply", configHandler.ApplyConfig)
|
||||
protected.GET("/config", configHandler.GetConfig)
|
||||
protected.PUT("/config", configHandler.UpdateConfig)
|
||||
protected.POST("/config/apply", configHandler.ApplyConfig)
|
||||
|
||||
// MCP端点
|
||||
api.POST("/mcp", func(c *gin.Context) {
|
||||
protected.POST("/mcp", func(c *gin.Context) {
|
||||
mcpServer.HandleHTTP(c.Writer, c.Request)
|
||||
})
|
||||
}
|
||||
|
||||
+167
-31
@@ -1,6 +1,8 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -17,6 +19,7 @@ type Config struct {
|
||||
Agent AgentConfig `yaml:"agent"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -42,8 +45,8 @@ type OpenAIConfig struct {
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
|
||||
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
|
||||
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
|
||||
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
@@ -54,29 +57,36 @@ type AgentConfig struct {
|
||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
Password string `yaml:"password" json:"password"`
|
||||
SessionDurationHours int `yaml:"session_duration_hours" json:"session_duration_hours"`
|
||||
GeneratedPassword string `yaml:"-" json:"-"`
|
||||
GeneratedPasswordPersisted bool `yaml:"-" json:"-"`
|
||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||
}
|
||||
type ToolConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Command string `yaml:"command"`
|
||||
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
|
||||
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
|
||||
ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗)
|
||||
Description string `yaml:"description"` // 详细描述(用于工具文档)
|
||||
Description string `yaml:"description"` // 详细描述(用于工具文档)
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
|
||||
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
|
||||
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
|
||||
}
|
||||
|
||||
// ParameterConfig 参数配置
|
||||
type ParameterConfig struct {
|
||||
Name string `yaml:"name"` // 参数名称
|
||||
Type string `yaml:"type"` // 参数类型: string, int, bool, array
|
||||
Description string `yaml:"description"` // 参数描述
|
||||
Required bool `yaml:"required,omitempty"` // 是否必需
|
||||
Default interface{} `yaml:"default,omitempty"` // 默认值
|
||||
Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
|
||||
Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始)
|
||||
Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
|
||||
Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
|
||||
Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
|
||||
Name string `yaml:"name"` // 参数名称
|
||||
Type string `yaml:"type"` // 参数类型: string, int, bool, array
|
||||
Description string `yaml:"description"` // 参数描述
|
||||
Required bool `yaml:"required,omitempty"` // 是否必需
|
||||
Default interface{} `yaml:"default,omitempty"` // 默认值
|
||||
Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
|
||||
Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始)
|
||||
Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
|
||||
Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
|
||||
Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
@@ -90,65 +100,189 @@ func Load(path string) (*Config, error) {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Auth.SessionDurationHours <= 0 {
|
||||
cfg.Auth.SessionDurationHours = 12
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cfg.Auth.Password) == "" {
|
||||
password, err := generateStrongPassword(24)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成默认密码失败: %w", err)
|
||||
}
|
||||
|
||||
cfg.Auth.Password = password
|
||||
cfg.Auth.GeneratedPassword = password
|
||||
|
||||
if err := PersistAuthPassword(path, password); err != nil {
|
||||
cfg.Auth.GeneratedPasswordPersisted = false
|
||||
cfg.Auth.GeneratedPasswordPersistErr = err.Error()
|
||||
} else {
|
||||
cfg.Auth.GeneratedPasswordPersisted = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果配置了工具目录,从目录加载工具配置
|
||||
if cfg.Security.ToolsDir != "" {
|
||||
configDir := filepath.Dir(path)
|
||||
toolsDir := cfg.Security.ToolsDir
|
||||
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(toolsDir) {
|
||||
toolsDir = filepath.Join(configDir, toolsDir)
|
||||
}
|
||||
|
||||
|
||||
tools, err := LoadToolsFromDir(toolsDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// 合并工具配置:目录中的工具优先,主配置中的工具作为补充
|
||||
existingTools := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
existingTools[tool.Name] = true
|
||||
}
|
||||
|
||||
|
||||
// 添加主配置中不存在于目录中的工具(向后兼容)
|
||||
for _, tool := range cfg.Security.Tools {
|
||||
if !existingTools[tool.Name] {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
cfg.Security.Tools = tools
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func generateStrongPassword(length int) (string, error) {
|
||||
if length <= 0 {
|
||||
length = 24
|
||||
}
|
||||
|
||||
bytesLen := length
|
||||
randomBytes := make([]byte, bytesLen)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
password := base64.RawURLEncoding.EncodeToString(randomBytes)
|
||||
if len(password) > length {
|
||||
password = password[:length]
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func PersistAuthPassword(path, password string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
inAuthBlock := false
|
||||
authIndent := -1
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if !inAuthBlock {
|
||||
if strings.HasPrefix(trimmed, "auth:") {
|
||||
inAuthBlock = true
|
||||
authIndent = len(line) - len(strings.TrimLeft(line, " "))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
leadingSpaces := len(line) - len(strings.TrimLeft(line, " "))
|
||||
if leadingSpaces <= authIndent {
|
||||
// 离开 auth 块
|
||||
inAuthBlock = false
|
||||
authIndent = -1
|
||||
// 继续寻找其它 auth 块(理论上没有)
|
||||
if strings.HasPrefix(trimmed, "auth:") {
|
||||
inAuthBlock = true
|
||||
authIndent = leadingSpaces
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(strings.TrimSpace(line), "password:") {
|
||||
prefix := line[:len(line)-len(strings.TrimLeft(line, " "))]
|
||||
comment := ""
|
||||
if idx := strings.Index(line, "#"); idx >= 0 {
|
||||
comment = strings.TrimRight(line[idx:], " ")
|
||||
}
|
||||
|
||||
newLine := fmt.Sprintf("%spassword: %s", prefix, password)
|
||||
if comment != "" {
|
||||
if !strings.HasPrefix(comment, " ") {
|
||||
newLine += " "
|
||||
}
|
||||
newLine += comment
|
||||
}
|
||||
lines[i] = newLine
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644)
|
||||
}
|
||||
|
||||
func PrintGeneratedPasswordWarning(password string, persisted bool, persistErr string) {
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if persisted {
|
||||
fmt.Println("[CyberStrikeAI] ✅ 已为您自动生成并写入 Web 登录密码。")
|
||||
} else {
|
||||
if persistErr != "" {
|
||||
fmt.Printf("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码: %s\n", persistErr)
|
||||
} else {
|
||||
fmt.Println("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码。")
|
||||
}
|
||||
fmt.Println("请手动将以下随机密码写入 config.yaml 的 auth.password:")
|
||||
}
|
||||
|
||||
fmt.Println("----------------------------------------------------------------")
|
||||
fmt.Println("CyberStrikeAI Auto-Generated Web Password")
|
||||
fmt.Printf("Password: %s\n", password)
|
||||
fmt.Println("WARNING: Anyone with this password can fully control CyberStrikeAI.")
|
||||
fmt.Println("Please store it securely and change it in config.yaml as soon as possible.")
|
||||
fmt.Println("警告:持有此密码的人将拥有对 CyberStrikeAI 的完全控制权限。")
|
||||
fmt.Println("请妥善保管,并尽快在 config.yaml 中修改 auth.password!")
|
||||
fmt.Println("----------------------------------------------------------------")
|
||||
}
|
||||
|
||||
// LoadToolsFromDir 从目录加载所有工具配置文件
|
||||
func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
|
||||
var tools []ToolConfig
|
||||
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return tools, nil // 目录不存在时返回空列表,不报错
|
||||
}
|
||||
|
||||
|
||||
// 读取目录中的所有 .yaml 和 .yml 文件
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取工具目录失败: %w", err)
|
||||
}
|
||||
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
filePath := filepath.Join(dir, name)
|
||||
tool, err := LoadToolFromFile(filePath)
|
||||
if err != nil {
|
||||
@@ -156,10 +290,10 @@ func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
|
||||
fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
tools = append(tools, *tool)
|
||||
}
|
||||
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
@@ -169,12 +303,12 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取文件失败: %w", err)
|
||||
}
|
||||
|
||||
|
||||
var tool ToolConfig
|
||||
if err := yaml.Unmarshal(data, &tool); err != nil {
|
||||
return nil, fmt.Errorf("解析工具配置失败: %w", err)
|
||||
}
|
||||
|
||||
|
||||
// 验证必需字段
|
||||
if tool.Name == "" {
|
||||
return nil, fmt.Errorf("工具名称不能为空")
|
||||
@@ -182,7 +316,7 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
|
||||
if tool.Command == "" {
|
||||
return nil, fmt.Errorf("工具命令不能为空")
|
||||
}
|
||||
|
||||
|
||||
return &tool, nil
|
||||
}
|
||||
|
||||
@@ -215,6 +349,8 @@ func Default() *Config {
|
||||
Database: DatabaseConfig{
|
||||
Path: "data/conversations.db",
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
SessionDurationHours: 12,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication-related endpoints.
|
||||
type AuthHandler struct {
|
||||
manager *security.AuthManager
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler.
|
||||
func NewAuthHandler(manager *security.AuthManager, cfg *config.Config, configPath string, logger *zap.Logger) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
manager: manager,
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
type changePasswordRequest struct {
|
||||
OldPassword string `json:"oldPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
// Login verifies password and returns a session token.
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req loginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
token, expiresAt, err := h.manager.Authenticate(req.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
"session_duration_hr": h.manager.SessionDurationHours(),
|
||||
})
|
||||
}
|
||||
|
||||
// Logout revokes the current session token.
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
token := c.GetString(security.ContextAuthTokenKey)
|
||||
if token == "" {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") {
|
||||
token = strings.TrimSpace(authHeader[7:])
|
||||
} else {
|
||||
token = strings.TrimSpace(authHeader)
|
||||
}
|
||||
}
|
||||
|
||||
h.manager.RevokeToken(token)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
|
||||
}
|
||||
|
||||
// ChangePassword updates the login password.
|
||||
func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
var req changePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "参数无效"})
|
||||
return
|
||||
}
|
||||
|
||||
oldPassword := strings.TrimSpace(req.OldPassword)
|
||||
newPassword := strings.TrimSpace(req.NewPassword)
|
||||
|
||||
if oldPassword == "" || newPassword == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码和新密码均不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if len(newPassword) < 8 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码长度至少需要 8 位"})
|
||||
return
|
||||
}
|
||||
|
||||
if oldPassword == newPassword {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码不能与旧密码相同"})
|
||||
return
|
||||
}
|
||||
|
||||
if !h.manager.CheckPassword(oldPassword) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "当前密码不正确"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := config.PersistAuthPassword(h.configPath, newPassword); err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Error("保存新密码失败", zap.Error(err))
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存新密码失败,请重试"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.manager.UpdateConfig(newPassword, h.config.Auth.SessionDurationHours); err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Error("更新认证配置失败", zap.Error(err))
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新认证配置失败"})
|
||||
return
|
||||
}
|
||||
|
||||
h.config.Auth.Password = newPassword
|
||||
h.config.Auth.GeneratedPassword = ""
|
||||
h.config.Auth.GeneratedPasswordPersisted = false
|
||||
h.config.Auth.GeneratedPasswordPersistErr = ""
|
||||
|
||||
if h.logger != nil {
|
||||
h.logger.Info("登录密码已更新,所有会话已失效")
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
|
||||
}
|
||||
|
||||
// Validate returns the current session status.
|
||||
func (h *AuthHandler) Validate(c *gin.Context) {
|
||||
token := c.GetString(security.ContextAuthTokenKey)
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话无效"})
|
||||
return
|
||||
}
|
||||
|
||||
session, ok := h.manager.ValidateToken(token)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话已过期"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": session.Token,
|
||||
"expires_at": session.ExpiresAt.UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Predefined errors for authentication operations.
|
||||
var (
|
||||
ErrInvalidPassword = errors.New("invalid password")
|
||||
)
|
||||
|
||||
// Session represents an authenticated user session.
|
||||
type Session struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// AuthManager manages password-based authentication and session lifecycle.
|
||||
type AuthManager struct {
|
||||
password string
|
||||
sessionDuration time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
sessions map[string]Session
|
||||
}
|
||||
|
||||
// NewAuthManager creates a new AuthManager instance.
|
||||
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return nil, errors.New("auth password must be configured")
|
||||
}
|
||||
|
||||
if sessionDurationHours <= 0 {
|
||||
sessionDurationHours = 12
|
||||
}
|
||||
|
||||
return &AuthManager{
|
||||
password: password,
|
||||
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
|
||||
sessions: make(map[string]Session),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate validates the password and creates a new session.
|
||||
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
|
||||
if password != a.password {
|
||||
return "", time.Time{}, ErrInvalidPassword
|
||||
}
|
||||
|
||||
token := uuid.NewString()
|
||||
expiresAt := time.Now().Add(a.sessionDuration)
|
||||
|
||||
a.mu.Lock()
|
||||
a.sessions[token] = Session{
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
return token, expiresAt, nil
|
||||
}
|
||||
|
||||
// ValidateToken checks whether the provided token is still valid.
|
||||
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
session, ok := a.sessions[token]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
return session, true
|
||||
}
|
||||
|
||||
// CheckPassword verifies whether the provided password matches the current password.
|
||||
func (a *AuthManager) CheckPassword(password string) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
return password == a.password
|
||||
}
|
||||
|
||||
// RevokeToken invalidates the specified token.
|
||||
func (a *AuthManager) RevokeToken(token string) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
// SessionDurationHours returns the configured session duration in hours.
|
||||
func (a *AuthManager) SessionDurationHours() int {
|
||||
return int(a.sessionDuration / time.Hour)
|
||||
}
|
||||
|
||||
// UpdateConfig updates the password and session duration, revoking existing sessions.
|
||||
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
|
||||
password = strings.TrimSpace(password)
|
||||
if password == "" {
|
||||
return errors.New("auth password must be configured")
|
||||
}
|
||||
|
||||
if sessionDurationHours <= 0 {
|
||||
sessionDurationHours = 12
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
a.password = password
|
||||
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
|
||||
a.sessions = make(map[string]Session)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
ContextAuthTokenKey = "authToken"
|
||||
ContextSessionExpiry = "authSessionExpiry"
|
||||
)
|
||||
|
||||
// AuthMiddleware enforces authentication on protected routes.
|
||||
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := extractTokenFromRequest(c)
|
||||
session, ok := manager.ValidateToken(token)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "未授权访问,请先登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(ContextAuthTokenKey, session.Token)
|
||||
c.Set(ContextSessionExpiry, session.ExpiresAt)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func extractTokenFromRequest(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
|
||||
return strings.TrimSpace(authHeader[7:])
|
||||
}
|
||||
return strings.TrimSpace(authHeader)
|
||||
}
|
||||
|
||||
if token := c.Query("token"); token != "" {
|
||||
return strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
if cookie, err := c.Cookie("auth_token"); err == nil {
|
||||
return strings.TrimSpace(cookie)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user