Add files via upload

This commit is contained in:
公明
2025-11-13 23:41:00 +08:00
committed by GitHub
parent b63fd24b18
commit 6f0044b6fd
11 changed files with 1163 additions and 97 deletions
+69 -23
View File
@@ -27,6 +27,7 @@ type App struct {
agent *agent.Agent
executor *security.Executor
db *database.DB
auth *security.AuthManager
}
// New 创建新应用
@@ -37,6 +38,12 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
// CORS中间件
router.Use(corsMiddleware())
// 认证管理器
authManager, err := security.NewAuthManager(cfg.Auth.Password, cfg.Auth.SessionDurationHours)
if err != nil {
return nil, fmt.Errorf("初始化认证失败: %w", err)
}
// 初始化数据库
dbPath := cfg.Database.Path
if dbPath == "" {
@@ -62,6 +69,13 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
// 注册工具
executor.RegisterTools(mcpServer)
if cfg.Auth.GeneratedPassword != "" {
config.PrintGeneratedPasswordWarning(cfg.Auth.GeneratedPassword, cfg.Auth.GeneratedPasswordPersisted, cfg.Auth.GeneratedPasswordPersistErr)
cfg.Auth.GeneratedPassword = ""
cfg.Auth.GeneratedPasswordPersisted = false
cfg.Auth.GeneratedPasswordPersistErr = ""
}
// 创建Agent
maxIterations := cfg.Agent.MaxIterations
if maxIterations <= 0 {
@@ -69,20 +83,30 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
}
agent := agent.NewAgent(&cfg.OpenAI, mcpServer, log.Logger, maxIterations)
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger)
conversationHandler := handler.NewConversationHandler(db, log.Logger)
// 获取配置文件路径
configPath := "config.yaml"
if len(os.Args) > 1 {
configPath = os.Args[1]
}
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger)
conversationHandler := handler.NewConversationHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, log.Logger)
// 设置路由
setupRoutes(router, agentHandler, monitorHandler, conversationHandler, configHandler, mcpServer)
setupRoutes(
router,
authHandler,
agentHandler,
monitorHandler,
conversationHandler,
configHandler,
mcpServer,
authManager,
)
return &App{
config: cfg,
@@ -92,6 +116,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
agent: agent,
executor: executor,
db: db,
auth: authManager,
}, nil
}
@@ -120,37 +145,58 @@ func (a *App) Run() error {
}
// setupRoutes 设置路由
func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, configHandler *handler.ConfigHandler, mcpServer *mcp.Server) {
func setupRoutes(
router *gin.Engine,
authHandler *handler.AuthHandler,
agentHandler *handler.AgentHandler,
monitorHandler *handler.MonitorHandler,
conversationHandler *handler.ConversationHandler,
configHandler *handler.ConfigHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
) {
// API路由
api := router.Group("/api")
// 认证相关路由
authRoutes := api.Group("/auth")
{
authRoutes.POST("/login", authHandler.Login)
authRoutes.POST("/logout", security.AuthMiddleware(authManager), authHandler.Logout)
authRoutes.POST("/change-password", security.AuthMiddleware(authManager), authHandler.ChangePassword)
authRoutes.GET("/validate", security.AuthMiddleware(authManager), authHandler.Validate)
}
protected := api.Group("")
protected.Use(security.AuthMiddleware(authManager))
{
// Agent Loop
api.POST("/agent-loop", agentHandler.AgentLoop)
protected.POST("/agent-loop", agentHandler.AgentLoop)
// Agent Loop 流式输出
api.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
protected.POST("/agent-loop/stream", agentHandler.AgentLoopStream)
// Agent Loop 取消与任务列表
api.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
api.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
protected.POST("/agent-loop/cancel", agentHandler.CancelAgentLoop)
protected.GET("/agent-loop/tasks", agentHandler.ListAgentTasks)
// 对话历史
api.POST("/conversations", conversationHandler.CreateConversation)
api.GET("/conversations", conversationHandler.ListConversations)
api.GET("/conversations/:id", conversationHandler.GetConversation)
api.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
protected.POST("/conversations", conversationHandler.CreateConversation)
protected.GET("/conversations", conversationHandler.ListConversations)
protected.GET("/conversations/:id", conversationHandler.GetConversation)
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
// 监控
api.GET("/monitor", monitorHandler.Monitor)
api.GET("/monitor/execution/:id", monitorHandler.GetExecution)
api.GET("/monitor/stats", monitorHandler.GetStats)
api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities)
protected.GET("/monitor", monitorHandler.Monitor)
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
protected.GET("/monitor/stats", monitorHandler.GetStats)
protected.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities)
// 配置管理
api.GET("/config", configHandler.GetConfig)
api.PUT("/config", configHandler.UpdateConfig)
api.POST("/config/apply", configHandler.ApplyConfig)
protected.GET("/config", configHandler.GetConfig)
protected.PUT("/config", configHandler.UpdateConfig)
protected.POST("/config/apply", configHandler.ApplyConfig)
// MCP端点
api.POST("/mcp", func(c *gin.Context) {
protected.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request)
})
}
+167 -31
View File
@@ -1,6 +1,8 @@
package config
import (
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"path/filepath"
@@ -17,6 +19,7 @@ type Config struct {
Agent AgentConfig `yaml:"agent"`
Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"`
}
type ServerConfig struct {
@@ -42,8 +45,8 @@ type OpenAIConfig struct {
}
type SecurityConfig struct {
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
}
type DatabaseConfig struct {
@@ -54,29 +57,36 @@ type AgentConfig struct {
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
}
type AuthConfig struct {
Password string `yaml:"password" json:"password"`
SessionDurationHours int `yaml:"session_duration_hours" json:"session_duration_hours"`
GeneratedPassword string `yaml:"-" json:"-"`
GeneratedPasswordPersisted bool `yaml:"-" json:"-"`
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
}
type ToolConfig struct {
Name string `yaml:"name"`
Command string `yaml:"command"`
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗)
Description string `yaml:"description"` // 详细描述(用于工具文档)
Description string `yaml:"description"` // 详细描述(用于工具文档)
Enabled bool `yaml:"enabled"`
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
}
// ParameterConfig 参数配置
type ParameterConfig struct {
Name string `yaml:"name"` // 参数名称
Type string `yaml:"type"` // 参数类型: string, int, bool, array
Description string `yaml:"description"` // 参数描述
Required bool `yaml:"required,omitempty"` // 是否必需
Default interface{} `yaml:"default,omitempty"` // 默认值
Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始)
Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
Name string `yaml:"name"` // 参数名称
Type string `yaml:"type"` // 参数类型: string, int, bool, array
Description string `yaml:"description"` // 参数描述
Required bool `yaml:"required,omitempty"` // 是否必需
Default interface{} `yaml:"default,omitempty"` // 默认值
Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始)
Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
}
func Load(path string) (*Config, error) {
@@ -90,65 +100,189 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
if cfg.Auth.SessionDurationHours <= 0 {
cfg.Auth.SessionDurationHours = 12
}
if strings.TrimSpace(cfg.Auth.Password) == "" {
password, err := generateStrongPassword(24)
if err != nil {
return nil, fmt.Errorf("生成默认密码失败: %w", err)
}
cfg.Auth.Password = password
cfg.Auth.GeneratedPassword = password
if err := PersistAuthPassword(path, password); err != nil {
cfg.Auth.GeneratedPasswordPersisted = false
cfg.Auth.GeneratedPasswordPersistErr = err.Error()
} else {
cfg.Auth.GeneratedPasswordPersisted = true
}
}
// 如果配置了工具目录,从目录加载工具配置
if cfg.Security.ToolsDir != "" {
configDir := filepath.Dir(path)
toolsDir := cfg.Security.ToolsDir
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(toolsDir) {
toolsDir = filepath.Join(configDir, toolsDir)
}
tools, err := LoadToolsFromDir(toolsDir)
if err != nil {
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
}
// 合并工具配置:目录中的工具优先,主配置中的工具作为补充
existingTools := make(map[string]bool)
for _, tool := range tools {
existingTools[tool.Name] = true
}
// 添加主配置中不存在于目录中的工具(向后兼容)
for _, tool := range cfg.Security.Tools {
if !existingTools[tool.Name] {
tools = append(tools, tool)
}
}
cfg.Security.Tools = tools
}
return &cfg, nil
}
func generateStrongPassword(length int) (string, error) {
if length <= 0 {
length = 24
}
bytesLen := length
randomBytes := make([]byte, bytesLen)
if _, err := rand.Read(randomBytes); err != nil {
return "", err
}
password := base64.RawURLEncoding.EncodeToString(randomBytes)
if len(password) > length {
password = password[:length]
}
return password, nil
}
func PersistAuthPassword(path, password string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
lines := strings.Split(string(data), "\n")
inAuthBlock := false
authIndent := -1
for i, line := range lines {
trimmed := strings.TrimSpace(line)
if !inAuthBlock {
if strings.HasPrefix(trimmed, "auth:") {
inAuthBlock = true
authIndent = len(line) - len(strings.TrimLeft(line, " "))
}
continue
}
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
continue
}
leadingSpaces := len(line) - len(strings.TrimLeft(line, " "))
if leadingSpaces <= authIndent {
// 离开 auth 块
inAuthBlock = false
authIndent = -1
// 继续寻找其它 auth 块(理论上没有)
if strings.HasPrefix(trimmed, "auth:") {
inAuthBlock = true
authIndent = leadingSpaces
}
continue
}
if strings.HasPrefix(strings.TrimSpace(line), "password:") {
prefix := line[:len(line)-len(strings.TrimLeft(line, " "))]
comment := ""
if idx := strings.Index(line, "#"); idx >= 0 {
comment = strings.TrimRight(line[idx:], " ")
}
newLine := fmt.Sprintf("%spassword: %s", prefix, password)
if comment != "" {
if !strings.HasPrefix(comment, " ") {
newLine += " "
}
newLine += comment
}
lines[i] = newLine
break
}
}
return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644)
}
func PrintGeneratedPasswordWarning(password string, persisted bool, persistErr string) {
if strings.TrimSpace(password) == "" {
return
}
if persisted {
fmt.Println("[CyberStrikeAI] ✅ 已为您自动生成并写入 Web 登录密码。")
} else {
if persistErr != "" {
fmt.Printf("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码: %s\n", persistErr)
} else {
fmt.Println("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码。")
}
fmt.Println("请手动将以下随机密码写入 config.yaml 的 auth.password")
}
fmt.Println("----------------------------------------------------------------")
fmt.Println("CyberStrikeAI Auto-Generated Web Password")
fmt.Printf("Password: %s\n", password)
fmt.Println("WARNING: Anyone with this password can fully control CyberStrikeAI.")
fmt.Println("Please store it securely and change it in config.yaml as soon as possible.")
fmt.Println("警告:持有此密码的人将拥有对 CyberStrikeAI 的完全控制权限。")
fmt.Println("请妥善保管,并尽快在 config.yaml 中修改 auth.password")
fmt.Println("----------------------------------------------------------------")
}
// LoadToolsFromDir 从目录加载所有工具配置文件
func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
var tools []ToolConfig
// 检查目录是否存在
if _, err := os.Stat(dir); os.IsNotExist(err) {
return tools, nil // 目录不存在时返回空列表,不报错
}
// 读取目录中的所有 .yaml 和 .yml 文件
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("读取工具目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
continue
}
filePath := filepath.Join(dir, name)
tool, err := LoadToolFromFile(filePath)
if err != nil {
@@ -156,10 +290,10 @@ func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err)
continue
}
tools = append(tools, *tool)
}
return tools, nil
}
@@ -169,12 +303,12 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
var tool ToolConfig
if err := yaml.Unmarshal(data, &tool); err != nil {
return nil, fmt.Errorf("解析工具配置失败: %w", err)
}
// 验证必需字段
if tool.Name == "" {
return nil, fmt.Errorf("工具名称不能为空")
@@ -182,7 +316,7 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
if tool.Command == "" {
return nil, fmt.Errorf("工具命令不能为空")
}
return &tool, nil
}
@@ -215,6 +349,8 @@ func Default() *Config {
Database: DatabaseConfig{
Path: "data/conversations.db",
},
Auth: AuthConfig{
SessionDurationHours: 12,
},
}
}
+156
View File
@@ -0,0 +1,156 @@
package handler
import (
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AuthHandler handles authentication-related endpoints.
type AuthHandler struct {
manager *security.AuthManager
config *config.Config
configPath string
logger *zap.Logger
}
// NewAuthHandler creates a new AuthHandler.
func NewAuthHandler(manager *security.AuthManager, cfg *config.Config, configPath string, logger *zap.Logger) *AuthHandler {
return &AuthHandler{
manager: manager,
config: cfg,
configPath: configPath,
logger: logger,
}
}
type loginRequest struct {
Password string `json:"password" binding:"required"`
}
type changePasswordRequest struct {
OldPassword string `json:"oldPassword"`
NewPassword string `json:"newPassword"`
}
// Login verifies password and returns a session token.
func (h *AuthHandler) Login(c *gin.Context) {
var req loginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"})
return
}
token, expiresAt, err := h.manager.Authenticate(req.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"expires_at": expiresAt.UTC().Format(time.RFC3339),
"session_duration_hr": h.manager.SessionDurationHours(),
})
}
// Logout revokes the current session token.
func (h *AuthHandler) Logout(c *gin.Context) {
token := c.GetString(security.ContextAuthTokenKey)
if token == "" {
authHeader := c.GetHeader("Authorization")
if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") {
token = strings.TrimSpace(authHeader[7:])
} else {
token = strings.TrimSpace(authHeader)
}
}
h.manager.RevokeToken(token)
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
}
// ChangePassword updates the login password.
func (h *AuthHandler) ChangePassword(c *gin.Context) {
var req changePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "参数无效"})
return
}
oldPassword := strings.TrimSpace(req.OldPassword)
newPassword := strings.TrimSpace(req.NewPassword)
if oldPassword == "" || newPassword == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码和新密码均不能为空"})
return
}
if len(newPassword) < 8 {
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码长度至少需要 8 位"})
return
}
if oldPassword == newPassword {
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码不能与旧密码相同"})
return
}
if !h.manager.CheckPassword(oldPassword) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "当前密码不正确"})
return
}
if err := config.PersistAuthPassword(h.configPath, newPassword); err != nil {
if h.logger != nil {
h.logger.Error("保存新密码失败", zap.Error(err))
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存新密码失败,请重试"})
return
}
if err := h.manager.UpdateConfig(newPassword, h.config.Auth.SessionDurationHours); err != nil {
if h.logger != nil {
h.logger.Error("更新认证配置失败", zap.Error(err))
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新认证配置失败"})
return
}
h.config.Auth.Password = newPassword
h.config.Auth.GeneratedPassword = ""
h.config.Auth.GeneratedPasswordPersisted = false
h.config.Auth.GeneratedPasswordPersistErr = ""
if h.logger != nil {
h.logger.Info("登录密码已更新,所有会话已失效")
}
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
}
// Validate returns the current session status.
func (h *AuthHandler) Validate(c *gin.Context) {
token := c.GetString(security.ContextAuthTokenKey)
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话无效"})
return
}
session, ok := h.manager.ValidateToken(token)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话已过期"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": session.Token,
"expires_at": session.ExpiresAt.UTC().Format(time.RFC3339),
})
}
+132
View File
@@ -0,0 +1,132 @@
package security
import (
"errors"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// Predefined errors for authentication operations.
var (
ErrInvalidPassword = errors.New("invalid password")
)
// Session represents an authenticated user session.
type Session struct {
Token string
ExpiresAt time.Time
}
// AuthManager manages password-based authentication and session lifecycle.
type AuthManager struct {
password string
sessionDuration time.Duration
mu sync.RWMutex
sessions map[string]Session
}
// NewAuthManager creates a new AuthManager instance.
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
if strings.TrimSpace(password) == "" {
return nil, errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
return &AuthManager{
password: password,
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
sessions: make(map[string]Session),
}, nil
}
// Authenticate validates the password and creates a new session.
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
if password != a.password {
return "", time.Time{}, ErrInvalidPassword
}
token := uuid.NewString()
expiresAt := time.Now().Add(a.sessionDuration)
a.mu.Lock()
a.sessions[token] = Session{
Token: token,
ExpiresAt: expiresAt,
}
a.mu.Unlock()
return token, expiresAt, nil
}
// ValidateToken checks whether the provided token is still valid.
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
if strings.TrimSpace(token) == "" {
return Session{}, false
}
a.mu.RLock()
session, ok := a.sessions[token]
a.mu.RUnlock()
if !ok {
return Session{}, false
}
if time.Now().After(session.ExpiresAt) {
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
return Session{}, false
}
return session, true
}
// CheckPassword verifies whether the provided password matches the current password.
func (a *AuthManager) CheckPassword(password string) bool {
a.mu.RLock()
defer a.mu.RUnlock()
return password == a.password
}
// RevokeToken invalidates the specified token.
func (a *AuthManager) RevokeToken(token string) {
if strings.TrimSpace(token) == "" {
return
}
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
}
// SessionDurationHours returns the configured session duration in hours.
func (a *AuthManager) SessionDurationHours() int {
return int(a.sessionDuration / time.Hour)
}
// UpdateConfig updates the password and session duration, revoking existing sessions.
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
password = strings.TrimSpace(password)
if password == "" {
return errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
a.mu.Lock()
defer a.mu.Unlock()
a.password = password
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
a.sessions = make(map[string]Session)
return nil
}
+51
View File
@@ -0,0 +1,51 @@
package security
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
const (
ContextAuthTokenKey = "authToken"
ContextSessionExpiry = "authSessionExpiry"
)
// AuthMiddleware enforces authentication on protected routes.
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
return func(c *gin.Context) {
token := extractTokenFromRequest(c)
session, ok := manager.ValidateToken(token)
if !ok {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "未授权访问,请先登录",
})
return
}
c.Set(ContextAuthTokenKey, session.Token)
c.Set(ContextSessionExpiry, session.ExpiresAt)
c.Next()
}
}
func extractTokenFromRequest(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
return strings.TrimSpace(authHeader[7:])
}
return strings.TrimSpace(authHeader)
}
if token := c.Query("token"); token != "" {
return strings.TrimSpace(token)
}
if cookie, err := c.Cookie("auth_token"); err == nil {
return strings.TrimSpace(cookie)
}
return ""
}