mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 18:26:38 +02:00
Delete security directory
This commit is contained in:
@@ -1,132 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Predefined errors for authentication operations.
|
||||
var (
|
||||
ErrInvalidPassword = errors.New("invalid password")
|
||||
)
|
||||
|
||||
// Session represents an authenticated user session.
|
||||
type Session struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// AuthManager manages password-based authentication and session lifecycle.
|
||||
type AuthManager struct {
|
||||
password string
|
||||
sessionDuration time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
sessions map[string]Session
|
||||
}
|
||||
|
||||
// NewAuthManager creates a new AuthManager instance.
|
||||
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return nil, errors.New("auth password must be configured")
|
||||
}
|
||||
|
||||
if sessionDurationHours <= 0 {
|
||||
sessionDurationHours = 12
|
||||
}
|
||||
|
||||
return &AuthManager{
|
||||
password: password,
|
||||
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
|
||||
sessions: make(map[string]Session),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate validates the password and creates a new session.
|
||||
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
|
||||
if password != a.password {
|
||||
return "", time.Time{}, ErrInvalidPassword
|
||||
}
|
||||
|
||||
token := uuid.NewString()
|
||||
expiresAt := time.Now().Add(a.sessionDuration)
|
||||
|
||||
a.mu.Lock()
|
||||
a.sessions[token] = Session{
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
return token, expiresAt, nil
|
||||
}
|
||||
|
||||
// ValidateToken checks whether the provided token is still valid.
|
||||
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
session, ok := a.sessions[token]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
return Session{}, false
|
||||
}
|
||||
|
||||
return session, true
|
||||
}
|
||||
|
||||
// CheckPassword verifies whether the provided password matches the current password.
|
||||
func (a *AuthManager) CheckPassword(password string) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
return password == a.password
|
||||
}
|
||||
|
||||
// RevokeToken invalidates the specified token.
|
||||
func (a *AuthManager) RevokeToken(token string) {
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
delete(a.sessions, token)
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
// SessionDurationHours returns the configured session duration in hours.
|
||||
func (a *AuthManager) SessionDurationHours() int {
|
||||
return int(a.sessionDuration / time.Hour)
|
||||
}
|
||||
|
||||
// UpdateConfig updates the password and session duration, revoking existing sessions.
|
||||
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
|
||||
password = strings.TrimSpace(password)
|
||||
if password == "" {
|
||||
return errors.New("auth password must be configured")
|
||||
}
|
||||
|
||||
if sessionDurationHours <= 0 {
|
||||
sessionDurationHours = 12
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
a.password = password
|
||||
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
|
||||
a.sessions = make(map[string]Session)
|
||||
return nil
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
ContextAuthTokenKey = "authToken"
|
||||
ContextSessionExpiry = "authSessionExpiry"
|
||||
)
|
||||
|
||||
// AuthMiddleware enforces authentication on protected routes.
|
||||
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := extractTokenFromRequest(c)
|
||||
session, ok := manager.ValidateToken(token)
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "未授权访问,请先登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(ContextAuthTokenKey, session.Token)
|
||||
c.Set(ContextSessionExpiry, session.ExpiresAt)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func extractTokenFromRequest(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
|
||||
return strings.TrimSpace(authHeader[7:])
|
||||
}
|
||||
return strings.TrimSpace(authHeader)
|
||||
}
|
||||
|
||||
if token := c.Query("token"); token != "" {
|
||||
return strings.TrimSpace(token)
|
||||
}
|
||||
|
||||
if cookie, err := c.Cookie("auth_token"); err == nil {
|
||||
return strings.TrimSpace(cookie)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,268 +0,0 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// setupTestExecutor 创建测试用的执行器
|
||||
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
|
||||
logger := zap.NewNop()
|
||||
mcpServer := mcp.NewServer(logger)
|
||||
|
||||
cfg := &config.SecurityConfig{
|
||||
Tools: []config.ToolConfig{},
|
||||
}
|
||||
|
||||
executor := NewExecutor(cfg, mcpServer, logger)
|
||||
return executor, mcpServer
|
||||
}
|
||||
|
||||
// setupTestStorage 创建测试用的存储
|
||||
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
|
||||
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
|
||||
logger := zap.NewNop()
|
||||
|
||||
storage, err := storage.NewFileResultStorage(tmpDir, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("创建测试存储失败: %v", err)
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
testStorage := setupTestStorage(t)
|
||||
executor.SetResultStorage(testStorage)
|
||||
|
||||
// 准备测试数据
|
||||
executionID := "test_exec_001"
|
||||
toolName := "nmap_scan"
|
||||
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
|
||||
|
||||
// 保存测试结果
|
||||
err := testStorage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存测试结果失败: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 测试1: 基本查询(第一页)
|
||||
args := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
"page": float64(1),
|
||||
"limit": float64(2),
|
||||
}
|
||||
|
||||
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
if toolResult.IsError {
|
||||
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
|
||||
}
|
||||
|
||||
// 验证结果包含预期内容
|
||||
resultText := toolResult.Content[0].Text
|
||||
if !strings.Contains(resultText, executionID) {
|
||||
t.Errorf("结果中应该包含执行ID: %s", executionID)
|
||||
}
|
||||
|
||||
if !strings.Contains(resultText, "第 1/") {
|
||||
t.Errorf("结果中应该包含分页信息")
|
||||
}
|
||||
|
||||
// 测试2: 搜索功能
|
||||
args2 := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
"search": "error",
|
||||
"page": float64(1),
|
||||
"limit": float64(10),
|
||||
}
|
||||
|
||||
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
|
||||
if err != nil {
|
||||
t.Fatalf("执行搜索失败: %v", err)
|
||||
}
|
||||
|
||||
if toolResult2.IsError {
|
||||
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
|
||||
}
|
||||
|
||||
resultText2 := toolResult2.Content[0].Text
|
||||
if !strings.Contains(resultText2, "error") {
|
||||
t.Errorf("搜索结果中应该包含关键词: error")
|
||||
}
|
||||
|
||||
// 测试3: 过滤功能
|
||||
args3 := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
"filter": "Port",
|
||||
"page": float64(1),
|
||||
"limit": float64(10),
|
||||
}
|
||||
|
||||
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
|
||||
if err != nil {
|
||||
t.Fatalf("执行过滤失败: %v", err)
|
||||
}
|
||||
|
||||
if toolResult3.IsError {
|
||||
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
|
||||
}
|
||||
|
||||
resultText3 := toolResult3.Content[0].Text
|
||||
if !strings.Contains(resultText3, "Port") {
|
||||
t.Errorf("过滤结果中应该包含关键词: Port")
|
||||
}
|
||||
|
||||
// 测试4: 缺少必需参数
|
||||
args4 := map[string]interface{}{
|
||||
"page": float64(1),
|
||||
}
|
||||
|
||||
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult4.IsError {
|
||||
t.Fatal("缺少execution_id应该返回错误")
|
||||
}
|
||||
|
||||
// 测试5: 不存在的执行ID
|
||||
args5 := map[string]interface{}{
|
||||
"execution_id": "nonexistent_id",
|
||||
"page": float64(1),
|
||||
}
|
||||
|
||||
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult5.IsError {
|
||||
t.Fatal("不存在的执行ID应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"test": "value",
|
||||
}
|
||||
|
||||
// 测试未知的内部工具类型
|
||||
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行内部工具失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult.IsError {
|
||||
t.Fatal("未知的工具类型应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
|
||||
t.Errorf("错误消息应该包含'未知的内部工具类型'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
// 不设置存储,测试未初始化的情况
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"execution_id": "test_id",
|
||||
}
|
||||
|
||||
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
if !toolResult.IsError {
|
||||
t.Fatal("未初始化的存储应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
|
||||
t.Errorf("错误消息应该包含'结果存储未初始化'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginateLines(t *testing.T) {
|
||||
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
|
||||
|
||||
// 测试第一页
|
||||
page := paginateLines(lines, 1, 2)
|
||||
if page.Page != 1 {
|
||||
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
|
||||
}
|
||||
if page.Limit != 2 {
|
||||
t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit)
|
||||
}
|
||||
if page.TotalLines != 5 {
|
||||
t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines)
|
||||
}
|
||||
if page.TotalPages != 3 {
|
||||
t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages)
|
||||
}
|
||||
if len(page.Lines) != 2 {
|
||||
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
|
||||
}
|
||||
|
||||
// 测试第二页
|
||||
page2 := paginateLines(lines, 2, 2)
|
||||
if len(page2.Lines) != 2 {
|
||||
t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines))
|
||||
}
|
||||
if page2.Lines[0] != "Line 3" {
|
||||
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
|
||||
}
|
||||
|
||||
// 测试最后一页
|
||||
page3 := paginateLines(lines, 3, 2)
|
||||
if len(page3.Lines) != 1 {
|
||||
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
|
||||
}
|
||||
|
||||
// 测试超出范围的页码(应该返回最后一页)
|
||||
page4 := paginateLines(lines, 4, 2)
|
||||
if page4.Page != 3 {
|
||||
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page)
|
||||
}
|
||||
if len(page4.Lines) != 1 {
|
||||
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
|
||||
}
|
||||
|
||||
// 测试无效页码(小于1)
|
||||
page0 := paginateLines(lines, 0, 2)
|
||||
if page0.Page != 1 {
|
||||
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
|
||||
}
|
||||
|
||||
// 测试空列表
|
||||
emptyPage := paginateLines([]string{}, 1, 10)
|
||||
if emptyPage.TotalLines != 0 {
|
||||
t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines)
|
||||
}
|
||||
if len(emptyPage.Lines) != 0 {
|
||||
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user