mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-05 12:07:52 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,132 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Predefined errors for authentication operations.
|
||||
var (
|
||||
ErrInvalidPassword = errors.New("invalid password")
|
||||
)
|
||||
|
||||
// Session represents an authenticated user session.
|
||||
type Session struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// AuthManager manages password-based authentication and session lifecycle.
|
||||
type AuthManager struct {
|
||||
password string
|
||||
sessionDuration time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
sessions map[string]Session
|
||||
}
|
||||
|
||||
// NewAuthManager creates a new AuthManager instance.
|
||||
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return nil, errors.New("auth password must be configured")
|
||||
}
|
||||
|
||||
if sessionDurationHours <= 0 {
|
||||
sessionDurationHours = 12
|
||||
}
|
||||
|
||||
return &AuthManager{
|
||||
password: password,
|
||||
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
|
||||
sessions: make(map[string]Session),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate validates the password and creates a new session.
|
||||
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
|
||||
if password != a.password {
|
||||
return "", time.Time{}, ErrInvalidPassword
|
||||
}
|
||||
|
||||
token := uuid.NewString()
|
||||
expiresAt := time.Now().Add(a.sessionDuration)
|
||||
|
||||
a.mu.Lock()
|
||||
a.sessions[token] = Session{
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
return token, expiresAt, nil
|
||||
}
|
||||
|
||||
// ValidateToken checks whether the provided token is still valid.
|
||||
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
session, ok := a.sessions[token]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
return session, true
|
||||
}
|
||||
|
||||
// CheckPassword verifies whether the provided password matches the current password.
|
||||
func (a *AuthManager) CheckPassword(password string) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
return password == a.password
|
||||
}
|
||||
|
||||
// RevokeToken invalidates the specified token.
|
||||
func (a *AuthManager) RevokeToken(token string) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
// SessionDurationHours returns the configured session duration in hours.
|
||||
func (a *AuthManager) SessionDurationHours() int {
|
||||
return int(a.sessionDuration / time.Hour)
|
||||
}
|
||||
|
||||
// UpdateConfig updates the password and session duration, revoking existing sessions.
|
||||
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
|
||||
password = strings.TrimSpace(password)
|
||||
if password == "" {
|
||||
return errors.New("auth password must be configured")
|
||||
}
|
||||
|
||||
if sessionDurationHours <= 0 {
|
||||
sessionDurationHours = 12
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
a.password = password
|
||||
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
|
||||
a.sessions = make(map[string]Session)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
ContextAuthTokenKey = "authToken"
|
||||
ContextSessionExpiry = "authSessionExpiry"
|
||||
)
|
||||
|
||||
// AuthMiddleware enforces authentication on protected routes.
|
||||
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := extractTokenFromRequest(c)
|
||||
session, ok := manager.ValidateToken(token)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "未授权访问,请先登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(ContextAuthTokenKey, session.Token)
|
||||
c.Set(ContextSessionExpiry, session.ExpiresAt)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func extractTokenFromRequest(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
|
||||
return strings.TrimSpace(authHeader[7:])
|
||||
}
|
||||
return strings.TrimSpace(authHeader)
|
||||
}
|
||||
|
||||
if token := c.Query("token"); token != "" {
|
||||
return strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
if cookie, err := c.Cookie("auth_token"); err == nil {
|
||||
return strings.TrimSpace(cookie)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,128 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// setupTestExecutor 创建测试用的执行器
|
||||
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
|
||||
logger := zap.NewNop()
|
||||
mcpServer := mcp.NewServer(logger)
|
||||
|
||||
cfg := &config.SecurityConfig{
|
||||
Tools: []config.ToolConfig{},
|
||||
}
|
||||
|
||||
executor := NewExecutor(cfg, mcpServer, logger)
|
||||
return executor, mcpServer
|
||||
}
|
||||
|
||||
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"test": "value",
|
||||
}
|
||||
|
||||
// 测试未知的内部工具类型
|
||||
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行内部工具失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult.IsError {
|
||||
t.Fatal("未知的工具类型应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
|
||||
t.Errorf("错误消息应该包含'未知的内部工具类型'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
// 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout,
|
||||
// ReadString('\n') 会阻塞到子进程退出。后台包装须将子进程标准流与 PID 行分离。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||
defer cancel()
|
||||
args := map[string]interface{}{
|
||||
"command": `(sh -c 'printf x; sleep 120') &`,
|
||||
"shell": "sh",
|
||||
}
|
||||
res, err := executor.executeSystemCommand(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("executeSystemCommand: %v", err)
|
||||
}
|
||||
if res == nil || res.IsError {
|
||||
t.Fatalf("expected success, got %+v", res)
|
||||
}
|
||||
txt := res.Content[0].Text
|
||||
if !strings.Contains(txt, "后台命令已启动") {
|
||||
t.Fatalf("unexpected body: %q", txt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) {
|
||||
pos1 := 1
|
||||
executor, _ := setupTestExecutor(t)
|
||||
toolConfig := &config.ToolConfig{
|
||||
Name: "nmap",
|
||||
Command: "nmap",
|
||||
Args: []string{"-sT", "-sV", "-sC"},
|
||||
Parameters: []config.ParameterConfig{
|
||||
{Name: "target", Type: "string", Required: true, Position: &pos1, Format: "positional"},
|
||||
{Name: "ports", Type: "string", Flag: "-p", Format: "flag"},
|
||||
{Name: "timing", Type: "string", Template: "-T{value}", Format: "template"},
|
||||
{Name: "nse_scripts", Type: "string", Flag: "--script", Format: "flag"},
|
||||
{Name: "os_detection", Type: "bool", Flag: "-O", Format: "flag", Default: false},
|
||||
{Name: "aggressive", Type: "bool", Flag: "-A", Format: "flag", Default: false},
|
||||
{Name: "scan_type", Type: "string", Format: "template", Template: "{value}"},
|
||||
{Name: "additional_args", Type: "string", Format: "positional"},
|
||||
},
|
||||
}
|
||||
|
||||
args := map[string]interface{}{
|
||||
"target": "110.52.223.114",
|
||||
"ports": "21, 22, 80, 443",
|
||||
"timing": "4",
|
||||
"nse_scripts": "",
|
||||
"scan_type": "",
|
||||
"os_detection": false,
|
||||
"aggressive": false,
|
||||
"additional_args": "-Pn",
|
||||
}
|
||||
|
||||
cmdArgs := executor.buildCommandArgs("nmap", toolConfig, args)
|
||||
joined := strings.Join(cmdArgs, " ")
|
||||
|
||||
if strings.Contains(joined, "--script") {
|
||||
t.Fatalf("empty nse_scripts must not emit --script, got: %v", cmdArgs)
|
||||
}
|
||||
if !strings.Contains(joined, "110.52.223.114") {
|
||||
t.Fatalf("target missing from args: %v", cmdArgs)
|
||||
}
|
||||
// target 应出现在 -Pn 之前,避免被误当作 --script 的参数
|
||||
pnIdx := indexOf(cmdArgs, "-Pn")
|
||||
targetIdx := indexOf(cmdArgs, "110.52.223.114")
|
||||
if pnIdx < 0 || targetIdx < 0 || targetIdx >= pnIdx {
|
||||
t.Fatalf("expected target before -Pn, got: %v", cmdArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func indexOf(slice []string, s string) int {
|
||||
for i, v := range slice {
|
||||
if v == s {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
//go:build !windows
|
||||
|
||||
package security
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// prepareShellCmdSession 让 shell 子进程在独立会话中运行,便于超时/取消时整组 SIGKILL(含子进程)。
|
||||
func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||
if cmd == nil {
|
||||
return nil
|
||||
}
|
||||
if cmd.SysProcAttr == nil {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
}
|
||||
cmd.SysProcAttr.Setsid = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
pid := cmd.Process.Pid
|
||||
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
//go:build windows
|
||||
|
||||
package security
|
||||
|
||||
import "os/exec"
|
||||
|
||||
func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||
_ = cmd
|
||||
return nil
|
||||
}
|
||||
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// rateLimitEntry 记录某个 IP 的请求窗口信息
|
||||
type rateLimitEntry struct {
|
||||
count int
|
||||
windowAt time.Time
|
||||
}
|
||||
|
||||
// RateLimiter 基于 IP 的滑动窗口速率限制器
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*rateLimitEntry
|
||||
limit int // 窗口内允许的最大请求数
|
||||
window time.Duration // 窗口时长
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建速率限制器
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
entries: make(map[string]*rateLimitEntry),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
// 后台定期清理过期条目,防止内存泄漏
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanup 每分钟清理一次过期条目
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
for ip, entry := range rl.entries {
|
||||
if now.Sub(entry.windowAt) > rl.window {
|
||||
delete(rl.entries, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// allow 检查指定 IP 是否允许通过
|
||||
func (rl *RateLimiter) allow(ip string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
entry, ok := rl.entries[ip]
|
||||
if !ok || now.Sub(entry.windowAt) > rl.window {
|
||||
rl.entries[ip] = &rateLimitEntry{count: 1, windowAt: now}
|
||||
return true
|
||||
}
|
||||
|
||||
entry.count++
|
||||
return entry.count <= rl.limit
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 返回 Gin 中间件,对超限请求返回 429
|
||||
func RateLimitMiddleware(rl *RateLimiter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
if !rl.allow(ip) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded, please try again later",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user