Add files via upload

This commit is contained in:
公明
2026-06-18 12:40:54 +08:00
committed by GitHub
parent 56faefaaf9
commit d5a0f93c6c
94 changed files with 24645 additions and 0 deletions
+132
View File
@@ -0,0 +1,132 @@
package security
import (
"errors"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// Predefined errors for authentication operations.
var (
ErrInvalidPassword = errors.New("invalid password")
)
// Session represents an authenticated user session.
type Session struct {
Token string
ExpiresAt time.Time
}
// AuthManager manages password-based authentication and session lifecycle.
type AuthManager struct {
password string
sessionDuration time.Duration
mu sync.RWMutex
sessions map[string]Session
}
// NewAuthManager creates a new AuthManager instance.
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
if strings.TrimSpace(password) == "" {
return nil, errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
return &AuthManager{
password: password,
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
sessions: make(map[string]Session),
}, nil
}
// Authenticate validates the password and creates a new session.
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
if password != a.password {
return "", time.Time{}, ErrInvalidPassword
}
token := uuid.NewString()
expiresAt := time.Now().Add(a.sessionDuration)
a.mu.Lock()
a.sessions[token] = Session{
Token: token,
ExpiresAt: expiresAt,
}
a.mu.Unlock()
return token, expiresAt, nil
}
// ValidateToken checks whether the provided token is still valid.
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
if strings.TrimSpace(token) == "" {
return Session{}, false
}
a.mu.RLock()
session, ok := a.sessions[token]
a.mu.RUnlock()
if !ok {
return Session{}, false
}
if time.Now().After(session.ExpiresAt) {
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
return Session{}, false
}
return session, true
}
// CheckPassword verifies whether the provided password matches the current password.
func (a *AuthManager) CheckPassword(password string) bool {
a.mu.RLock()
defer a.mu.RUnlock()
return password == a.password
}
// RevokeToken invalidates the specified token.
func (a *AuthManager) RevokeToken(token string) {
if strings.TrimSpace(token) == "" {
return
}
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
}
// SessionDurationHours returns the configured session duration in hours.
func (a *AuthManager) SessionDurationHours() int {
return int(a.sessionDuration / time.Hour)
}
// UpdateConfig updates the password and session duration, revoking existing sessions.
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
password = strings.TrimSpace(password)
if password == "" {
return errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
a.mu.Lock()
defer a.mu.Unlock()
a.password = password
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
a.sessions = make(map[string]Session)
return nil
}
+51
View File
@@ -0,0 +1,51 @@
package security
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
const (
ContextAuthTokenKey = "authToken"
ContextSessionExpiry = "authSessionExpiry"
)
// AuthMiddleware enforces authentication on protected routes.
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
return func(c *gin.Context) {
token := extractTokenFromRequest(c)
session, ok := manager.ValidateToken(token)
if !ok {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "未授权访问,请先登录",
})
return
}
c.Set(ContextAuthTokenKey, session.Token)
c.Set(ContextSessionExpiry, session.ExpiresAt)
c.Next()
}
}
func extractTokenFromRequest(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
return strings.TrimSpace(authHeader[7:])
}
return strings.TrimSpace(authHeader)
}
if token := c.Query("token"); token != "" {
return strings.TrimSpace(token)
}
if cookie, err := c.Cookie("auth_token"); err == nil {
return strings.TrimSpace(cookie)
}
return ""
}
File diff suppressed because it is too large Load Diff
+128
View File
@@ -0,0 +1,128 @@
package security
import (
"context"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// setupTestExecutor 创建测试用的执行器
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
cfg := &config.SecurityConfig{
Tools: []config.ToolConfig{},
}
executor := NewExecutor(cfg, mcpServer, logger)
return executor, mcpServer
}
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
executor, _ := setupTestExecutor(t)
ctx := context.Background()
args := map[string]interface{}{
"test": "value",
}
// 测试未知的内部工具类型
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
if err != nil {
t.Fatalf("执行内部工具失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未知的工具类型应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
t.Errorf("错误消息应该包含'未知的内部工具类型'")
}
}
func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) {
executor, _ := setupTestExecutor(t)
// 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout,
// ReadString('\n') 会阻塞到子进程退出。后台包装须将子进程标准流与 PID 行分离。
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel()
args := map[string]interface{}{
"command": `(sh -c 'printf x; sleep 120') &`,
"shell": "sh",
}
res, err := executor.executeSystemCommand(ctx, args)
if err != nil {
t.Fatalf("executeSystemCommand: %v", err)
}
if res == nil || res.IsError {
t.Fatalf("expected success, got %+v", res)
}
txt := res.Content[0].Text
if !strings.Contains(txt, "后台命令已启动") {
t.Fatalf("unexpected body: %q", txt)
}
}
func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) {
pos1 := 1
executor, _ := setupTestExecutor(t)
toolConfig := &config.ToolConfig{
Name: "nmap",
Command: "nmap",
Args: []string{"-sT", "-sV", "-sC"},
Parameters: []config.ParameterConfig{
{Name: "target", Type: "string", Required: true, Position: &pos1, Format: "positional"},
{Name: "ports", Type: "string", Flag: "-p", Format: "flag"},
{Name: "timing", Type: "string", Template: "-T{value}", Format: "template"},
{Name: "nse_scripts", Type: "string", Flag: "--script", Format: "flag"},
{Name: "os_detection", Type: "bool", Flag: "-O", Format: "flag", Default: false},
{Name: "aggressive", Type: "bool", Flag: "-A", Format: "flag", Default: false},
{Name: "scan_type", Type: "string", Format: "template", Template: "{value}"},
{Name: "additional_args", Type: "string", Format: "positional"},
},
}
args := map[string]interface{}{
"target": "110.52.223.114",
"ports": "21, 22, 80, 443",
"timing": "4",
"nse_scripts": "",
"scan_type": "",
"os_detection": false,
"aggressive": false,
"additional_args": "-Pn",
}
cmdArgs := executor.buildCommandArgs("nmap", toolConfig, args)
joined := strings.Join(cmdArgs, " ")
if strings.Contains(joined, "--script") {
t.Fatalf("empty nse_scripts must not emit --script, got: %v", cmdArgs)
}
if !strings.Contains(joined, "110.52.223.114") {
t.Fatalf("target missing from args: %v", cmdArgs)
}
// target 应出现在 -Pn 之前,避免被误当作 --script 的参数
pnIdx := indexOf(cmdArgs, "-Pn")
targetIdx := indexOf(cmdArgs, "110.52.223.114")
if pnIdx < 0 || targetIdx < 0 || targetIdx >= pnIdx {
t.Fatalf("expected target before -Pn, got: %v", cmdArgs)
}
}
func indexOf(slice []string, s string) int {
for i, v := range slice {
if v == s {
return i
}
}
return -1
}
+31
View File
@@ -0,0 +1,31 @@
//go:build !windows
package security
import (
"os/exec"
"syscall"
)
// prepareShellCmdSession 让 shell 子进程在独立会话中运行,便于超时/取消时整组 SIGKILL(含子进程)。
func prepareShellCmdSession(cmd *exec.Cmd) error {
if cmd == nil {
return nil
}
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
cmd.SysProcAttr.Setsid = true
return nil
}
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
func terminateCmdTree(cmd *exec.Cmd) {
if cmd == nil || cmd.Process == nil {
return
}
pid := cmd.Process.Pid
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
_ = cmd.Process.Kill()
}
}
+17
View File
@@ -0,0 +1,17 @@
//go:build windows
package security
import "os/exec"
func prepareShellCmdSession(cmd *exec.Cmd) error {
_ = cmd
return nil
}
func terminateCmdTree(cmd *exec.Cmd) {
if cmd == nil || cmd.Process == nil {
return
}
_ = cmd.Process.Kill()
}
+81
View File
@@ -0,0 +1,81 @@
package security
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// rateLimitEntry 记录某个 IP 的请求窗口信息
type rateLimitEntry struct {
count int
windowAt time.Time
}
// RateLimiter 基于 IP 的滑动窗口速率限制器
type RateLimiter struct {
mu sync.Mutex
entries map[string]*rateLimitEntry
limit int // 窗口内允许的最大请求数
window time.Duration // 窗口时长
}
// NewRateLimiter 创建速率限制器
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
entries: make(map[string]*rateLimitEntry),
limit: limit,
window: window,
}
// 后台定期清理过期条目,防止内存泄漏
go rl.cleanup()
return rl
}
// cleanup 每分钟清理一次过期条目
func (rl *RateLimiter) cleanup() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
now := time.Now()
for ip, entry := range rl.entries {
if now.Sub(entry.windowAt) > rl.window {
delete(rl.entries, ip)
}
}
rl.mu.Unlock()
}
}
// allow 检查指定 IP 是否允许通过
func (rl *RateLimiter) allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
entry, ok := rl.entries[ip]
if !ok || now.Sub(entry.windowAt) > rl.window {
rl.entries[ip] = &rateLimitEntry{count: 1, windowAt: now}
return true
}
entry.count++
return entry.count <= rl.limit
}
// RateLimitMiddleware 返回 Gin 中间件,对超限请求返回 429
func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
if !rl.allow(ip) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded, please try again later",
})
return
}
c.Next()
}
}