Add files via upload

This commit is contained in:
公明
2026-04-19 19:14:53 +08:00
committed by GitHub
parent 42fed78227
commit cad7611548
77 changed files with 34253 additions and 0 deletions
+132
View File
@@ -0,0 +1,132 @@
package security
import (
"errors"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// Predefined errors for authentication operations.
var (
ErrInvalidPassword = errors.New("invalid password")
)
// Session represents an authenticated user session.
type Session struct {
Token string
ExpiresAt time.Time
}
// AuthManager manages password-based authentication and session lifecycle.
type AuthManager struct {
password string
sessionDuration time.Duration
mu sync.RWMutex
sessions map[string]Session
}
// NewAuthManager creates a new AuthManager instance.
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
if strings.TrimSpace(password) == "" {
return nil, errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
return &AuthManager{
password: password,
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
sessions: make(map[string]Session),
}, nil
}
// Authenticate validates the password and creates a new session.
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
if password != a.password {
return "", time.Time{}, ErrInvalidPassword
}
token := uuid.NewString()
expiresAt := time.Now().Add(a.sessionDuration)
a.mu.Lock()
a.sessions[token] = Session{
Token: token,
ExpiresAt: expiresAt,
}
a.mu.Unlock()
return token, expiresAt, nil
}
// ValidateToken checks whether the provided token is still valid.
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
if strings.TrimSpace(token) == "" {
return Session{}, false
}
a.mu.RLock()
session, ok := a.sessions[token]
a.mu.RUnlock()
if !ok {
return Session{}, false
}
if time.Now().After(session.ExpiresAt) {
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
return Session{}, false
}
return session, true
}
// CheckPassword verifies whether the provided password matches the current password.
func (a *AuthManager) CheckPassword(password string) bool {
a.mu.RLock()
defer a.mu.RUnlock()
return password == a.password
}
// RevokeToken invalidates the specified token.
func (a *AuthManager) RevokeToken(token string) {
if strings.TrimSpace(token) == "" {
return
}
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
}
// SessionDurationHours returns the configured session duration in hours.
func (a *AuthManager) SessionDurationHours() int {
return int(a.sessionDuration / time.Hour)
}
// UpdateConfig updates the password and session duration, revoking existing sessions.
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
password = strings.TrimSpace(password)
if password == "" {
return errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
a.mu.Lock()
defer a.mu.Unlock()
a.password = password
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
a.sessions = make(map[string]Session)
return nil
}
+51
View File
@@ -0,0 +1,51 @@
package security
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
const (
ContextAuthTokenKey = "authToken"
ContextSessionExpiry = "authSessionExpiry"
)
// AuthMiddleware enforces authentication on protected routes.
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
return func(c *gin.Context) {
token := extractTokenFromRequest(c)
session, ok := manager.ValidateToken(token)
if !ok {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "未授权访问,请先登录",
})
return
}
c.Set(ContextAuthTokenKey, session.Token)
c.Set(ContextSessionExpiry, session.ExpiresAt)
c.Next()
}
}
func extractTokenFromRequest(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
return strings.TrimSpace(authHeader[7:])
}
return strings.TrimSpace(authHeader)
}
if token := c.Query("token"); token != "" {
return strings.TrimSpace(token)
}
if cookie, err := c.Cookie("auth_token"); err == nil {
return strings.TrimSpace(cookie)
}
return ""
}
+1575
View File
File diff suppressed because it is too large Load Diff
+268
View File
@@ -0,0 +1,268 @@
package security
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap"
)
// setupTestExecutor 创建测试用的执行器
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
cfg := &config.SecurityConfig{
Tools: []config.ToolConfig{},
}
executor := NewExecutor(cfg, mcpServer, logger)
return executor, mcpServer
}
// setupTestStorage 创建测试用的存储
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
logger := zap.NewNop()
storage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
return storage
}
func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
executor, _ := setupTestExecutor(t)
testStorage := setupTestStorage(t)
executor.SetResultStorage(testStorage)
// 准备测试数据
executionID := "test_exec_001"
toolName := "nmap_scan"
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
// 保存测试结果
err := testStorage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存测试结果失败: %v", err)
}
ctx := context.Background()
// 测试1: 基本查询(第一页)
args := map[string]interface{}{
"execution_id": executionID,
"page": float64(1),
"limit": float64(2),
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if toolResult.IsError {
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
}
// 验证结果包含预期内容
resultText := toolResult.Content[0].Text
if !strings.Contains(resultText, executionID) {
t.Errorf("结果中应该包含执行ID: %s", executionID)
}
if !strings.Contains(resultText, "第 1/") {
t.Errorf("结果中应该包含分页信息")
}
// 测试2: 搜索功能
args2 := map[string]interface{}{
"execution_id": executionID,
"search": "error",
"page": float64(1),
"limit": float64(10),
}
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
if err != nil {
t.Fatalf("执行搜索失败: %v", err)
}
if toolResult2.IsError {
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
}
resultText2 := toolResult2.Content[0].Text
if !strings.Contains(resultText2, "error") {
t.Errorf("搜索结果中应该包含关键词: error")
}
// 测试3: 过滤功能
args3 := map[string]interface{}{
"execution_id": executionID,
"filter": "Port",
"page": float64(1),
"limit": float64(10),
}
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
if err != nil {
t.Fatalf("执行过滤失败: %v", err)
}
if toolResult3.IsError {
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
}
resultText3 := toolResult3.Content[0].Text
if !strings.Contains(resultText3, "Port") {
t.Errorf("过滤结果中应该包含关键词: Port")
}
// 测试4: 缺少必需参数
args4 := map[string]interface{}{
"page": float64(1),
}
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult4.IsError {
t.Fatal("缺少execution_id应该返回错误")
}
// 测试5: 不存在的执行ID
args5 := map[string]interface{}{
"execution_id": "nonexistent_id",
"page": float64(1),
}
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult5.IsError {
t.Fatal("不存在的执行ID应该返回错误")
}
}
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
executor, _ := setupTestExecutor(t)
ctx := context.Background()
args := map[string]interface{}{
"test": "value",
}
// 测试未知的内部工具类型
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
if err != nil {
t.Fatalf("执行内部工具失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未知的工具类型应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
t.Errorf("错误消息应该包含'未知的内部工具类型'")
}
}
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
executor, _ := setupTestExecutor(t)
// 不设置存储,测试未初始化的情况
ctx := context.Background()
args := map[string]interface{}{
"execution_id": "test_id",
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未初始化的存储应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
t.Errorf("错误消息应该包含'结果存储未初始化'")
}
}
func TestPaginateLines(t *testing.T) {
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
// 测试第一页
page := paginateLines(lines, 1, 2)
if page.Page != 1 {
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
}
if page.Limit != 2 {
t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit)
}
if page.TotalLines != 5 {
t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines)
}
if page.TotalPages != 3 {
t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages)
}
if len(page.Lines) != 2 {
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
}
// 测试第二页
page2 := paginateLines(lines, 2, 2)
if len(page2.Lines) != 2 {
t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines))
}
if page2.Lines[0] != "Line 3" {
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
}
// 测试最后一页
page3 := paginateLines(lines, 3, 2)
if len(page3.Lines) != 1 {
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
}
// 测试超出范围的页码(应该返回最后一页)
page4 := paginateLines(lines, 4, 2)
if page4.Page != 3 {
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page)
}
if len(page4.Lines) != 1 {
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
}
// 测试无效页码(小于1
page0 := paginateLines(lines, 0, 2)
if page0.Page != 1 {
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
}
// 测试空列表
emptyPage := paginateLines([]string{}, 1, 10)
if emptyPage.TotalLines != 0 {
t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines)
}
if len(emptyPage.Lines) != 0 {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
}
}