Delete internal directory

This commit is contained in:
公明
2026-04-19 01:14:50 +08:00
committed by GitHub
parent 5c444afe06
commit 5ef7618f44
98 changed files with 0 additions and 43993 deletions
File diff suppressed because it is too large Load Diff
-286
View File
@@ -1,286 +0,0 @@
package agent
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap"
)
// setupTestAgent 创建测试用的Agent
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
agentCfg := &config.AgentConfig{
MaxIterations: 10,
LargeResultThreshold: 100, // 设置较小的阈值便于测试
ResultStorageDir: "",
}
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
// 创建测试存储
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
agent.SetResultStorage(testStorage)
return agent, testStorage
}
func TestAgent_FormatMinimalNotification(t *testing.T) {
agent, testStorage := setupTestAgent(t)
_ = testStorage // 避免未使用变量警告
executionID := "test_exec_001"
toolName := "nmap_scan"
size := 50000
lineCount := 1000
filePath := "tmp/test_exec_001.txt"
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
// 验证通知包含必要信息
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID: %s", executionID)
}
if !strings.Contains(notification, toolName) {
t.Errorf("通知中应该包含工具名称: %s", toolName)
}
if !strings.Contains(notification, "50000") {
t.Errorf("通知中应该包含大小信息")
}
if !strings.Contains(notification, "1000") {
t.Errorf("通知中应该包含行数信息")
}
if !strings.Contains(notification, "query_execution_result") {
t.Errorf("通知中应该包含查询工具的使用说明")
}
}
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建模拟的MCP工具结果(大结果)
largeResult := &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: strings.Repeat("This is a test line with some content.\n", 1000), // 约50KB
},
},
IsError: false,
}
// 模拟MCP服务器返回大结果
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
// 为了简化测试,我们直接测试结果处理逻辑
// 设置阈值
agent.mu.Lock()
agent.largeResultThreshold = 1000 // 设置较小的阈值
agent.mu.Unlock()
// 创建执行ID
executionID := "test_exec_large_001"
toolName := "test_tool"
// 格式化结果
var resultText strings.Builder
for _, content := range largeResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 保存大结果
err := storage.SaveResult(executionID, toolName, resultStr)
if err != nil {
t.Fatalf("保存大结果失败: %v", err)
}
// 生成通知
lines := strings.Split(resultStr, "\n")
filePath := storage.GetResultPath(executionID)
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
// 验证通知格式
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID")
}
// 验证结果已保存
savedResult, err := storage.GetResult(executionID)
if err != nil {
t.Fatalf("获取保存的结果失败: %v", err)
}
if savedResult != resultStr {
t.Errorf("保存的结果与原始结果不匹配")
}
} else {
t.Fatal("大结果应该被检测到并保存")
}
}
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建小结果
smallResult := &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "Small result content",
},
},
IsError: false,
}
// 设置较大的阈值
agent.mu.Lock()
agent.largeResultThreshold = 100000 // 100KB
agent.mu.Unlock()
// 格式化结果
var resultText strings.Builder
for _, content := range smallResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
t.Fatal("小结果不应该被保存")
}
// 小结果应该直接返回
if resultSize <= threshold {
// 这是预期的行为
if resultStr == "" {
t.Fatal("小结果应该直接返回,不应该为空")
}
}
}
func TestAgent_SetResultStorage(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建新的存储
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
if err != nil {
t.Fatalf("创建新存储失败: %v", err)
}
// 设置新存储
agent.SetResultStorage(newStorage)
// 验证存储已更新
agent.mu.RLock()
currentStorage := agent.resultStorage
agent.mu.RUnlock()
if currentStorage != newStorage {
t.Fatal("存储未正确更新")
}
// 清理
os.RemoveAll(tmpDir)
}
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
// 测试默认配置
agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0)
if agent.maxIterations != 30 {
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
}
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 50*1024 {
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
}
}
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
agentCfg := &config.AgentConfig{
MaxIterations: 20,
LargeResultThreshold: 100 * 1024, // 100KB
ResultStorageDir: "custom_tmp",
}
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
if agent.maxIterations != 15 {
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
}
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 100*1024 {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
}
}
-491
View File
@@ -1,491 +0,0 @@
package agent
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/openai"
"github.com/pkoukk/tiktoken-go"
"go.uber.org/zap"
)
const (
// DefaultMinRecentMessage 压缩历史消息时保留的最近消息数量,确保最近的对话上下文不被压缩
DefaultMinRecentMessage = 5
// defaultChunkSize 压缩历史消息时每次处理的消息块大小,将旧消息分成多个块进行摘要
defaultChunkSize = 10
// defaultMaxImages 压缩时最多保留的图片数量,超过此数量的图片会被移除以节省上下文空间
defaultMaxImages = 3
// defaultSummaryTimeout 生成消息摘要时的超时时间
defaultSummaryTimeout = 10 * time.Minute
summaryPromptTemplate = `你是一名负责为安全代理执行上下文压缩的助手,任务是在保持所有关键渗透信息完整的前提下压缩扫描数据。
必须保留的关键信息:
- 已发现的漏洞与潜在攻击路径
- 扫描结果与工具输出(可压缩,但需保留核心发现)
- 获取到的访问凭证、令牌或认证细节
- 系统架构洞察与潜在薄弱点
- 当前评估进展
- 失败尝试与死路(避免重复劳动)
- 关于测试策略的所有决策记录
压缩指南:
- 保留精确技术细节(URL、路径、参数、Payload 等)
- 将冗长的工具输出压缩成概述,但保留关键发现
- 记录版本号与识别出的技术/组件信息
- 保留可能暗示漏洞的原始报错
- 将重复或相似发现整合成一条带有共性说明的结论
请牢记:另一位安全代理会依赖这份摘要继续测试,他必须在不损失任何作战上下文的情况下无缝接手。
需要压缩的对话片段:
%s
请给出技术精准且简明扼要的摘要,覆盖全部与安全评估相关的上下文。`
)
// MemoryCompressor 负责在调用LLM前压缩历史上下文,以避免Token爆炸。
type MemoryCompressor struct {
maxTotalTokens int
minRecentMessage int
maxImages int
chunkSize int
summaryModel string
timeout time.Duration
tokenCounter TokenCounter
completionClient CompletionClient
logger *zap.Logger
}
// MemoryCompressorConfig 用于初始化 MemoryCompressor。
type MemoryCompressorConfig struct {
MaxTotalTokens int
MinRecentMessage int
MaxImages int
ChunkSize int
SummaryModel string
Timeout time.Duration
TokenCounter TokenCounter
CompletionClient CompletionClient
Logger *zap.Logger
// 当 CompletionClient 为空时,可以通过 OpenAIConfig + HTTPClient 构造默认的客户端。
OpenAIConfig *config.OpenAIConfig
HTTPClient *http.Client
}
// NewMemoryCompressor 创建新的 MemoryCompressor。
func NewMemoryCompressor(cfg MemoryCompressorConfig) (*MemoryCompressor, error) {
if cfg.Logger == nil {
cfg.Logger = zap.NewNop()
}
// 如果没有显式配置 MaxTotalTokens,则后续逻辑会根据模型的最大上下文长度进行控制;
// 优先推荐在 config.yaml 的 openai.max_total_tokens 中统一配置。
if cfg.MinRecentMessage <= 0 {
cfg.MinRecentMessage = DefaultMinRecentMessage
}
if cfg.MaxImages <= 0 {
cfg.MaxImages = defaultMaxImages
}
if cfg.ChunkSize <= 0 {
cfg.ChunkSize = defaultChunkSize
}
if cfg.Timeout <= 0 {
cfg.Timeout = defaultSummaryTimeout
}
if cfg.SummaryModel == "" && cfg.OpenAIConfig != nil && cfg.OpenAIConfig.Model != "" {
cfg.SummaryModel = cfg.OpenAIConfig.Model
}
if cfg.SummaryModel == "" {
return nil, errors.New("summary model is required (either SummaryModel or OpenAIConfig.Model must be set)")
}
if cfg.TokenCounter == nil {
cfg.TokenCounter = NewTikTokenCounter()
}
if cfg.CompletionClient == nil {
if cfg.OpenAIConfig == nil {
return nil, errors.New("memory compressor requires either CompletionClient or OpenAIConfig")
}
if cfg.HTTPClient == nil {
cfg.HTTPClient = &http.Client{
Timeout: 5 * time.Minute,
}
}
cfg.CompletionClient = NewOpenAICompletionClient(cfg.OpenAIConfig, cfg.HTTPClient, cfg.Logger)
}
return &MemoryCompressor{
maxTotalTokens: cfg.MaxTotalTokens,
minRecentMessage: cfg.MinRecentMessage,
maxImages: cfg.MaxImages,
chunkSize: cfg.ChunkSize,
summaryModel: cfg.SummaryModel,
timeout: cfg.Timeout,
tokenCounter: cfg.TokenCounter,
completionClient: cfg.CompletionClient,
logger: cfg.Logger,
}, nil
}
// UpdateConfig 更新OpenAI配置(用于动态更新模型配置)
func (mc *MemoryCompressor) UpdateConfig(cfg *config.OpenAIConfig) {
if cfg == nil {
return
}
// 更新summaryModel字段
if cfg.Model != "" {
mc.summaryModel = cfg.Model
}
// 更新completionClient中的配置(如果是OpenAICompletionClient
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
openAIClient.UpdateConfig(cfg)
mc.logger.Info("MemoryCompressor配置已更新",
zap.String("model", cfg.Model),
)
}
}
// CompressHistory 根据 Token 限制压缩历史消息。reservedTokens 为预留给 tools 等非消息内容的 token 数,压缩时使用 (maxTotalTokens - reservedTokens) 作为消息上限。
func (mc *MemoryCompressor) CompressHistory(ctx context.Context, messages []ChatMessage, reservedTokens int) ([]ChatMessage, bool, error) {
if len(messages) == 0 {
return messages, false, nil
}
mc.handleImages(messages)
systemMsgs, regularMsgs := mc.splitMessages(messages)
if len(regularMsgs) <= mc.minRecentMessage {
return messages, false, nil
}
effectiveMax := mc.maxTotalTokens
if reservedTokens > 0 && reservedTokens < mc.maxTotalTokens {
effectiveMax = mc.maxTotalTokens - reservedTokens
}
totalTokens := mc.countTotalTokens(systemMsgs, regularMsgs)
if totalTokens <= int(float64(effectiveMax)*0.9) {
return messages, false, nil
}
recentStart := len(regularMsgs) - mc.minRecentMessage
recentStart = mc.adjustRecentStartForToolCalls(regularMsgs, recentStart)
oldMsgs := regularMsgs[:recentStart]
recentMsgs := regularMsgs[recentStart:]
mc.logger.Info("memory compression triggered",
zap.Int("total_tokens", totalTokens),
zap.Int("max_total_tokens", mc.maxTotalTokens),
zap.Int("reserved_tokens", reservedTokens),
zap.Int("effective_max", effectiveMax),
zap.Int("system_messages", len(systemMsgs)),
zap.Int("regular_messages", len(regularMsgs)),
zap.Int("old_messages", len(oldMsgs)),
zap.Int("recent_messages", len(recentMsgs)))
var compressed []ChatMessage
for i := 0; i < len(oldMsgs); i += mc.chunkSize {
end := i + mc.chunkSize
if end > len(oldMsgs) {
end = len(oldMsgs)
}
chunk := oldMsgs[i:end]
if len(chunk) == 0 {
continue
}
summary, err := mc.summarizeChunk(ctx, chunk)
if err != nil {
mc.logger.Warn("chunk summary failed, fallback to raw chunk",
zap.Error(err),
zap.Int("start", i),
zap.Int("end", end))
compressed = append(compressed, chunk...)
continue
}
compressed = append(compressed, summary)
}
finalMessages := make([]ChatMessage, 0, len(systemMsgs)+len(compressed)+len(recentMsgs))
finalMessages = append(finalMessages, systemMsgs...)
finalMessages = append(finalMessages, compressed...)
finalMessages = append(finalMessages, recentMsgs...)
return finalMessages, true, nil
}
func (mc *MemoryCompressor) handleImages(messages []ChatMessage) {
if mc.maxImages <= 0 {
return
}
count := 0
for i := len(messages) - 1; i >= 0; i-- {
content := messages[i].Content
if !strings.Contains(content, "[IMAGE]") {
continue
}
count++
if count > mc.maxImages {
messages[i].Content = "[Previously attached image removed to preserve context]"
}
}
}
func (mc *MemoryCompressor) splitMessages(messages []ChatMessage) (systemMsgs, regularMsgs []ChatMessage) {
for _, msg := range messages {
if strings.EqualFold(msg.Role, "system") {
systemMsgs = append(systemMsgs, msg)
} else {
regularMsgs = append(regularMsgs, msg)
}
}
return
}
func (mc *MemoryCompressor) countTotalTokens(systemMsgs, regularMsgs []ChatMessage) int {
total := 0
for _, msg := range systemMsgs {
total += mc.countTokens(msg.Content)
}
for _, msg := range regularMsgs {
total += mc.countTokens(msg.Content)
}
return total
}
// getModelName 获取当前使用的模型名称(优先从completionClient获取最新配置)
func (mc *MemoryCompressor) getModelName() string {
// 如果completionClient是OpenAICompletionClient,从它获取最新的模型名称
if openAIClient, ok := mc.completionClient.(*OpenAICompletionClient); ok {
if openAIClient.config != nil && openAIClient.config.Model != "" {
return openAIClient.config.Model
}
}
// 否则使用保存的summaryModel
return mc.summaryModel
}
func (mc *MemoryCompressor) countTokens(text string) int {
if mc.tokenCounter == nil {
return len(text) / 4
}
modelName := mc.getModelName()
count, err := mc.tokenCounter.Count(modelName, text)
if err != nil {
return len(text) / 4
}
return count
}
// CountTextTokens 对外暴露的文本 Token 计数,用于统计 tools 等非消息内容的 token(如 agent 侧序列化 tools 后计数)。
func (mc *MemoryCompressor) CountTextTokens(text string) int {
return mc.countTokens(text)
}
// totalTokensFor provides token statistics without mutating the message list.
func (mc *MemoryCompressor) totalTokensFor(messages []ChatMessage) (totalTokens int, systemCount int, regularCount int) {
if len(messages) == 0 {
return 0, 0, 0
}
systemMsgs, regularMsgs := mc.splitMessages(messages)
return mc.countTotalTokens(systemMsgs, regularMsgs), len(systemMsgs), len(regularMsgs)
}
func (mc *MemoryCompressor) summarizeChunk(ctx context.Context, chunk []ChatMessage) (ChatMessage, error) {
if len(chunk) == 0 {
return ChatMessage{}, errors.New("chunk is empty")
}
formatted := make([]string, 0, len(chunk))
for _, msg := range chunk {
formatted = append(formatted, fmt.Sprintf("%s: %s", msg.Role, mc.extractMessageText(msg)))
}
conversation := strings.Join(formatted, "\n")
prompt := fmt.Sprintf(summaryPromptTemplate, conversation)
// 使用动态获取的模型名称,而不是保存的summaryModel
modelName := mc.getModelName()
summary, err := mc.completionClient.Complete(ctx, modelName, prompt, mc.timeout)
if err != nil {
return ChatMessage{}, err
}
summary = strings.TrimSpace(summary)
if summary == "" {
return chunk[0], nil
}
return ChatMessage{
Role: "assistant",
Content: fmt.Sprintf("<context_summary message_count='%d'>%s</context_summary>", len(chunk), summary),
}, nil
}
func (mc *MemoryCompressor) extractMessageText(msg ChatMessage) string {
return msg.Content
}
func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, recentStart int) int {
if recentStart <= 0 || recentStart >= len(msgs) {
return recentStart
}
adjusted := recentStart
for adjusted > 0 && strings.EqualFold(msgs[adjusted].Role, "tool") {
adjusted--
}
if adjusted != recentStart {
mc.logger.Debug("adjusted recent window to keep tool call context",
zap.Int("original_recent_start", recentStart),
zap.Int("adjusted_recent_start", adjusted),
)
}
return adjusted
}
// TokenCounter 用于计算文本Token数量。
type TokenCounter interface {
Count(model, text string) (int, error)
}
// TikTokenCounter 基于 tiktoken 的 Token 统计器。
type TikTokenCounter struct {
mu sync.RWMutex
cache map[string]*tiktoken.Tiktoken
fallbackEncoding *tiktoken.Tiktoken
}
// NewTikTokenCounter 创建新的 TikTokenCounter。
func NewTikTokenCounter() *TikTokenCounter {
return &TikTokenCounter{
cache: make(map[string]*tiktoken.Tiktoken),
}
}
// Count 实现 TokenCounter 接口。
func (tc *TikTokenCounter) Count(model, text string) (int, error) {
enc, err := tc.encodingForModel(model)
if err != nil {
return len(text) / 4, err
}
tokens := enc.Encode(text, nil, nil)
return len(tokens), nil
}
func (tc *TikTokenCounter) encodingForModel(model string) (*tiktoken.Tiktoken, error) {
tc.mu.RLock()
if enc, ok := tc.cache[model]; ok {
tc.mu.RUnlock()
return enc, nil
}
tc.mu.RUnlock()
tc.mu.Lock()
defer tc.mu.Unlock()
if enc, ok := tc.cache[model]; ok {
return enc, nil
}
enc, err := tiktoken.EncodingForModel(model)
if err != nil {
if tc.fallbackEncoding == nil {
tc.fallbackEncoding, err = tiktoken.GetEncoding("cl100k_base")
if err != nil {
return nil, err
}
}
tc.cache[model] = tc.fallbackEncoding
return tc.fallbackEncoding, nil
}
tc.cache[model] = enc
return enc, nil
}
// CompletionClient 对话压缩时使用的补全接口。
type CompletionClient interface {
Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error)
}
// OpenAICompletionClient 基于 OpenAI Chat Completion。
type OpenAICompletionClient struct {
config *config.OpenAIConfig
client *openai.Client
logger *zap.Logger
}
// NewOpenAICompletionClient 创建 OpenAICompletionClient。
func NewOpenAICompletionClient(cfg *config.OpenAIConfig, client *http.Client, logger *zap.Logger) *OpenAICompletionClient {
if logger == nil {
logger = zap.NewNop()
}
return &OpenAICompletionClient{
config: cfg,
client: openai.NewClient(cfg, client, logger),
logger: logger,
}
}
// UpdateConfig 更新底层配置。
func (c *OpenAICompletionClient) UpdateConfig(cfg *config.OpenAIConfig) {
c.config = cfg
if c.client != nil {
c.client.UpdateConfig(cfg)
}
}
// Complete 调用OpenAI获取摘要。
func (c *OpenAICompletionClient) Complete(ctx context.Context, model string, prompt string, timeout time.Duration) (string, error) {
if c.config == nil {
return "", errors.New("openai config is required")
}
if model == "" {
return "", errors.New("model name is required")
}
reqBody := OpenAIRequest{
Model: model,
Messages: []ChatMessage{
{Role: "user", Content: prompt},
},
}
requestCtx := ctx
var cancel context.CancelFunc
if timeout > 0 {
requestCtx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
var completion OpenAIResponse
if c.client == nil {
return "", errors.New("openai completion client not initialized")
}
if err := c.client.ChatCompletion(requestCtx, reqBody, &completion); err != nil {
if apiErr, ok := err.(*openai.APIError); ok {
return "", fmt.Errorf("openai completion failed, status: %d, body: %s", apiErr.StatusCode, apiErr.Body)
}
return "", err
}
if completion.Error != nil {
return "", errors.New(completion.Error.Message)
}
if len(completion.Choices) == 0 || completion.Choices[0].Message.Content == "" {
return "", errors.New("empty completion response")
}
return completion.Choices[0].Message.Content, nil
}
-449
View File
@@ -1,449 +0,0 @@
// Package agents 从 agents/ 目录加载 Markdown 代理定义(子代理 + 可选主代理 orchestrator.md / kind: orchestrator)。
package agents
import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"unicode"
"cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3"
)
// OrchestratorMarkdownFilename 固定文件名:存在则视为 Deep 主代理定义,且不参与子代理列表。
const OrchestratorMarkdownFilename = "orchestrator.md"
// FrontMatter 对应 Markdown 文件头部字段(与文档示例一致)。
type FrontMatter struct {
Name string `yaml:"name"`
ID string `yaml:"id"`
Description string `yaml:"description"`
Tools interface{} `yaml:"tools"` // 字符串 "A, B" 或 []string
MaxIterations int `yaml:"max_iterations"`
BindRole string `yaml:"bind_role,omitempty"`
Kind string `yaml:"kind,omitempty"` // orchestrator = 主代理(亦可仅用文件名 orchestrator.md
}
// OrchestratorMarkdown 从 agents 目录解析出的主代理(Deep 协调者)定义。
type OrchestratorMarkdown struct {
Filename string
EinoName string // 写入 deep.Config.Name / 流式事件过滤
DisplayName string
Description string
Instruction string
}
// MarkdownDirLoad 一次扫描 agents 目录的结果(子代理不含主代理文件)。
type MarkdownDirLoad struct {
SubAgents []config.MultiAgentSubConfig
Orchestrator *OrchestratorMarkdown
FileEntries []FileAgent // 含主代理与所有子代理,供管理 API 列表
}
// IsOrchestratorMarkdown 判断该文件是否表示主代理:固定文件名 orchestrator.md,或 front matter kind: orchestrator。
func IsOrchestratorMarkdown(filename string, fm FrontMatter) bool {
base := filepath.Base(strings.TrimSpace(filename))
if strings.EqualFold(base, OrchestratorMarkdownFilename) {
return true
}
return strings.EqualFold(strings.TrimSpace(fm.Kind), "orchestrator")
}
// WantsMarkdownOrchestrator 保存前判断是否会把该文件作为主代理(用于唯一性校验)。
func WantsMarkdownOrchestrator(filename string, kindField string, raw string) bool {
if strings.EqualFold(strings.TrimSpace(kindField), "orchestrator") {
return true
}
base := filepath.Base(strings.TrimSpace(filename))
if strings.EqualFold(base, OrchestratorMarkdownFilename) {
return true
}
if strings.TrimSpace(raw) == "" {
return false
}
sub, err := ParseMarkdownSubAgent(filename, raw)
if err != nil {
return false
}
return strings.EqualFold(strings.TrimSpace(sub.Kind), "orchestrator")
}
// SplitFrontMatter 分离 YAML front matter 与正文(--- ... ---)。
func SplitFrontMatter(content string) (frontYAML string, body string, err error) {
s := strings.TrimSpace(content)
if !strings.HasPrefix(s, "---") {
return "", s, nil
}
rest := strings.TrimPrefix(s, "---")
rest = strings.TrimLeft(rest, "\r\n")
end := strings.Index(rest, "\n---")
if end < 0 {
return "", "", fmt.Errorf("agents: 缺少结束的 --- 分隔符")
}
fm := strings.TrimSpace(rest[:end])
body = strings.TrimSpace(rest[end+4:])
body = strings.TrimLeft(body, "\r\n")
return fm, body, nil
}
func parseToolsField(v interface{}) []string {
if v == nil {
return nil
}
switch t := v.(type) {
case string:
return splitToolList(t)
case []interface{}:
var out []string
for _, x := range t {
if s, ok := x.(string); ok && strings.TrimSpace(s) != "" {
out = append(out, strings.TrimSpace(s))
}
}
return out
case []string:
var out []string
for _, s := range t {
if strings.TrimSpace(s) != "" {
out = append(out, strings.TrimSpace(s))
}
}
return out
default:
return nil
}
}
func splitToolList(s string) []string {
s = strings.TrimSpace(s)
if s == "" {
return nil
}
parts := strings.FieldsFunc(s, func(r rune) bool {
return r == ',' || r == ';' || r == '|'
})
var out []string
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
// SlugID 从 name 生成可用的代理 id(小写、连字符)。
func SlugID(name string) string {
var b strings.Builder
name = strings.TrimSpace(strings.ToLower(name))
lastDash := false
for _, r := range name {
switch {
case unicode.IsLetter(r) && r < unicode.MaxASCII, unicode.IsDigit(r):
b.WriteRune(r)
lastDash = false
case r == ' ' || r == '_' || r == '/' || r == '.':
if !lastDash && b.Len() > 0 {
b.WriteByte('-')
lastDash = true
}
}
}
s := strings.Trim(b.String(), "-")
if s == "" {
return "agent"
}
return s
}
// sanitizeEinoAgentID 规范化 Deep 主代理在 Eino 中的 Name:小写 ASCII、数字、连字符,与默认 cyberstrike-deep 一致。
func sanitizeEinoAgentID(s string) string {
s = strings.TrimSpace(strings.ToLower(s))
var b strings.Builder
for _, r := range s {
switch {
case unicode.IsLetter(r) && r < unicode.MaxASCII, unicode.IsDigit(r):
b.WriteRune(r)
case r == '-':
b.WriteRune(r)
}
}
out := strings.Trim(b.String(), "-")
if out == "" {
return "cyberstrike-deep"
}
return out
}
func parseMarkdownAgentRaw(filename string, content string) (FrontMatter, string, error) {
var fm FrontMatter
fmStr, body, err := SplitFrontMatter(content)
if err != nil {
return fm, "", err
}
if strings.TrimSpace(fmStr) == "" {
return fm, "", fmt.Errorf("agents: %s 无 YAML front matter", filename)
}
if err := yaml.Unmarshal([]byte(fmStr), &fm); err != nil {
return fm, "", fmt.Errorf("agents: 解析 front matter: %w", err)
}
return fm, body, nil
}
func orchestratorFromParsed(filename string, fm FrontMatter, body string) (*OrchestratorMarkdown, error) {
display := strings.TrimSpace(fm.Name)
if display == "" {
display = "Orchestrator"
}
rawID := strings.TrimSpace(fm.ID)
if rawID == "" {
rawID = SlugID(display)
}
eino := sanitizeEinoAgentID(rawID)
return &OrchestratorMarkdown{
Filename: filepath.Base(strings.TrimSpace(filename)),
EinoName: eino,
DisplayName: display,
Description: strings.TrimSpace(fm.Description),
Instruction: strings.TrimSpace(body),
}, nil
}
func orchestratorConfigFromOrchestrator(o *OrchestratorMarkdown) config.MultiAgentSubConfig {
if o == nil {
return config.MultiAgentSubConfig{}
}
return config.MultiAgentSubConfig{
ID: o.EinoName,
Name: o.DisplayName,
Description: o.Description,
Instruction: o.Instruction,
Kind: "orchestrator",
}
}
func subAgentFromFrontMatter(filename string, fm FrontMatter, body string) (config.MultiAgentSubConfig, error) {
var out config.MultiAgentSubConfig
name := strings.TrimSpace(fm.Name)
if name == "" {
return out, fmt.Errorf("agents: %s 缺少 name 字段", filename)
}
id := strings.TrimSpace(fm.ID)
if id == "" {
id = SlugID(name)
}
out.ID = id
out.Name = name
out.Description = strings.TrimSpace(fm.Description)
out.Instruction = strings.TrimSpace(body)
out.RoleTools = parseToolsField(fm.Tools)
out.MaxIterations = fm.MaxIterations
out.BindRole = strings.TrimSpace(fm.BindRole)
out.Kind = strings.TrimSpace(fm.Kind)
return out, nil
}
func collectMarkdownBasenames(dir string) ([]string, error) {
if strings.TrimSpace(dir) == "" {
return nil, nil
}
st, err := os.Stat(dir)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
if !st.IsDir() {
return nil, fmt.Errorf("agents: 不是目录: %s", dir)
}
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
var names []string
for _, e := range entries {
if e.IsDir() {
continue
}
n := e.Name()
if strings.HasPrefix(n, ".") {
continue
}
if !strings.EqualFold(filepath.Ext(n), ".md") {
continue
}
if strings.EqualFold(n, "README.md") {
continue
}
names = append(names, n)
}
sort.Strings(names)
return names, nil
}
// LoadMarkdownAgentsDir 扫描 agents 目录:拆出至多一个主代理与其余子代理。
func LoadMarkdownAgentsDir(dir string) (*MarkdownDirLoad, error) {
out := &MarkdownDirLoad{}
names, err := collectMarkdownBasenames(dir)
if err != nil {
return nil, err
}
for _, n := range names {
p := filepath.Join(dir, n)
b, err := os.ReadFile(p)
if err != nil {
return nil, err
}
fm, body, err := parseMarkdownAgentRaw(n, string(b))
if err != nil {
return nil, fmt.Errorf("%s: %w", n, err)
}
if IsOrchestratorMarkdown(n, fm) {
if out.Orchestrator != nil {
return nil, fmt.Errorf("agents: 仅能定义一个主代理(Deep 协调者),已有 %s,又与 %s 冲突", out.Orchestrator.Filename, n)
}
orch, err := orchestratorFromParsed(n, fm, body)
if err != nil {
return nil, fmt.Errorf("%s: %w", n, err)
}
out.Orchestrator = orch
out.FileEntries = append(out.FileEntries, FileAgent{
Filename: n,
Config: orchestratorConfigFromOrchestrator(orch),
IsOrchestrator: true,
})
continue
}
sub, err := subAgentFromFrontMatter(n, fm, body)
if err != nil {
return nil, fmt.Errorf("%s: %w", n, err)
}
out.SubAgents = append(out.SubAgents, sub)
out.FileEntries = append(out.FileEntries, FileAgent{Filename: n, Config: sub, IsOrchestrator: false})
}
return out, nil
}
// ParseMarkdownSubAgent 将单个 Markdown 文件解析为 MultiAgentSubConfig。
func ParseMarkdownSubAgent(filename string, content string) (config.MultiAgentSubConfig, error) {
fm, body, err := parseMarkdownAgentRaw(filename, content)
if err != nil {
return config.MultiAgentSubConfig{}, err
}
if IsOrchestratorMarkdown(filename, fm) {
orch, err := orchestratorFromParsed(filename, fm, body)
if err != nil {
return config.MultiAgentSubConfig{}, err
}
return orchestratorConfigFromOrchestrator(orch), nil
}
return subAgentFromFrontMatter(filename, fm, body)
}
// LoadMarkdownSubAgents 读取目录下所有子代理 .md(不含主代理 orchestrator.md / kind: orchestrator)。
func LoadMarkdownSubAgents(dir string) ([]config.MultiAgentSubConfig, error) {
load, err := LoadMarkdownAgentsDir(dir)
if err != nil {
return nil, err
}
return load.SubAgents, nil
}
// FileAgent 单个 Markdown 文件及其解析结果。
type FileAgent struct {
Filename string
Config config.MultiAgentSubConfig
IsOrchestrator bool
}
// LoadMarkdownAgentFiles 列出目录下全部 .md(含主代理),供管理 API 使用。
func LoadMarkdownAgentFiles(dir string) ([]FileAgent, error) {
load, err := LoadMarkdownAgentsDir(dir)
if err != nil {
return nil, err
}
return load.FileEntries, nil
}
// MergeYAMLAndMarkdown 合并 config.yaml 中的 sub_agents 与 Markdown 定义:同 id 时 Markdown 覆盖 YAML;仅存在于 Markdown 的条目追加在 YAML 顺序之后。
func MergeYAMLAndMarkdown(yamlSubs []config.MultiAgentSubConfig, mdSubs []config.MultiAgentSubConfig) []config.MultiAgentSubConfig {
mdByID := make(map[string]config.MultiAgentSubConfig)
for _, m := range mdSubs {
id := strings.TrimSpace(m.ID)
if id == "" {
continue
}
mdByID[id] = m
}
yamlIDSet := make(map[string]bool)
for _, y := range yamlSubs {
yamlIDSet[strings.TrimSpace(y.ID)] = true
}
out := make([]config.MultiAgentSubConfig, 0, len(yamlSubs)+len(mdSubs))
for _, y := range yamlSubs {
id := strings.TrimSpace(y.ID)
if id == "" {
continue
}
if m, ok := mdByID[id]; ok {
out = append(out, m)
} else {
out = append(out, y)
}
}
for _, m := range mdSubs {
id := strings.TrimSpace(m.ID)
if id == "" || yamlIDSet[id] {
continue
}
out = append(out, m)
}
return out
}
// EffectiveSubAgents 供多代理运行时使用。
func EffectiveSubAgents(yamlSubs []config.MultiAgentSubConfig, agentsDir string) ([]config.MultiAgentSubConfig, error) {
md, err := LoadMarkdownSubAgents(agentsDir)
if err != nil {
return nil, err
}
if len(md) == 0 {
return yamlSubs, nil
}
return MergeYAMLAndMarkdown(yamlSubs, md), nil
}
// BuildMarkdownFile 根据配置序列化为可写回磁盘的 Markdown。
func BuildMarkdownFile(sub config.MultiAgentSubConfig) ([]byte, error) {
fm := FrontMatter{
Name: sub.Name,
ID: sub.ID,
Description: sub.Description,
MaxIterations: sub.MaxIterations,
BindRole: sub.BindRole,
}
if k := strings.TrimSpace(sub.Kind); k != "" {
fm.Kind = k
}
if len(sub.RoleTools) > 0 {
fm.Tools = sub.RoleTools
}
head, err := yaml.Marshal(fm)
if err != nil {
return nil, err
}
var b strings.Builder
b.WriteString("---\n")
b.Write(head)
b.WriteString("---\n\n")
b.WriteString(strings.TrimSpace(sub.Instruction))
if !strings.HasSuffix(sub.Instruction, "\n") && sub.Instruction != "" {
b.WriteString("\n")
}
return []byte(b.String()), nil
}
@@ -1,66 +0,0 @@
package agents
import (
"os"
"path/filepath"
"testing"
)
func TestLoadMarkdownAgentsDir_OrchestratorExcludedFromSubs(t *testing.T) {
dir := t.TempDir()
orch := filepath.Join(dir, OrchestratorMarkdownFilename)
if err := os.WriteFile(orch, []byte(`---
id: cyberstrike-deep
name: Main
description: Test desc
---
Hello orchestrator
`), 0644); err != nil {
t.Fatal(err)
}
subPath := filepath.Join(dir, "worker.md")
if err := os.WriteFile(subPath, []byte(`---
id: worker
name: Worker
description: W
---
Do work
`), 0644); err != nil {
t.Fatal(err)
}
load, err := LoadMarkdownAgentsDir(dir)
if err != nil {
t.Fatal(err)
}
if load.Orchestrator == nil || load.Orchestrator.EinoName != "cyberstrike-deep" {
t.Fatalf("orchestrator: %+v", load.Orchestrator)
}
if len(load.SubAgents) != 1 || load.SubAgents[0].ID != "worker" {
t.Fatalf("subs: %+v", load.SubAgents)
}
if len(load.FileEntries) != 2 {
t.Fatalf("file entries: %d", len(load.FileEntries))
}
var orchFile *FileAgent
for i := range load.FileEntries {
if load.FileEntries[i].IsOrchestrator {
orchFile = &load.FileEntries[i]
break
}
}
if orchFile == nil || orchFile.Filename != OrchestratorMarkdownFilename {
t.Fatal("missing orchestrator file entry")
}
}
func TestLoadMarkdownAgentsDir_DuplicateOrchestrator(t *testing.T) {
dir := t.TempDir()
_ = os.WriteFile(filepath.Join(dir, OrchestratorMarkdownFilename), []byte("---\nname: A\n---\n\nx\n"), 0644)
_ = os.WriteFile(filepath.Join(dir, "b.md"), []byte("---\nname: B\nkind: orchestrator\n---\n\ny\n"), 0644)
_, err := LoadMarkdownAgentsDir(dir)
if err == nil {
t.Fatal("expected duplicate orchestrator error")
}
}
-1834
View File
File diff suppressed because it is too large Load Diff
-40
View File
@@ -1,40 +0,0 @@
package app
import (
"time"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/skills"
)
// skillStatsDBAdapter 将database.DB适配为skills.SkillStatsStorage接口
type skillStatsDBAdapter struct {
db *database.DB
}
// UpdateSkillStats 更新Skills统计信息
func (a *skillStatsDBAdapter) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
return a.db.UpdateSkillStats(skillName, totalCalls, successCalls, failedCalls, lastCallTime)
}
// LoadSkillStats 加载所有Skills统计信息
func (a *skillStatsDBAdapter) LoadSkillStats() (map[string]*skills.SkillStats, error) {
dbStats, err := a.db.LoadSkillStats()
if err != nil {
return nil, err
}
// 转换为skills.SkillStats格式
result := make(map[string]*skills.SkillStats)
for name, stat := range dbStats {
result[name] = &skills.SkillStats{
SkillName: stat.SkillName,
TotalCalls: stat.TotalCalls,
SuccessCalls: stat.SuccessCalls,
FailedCalls: stat.FailedCalls,
LastCallTime: stat.LastCallTime,
}
}
return result, nil
}
-933
View File
@@ -1,933 +0,0 @@
package attackchain
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/openai"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Builder 攻击链构建器
type Builder struct {
db *database.DB
logger *zap.Logger
openAIClient *openai.Client
openAIConfig *config.OpenAIConfig
tokenCounter agent.TokenCounter
maxTokens int // 最大tokens限制,默认100000
}
// Node 攻击链节点(使用database包的类型)
type Node = database.AttackChainNode
// Edge 攻击链边(使用database包的类型)
type Edge = database.AttackChainEdge
// Chain 完整的攻击链
type Chain struct {
Nodes []Node `json:"nodes"`
Edges []Edge `json:"edges"`
}
// NewBuilder 创建新的攻击链构建器
func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *Builder {
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
}
httpClient := &http.Client{Timeout: 5 * time.Minute, Transport: transport}
// 优先使用配置文件中的统一 Token 上限(config.yaml -> openai.max_total_tokens
maxTokens := 0
if openAIConfig != nil && openAIConfig.MaxTotalTokens > 0 {
maxTokens = openAIConfig.MaxTotalTokens
} else if openAIConfig != nil {
// 如果未显式配置 max_total_tokens,则根据模型设置一个合理的默认值
model := strings.ToLower(openAIConfig.Model)
if strings.Contains(model, "gpt-4") {
maxTokens = 128000 // gpt-4通常支持128k
} else if strings.Contains(model, "gpt-3.5") {
maxTokens = 16000 // gpt-3.5-turbo通常支持16k
} else if strings.Contains(model, "deepseek") {
maxTokens = 131072 // deepseek-chat通常支持131k
} else {
maxTokens = 100000 // 兜底默认值
}
} else {
// 没有 OpenAI 配置时使用兜底值,避免为 0
maxTokens = 100000
}
return &Builder{
db: db,
logger: logger,
openAIClient: openai.NewClient(openAIConfig, httpClient, logger),
openAIConfig: openAIConfig,
tokenCounter: agent.NewTikTokenCounter(),
maxTokens: maxTokens,
}
}
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出)
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
// 0. 首先检查是否有实际的工具执行记录
messages, err := b.db.GetMessages(conversationID)
if err != nil {
return nil, fmt.Errorf("获取对话消息失败: %w", err)
}
if len(messages) == 0 {
b.logger.Info("对话中没有数据", zap.String("conversationId", conversationID))
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
}
// 检查是否有实际的工具执行:assistant 的 mcp_execution_ids,或过程详情中的 tool_call/tool_result
//(多代理下若 MCP 未返回 execution_idIDs 可能为空,但工具已通过 Eino 执行并写入 process_details
hasToolExecutions := false
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "assistant") {
if len(messages[i].MCPExecutionIDs) > 0 {
hasToolExecutions = true
break
}
}
}
if !hasToolExecutions {
if pdOK, err := b.db.ConversationHasToolProcessDetails(conversationID); err != nil {
b.logger.Warn("查询过程详情判定工具执行失败", zap.Error(err))
} else if pdOK {
hasToolExecutions = true
}
}
// 检查任务是否被取消(通过检查最后一条assistant消息内容或process_details
taskCancelled := false
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "assistant") {
content := strings.ToLower(messages[i].Content)
if strings.Contains(content, "取消") || strings.Contains(content, "cancelled") {
taskCancelled = true
}
break
}
}
// 如果任务被取消且没有实际工具执行,返回空攻击链
if taskCancelled && !hasToolExecutions {
b.logger.Info("任务已取消且没有实际工具执行,返回空攻击链",
zap.String("conversationId", conversationID),
zap.Bool("taskCancelled", taskCancelled),
zap.Bool("hasToolExecutions", hasToolExecutions))
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
}
// 如果没有实际工具执行,也返回空攻击链(避免AI编造)
if !hasToolExecutions {
b.logger.Info("没有实际工具执行记录,返回空攻击链",
zap.String("conversationId", conversationID))
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
}
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID)
if err != nil {
b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err))
// 继续使用原来的逻辑
reactInputJSON = ""
modelOutput = ""
}
// var userInput string
var reactInputFinal string
var dataSource string // 记录数据来源
// 如果成功获取到保存的ReAct数据,直接使用
if reactInputJSON != "" && modelOutput != "" {
// 计算 ReAct 输入的哈希值,用于追踪
hash := sha256.Sum256([]byte(reactInputJSON))
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
// 统计消息数量
var messageCount int
var tempMessages []interface{}
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
messageCount = len(tempMessages)
}
dataSource = "database_last_react_input"
b.logger.Info("使用保存的ReAct数据构建攻击链",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)),
zap.Int("messageCount", messageCount),
zap.String("reactInputHash", reactInputHash),
zap.Int("modelOutputSize", len(modelOutput)))
// 从保存的ReAct输入(JSON格式)中提取用户输入
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
// 将JSON格式的messages转换为可读格式
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
} else {
// 2. 如果没有保存的ReAct数据,从对话消息构建
dataSource = "messages_table"
b.logger.Info("从消息历史构建ReAct数据",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("messageCount", len(messages)))
// 提取用户输入(最后一条user消息)
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "user") {
// userInput = messages[i].Content
break
}
}
// 提取最后一轮ReAct的输入(历史消息+当前用户输入)
reactInputFinal = b.buildReActInput(messages)
// 提取大模型最后的输出(最后一条assistant消息)
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "assistant") {
modelOutput = messages[i].Content
break
}
}
}
// 多代理:保存的 last_react_input 可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理「最后一轮 ReAct」对齐)
hasMCPOnAssistant := false
var lastAssistantID string
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "assistant") {
lastAssistantID = messages[i].ID
if len(messages[i].MCPExecutionIDs) > 0 {
hasMCPOnAssistant = true
}
break
}
}
if lastAssistantID != "" {
pdHasTools, _ := b.db.ConversationHasToolProcessDetails(conversationID)
if pdHasTools && !(hasMCPOnAssistant && reactInputContainsToolTrace(reactInputJSON)) {
detailsMap, err := b.db.GetProcessDetailsByConversation(conversationID)
if err != nil {
b.logger.Warn("加载过程详情用于攻击链失败", zap.Error(err))
} else if dets := detailsMap[lastAssistantID]; len(dets) > 0 {
extra := b.formatProcessDetailsForAttackChain(dets)
if strings.TrimSpace(extra) != "" {
reactInputFinal = reactInputFinal + "\n\n## 执行过程与工具记录(含多代理编排与子任务)\n\n" + extra
b.logger.Info("攻击链输入已补充过程详情",
zap.String("conversationId", conversationID),
zap.String("messageId", lastAssistantID),
zap.Int("detailEvents", len(dets)))
}
}
}
}
// 3. 构建简化的prompt,一次性传递给大模型
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
// fmt.Println(prompt)
// 6. 调用AI生成攻击链(一次性,不做任何处理)
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
if err != nil {
return nil, fmt.Errorf("AI生成失败: %w", err)
}
// 7. 解析JSON并生成节点/边ID(前端需要有效的ID)
chainData, err := b.parseChainJSON(chainJSON)
if err != nil {
// 如果解析失败,返回空链,让前端处理错误
b.logger.Warn("解析攻击链JSON失败", zap.Error(err), zap.String("raw_json", chainJSON))
return &Chain{
Nodes: []Node{},
Edges: []Edge{},
}, nil
}
b.logger.Info("攻击链构建完成",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("nodes", len(chainData.Nodes)),
zap.Int("edges", len(chainData.Edges)))
// 保存到数据库(供后续加载使用)
if err := b.saveChain(conversationID, chainData.Nodes, chainData.Edges); err != nil {
b.logger.Warn("保存攻击链到数据库失败", zap.Error(err))
// 即使保存失败,也返回数据给前端
}
// 直接返回,不做任何处理和校验
return chainData, nil
}
// reactInputContainsToolTrace 判断保存的 ReAct JSON 是否包含可解析的工具调用轨迹(单代理完整保存时为 true)。
func reactInputContainsToolTrace(reactInputJSON string) bool {
s := strings.TrimSpace(reactInputJSON)
if s == "" {
return false
}
return strings.Contains(s, "tool_calls") ||
strings.Contains(s, "tool_call_id") ||
strings.Contains(s, `"role":"tool"`) ||
strings.Contains(s, `"role": "tool"`)
}
// formatProcessDetailsForAttackChain 将最后一轮助手的过程详情格式化为攻击链分析的输入(覆盖多代理下 last_react_input 不完整的情况)。
func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessDetail) string {
if len(details) == 0 {
return ""
}
var sb strings.Builder
for _, d := range details {
// 目标:以主 agent(编排器)视角输出整轮迭代
// - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理)
// - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "planning" {
continue
}
// 解析 dataJSON string),用于识别 einoRole / toolName 等
var dataMap map[string]interface{}
if strings.TrimSpace(d.Data) != "" {
_ = json.Unmarshal([]byte(d.Data), &dataMap)
}
einoRole := ""
if v, ok := dataMap["einoRole"]; ok {
einoRole = strings.ToLower(strings.TrimSpace(fmt.Sprint(v)))
}
toolName := ""
if v, ok := dataMap["toolName"]; ok {
toolName = strings.TrimSpace(fmt.Sprint(v))
}
// 1) 编排器的工具调用/结果:保留(这是“主 agent 调了什么工具”)
if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration" || d.EventType == "eino_recovery") && einoRole == "orchestrator" {
sb.WriteString("[")
sb.WriteString(d.EventType)
sb.WriteString("] ")
sb.WriteString(strings.TrimSpace(d.Message))
sb.WriteString("\n")
if strings.TrimSpace(d.Data) != "" {
sb.WriteString(d.Data)
sb.WriteString("\n")
}
sb.WriteString("\n")
continue
}
// 2) 子代理调度:tool_call(toolName=="task") 代表编排器把子任务派发出去;保留(只需任务,不要子代理推理)
if d.EventType == "tool_call" && strings.EqualFold(toolName, "task") {
sb.WriteString("[dispatch_subagent_task] ")
sb.WriteString(strings.TrimSpace(d.Message))
sb.WriteString("\n")
if strings.TrimSpace(d.Data) != "" {
sb.WriteString(d.Data)
sb.WriteString("\n")
}
sb.WriteString("\n")
continue
}
// 3) 子代理最终回复:保留(只保留最终输出,不保留分析过程)
if d.EventType == "eino_agent_reply" && einoRole == "sub" {
sb.WriteString("[subagent_final_reply] ")
sb.WriteString(strings.TrimSpace(d.Message))
sb.WriteString("\n")
// data 里含 einoAgent 等元信息,保留有助于追踪“哪个子代理说的”
if strings.TrimSpace(d.Data) != "" {
sb.WriteString(d.Data)
sb.WriteString("\n")
}
sb.WriteString("\n")
continue
}
// 其他事件默认丢弃,避免把子代理工具细节/推理塞进 prompt,偏离“主 agent 一轮迭代”的视角。
}
return strings.TrimSpace(sb.String())
}
// buildReActInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
func (b *Builder) buildReActInput(messages []database.Message) string {
var builder strings.Builder
for _, msg := range messages {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
}
return builder.String()
}
// extractUserInputFromReActInput 从保存的ReAct输入(JSON格式的messages数组)中提取最后一条用户输入
// func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
// // reactInputJSON是JSON格式的ChatMessage数组,需要解析
// var messages []map[string]interface{}
// if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
// b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
// return ""
// }
// // 从后往前查找最后一条user消息
// for i := len(messages) - 1; i >= 0; i-- {
// if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
// if content, ok := messages[i]["content"].(string); ok {
// return content
// }
// }
// }
// return ""
// }
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
var messages []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
return reactInputJSON // 如果解析失败,返回原始JSON
}
var builder strings.Builder
for _, msg := range messages {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
// 处理assistant消息:提取tool_calls信息
if role == "assistant" {
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
// 如果有文本内容,先显示
if content != "" {
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
}
// 详细显示每个工具调用
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
for i, toolCall := range toolCalls {
if tc, ok := toolCall.(map[string]interface{}); ok {
toolCallID, _ := tc["id"].(string)
if funcData, ok := tc["function"].(map[string]interface{}); ok {
toolName, _ := funcData["name"].(string)
arguments, _ := funcData["arguments"].(string)
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
}
}
}
builder.WriteString("\n")
continue
}
}
// 处理tool消息:显示tool_call_id和完整内容
if role == "tool" {
toolCallID, _ := msg["tool_call_id"].(string)
if toolCallID != "" {
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
} else {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
}
continue
}
// 其他消息类型(system, user等)正常显示
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
}
return builder.String()
}
// buildSimplePrompt 构建简化的prompt
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。
## 核心目标
构建一个能够讲述完整攻击故事的攻击链让学习者能够:
1. 理解渗透测试的完整流程和思维逻辑(从目标识别到漏洞发现的每一步)
2. 学习如何从失败中获取线索并调整策略
3. 掌握工具使用的实际效果和局限性
4. 理解漏洞发现和利用的因果关系
**关键原则**:完整性优先。必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而遗漏重要信息。
## 构建流程(按此顺序思考)
### 第一步:理解上下文
仔细分析ReAct输入中的工具调用序列和大模型输出,识别:
- 测试目标(IP、域名、URL等)
- 实际执行的工具和参数
- 工具返回的关键信息(成功结果、错误信息、超时等)
- AI的分析和决策过程
### 第二步:提取关键节点
从工具执行记录中提取有意义的节点,**确保不遗漏任何关键步骤**:
- **target节点**:每个独立的测试目标创建一个target节点
- **action节点**:每个有意义的工具执行创建一个action节点(包括提供线索的失败、成功的信息收集、漏洞验证等)
- **vulnerability节点**:每个真实确认的漏洞创建一个vulnerability节点
- **完整性检查**:对照ReAct输入中的工具调用序列,确保每个有意义的工具执行都被包含在攻击链中
### 第三步:构建逻辑关系(树状结构)
**重要:必须构建树状结构,而不是简单的线性链。**
按照因果关系连接节点,形成树状图(因为是单agent执行,所以可以不按照时间顺序):
- **分支结构**:一个节点可以有多个后续节点(例如:端口扫描发现多个端口后,可以同时进行多个不同的测试)
- **汇聚结构**:多个节点可以指向同一个节点(例如:多个不同的测试都发现了同一个漏洞)
- 识别哪些action是基于前面action的结果而执行的
- 识别哪些vulnerability是由哪些action发现的
- 识别失败节点如何为后续成功提供线索
- **避免线性链**:不要将所有节点连成一条线,应该根据实际的并行测试和分支探索构建树状结构
### 第四步:优化和精简
- **完整性检查**:确保所有有意义的工具执行都被包含,不要遗漏关键步骤
- **合并规则**:只合并真正相似或重复的action节点(如多次相同工具的相似调用)
- **删除规则**:只删除完全无价值的失败节点(完全无输出、纯系统错误、重复的相同失败)
- **重要提醒**:宁可保留更多节点,也不要遗漏关键步骤。攻击链必须完整展现渗透测试过程
- 确保攻击链逻辑连贯,能够讲述完整故事
## 节点类型详解
### target(目标节点)
- **用途**:标识测试目标
- **创建规则**:每个独立目标(不同IP/域名)创建一个target节点
- **多目标处理**:不同目标的节点不相互连接,各自形成独立的子图
- **metadata.target**:精确记录目标标识(IP地址、域名、URL等)
### action(行动节点)
- **用途**:记录工具执行和AI分析结果
- **标签规则**
* 15-25个汉字,动宾结构
* 成功节点:描述执行结果(如"扫描端口发现80/443/8080"、"目录扫描发现/admin路径"
* 失败节点:描述失败原因(如"尝试SQL注入(被WAF拦截)"、"端口扫描超时(目标不可达)")
- **ai_analysis要求**
* 成功节点:总结工具执行的关键发现,说明这些发现的意义
* 失败节点:必须说明失败原因、获得的线索、这些线索如何指引后续行动
* 不超过150字,要具体、有信息量
- **findings要求**
* 提取工具返回结果中的关键信息点
* 每个finding应该是独立的、有价值的信息片段
* 成功节点:列出关键发现(如["80端口开放", "443端口开放", "HTTP服务为Apache 2.4"]
* 失败节点:列出失败线索(如["WAF拦截", "返回403", "检测到Cloudflare"]
- **status标记**
* 成功节点:不设置或设为"success"
* 提供线索的失败节点:必须设为"failed_insight"
- **risk_score**:始终为0action节点不评估风险)
### vulnerability(漏洞节点)
- **用途**:记录真实确认的安全漏洞
- **创建规则**
* 必须是真实确认的漏洞,不是所有发现都是漏洞
* 需要明确的漏洞证据(如SQL注入返回数据库错误、XSS成功执行等)
- **risk_score规则**
* critical90-100):可导致系统完全沦陷(RCE、SQL注入导致数据泄露等)
* high(80-89):可导致敏感信息泄露或权限提升
* medium(60-79):存在安全风险但影响有限
* low40-59):轻微安全问题
- **metadata要求**
* vulnerability_type:漏洞类型(SQL注入、XSS、RCE等)
* description:详细描述漏洞位置、原理、影响
* severitycritical/high/medium/low
* location:精确的漏洞位置(URL、参数、文件路径等)
## 节点过滤和合并规则
### 必须保留的失败节点
以下失败情况必须创建节点,因为它们提供了有价值的线索:
- 工具返回明确的错误信息(权限错误、连接拒绝、认证失败等)
- 超时或连接失败(可能表明防火墙、网络隔离等)
- WAF/防火墙拦截(返回403、406等,表明存在防护机制)
- 工具未安装或配置错误(但执行了调用)
- 目标不可达(DNS解析失败、网络不通等)
### 应该删除的失败节点
以下情况不应创建节点:
- 完全无输出的工具调用
- 纯系统错误(与目标无关,如本地环境问题)
- 重复的相同失败(多次相同错误只保留第一次)
### 节点合并规则
以下情况应合并节点:
- 同一工具的多次相似调用(如多次nmap扫描不同端口范围,合并为一个"端口扫描"节点)
- 同一目标的多个相似探测(如多个目录扫描工具,合并为一个"目录扫描"节点)
### 节点数量控制
- **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制数量而删除重要节点
- **建议范围**:单目标通常8-15个节点,但如果实际执行步骤较多,可以适当增加(最多20个节点)
- **优先保留**:关键成功步骤、提供线索的失败、发现的漏洞、重要的信息收集步骤
- **可以合并**:同一工具的多次相似调用(如多次nmap扫描不同端口范围,合并为一个"端口扫描"节点)
- **可以删除**:完全无输出的工具调用、纯系统错误、重复的相同失败(多次相同错误只保留第一次)
- **重要原则**:宁可节点稍多,也不要遗漏关键步骤。攻击链必须能够完整展现渗透测试的完整过程
## 边的类型和权重
### 边的类型
- **leads_to**:表示"导致"或"引导到",用于action→action、target→action
* 例如:端口扫描 → 目录扫描(因为发现了80端口,所以进行目录扫描)
- **discovers**:表示"发现"**专门用于action→vulnerability**
* 例如:SQL注入测试 → SQL注入漏洞
* **重要**:所有action→vulnerability的边都必须使用discovers类型,即使多个action都指向同一个vulnerability,也应该统一使用discovers
- **enables**:表示"使能"或"促成"**仅用于vulnerability→vulnerability、action→action(当后续行动依赖前面结果时)**
* 例如:信息泄露漏洞 → 权限提升漏洞(通过信息泄露获得的信息促成了权限提升)
* **重要**enables不能用于action→vulnerabilityaction→vulnerability必须使用discovers
### 边的权重
- **权重1-2**:弱关联(如初步探测到进一步探测)
- **权重3-4**:中等关联(如发现端口到服务识别)
- **权重5-7**:强关联(如发现漏洞、关键信息泄露)
- **权重8-10**:极强关联(如漏洞利用成功、权限提升)
### DAG结构要求(有向无环图)
**关键:必须确保生成的是真正的DAG(有向无环图),不能有任何循环。**
- **节点编号规则**:节点id从"node_1"开始递增(node_1, node_2, node_3...
- **边的方向规则**:所有边的source节点id必须严格小于target节点idsource < target),这是确保无环的关键
* 例如:node_1 → node_2 ✓(正确)
* 例如:node_2 → node_1 ✗(错误,会形成环)
* 例如:node_3 → node_5 ✓(正确)
- **无环验证**:在输出JSON前,必须检查所有边,确保没有任何一条边的source >= target
- **无孤立节点**:确保每个节点至少有一条边连接(除了可能的根节点)
- **DAG结构特点**
* 一个节点可以有多个后续节点(分支),例如:node_2(端口扫描)可以同时连接到node_3、node_4、node_5等多个节点
* 多个节点可以汇聚到一个节点(汇聚),例如:node_3、node_4、node_5都指向node_6(漏洞节点)
* 避免将所有节点连成一条线,应该根据实际的并行测试和分支探索构建DAG结构
- **拓扑排序验证**:如果按照节点id从小到大排序,所有边都应该从左指向右(从上指向下),这样就能保证无环
## 攻击链逻辑连贯性要求
构建的攻击链应该能够回答以下问题:
1. **起点**:测试从哪里开始?(target节点)
2. **探索过程**:如何逐步收集信息?(action节点序列)
3. **失败与调整**:遇到障碍时如何调整策略?(failed_insight节点)
4. **关键发现**:发现了哪些重要信息?(action的findings
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
## 最后一轮ReAct输入
%s
## 大模型输出
%s
## 输出格式
严格按照以下JSON格式输出,不要添加任何其他文字:
**重要:示例展示的是树状结构,注意node_2(端口扫描)同时连接到多个后续节点(node_3、node_4),形成分支结构。**
{
"nodes": [
{
"id": "node_1",
"type": "target",
"label": "测试目标: example.com",
"risk_score": 40,
"metadata": {
"target": "example.com"
}
},
{
"id": "node_2",
"type": "action",
"label": "扫描端口发现80/443/8080",
"risk_score": 0,
"metadata": {
"tool_name": "nmap",
"tool_intent": "端口扫描",
"ai_analysis": "使用nmap对目标进行端口扫描,发现80、443、8080端口开放。80端口运行HTTP服务,443端口运行HTTPS服务,8080端口可能为管理后台。这些开放端口为后续Web应用测试提供了入口。",
"findings": ["80端口开放", "443端口开放", "8080端口开放", "HTTP服务为Apache 2.4"]
}
},
{
"id": "node_3",
"type": "action",
"label": "目录扫描发现/admin后台",
"risk_score": 0,
"metadata": {
"tool_name": "dirsearch",
"tool_intent": "目录扫描",
"ai_analysis": "使用dirsearch对目标进行目录扫描,发现/admin目录存在且可访问。该目录可能为管理后台,是重要的测试目标。",
"findings": ["/admin目录存在", "返回200状态码", "疑似管理后台"]
}
},
{
"id": "node_4",
"type": "action",
"label": "识别Web服务为Apache 2.4",
"risk_score": 0,
"metadata": {
"tool_name": "whatweb",
"tool_intent": "Web服务识别",
"ai_analysis": "识别出目标运行Apache 2.4服务器,这为后续的漏洞测试提供了重要信息。",
"findings": ["Apache 2.4", "PHP版本信息"]
}
},
{
"id": "node_5",
"type": "action",
"label": "尝试SQL注入(被WAF拦截)",
"risk_score": 0,
"metadata": {
"tool_name": "sqlmap",
"tool_intent": "SQL注入检测",
"ai_analysis": "对/login.php进行SQL注入测试时被WAF拦截,返回403错误。错误信息显示检测到Cloudflare防护。这表明目标部署了WAF,需要调整测试策略。",
"findings": ["WAF拦截", "返回403", "检测到Cloudflare", "目标部署WAF"],
"status": "failed_insight"
}
},
{
"id": "node_6",
"type": "vulnerability",
"label": "SQL注入漏洞",
"risk_score": 85,
"metadata": {
"vulnerability_type": "SQL注入",
"description": "在/admin/login.php的username参数发现SQL注入漏洞,可通过注入payload绕过登录验证,直接获取管理员权限。漏洞返回数据库错误信息,确认存在注入点。",
"severity": "high",
"location": "/admin/login.php?username="
}
}
],
"edges": [
{
"source": "node_1",
"target": "node_2",
"type": "leads_to",
"weight": 3
},
{
"source": "node_2",
"target": "node_3",
"type": "leads_to",
"weight": 4
},
{
"source": "node_2",
"target": "node_4",
"type": "leads_to",
"weight": 3
},
{
"source": "node_3",
"target": "node_5",
"type": "leads_to",
"weight": 4
},
{
"source": "node_5",
"target": "node_6",
"type": "discovers",
"weight": 7
}
]
}
## 重要提醒
1. **严禁杜撰**:只使用ReAct输入中实际执行的工具和实际返回的结果。如无实际数据,返回空的nodes和edges数组。
2. **DAG结构必须**:必须构建真正的DAG(有向无环图),不能有任何循环。所有边的source节点id必须严格小于target节点idsource < target)。
3. **拓扑顺序**:节点应该按照逻辑顺序编号,target节点通常是node_1,后续的action节点按执行顺序递增,vulnerability节点在最后。
4. **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而删除重要节点。攻击链必须能够完整展现从目标识别到漏洞发现的完整过程。
5. **逻辑连贯**:确保攻击链能够讲述一个完整、连贯的渗透测试故事,包括所有关键步骤和决策点。
6. **教育价值**:优先保留有教育意义的节点,帮助学习者理解渗透测试思维和完整流程。
7. **准确性**:所有节点信息必须基于实际数据,不要推测或假设。
8. **完整性检查**:确保每个节点都有必要的metadata字段,每条边都有正确的source和target,没有孤立节点,没有循环。
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
现在开始分析并构建攻击链:`, reactInput, modelOutput)
}
// saveChain 保存攻击链到数据库
func (b *Builder) saveChain(conversationID string, nodes []Node, edges []Edge) error {
// 先删除旧的攻击链数据
if err := b.db.DeleteAttackChain(conversationID); err != nil {
b.logger.Warn("删除旧攻击链失败", zap.Error(err))
}
for _, node := range nodes {
metadataJSON, _ := json.Marshal(node.Metadata)
if err := b.db.SaveAttackChainNode(conversationID, node.ID, node.Type, node.Label, "", string(metadataJSON), node.RiskScore); err != nil {
b.logger.Warn("保存攻击链节点失败", zap.String("nodeId", node.ID), zap.Error(err))
}
}
// 保存边
for _, edge := range edges {
if err := b.db.SaveAttackChainEdge(conversationID, edge.ID, edge.Source, edge.Target, edge.Type, edge.Weight); err != nil {
b.logger.Warn("保存攻击链边失败", zap.String("edgeId", edge.ID), zap.Error(err))
}
}
return nil
}
// LoadChainFromDatabase 从数据库加载攻击链
func (b *Builder) LoadChainFromDatabase(conversationID string) (*Chain, error) {
nodes, err := b.db.LoadAttackChainNodes(conversationID)
if err != nil {
return nil, fmt.Errorf("加载攻击链节点失败: %w", err)
}
edges, err := b.db.LoadAttackChainEdges(conversationID)
if err != nil {
return nil, fmt.Errorf("加载攻击链边失败: %w", err)
}
return &Chain{
Nodes: nodes,
Edges: edges,
}, nil
}
// callAIForChainGeneration 调用AI生成攻击链
func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) {
requestBody := map[string]interface{}{
"model": b.openAIConfig.Model,
"messages": []map[string]interface{}{
{
"role": "system",
"content": "你是一个专业的安全测试分析师,擅长构建攻击链图。请严格按照JSON格式返回攻击链数据。",
},
{
"role": "user",
"content": prompt,
},
},
"temperature": 0.3,
"max_tokens": 8000,
}
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if b.openAIClient == nil {
return "", fmt.Errorf("OpenAI客户端未初始化")
}
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
var apiErr *openai.APIError
if errors.As(err, &apiErr) {
bodyStr := strings.ToLower(apiErr.Body)
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
return "", fmt.Errorf("context length exceeded")
}
} else if strings.Contains(strings.ToLower(err.Error()), "context") || strings.Contains(strings.ToLower(err.Error()), "length") {
return "", fmt.Errorf("context length exceeded")
}
return "", fmt.Errorf("请求失败: %w", err)
}
if len(apiResponse.Choices) == 0 {
return "", fmt.Errorf("API未返回有效响应")
}
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
// 尝试提取JSON(可能包含markdown代码块)
content = strings.TrimPrefix(content, "```json")
content = strings.TrimPrefix(content, "```")
content = strings.TrimSuffix(content, "```")
content = strings.TrimSpace(content)
return content, nil
}
// ChainJSON 攻击链JSON结构
type ChainJSON struct {
Nodes []struct {
ID string `json:"id"`
Type string `json:"type"`
Label string `json:"label"`
RiskScore int `json:"risk_score"`
Metadata map[string]interface{} `json:"metadata"`
} `json:"nodes"`
Edges []struct {
Source string `json:"source"`
Target string `json:"target"`
Type string `json:"type"`
Weight int `json:"weight"`
} `json:"edges"`
}
// parseChainJSON 解析攻击链JSON
func (b *Builder) parseChainJSON(chainJSON string) (*Chain, error) {
var chainData ChainJSON
if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil {
return nil, fmt.Errorf("解析JSON失败: %w", err)
}
// 创建节点ID映射(AI返回的ID -> 新的UUID
nodeIDMap := make(map[string]string)
// 转换为Chain结构
nodes := make([]Node, 0, len(chainData.Nodes))
for _, n := range chainData.Nodes {
// 生成新的UUID节点ID
newNodeID := fmt.Sprintf("node_%s", uuid.New().String())
nodeIDMap[n.ID] = newNodeID
node := Node{
ID: newNodeID,
Type: n.Type,
Label: n.Label,
RiskScore: n.RiskScore,
Metadata: n.Metadata,
}
if node.Metadata == nil {
node.Metadata = make(map[string]interface{})
}
nodes = append(nodes, node)
}
// 转换边
edges := make([]Edge, 0, len(chainData.Edges))
for _, e := range chainData.Edges {
sourceID, ok := nodeIDMap[e.Source]
if !ok {
continue
}
targetID, ok := nodeIDMap[e.Target]
if !ok {
continue
}
// 生成边的ID(前端需要)
edgeID := fmt.Sprintf("edge_%s", uuid.New().String())
edges = append(edges, Edge{
ID: edgeID,
Source: sourceID,
Target: targetID,
Type: e.Type,
Weight: e.Weight,
})
}
return &Chain{
Nodes: nodes,
Edges: edges,
}, nil
}
// 以下所有方法已不再使用,已删除以简化代码
-857
View File
@@ -1,857 +0,0 @@
package config
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"gopkg.in/yaml.v3"
)
type Config struct {
Version string `yaml:"version,omitempty" json:"version,omitempty"` // 前端显示的版本号,如 v1.3.3
Server ServerConfig `yaml:"server"`
Log LogConfig `yaml:"log"`
MCP MCPConfig `yaml:"mcp"`
OpenAI OpenAIConfig `yaml:"openai"`
FOFA FofaConfig `yaml:"fofa,omitempty" json:"fofa,omitempty"`
Agent AgentConfig `yaml:"agent"`
Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.mdYAML front matter
MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"`
}
// MultiAgentConfig 基于 CloudWeGo Eino DeepAgent 的多代理编排(与单 Agent /agent-loop 并存)。
type MultiAgentConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
DefaultMode string `yaml:"default_mode" json:"default_mode"` // single | multi,供前端默认展示
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // Deep 主代理最大推理轮次
SubAgentMaxIterations int `yaml:"sub_agent_max_iterations" json:"sub_agent_max_iterations"`
WithoutGeneralSubAgent bool `yaml:"without_general_sub_agent" json:"without_general_sub_agent"`
WithoutWriteTodos bool `yaml:"without_write_todos" json:"without_write_todos"`
OrchestratorInstruction string `yaml:"orchestrator_instruction" json:"orchestrator_instruction"`
SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"`
}
// MultiAgentSubConfig 子代理(Eino ChatModelAgent),由 DeepAgent 通过 task 工具调度。
type MultiAgentSubConfig struct {
ID string `yaml:"id" json:"id"`
Name string `yaml:"name" json:"name"`
Description string `yaml:"description" json:"description"`
Instruction string `yaml:"instruction" json:"instruction"`
BindRole string `yaml:"bind_role,omitempty" json:"bind_role,omitempty"` // 可选:关联主配置 roles 中的角色名;未配 role_tools 时沿用该角色的 tools,并把 skills 写入指令提示
RoleTools []string `yaml:"role_tools" json:"role_tools"` // 与单 Agent 角色工具相同 key;空表示全部工具(bind_role 可补全 tools
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
Kind string `yaml:"kind,omitempty" json:"kind,omitempty"` // 仅 Markdownkind=orchestrator 表示 Deep 主代理(与 orchestrator.md 二选一约定)
}
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
type MultiAgentPublic struct {
Enabled bool `json:"enabled"`
DefaultMode string `json:"default_mode"`
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
SubAgentCount int `json:"sub_agent_count"`
}
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
type MultiAgentAPIUpdate struct {
Enabled bool `json:"enabled"`
DefaultMode string `json:"default_mode"`
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
}
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
type RobotsConfig struct {
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
}
// RobotWecomConfig 企业微信机器人配置
type RobotWecomConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
Token string `yaml:"token" json:"token"` // 回调 URL 校验 Token
EncodingAESKey string `yaml:"encoding_aes_key" json:"encoding_aes_key"` // EncodingAESKey
CorpID string `yaml:"corp_id" json:"corp_id"` // 企业 ID
Secret string `yaml:"secret" json:"secret"` // 应用 Secret
AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId
}
// RobotDingtalkConfig 钉钉机器人配置
type RobotDingtalkConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey)
ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret
}
// RobotLarkConfig 飞书机器人配置
type RobotLarkConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID
AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret
VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选)
}
type ServerConfig struct {
Host string `yaml:"host"`
Port int `yaml:"port"`
}
type LogConfig struct {
Level string `yaml:"level"`
Output string `yaml:"output"`
}
type MCPConfig struct {
Enabled bool `yaml:"enabled"`
Host string `yaml:"host"`
Port int `yaml:"port"`
AuthHeader string `yaml:"auth_header,omitempty"` // 鉴权 header 名,留空表示不鉴权
AuthHeaderValue string `yaml:"auth_header_value,omitempty"` // 鉴权 header 值,需与请求中该 header 一致
}
type OpenAIConfig struct {
Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` // API 提供商: "openai"(默认) 或 "claude"claude 时自动桥接为 Anthropic Messages API
APIKey string `yaml:"api_key" json:"api_key"`
BaseURL string `yaml:"base_url" json:"base_url"`
Model string `yaml:"model" json:"model"`
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
}
type FofaConfig struct {
// Email 为 FOFA 账号邮箱;APIKey 为 FOFA API Key(建议使用只读权限的 Key)
Email string `yaml:"email,omitempty" json:"email,omitempty"`
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"`
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://fofa.info/api/v1/search/all
}
type SecurityConfig struct {
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
ToolDescriptionMode string `yaml:"tool_description_mode,omitempty"` // 工具描述模式: "short" | "full",默认 short
}
type DatabaseConfig struct {
Path string `yaml:"path"` // 会话数据库路径
KnowledgeDBPath string `yaml:"knowledge_db_path,omitempty"` // 知识库数据库路径(可选,为空则使用会话数据库)
}
type AgentConfig struct {
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB
ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
}
type AuthConfig struct {
Password string `yaml:"password" json:"password"`
SessionDurationHours int `yaml:"session_duration_hours" json:"session_duration_hours"`
GeneratedPassword string `yaml:"-" json:"-"`
GeneratedPasswordPersisted bool `yaml:"-" json:"-"`
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
}
// ExternalMCPConfig 外部MCP配置
type ExternalMCPConfig struct {
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
}
// ExternalMCPServerConfig 外部MCP服务器配置
type ExternalMCPServerConfig struct {
// stdio模式配置
Command string `yaml:"command,omitempty" json:"command,omitempty"`
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
// HTTP模式配置
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp)
URL string `yaml:"url,omitempty" json:"url,omitempty"`
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key
// 通用配置
Description string `yaml:"description,omitempty" json:"description,omitempty"`
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 超时时间(秒)
ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用外部MCP
ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态(工具名称 -> 是否启用)
// 向后兼容字段(已废弃,保留用于读取旧配置)
Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` // 已废弃,使用 external_mcp_enable
Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 已废弃,使用 external_mcp_enable
}
type ToolConfig struct {
Name string `yaml:"name"`
Command string `yaml:"command"`
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗)
Description string `yaml:"description"` // 详细描述(用于工具文档)
Enabled bool `yaml:"enabled"`
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
AllowedExitCodes []int `yaml:"allowed_exit_codes,omitempty"` // 允许的退出码列表(某些工具在成功时也返回非零退出码)
}
// ParameterConfig 参数配置
type ParameterConfig struct {
Name string `yaml:"name"` // 参数名称
Type string `yaml:"type"` // 参数类型: string, int, bool, array
Description string `yaml:"description"` // 参数描述
Required bool `yaml:"required,omitempty"` // 是否必需
Default interface{} `yaml:"default,omitempty"` // 默认值
ItemType string `yaml:"item_type,omitempty"` // 当 type 为 array 时,数组元素类型,如 string, number, object
Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始)
Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
}
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
if cfg.Auth.SessionDurationHours <= 0 {
cfg.Auth.SessionDurationHours = 12
}
if strings.TrimSpace(cfg.Auth.Password) == "" {
password, err := generateStrongPassword(24)
if err != nil {
return nil, fmt.Errorf("生成默认密码失败: %w", err)
}
cfg.Auth.Password = password
cfg.Auth.GeneratedPassword = password
if err := PersistAuthPassword(path, password); err != nil {
cfg.Auth.GeneratedPasswordPersisted = false
cfg.Auth.GeneratedPasswordPersistErr = err.Error()
} else {
cfg.Auth.GeneratedPasswordPersisted = true
}
}
// 如果配置了工具目录,从目录加载工具配置
if cfg.Security.ToolsDir != "" {
configDir := filepath.Dir(path)
toolsDir := cfg.Security.ToolsDir
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(toolsDir) {
toolsDir = filepath.Join(configDir, toolsDir)
}
tools, err := LoadToolsFromDir(toolsDir)
if err != nil {
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
}
// 合并工具配置:目录中的工具优先,主配置中的工具作为补充
existingTools := make(map[string]bool)
for _, tool := range tools {
existingTools[tool.Name] = true
}
// 添加主配置中不存在于目录中的工具(向后兼容)
for _, tool := range cfg.Security.Tools {
if !existingTools[tool.Name] {
tools = append(tools, tool)
}
}
cfg.Security.Tools = tools
}
// 迁移外部MCP配置:将旧的 enabled/disabled 字段迁移到 external_mcp_enable
if cfg.ExternalMCP.Servers != nil {
for name, serverCfg := range cfg.ExternalMCP.Servers {
// 如果已经设置了 external_mcp_enable,跳过迁移
// 否则从 enabled/disabled 字段迁移
// 注意:由于 ExternalMCPEnable 是 bool 类型,零值为 false,所以需要检查是否真的设置了
// 这里我们通过检查旧的 enabled/disabled 字段来判断是否需要迁移
if serverCfg.Disabled {
// 旧配置使用 disabled,迁移到 external_mcp_enable
serverCfg.ExternalMCPEnable = false
} else if serverCfg.Enabled {
// 旧配置使用 enabled,迁移到 external_mcp_enable
serverCfg.ExternalMCPEnable = true
} else {
// 都没有设置,默认为启用
serverCfg.ExternalMCPEnable = true
}
cfg.ExternalMCP.Servers[name] = serverCfg
}
}
// 从角色目录加载角色配置
if cfg.RolesDir != "" {
configDir := filepath.Dir(path)
rolesDir := cfg.RolesDir
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
roles, err := LoadRolesFromDir(rolesDir)
if err != nil {
return nil, fmt.Errorf("从角色目录加载角色配置失败: %w", err)
}
cfg.Roles = roles
} else {
// 如果未配置 roles_dir,初始化为空 map
if cfg.Roles == nil {
cfg.Roles = make(map[string]RoleConfig)
}
}
return &cfg, nil
}
func generateStrongPassword(length int) (string, error) {
if length <= 0 {
length = 24
}
bytesLen := length
randomBytes := make([]byte, bytesLen)
if _, err := rand.Read(randomBytes); err != nil {
return "", err
}
password := base64.RawURLEncoding.EncodeToString(randomBytes)
if len(password) > length {
password = password[:length]
}
return password, nil
}
func PersistAuthPassword(path, password string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
lines := strings.Split(string(data), "\n")
inAuthBlock := false
authIndent := -1
for i, line := range lines {
trimmed := strings.TrimSpace(line)
if !inAuthBlock {
if strings.HasPrefix(trimmed, "auth:") {
inAuthBlock = true
authIndent = len(line) - len(strings.TrimLeft(line, " "))
}
continue
}
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
continue
}
leadingSpaces := len(line) - len(strings.TrimLeft(line, " "))
if leadingSpaces <= authIndent {
// 离开 auth 块
inAuthBlock = false
authIndent = -1
// 继续寻找其它 auth 块(理论上没有)
if strings.HasPrefix(trimmed, "auth:") {
inAuthBlock = true
authIndent = leadingSpaces
}
continue
}
if strings.HasPrefix(strings.TrimSpace(line), "password:") {
prefix := line[:len(line)-len(strings.TrimLeft(line, " "))]
comment := ""
if idx := strings.Index(line, "#"); idx >= 0 {
comment = strings.TrimRight(line[idx:], " ")
}
newLine := fmt.Sprintf("%spassword: %s", prefix, password)
if comment != "" {
if !strings.HasPrefix(comment, " ") {
newLine += " "
}
newLine += comment
}
lines[i] = newLine
break
}
}
return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644)
}
func PrintGeneratedPasswordWarning(password string, persisted bool, persistErr string) {
if strings.TrimSpace(password) == "" {
return
}
if persisted {
fmt.Println("[CyberStrikeAI] ✅ 已为您自动生成并写入 Web 登录密码。")
} else {
if persistErr != "" {
fmt.Printf("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码: %s\n", persistErr)
} else {
fmt.Println("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码。")
}
fmt.Println("请手动将以下随机密码写入 config.yaml 的 auth.password")
}
fmt.Println("----------------------------------------------------------------")
fmt.Println("CyberStrikeAI Auto-Generated Web Password")
fmt.Printf("Password: %s\n", password)
fmt.Println("WARNING: Anyone with this password can fully control CyberStrikeAI.")
fmt.Println("Please store it securely and change it in config.yaml as soon as possible.")
fmt.Println("警告:持有此密码的人将拥有对 CyberStrikeAI 的完全控制权限。")
fmt.Println("请妥善保管,并尽快在 config.yaml 中修改 auth.password")
fmt.Println("----------------------------------------------------------------")
}
// generateRandomToken 生成用于 MCP 鉴权的随机字符串(64 位十六进制)
func generateRandomToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
// persistMCPAuth 将 MCP 的 auth_header / auth_header_value 写回配置文件
func persistMCPAuth(path string, mcp *MCPConfig) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
lines := strings.Split(string(data), "\n")
inMcpBlock := false
mcpIndent := -1
for i, line := range lines {
trimmed := strings.TrimSpace(line)
if !inMcpBlock {
if strings.HasPrefix(trimmed, "mcp:") {
inMcpBlock = true
mcpIndent = len(line) - len(strings.TrimLeft(line, " "))
}
continue
}
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
continue
}
leadingSpaces := len(line) - len(strings.TrimLeft(line, " "))
if leadingSpaces <= mcpIndent {
inMcpBlock = false
mcpIndent = -1
if strings.HasPrefix(trimmed, "mcp:") {
inMcpBlock = true
mcpIndent = leadingSpaces
}
continue
}
prefix := line[:leadingSpaces]
rest := strings.TrimSpace(line[leadingSpaces:])
comment := ""
if idx := strings.Index(line, "#"); idx >= 0 {
comment = strings.TrimRight(line[idx:], " ")
}
withComment := ""
if comment != "" {
if !strings.HasPrefix(comment, " ") {
withComment = " "
}
withComment += comment
}
if strings.HasPrefix(rest, "auth_header_value:") {
lines[i] = fmt.Sprintf("%sauth_header_value: %q%s", prefix, mcp.AuthHeaderValue, withComment)
} else if strings.HasPrefix(rest, "auth_header:") {
lines[i] = fmt.Sprintf("%sauth_header: %q%s", prefix, mcp.AuthHeader, withComment)
}
}
return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644)
}
// EnsureMCPAuth 在 MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
func EnsureMCPAuth(path string, cfg *Config) error {
if !cfg.MCP.Enabled || strings.TrimSpace(cfg.MCP.AuthHeaderValue) != "" {
return nil
}
token, err := generateRandomToken()
if err != nil {
return fmt.Errorf("生成 MCP 鉴权密钥失败: %w", err)
}
cfg.MCP.AuthHeaderValue = token
if strings.TrimSpace(cfg.MCP.AuthHeader) == "" {
cfg.MCP.AuthHeader = "X-MCP-Token"
}
return persistMCPAuth(path, &cfg.MCP)
}
// PrintMCPConfigJSON 向终端输出 MCP 配置的 JSON,可直接复制到 Cursor / Claude Code 的 mcp 配置中使用
func PrintMCPConfigJSON(mcp MCPConfig) {
if !mcp.Enabled {
return
}
hostForURL := strings.TrimSpace(mcp.Host)
if hostForURL == "" || hostForURL == "0.0.0.0" {
hostForURL = "localhost"
}
url := fmt.Sprintf("http://%s:%d/mcp", hostForURL, mcp.Port)
headers := map[string]string{}
if mcp.AuthHeader != "" {
headers[mcp.AuthHeader] = mcp.AuthHeaderValue
}
serverEntry := map[string]interface{}{
"url": url,
}
if len(headers) > 0 {
serverEntry["headers"] = headers
}
// Claude Code 需要 type: "http"
serverEntry["type"] = "http"
out := map[string]interface{}{
"mcpServers": map[string]interface{}{
"cyberstrike-ai": serverEntry,
},
}
b, _ := json.MarshalIndent(out, "", " ")
fmt.Println("[CyberStrikeAI] MCP 配置(可复制到 Cursor / Claude Code 使用):")
fmt.Println(" Cursor: 放入 ~/.cursor/mcp.json 的 mcpServers,或项目 .cursor/mcp.json")
fmt.Println(" Claude Code: 放入 .mcp.json 或 ~/.claude.json 的 mcpServers")
fmt.Println("----------------------------------------------------------------")
fmt.Println(string(b))
fmt.Println("----------------------------------------------------------------")
}
// LoadToolsFromDir 从目录加载所有工具配置文件
func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
var tools []ToolConfig
// 检查目录是否存在
if _, err := os.Stat(dir); os.IsNotExist(err) {
return tools, nil // 目录不存在时返回空列表,不报错
}
// 读取目录中的所有 .yaml 和 .yml 文件
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("读取工具目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
continue
}
filePath := filepath.Join(dir, name)
tool, err := LoadToolFromFile(filePath)
if err != nil {
// 记录错误但继续加载其他文件
fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err)
continue
}
tools = append(tools, *tool)
}
return tools, nil
}
// LoadToolFromFile 从单个文件加载工具配置
func LoadToolFromFile(path string) (*ToolConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
var tool ToolConfig
if err := yaml.Unmarshal(data, &tool); err != nil {
return nil, fmt.Errorf("解析工具配置失败: %w", err)
}
// 验证必需字段
if tool.Name == "" {
return nil, fmt.Errorf("工具名称不能为空")
}
if tool.Command == "" {
return nil, fmt.Errorf("工具命令不能为空")
}
return &tool, nil
}
// LoadRolesFromDir 从目录加载所有角色配置文件
func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) {
roles := make(map[string]RoleConfig)
// 检查目录是否存在
if _, err := os.Stat(dir); os.IsNotExist(err) {
return roles, nil // 目录不存在时返回空map,不报错
}
// 读取目录中的所有 .yaml 和 .yml 文件
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("读取角色目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
continue
}
filePath := filepath.Join(dir, name)
role, err := LoadRoleFromFile(filePath)
if err != nil {
// 记录错误但继续加载其他文件
fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err)
continue
}
// 使用角色名称作为key
roleName := role.Name
if roleName == "" {
// 如果角色名称为空,使用文件名(去掉扩展名)作为名称
roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml")
role.Name = roleName
}
roles[roleName] = *role
}
return roles, nil
}
// LoadRoleFromFile 从单个文件加载角色配置
func LoadRoleFromFile(path string) (*RoleConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
var role RoleConfig
if err := yaml.Unmarshal(data, &role); err != nil {
return nil, fmt.Errorf("解析角色配置失败: %w", err)
}
// 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符
// Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换
if role.Icon != "" {
icon := role.Icon
// 去除可能的引号
icon = strings.Trim(icon, `"`)
// 检查是否是 Unicode 转义格式 \U0001F3C68位十六进制)或 \uXXXX(4位十六进制)
if len(icon) >= 3 && icon[0] == '\\' {
if icon[1] == 'U' && len(icon) >= 10 {
// \U0001F3C6 格式(8位十六进制)
if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil {
role.Icon = string(rune(codePoint))
}
} else if icon[1] == 'u' && len(icon) >= 6 {
// \uXXXX 格式(4位十六进制)
if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil {
role.Icon = string(rune(codePoint))
}
}
}
}
// 验证必需字段
if role.Name == "" {
// 如果名称为空,尝试从文件名获取
baseName := filepath.Base(path)
role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml")
}
return &role, nil
}
func Default() *Config {
return &Config{
Server: ServerConfig{
Host: "0.0.0.0",
Port: 8080,
},
Log: LogConfig{
Level: "info",
Output: "stdout",
},
MCP: MCPConfig{
Enabled: true,
Host: "0.0.0.0",
Port: 8081,
},
OpenAI: OpenAIConfig{
BaseURL: "https://api.openai.com/v1",
Model: "gpt-4",
MaxTotalTokens: 120000,
},
Agent: AgentConfig{
MaxIterations: 30, // 默认最大迭代次数
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
},
Security: SecurityConfig{
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
ToolsDir: "tools", // 默认工具目录
},
Database: DatabaseConfig{
Path: "data/conversations.db",
KnowledgeDBPath: "data/knowledge.db", // 默认知识库数据库路径
},
Auth: AuthConfig{
SessionDurationHours: 12,
},
Knowledge: KnowledgeConfig{
Enabled: true,
BasePath: "knowledge_base",
Embedding: EmbeddingConfig{
Provider: "openai",
Model: "text-embedding-3-small",
BaseURL: "https://api.openai.com/v1",
},
Retrieval: RetrievalConfig{
TopK: 5,
SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检
},
Indexing: IndexingConfig{
ChunkStrategy: "markdown_then_recursive",
RequestTimeoutSeconds: 120,
ChunkSize: 768, // 增加到 768,更好的上下文保持
ChunkOverlap: 50,
MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额
BatchSize: 64,
PreferSourceFile: false,
MaxRPM: 100, // 默认 100 RPM,避免 429 错误
RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM
MaxRetries: 3,
RetryDelayMs: 1000,
SubIndexes: nil,
},
},
}
}
// KnowledgeConfig 知识库配置
type KnowledgeConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用知识检索
BasePath string `yaml:"base_path" json:"base_path"` // 知识库路径
Embedding EmbeddingConfig `yaml:"embedding" json:"embedding"`
Retrieval RetrievalConfig `yaml:"retrieval" json:"retrieval"`
Indexing IndexingConfig `yaml:"indexing,omitempty" json:"indexing,omitempty"` // 索引构建配置
}
// IndexingConfig 索引构建配置(用于控制知识库索引构建时的行为)
type IndexingConfig struct {
// ChunkStrategy: "markdown_then_recursive"(默认,Eino Markdown 标题切分后再递归切)或 "recursive"(仅递归切分)
ChunkStrategy string `yaml:"chunk_strategy,omitempty" json:"chunk_strategy,omitempty"`
// RequestTimeoutSeconds 嵌入 HTTP 客户端超时(秒),0 表示使用默认 120
RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"`
// 分块配置
ChunkSize int `yaml:"chunk_size,omitempty" json:"chunk_size,omitempty"` // 每个块的最大 token 数(估算),默认 512
ChunkOverlap int `yaml:"chunk_overlap,omitempty" json:"chunk_overlap,omitempty"` // 块之间的重叠 token 数,默认 50
MaxChunksPerItem int `yaml:"max_chunks_per_item,omitempty" json:"max_chunks_per_item,omitempty"` // 单个知识项的最大块数量,0 表示不限制
// PreferSourceFile 为 true 时优先用 Eino FileLoader 从 file_path 读原文再索引(与库内 content 不一致时以磁盘为准)
PreferSourceFile bool `yaml:"prefer_source_file,omitempty" json:"prefer_source_file,omitempty"`
// 速率限制配置(用于避免 API 速率限制)
RateLimitDelayMs int `yaml:"rate_limit_delay_ms,omitempty" json:"rate_limit_delay_ms,omitempty"` // 请求间隔时间(毫秒),0 表示不使用固定延迟
MaxRPM int `yaml:"max_rpm,omitempty" json:"max_rpm,omitempty"` // 每分钟最大请求数,0 表示不限制
// 重试配置(用于处理临时错误)
MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // 最大重试次数,默认 3
RetryDelayMs int `yaml:"retry_delay_ms,omitempty" json:"retry_delay_ms,omitempty"` // 重试间隔(毫秒),默认 1000
// BatchSize 嵌入批大小(SQLite 索引写入),0 表示默认 64
BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"`
// SubIndexes 传入 Eino indexer.WithSubIndexes(逻辑分区标记,随 Document 元数据传递)
SubIndexes []string `yaml:"sub_indexes,omitempty" json:"sub_indexes,omitempty"`
}
// EmbeddingConfig 嵌入配置
type EmbeddingConfig struct {
Provider string `yaml:"provider" json:"provider"` // 嵌入模型提供商
Model string `yaml:"model" json:"model"` // 模型名称
BaseURL string `yaml:"base_url" json:"base_url"` // API Base URL
APIKey string `yaml:"api_key" json:"api_key"` // API Key(从OpenAI配置继承)
}
// PostRetrieveConfig 检索后处理:固定对正文做规范化去重(最佳实践)、上下文预算截断;PrefetchTopK 用于多取候选再收敛到 top_k。
type PostRetrieveConfig struct {
// PrefetchTopK 向量检索阶段最多保留的候选数(余弦序),应 ≥ top_k,0 表示与 top_k 相同;上限见知识库包内常量。
PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"`
// MaxContextChars 返回文档内容总 Unicode 字符数上限(整段 chunk,不截断半段);0 表示不限制。
MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"`
// MaxContextTokens 返回文档内容总 token 上限(tiktoken,按嵌入模型名映射,失败则 cl100k_base);0 表示不限制。
MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"`
}
// RetrievalConfig 检索配置
type RetrievalConfig struct {
TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 余弦相似度阈值
// SubIndexFilter 非空时仅保留 sub_indexes 含该标签(逗号分隔之一)的行;sub_indexes 为空的旧行仍返回。
SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"`
// PostRetrieve 检索后处理(去重、预算截断);重排通过代码注入 [knowledge.DocumentReranker]。
PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"`
}
// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代)
// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig
type RolesConfig struct {
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"`
}
// RoleConfig 单个角色配置
type RoleConfig struct {
Name string `yaml:"name" json:"name"` // 角色名称
Description string `yaml:"description" json:"description"` // 角色描述
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName"
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
Skills []string `yaml:"skills,omitempty" json:"skills,omitempty"` // 关联的skills列表(skill名称列表,在执行任务前会读取这些skills的内容)
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
}
-168
View File
@@ -1,168 +0,0 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"go.uber.org/zap"
)
// AttackChainNode 攻击链节点
type AttackChainNode struct {
ID string `json:"id"`
Type string `json:"type"` // tool, vulnerability, target, exploit
Label string `json:"label"`
ToolExecutionID string `json:"tool_execution_id,omitempty"`
Metadata map[string]interface{} `json:"metadata"`
RiskScore int `json:"risk_score"`
}
// AttackChainEdge 攻击链边
type AttackChainEdge struct {
ID string `json:"id"`
Source string `json:"source"`
Target string `json:"target"`
Type string `json:"type"` // leads_to, exploits, enables, depends_on
Weight int `json:"weight"`
}
// SaveAttackChainNode 保存攻击链节点
func (db *DB) SaveAttackChainNode(conversationID, nodeID, nodeType, nodeName, toolExecutionID, metadata string, riskScore int) error {
var toolExecID sql.NullString
if toolExecutionID != "" {
toolExecID = sql.NullString{String: toolExecutionID, Valid: true}
}
var metadataJSON sql.NullString
if metadata != "" {
metadataJSON = sql.NullString{String: metadata, Valid: true}
}
query := `
INSERT OR REPLACE INTO attack_chain_nodes
(id, conversation_id, node_type, node_name, tool_execution_id, metadata, risk_score, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
`
_, err := db.Exec(query, nodeID, conversationID, nodeType, nodeName, toolExecID, metadataJSON, riskScore)
if err != nil {
db.logger.Error("保存攻击链节点失败", zap.Error(err), zap.String("nodeId", nodeID))
return err
}
return nil
}
// SaveAttackChainEdge 保存攻击链边
func (db *DB) SaveAttackChainEdge(conversationID, edgeID, sourceNodeID, targetNodeID, edgeType string, weight int) error {
query := `
INSERT OR REPLACE INTO attack_chain_edges
(id, conversation_id, source_node_id, target_node_id, edge_type, weight, created_at)
VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
`
_, err := db.Exec(query, edgeID, conversationID, sourceNodeID, targetNodeID, edgeType, weight)
if err != nil {
db.logger.Error("保存攻击链边失败", zap.Error(err), zap.String("edgeId", edgeID))
return err
}
return nil
}
// LoadAttackChainNodes 加载攻击链节点
func (db *DB) LoadAttackChainNodes(conversationID string) ([]AttackChainNode, error) {
query := `
SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score
FROM attack_chain_nodes
WHERE conversation_id = ?
ORDER BY created_at ASC
`
rows, err := db.Query(query, conversationID)
if err != nil {
return nil, fmt.Errorf("查询攻击链节点失败: %w", err)
}
defer rows.Close()
var nodes []AttackChainNode
for rows.Next() {
var node AttackChainNode
var toolExecID sql.NullString
var metadataJSON sql.NullString
err := rows.Scan(&node.ID, &node.Type, &node.Label, &toolExecID, &metadataJSON, &node.RiskScore)
if err != nil {
db.logger.Warn("扫描攻击链节点失败", zap.Error(err))
continue
}
if toolExecID.Valid {
node.ToolExecutionID = toolExecID.String
}
if metadataJSON.Valid && metadataJSON.String != "" {
if err := json.Unmarshal([]byte(metadataJSON.String), &node.Metadata); err != nil {
db.logger.Warn("解析节点元数据失败", zap.Error(err))
node.Metadata = make(map[string]interface{})
}
} else {
node.Metadata = make(map[string]interface{})
}
nodes = append(nodes, node)
}
return nodes, nil
}
// LoadAttackChainEdges 加载攻击链边
func (db *DB) LoadAttackChainEdges(conversationID string) ([]AttackChainEdge, error) {
query := `
SELECT id, source_node_id, target_node_id, edge_type, weight
FROM attack_chain_edges
WHERE conversation_id = ?
ORDER BY created_at ASC
`
rows, err := db.Query(query, conversationID)
if err != nil {
return nil, fmt.Errorf("查询攻击链边失败: %w", err)
}
defer rows.Close()
var edges []AttackChainEdge
for rows.Next() {
var edge AttackChainEdge
err := rows.Scan(&edge.ID, &edge.Source, &edge.Target, &edge.Type, &edge.Weight)
if err != nil {
db.logger.Warn("扫描攻击链边失败", zap.Error(err))
continue
}
edges = append(edges, edge)
}
return edges, nil
}
// DeleteAttackChain 删除对话的攻击链数据
func (db *DB) DeleteAttackChain(conversationID string) error {
// 先删除边(因为有外键约束)
_, err := db.Exec("DELETE FROM attack_chain_edges WHERE conversation_id = ?", conversationID)
if err != nil {
db.logger.Warn("删除攻击链边失败", zap.Error(err))
}
// 再删除节点
_, err = db.Exec("DELETE FROM attack_chain_nodes WHERE conversation_id = ?", conversationID)
if err != nil {
db.logger.Error("删除攻击链节点失败", zap.Error(err), zap.String("conversationId", conversationID))
return err
}
return nil
}
-537
View File
@@ -1,537 +0,0 @@
package database
import (
"database/sql"
"fmt"
"strings"
"time"
"go.uber.org/zap"
)
// BatchTaskQueueRow 批量任务队列数据库行
type BatchTaskQueueRow struct {
ID string
Title sql.NullString
Role sql.NullString
AgentMode sql.NullString
ScheduleMode sql.NullString
CronExpr sql.NullString
NextRunAt sql.NullTime
ScheduleEnabled sql.NullInt64
LastScheduleTriggerAt sql.NullTime
LastScheduleError sql.NullString
LastRunError sql.NullString
Status string
CreatedAt time.Time
StartedAt sql.NullTime
CompletedAt sql.NullTime
CurrentIndex int
}
// BatchTaskRow 批量任务数据库行
type BatchTaskRow struct {
ID string
QueueID string
Message string
ConversationID sql.NullString
Status string
StartedAt sql.NullTime
CompletedAt sql.NullTime
Error sql.NullString
Result sql.NullString
}
// CreateBatchQueue 创建批量任务队列
func (db *DB) CreateBatchQueue(
queueID string,
title string,
role string,
agentMode string,
scheduleMode string,
cronExpr string,
nextRunAt *time.Time,
tasks []map[string]interface{},
) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
now := time.Now()
var nextRunAtValue interface{}
if nextRunAt != nil {
nextRunAtValue = *nextRunAt
}
_, err = tx.Exec(
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0,
)
if err != nil {
return fmt.Errorf("创建批量任务队列失败: %w", err)
}
// 插入任务
for _, task := range tasks {
taskID, ok := task["id"].(string)
if !ok {
continue
}
message, ok := task["message"].(string)
if !ok {
continue
}
_, err = tx.Exec(
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
taskID, queueID, message, "pending",
)
if err != nil {
return fmt.Errorf("创建批量任务失败: %w", err)
}
}
return tx.Commit()
}
// GetBatchQueue 获取批量任务队列
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
var row BatchTaskQueueRow
var createdAt string
err := db.QueryRow(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
queueID,
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("查询批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
// 尝试其他时间格式
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
return &row, nil
}
// GetAllBatchQueues 获取所有批量任务队列
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
rows, err := db.Query(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
)
if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
}
defer rows.Close()
var queues []*BatchTaskQueueRow
for rows.Next() {
var row BatchTaskQueueRow
var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
queues = append(queues, &row)
}
return queues, nil
}
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
args := []interface{}{}
// 状态筛选
if status != "" && status != "all" {
query += " AND status = ?"
args = append(args, status)
}
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
query += " AND (id LIKE ? OR title LIKE ?)"
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
}
defer rows.Close()
var queues []*BatchTaskQueueRow
for rows.Next() {
var row BatchTaskQueueRow
var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
}
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
if parseErr != nil {
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
if parseErr != nil {
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
parsedTime = time.Now()
}
}
row.CreatedAt = parsedTime
queues = append(queues, &row)
}
return queues, nil
}
// CountBatchQueues 统计批量任务队列总数(支持筛选条件)
func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
query := "SELECT COUNT(*) FROM batch_task_queues WHERE 1=1"
args := []interface{}{}
// 状态筛选
if status != "" && status != "all" {
query += " AND status = ?"
args = append(args, status)
}
// 关键字搜索(搜索队列ID和标题)
if keyword != "" {
query += " AND (id LIKE ? OR title LIKE ?)"
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
}
var count int
err := db.QueryRow(query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("统计批量任务队列总数失败: %w", err)
}
return count, nil
}
// GetBatchTasks 获取批量任务队列的所有任务
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
rows, err := db.Query(
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id",
queueID,
)
if err != nil {
return nil, fmt.Errorf("查询批量任务失败: %w", err)
}
defer rows.Close()
var tasks []*BatchTaskRow
for rows.Next() {
var task BatchTaskRow
if err := rows.Scan(
&task.ID, &task.QueueID, &task.Message, &task.ConversationID,
&task.Status, &task.StartedAt, &task.CompletedAt, &task.Error, &task.Result,
); err != nil {
return nil, fmt.Errorf("扫描批量任务失败: %w", err)
}
tasks = append(tasks, &task)
}
return tasks, nil
}
// UpdateBatchQueueStatus 更新批量任务队列状态
func (db *DB) UpdateBatchQueueStatus(queueID, status string) error {
var err error
now := time.Now()
if status == "running" {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ?, started_at = COALESCE(started_at, ?) WHERE id = ?",
status, now, queueID,
)
} else if status == "completed" || status == "cancelled" {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ?, completed_at = COALESCE(completed_at, ?) WHERE id = ?",
status, now, queueID,
)
} else {
_, err = db.Exec(
"UPDATE batch_task_queues SET status = ? WHERE id = ?",
status, queueID,
)
}
if err != nil {
return fmt.Errorf("更新批量任务队列状态失败: %w", err)
}
return nil
}
// UpdateBatchTaskStatus 更新批量任务状态
func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversationID, result, errorMsg string) error {
var err error
now := time.Now()
// 构建更新语句
var updates []string
var args []interface{}
updates = append(updates, "status = ?")
args = append(args, status)
if conversationID != "" {
updates = append(updates, "conversation_id = ?")
args = append(args, conversationID)
}
if result != "" {
updates = append(updates, "result = ?")
args = append(args, result)
}
if errorMsg != "" {
updates = append(updates, "error = ?")
args = append(args, errorMsg)
}
if status == "running" {
updates = append(updates, "started_at = COALESCE(started_at, ?)")
args = append(args, now)
}
if status == "completed" || status == "failed" || status == "cancelled" {
updates = append(updates, "completed_at = COALESCE(completed_at, ?)")
args = append(args, now)
}
args = append(args, queueID, taskID)
// 构建SQL语句
sql := "UPDATE batch_tasks SET "
for i, update := range updates {
if i > 0 {
sql += ", "
}
sql += update
}
sql += " WHERE queue_id = ? AND id = ?"
_, err = db.Exec(sql, args...)
if err != nil {
return fmt.Errorf("更新批量任务状态失败: %w", err)
}
return nil
}
// UpdateBatchQueueCurrentIndex 更新批量任务队列的当前索引
func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET current_index = ? WHERE id = ?",
currentIndex, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务队列当前索引失败: %w", err)
}
return nil
}
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?",
title, role, agentMode, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
}
return nil
}
// UpdateBatchQueueSchedule 更新批量任务队列调度相关信息
func (db *DB) UpdateBatchQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) error {
var nextRunAtValue interface{}
if nextRunAt != nil {
nextRunAtValue = *nextRunAt
}
_, err := db.Exec(
"UPDATE batch_task_queues SET schedule_mode = ?, cron_expr = ?, next_run_at = ? WHERE id = ?",
scheduleMode, cronExpr, nextRunAtValue, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务调度配置失败: %w", err)
}
return nil
}
// UpdateBatchQueueScheduleEnabled 是否允许 Cron 自动触发(手工「开始执行」不受影响)
func (db *DB) UpdateBatchQueueScheduleEnabled(queueID string, enabled bool) error {
v := 0
if enabled {
v = 1
}
_, err := db.Exec(
"UPDATE batch_task_queues SET schedule_enabled = ? WHERE id = ?",
v, queueID,
)
if err != nil {
return fmt.Errorf("更新批量任务调度开关失败: %w", err)
}
return nil
}
// RecordBatchQueueScheduledTriggerStart 记录一次由调度触发的开始时间并清空调度层错误
func (db *DB) RecordBatchQueueScheduledTriggerStart(queueID string, at time.Time) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET last_schedule_trigger_at = ?, last_schedule_error = NULL WHERE id = ?",
at, queueID,
)
if err != nil {
return fmt.Errorf("记录调度触发时间失败: %w", err)
}
return nil
}
// SetBatchQueueLastScheduleError 调度启动失败等原因(如状态不允许、重置失败)
func (db *DB) SetBatchQueueLastScheduleError(queueID, msg string) error {
_, err := db.Exec(
"UPDATE batch_task_queues SET last_schedule_error = ? WHERE id = ?",
msg, queueID,
)
if err != nil {
return fmt.Errorf("写入调度错误信息失败: %w", err)
}
return nil
}
// SetBatchQueueLastRunError 最近一轮执行中出现的子任务失败摘要(空串表示清空)
func (db *DB) SetBatchQueueLastRunError(queueID, msg string) error {
var v interface{}
if strings.TrimSpace(msg) == "" {
v = nil
} else {
v = msg
}
_, err := db.Exec(
"UPDATE batch_task_queues SET last_run_error = ? WHERE id = ?",
v, queueID,
)
if err != nil {
return fmt.Errorf("写入最近运行错误失败: %w", err)
}
return nil
}
// ResetBatchQueueForRerun 重置队列和任务状态用于下一轮调度执行
func (db *DB) ResetBatchQueueForRerun(queueID string) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
_, err = tx.Exec(
"UPDATE batch_task_queues SET status = ?, current_index = 0, started_at = NULL, completed_at = NULL, last_run_error = NULL, last_schedule_error = NULL WHERE id = ?",
"pending", queueID,
)
if err != nil {
return fmt.Errorf("重置批量任务队列状态失败: %w", err)
}
_, err = tx.Exec(
"UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ?",
"pending", queueID,
)
if err != nil {
return fmt.Errorf("重置批量任务状态失败: %w", err)
}
return tx.Commit()
}
// UpdateBatchTaskMessage 更新批量任务消息
func (db *DB) UpdateBatchTaskMessage(queueID, taskID, message string) error {
_, err := db.Exec(
"UPDATE batch_tasks SET message = ? WHERE queue_id = ? AND id = ?",
message, queueID, taskID,
)
if err != nil {
return fmt.Errorf("更新批量任务消息失败: %w", err)
}
return nil
}
// AddBatchTask 添加任务到批量任务队列
func (db *DB) AddBatchTask(queueID, taskID, message string) error {
_, err := db.Exec(
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
taskID, queueID, message, "pending",
)
if err != nil {
return fmt.Errorf("添加批量任务失败: %w", err)
}
return nil
}
// CancelPendingBatchTasks 批量取消队列中所有 pending 状态的任务(单条 SQL
func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) error {
_, err := db.Exec(
"UPDATE batch_tasks SET status = ?, completed_at = ? WHERE queue_id = ? AND status = ?",
"cancelled", completedAt, queueID, "pending",
)
if err != nil {
return fmt.Errorf("批量取消 pending 任务失败: %w", err)
}
return nil
}
// DeleteBatchTask 删除批量任务
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
_, err := db.Exec(
"DELETE FROM batch_tasks WHERE queue_id = ? AND id = ?",
queueID, taskID,
)
if err != nil {
return fmt.Errorf("删除批量任务失败: %w", err)
}
return nil
}
// DeleteBatchQueue 删除批量任务队列
func (db *DB) DeleteBatchQueue(queueID string) error {
tx, err := db.Begin()
if err != nil {
return fmt.Errorf("开始事务失败: %w", err)
}
defer tx.Rollback()
// 删除任务(外键会自动级联删除)
_, err = tx.Exec("DELETE FROM batch_tasks WHERE queue_id = ?", queueID)
if err != nil {
return fmt.Errorf("删除批量任务失败: %w", err)
}
// 删除队列
_, err = tx.Exec("DELETE FROM batch_task_queues WHERE id = ?", queueID)
if err != nil {
return fmt.Errorf("删除批量任务队列失败: %w", err)
}
return tx.Commit()
}
-758
View File
@@ -1,758 +0,0 @@
package database
import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Conversation 对话
type Conversation struct {
ID string `json:"id"`
Title string `json:"title"`
Pinned bool `json:"pinned"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
Messages []Message `json:"messages,omitempty"`
}
// Message 消息
type Message struct {
ID string `json:"id"`
ConversationID string `json:"conversationId"`
Role string `json:"role"`
Content string `json:"content"`
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// CreateConversation 创建新对话
func (db *DB) CreateConversation(title string) (*Conversation, error) {
return db.CreateConversationWithWebshell("", title)
}
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) {
id := uuid.New().String()
now := time.Now()
var err error
if webshellConnectionID != "" {
_, err = db.Exec(
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)",
id, title, now, now, webshellConnectionID,
)
} else {
_, err = db.Exec(
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
id, title, now, now,
)
}
if err != nil {
return nil, fmt.Errorf("创建对话失败: %w", err)
}
return &Conversation{
ID: id,
Title: title,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conversation, error) {
if connectionID == "" {
return nil, fmt.Errorf("connectionID is empty")
}
var conv Conversation
var createdAt, updatedAt string
var pinned int
err := db.QueryRow(
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC LIMIT 1",
connectionID,
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询对话失败: %w", err)
}
conv.Pinned = pinned != 0
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt); e == nil {
conv.CreatedAt = t
} else if t, e := time.Parse("2006-01-02 15:04:05", createdAt); e == nil {
conv.CreatedAt = t
} else {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil {
conv.UpdatedAt = t
} else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil {
conv.UpdatedAt = t
} else {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
messages, err := db.GetMessages(conv.ID)
if err != nil {
return nil, fmt.Errorf("加载消息失败: %w", err)
}
conv.Messages = messages
// 加载过程详情并附加到对应消息(与 GetConversation 一致,便于刷新后仍可查看执行过程)
processDetailsMap, err := db.GetProcessDetailsByConversation(conv.ID)
if err != nil {
db.logger.Warn("加载过程详情失败", zap.Error(err))
processDetailsMap = make(map[string][]ProcessDetail)
}
for i := range conv.Messages {
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
detailsJSON := make([]map[string]interface{}, len(details))
for j, detail := range details {
var data interface{}
if detail.Data != "" {
if err := json.Unmarshal([]byte(detail.Data), &data); err != nil {
db.logger.Warn("解析过程详情数据失败", zap.Error(err))
}
}
detailsJSON[j] = map[string]interface{}{
"id": detail.ID,
"messageId": detail.MessageID,
"conversationId": detail.ConversationID,
"eventType": detail.EventType,
"message": detail.Message,
"data": data,
"createdAt": detail.CreatedAt,
}
}
conv.Messages[i].ProcessDetails = detailsJSON
}
}
return &conv, nil
}
// WebShellConversationItem 用于侧边栏列表,不含消息
type WebShellConversationItem struct {
ID string `json:"id"`
Title string `json:"title"`
UpdatedAt time.Time `json:"updatedAt"`
}
// ListConversationsByWebshellConnectionID 列出该 WebShell 连接下的所有对话(按更新时间倒序),供侧边栏展示
func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]WebShellConversationItem, error) {
if connectionID == "" {
return nil, nil
}
rows, err := db.Query(
"SELECT id, title, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC",
connectionID,
)
if err != nil {
return nil, fmt.Errorf("查询对话列表失败: %w", err)
}
defer rows.Close()
var list []WebShellConversationItem
for rows.Next() {
var item WebShellConversationItem
var updatedAt string
if err := rows.Scan(&item.ID, &item.Title, &updatedAt); err != nil {
continue
}
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil {
item.UpdatedAt = t
} else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil {
item.UpdatedAt = t
} else {
item.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
list = append(list, item)
}
return list, rows.Err()
}
// GetConversation 获取对话
func (db *DB) GetConversation(id string) (*Conversation, error) {
var conv Conversation
var createdAt, updatedAt string
var pinned int
err := db.QueryRow(
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
id,
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("对话不存在")
}
return nil, fmt.Errorf("查询对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
// 加载消息
messages, err := db.GetMessages(id)
if err != nil {
return nil, fmt.Errorf("加载消息失败: %w", err)
}
conv.Messages = messages
// 加载过程详情(按消息ID分组)
processDetailsMap, err := db.GetProcessDetailsByConversation(id)
if err != nil {
db.logger.Warn("加载过程详情失败", zap.Error(err))
processDetailsMap = make(map[string][]ProcessDetail)
}
// 将过程详情附加到对应的消息上
for i := range conv.Messages {
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
// 将ProcessDetail转换为JSON格式,以便前端使用
detailsJSON := make([]map[string]interface{}, len(details))
for j, detail := range details {
var data interface{}
if detail.Data != "" {
if err := json.Unmarshal([]byte(detail.Data), &data); err != nil {
db.logger.Warn("解析过程详情数据失败", zap.Error(err))
}
}
detailsJSON[j] = map[string]interface{}{
"id": detail.ID,
"messageId": detail.MessageID,
"conversationId": detail.ConversationID,
"eventType": detail.EventType,
"message": detail.Message,
"data": data,
"createdAt": detail.CreatedAt,
}
}
conv.Messages[i].ProcessDetails = detailsJSON
}
}
return &conv, nil
}
// GetConversationLite 获取对话(轻量版):包含 messages,但不加载 process_details。
// 用于历史会话快速切换,避免一次性把大体量过程详情灌到前端导致卡顿。
func (db *DB) GetConversationLite(id string) (*Conversation, error) {
var conv Conversation
var createdAt, updatedAt string
var pinned int
err := db.QueryRow(
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
id,
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("对话不存在")
}
return nil, fmt.Errorf("查询对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
// 加载消息(不加载 process_details
messages, err := db.GetMessages(id)
if err != nil {
return nil, fmt.Errorf("加载消息失败: %w", err)
}
conv.Messages = messages
return &conv, nil
}
// ListConversations 列出所有对话
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
var rows *sql.Rows
var err error
if search != "" {
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
searchPattern := "%" + search + "%"
rows, err = db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at
FROM conversations c
WHERE c.title LIKE ?
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
ORDER BY c.updated_at DESC
LIMIT ? OFFSET ?`,
searchPattern, searchPattern, limit, offset,
)
} else {
rows, err = db.Query(
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
limit, offset,
)
}
if err != nil {
return nil, fmt.Errorf("查询对话列表失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, nil
}
// UpdateConversationTitle 更新对话标题
func (db *DB) UpdateConversationTitle(id, title string) error {
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
_, err := db.Exec(
"UPDATE conversations SET title = ? WHERE id = ?",
title, id,
)
if err != nil {
return fmt.Errorf("更新对话标题失败: %w", err)
}
return nil
}
// UpdateConversationTime 更新对话时间
func (db *DB) UpdateConversationTime(id string) error {
_, err := db.Exec(
"UPDATE conversations SET updated_at = ? WHERE id = ?",
time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新对话时间失败: %w", err)
}
return nil
}
// DeleteConversation 删除对话及其所有相关数据
// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除:
// - messages(消息)
// - process_details(过程详情)
// - attack_chain_nodes(攻击链节点)
// - attack_chain_edges(攻击链边)
// - vulnerabilities(漏洞)
// - conversation_group_mappings(分组映射)
// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL
func (db *DB) DeleteConversation(id string) error {
// 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除)
_, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id)
if err != nil {
db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err))
// 不返回错误,继续删除对话
}
// 删除对话(外键CASCADE会自动删除其他相关数据)
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除对话失败: %w", err)
}
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
return nil
}
// SaveReActData 保存最后一轮ReAct的输入和输出
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error {
_, err := db.Exec(
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
reactInput, reactOutput, time.Now(), conversationID,
)
if err != nil {
return fmt.Errorf("保存ReAct数据失败: %w", err)
}
return nil
}
// GetReActData 获取最后一轮ReAct的输入和输出
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) {
var input, output sql.NullString
err = db.QueryRow(
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
conversationID,
).Scan(&input, &output)
if err != nil {
if err == sql.ErrNoRows {
return "", "", fmt.Errorf("对话不存在")
}
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err)
}
if input.Valid {
reactInput = input.String
}
if output.Valid {
reactOutput = output.String
}
return reactInput, reactOutput, nil
}
// ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。
func (db *DB) ConversationHasToolProcessDetails(conversationID string) (bool, error) {
var n int
err := db.QueryRow(
`SELECT COUNT(*) FROM process_details WHERE conversation_id = ? AND event_type IN ('tool_call', 'tool_result')`,
conversationID,
).Scan(&n)
if err != nil {
return false, fmt.Errorf("查询过程详情失败: %w", err)
}
return n > 0, nil
}
// AddMessage 添加消息
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
id := uuid.New().String()
var mcpIDsJSON string
if len(mcpExecutionIDs) > 0 {
jsonData, err := json.Marshal(mcpExecutionIDs)
if err != nil {
db.logger.Warn("序列化MCP执行ID失败", zap.Error(err))
} else {
mcpIDsJSON = string(jsonData)
}
}
_, err := db.Exec(
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
id, conversationID, role, content, mcpIDsJSON, time.Now(),
)
if err != nil {
return nil, fmt.Errorf("添加消息失败: %w", err)
}
// 更新对话时间
if err := db.UpdateConversationTime(conversationID); err != nil {
db.logger.Warn("更新对话时间失败", zap.Error(err))
}
message := &Message{
ID: id,
ConversationID: conversationID,
Role: role,
Content: content,
MCPExecutionIDs: mcpExecutionIDs,
CreatedAt: time.Now(),
}
return message, nil
}
// GetMessages 获取对话的所有消息
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
rows, err := db.Query(
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
conversationID,
)
if err != nil {
return nil, fmt.Errorf("查询消息失败: %w", err)
}
defer rows.Close()
var messages []Message
for rows.Next() {
var msg Message
var mcpIDsJSON sql.NullString
var createdAt string
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
return nil, fmt.Errorf("扫描消息失败: %w", err)
}
// 尝试多种时间格式解析
var err error
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err != nil {
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err != nil {
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
// 解析MCP执行ID
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
}
}
messages = append(messages, msg)
}
return messages, nil
}
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
idx := -1
for i := range msgs {
if msgs[i].ID == anchorID {
idx = i
break
}
}
if idx < 0 {
return 0, 0, fmt.Errorf("message not found")
}
start = idx
for start > 0 && msgs[start].Role != "user" {
start--
}
if start < len(msgs) && msgs[start].Role != "user" {
start = 0
}
end = len(msgs)
for i := start + 1; i < len(msgs); i++ {
if msgs[i].Role == "user" {
end = i
break
}
}
return start, end, nil
}
// DeleteConversationTurn 删除锚点所在轮次的全部消息(用户提问 + 该轮助手回复等),并清空 last_react_*,避免与消息表不一致。
func (db *DB) DeleteConversationTurn(conversationID, anchorMessageID string) (deletedIDs []string, err error) {
msgs, err := db.GetMessages(conversationID)
if err != nil {
return nil, err
}
start, end, err := turnSliceRange(msgs, anchorMessageID)
if err != nil {
return nil, err
}
if start >= end {
return nil, fmt.Errorf("empty turn range")
}
deletedIDs = make([]string, 0, end-start)
for i := start; i < end; i++ {
deletedIDs = append(deletedIDs, msgs[i].ID)
}
tx, err := db.Begin()
if err != nil {
return nil, fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
ph := strings.Repeat("?,", len(deletedIDs))
ph = ph[:len(ph)-1]
args := make([]interface{}, 0, 1+len(deletedIDs))
args = append(args, conversationID)
for _, id := range deletedIDs {
args = append(args, id)
}
res, err := tx.Exec(
"DELETE FROM messages WHERE conversation_id = ? AND id IN ("+ph+")",
args...,
)
if err != nil {
return nil, fmt.Errorf("delete messages: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return nil, err
}
if int(n) != len(deletedIDs) {
return nil, fmt.Errorf("deleted count mismatch")
}
_, err = tx.Exec(
`UPDATE conversations SET last_react_input = NULL, last_react_output = NULL, updated_at = ? WHERE id = ?`,
time.Now(), conversationID,
)
if err != nil {
return nil, fmt.Errorf("clear react data: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit: %w", err)
}
db.logger.Info("conversation turn deleted",
zap.String("conversationId", conversationID),
zap.Strings("deletedMessageIds", deletedIDs),
zap.Int("count", len(deletedIDs)),
)
return deletedIDs, nil
}
// ProcessDetail 过程详情事件
type ProcessDetail struct {
ID string `json:"id"`
MessageID string `json:"messageId"`
ConversationID string `json:"conversationId"`
EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error
Message string `json:"message"`
Data string `json:"data"` // JSON格式的数据
CreatedAt time.Time `json:"createdAt"`
}
// AddProcessDetail 添加过程详情事件
func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error {
id := uuid.New().String()
var dataJSON string
if data != nil {
jsonData, err := json.Marshal(data)
if err != nil {
db.logger.Warn("序列化过程详情数据失败", zap.Error(err))
} else {
dataJSON = string(jsonData)
}
}
_, err := db.Exec(
"INSERT INTO process_details (id, message_id, conversation_id, event_type, message, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, messageID, conversationID, eventType, message, dataJSON, time.Now(),
)
if err != nil {
return fmt.Errorf("添加过程详情失败: %w", err)
}
return nil
}
// GetProcessDetails 获取消息的过程详情
func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
rows, err := db.Query(
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC",
messageID,
)
if err != nil {
return nil, fmt.Errorf("查询过程详情失败: %w", err)
}
defer rows.Close()
var details []ProcessDetail
for rows.Next() {
var detail ProcessDetail
var createdAt string
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
return nil, fmt.Errorf("扫描过程详情失败: %w", err)
}
// 尝试多种时间格式解析
var err error
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err != nil {
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err != nil {
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
details = append(details, detail)
}
return details, nil
}
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
rows, err := db.Query(
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC",
conversationID,
)
if err != nil {
return nil, fmt.Errorf("查询过程详情失败: %w", err)
}
defer rows.Close()
detailsMap := make(map[string][]ProcessDetail)
for rows.Next() {
var detail ProcessDetail
var createdAt string
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
return nil, fmt.Errorf("扫描过程详情失败: %w", err)
}
// 尝试多种时间格式解析
var err error
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err != nil {
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err != nil {
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
detailsMap[detail.MessageID] = append(detailsMap[detail.MessageID], detail)
}
return detailsMap, nil
}
@@ -1,39 +0,0 @@
package database
import (
"testing"
)
func TestTurnSliceRange(t *testing.T) {
mk := func(id, role string) Message {
return Message{ID: id, Role: role}
}
msgs := []Message{
mk("u1", "user"),
mk("a1", "assistant"),
mk("u2", "user"),
mk("a2", "assistant"),
}
cases := []struct {
anchor string
start int
end int
}{
{"u1", 0, 2},
{"a1", 0, 2},
{"u2", 2, 4},
{"a2", 2, 4},
}
for _, tc := range cases {
s, e, err := turnSliceRange(msgs, tc.anchor)
if err != nil {
t.Fatalf("anchor %s: %v", tc.anchor, err)
}
if s != tc.start || e != tc.end {
t.Fatalf("anchor %s: got [%d,%d) want [%d,%d)", tc.anchor, s, e, tc.start, tc.end)
}
}
if _, _, err := turnSliceRange(msgs, "nope"); err == nil {
t.Fatal("expected error for missing id")
}
}
-809
View File
@@ -1,809 +0,0 @@
package database
import (
"database/sql"
"fmt"
"strings"
_ "github.com/mattn/go-sqlite3"
"go.uber.org/zap"
)
// DB 数据库连接
type DB struct {
*sql.DB
logger *zap.Logger
}
// NewDB 创建数据库连接
func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
if err != nil {
return nil, fmt.Errorf("打开数据库失败: %w", err)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
database := &DB{
DB: db,
logger: logger,
}
// 初始化表
if err := database.initTables(); err != nil {
return nil, fmt.Errorf("初始化表失败: %w", err)
}
return database, nil
}
// initTables 初始化数据库表
func (db *DB) initTables() error {
// 创建对话表
createConversationsTable := `
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL,
last_react_input TEXT,
last_react_output TEXT
);`
// 创建消息表
createMessagesTable := `
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
mcp_execution_ids TEXT,
created_at DATETIME NOT NULL,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);`
// 创建过程详情表
createProcessDetailsTable := `
CREATE TABLE IF NOT EXISTS process_details (
id TEXT PRIMARY KEY,
message_id TEXT NOT NULL,
conversation_id TEXT NOT NULL,
event_type TEXT NOT NULL,
message TEXT,
data TEXT,
created_at DATETIME NOT NULL,
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);`
// 创建工具执行记录表
createToolExecutionsTable := `
CREATE TABLE IF NOT EXISTS tool_executions (
id TEXT PRIMARY KEY,
tool_name TEXT NOT NULL,
arguments TEXT NOT NULL,
status TEXT NOT NULL,
result TEXT,
error TEXT,
start_time DATETIME NOT NULL,
end_time DATETIME,
duration_ms INTEGER,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
// 创建工具统计表
createToolStatsTable := `
CREATE TABLE IF NOT EXISTS tool_stats (
tool_name TEXT PRIMARY KEY,
total_calls INTEGER NOT NULL DEFAULT 0,
success_calls INTEGER NOT NULL DEFAULT 0,
failed_calls INTEGER NOT NULL DEFAULT 0,
last_call_time DATETIME,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
// 创建Skills统计表
createSkillStatsTable := `
CREATE TABLE IF NOT EXISTS skill_stats (
skill_name TEXT PRIMARY KEY,
total_calls INTEGER NOT NULL DEFAULT 0,
success_calls INTEGER NOT NULL DEFAULT 0,
failed_calls INTEGER NOT NULL DEFAULT 0,
last_call_time DATETIME,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
// 创建攻击链节点表
createAttackChainNodesTable := `
CREATE TABLE IF NOT EXISTS attack_chain_nodes (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
node_type TEXT NOT NULL,
node_name TEXT NOT NULL,
tool_execution_id TEXT,
metadata TEXT,
risk_score INTEGER DEFAULT 0,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE,
FOREIGN KEY (tool_execution_id) REFERENCES tool_executions(id) ON DELETE SET NULL
);`
// 创建攻击链边表
createAttackChainEdgesTable := `
CREATE TABLE IF NOT EXISTS attack_chain_edges (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
source_node_id TEXT NOT NULL,
target_node_id TEXT NOT NULL,
edge_type TEXT NOT NULL,
weight INTEGER DEFAULT 1,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE,
FOREIGN KEY (source_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE,
FOREIGN KEY (target_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE
);`
// 创建知识检索日志表(保留在会话数据库中,因为有外键关联)
createKnowledgeRetrievalLogsTable := `
CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs (
id TEXT PRIMARY KEY,
conversation_id TEXT,
message_id TEXT,
query TEXT NOT NULL,
risk_type TEXT,
retrieved_items TEXT,
created_at DATETIME NOT NULL,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL,
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL
);`
// 创建对话分组表
createConversationGroupsTable := `
CREATE TABLE IF NOT EXISTS conversation_groups (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
icon TEXT,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
);`
// 创建对话分组映射表
createConversationGroupMappingsTable := `
CREATE TABLE IF NOT EXISTS conversation_group_mappings (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
group_id TEXT NOT NULL,
created_at DATETIME NOT NULL,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE,
FOREIGN KEY (group_id) REFERENCES conversation_groups(id) ON DELETE CASCADE,
UNIQUE(conversation_id, group_id)
);`
// 创建漏洞表
createVulnerabilitiesTable := `
CREATE TABLE IF NOT EXISTS vulnerabilities (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
title TEXT NOT NULL,
description TEXT,
severity TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'open',
vulnerability_type TEXT,
target TEXT,
proof TEXT,
impact TEXT,
recommendation TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);`
// 创建批量任务队列表
createBatchTaskQueuesTable := `
CREATE TABLE IF NOT EXISTS batch_task_queues (
id TEXT PRIMARY KEY,
title TEXT,
role TEXT,
agent_mode TEXT NOT NULL DEFAULT 'single',
schedule_mode TEXT NOT NULL DEFAULT 'manual',
cron_expr TEXT,
next_run_at DATETIME,
schedule_enabled INTEGER NOT NULL DEFAULT 1,
last_schedule_trigger_at DATETIME,
last_schedule_error TEXT,
last_run_error TEXT,
status TEXT NOT NULL,
created_at DATETIME NOT NULL,
started_at DATETIME,
completed_at DATETIME,
current_index INTEGER NOT NULL DEFAULT 0
);`
// 创建批量任务表
createBatchTasksTable := `
CREATE TABLE IF NOT EXISTS batch_tasks (
id TEXT PRIMARY KEY,
queue_id TEXT NOT NULL,
message TEXT NOT NULL,
conversation_id TEXT,
status TEXT NOT NULL,
started_at DATETIME,
completed_at DATETIME,
error TEXT,
result TEXT,
FOREIGN KEY (queue_id) REFERENCES batch_task_queues(id) ON DELETE CASCADE
);`
// 创建 WebShell 连接表
createWebshellConnectionsTable := `
CREATE TABLE IF NOT EXISTS webshell_connections (
id TEXT PRIMARY KEY,
url TEXT NOT NULL,
password TEXT NOT NULL DEFAULT '',
type TEXT NOT NULL DEFAULT 'php',
method TEXT NOT NULL DEFAULT 'post',
cmd_param TEXT NOT NULL DEFAULT '',
remark TEXT NOT NULL DEFAULT '',
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
// 创建 WebShell 连接扩展状态表(前端工作区/终端状态持久化)
createWebshellConnectionStatesTable := `
CREATE TABLE IF NOT EXISTS webshell_connection_states (
connection_id TEXT PRIMARY KEY,
state_json TEXT NOT NULL DEFAULT '{}',
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (connection_id) REFERENCES webshell_connections(id) ON DELETE CASCADE
);`
// 创建索引
createIndexes := `
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at);
CREATE INDEX IF NOT EXISTS idx_process_details_message_id ON process_details(message_id);
CREATE INDEX IF NOT EXISTS idx_process_details_conversation_id ON process_details(conversation_id);
CREATE INDEX IF NOT EXISTS idx_tool_executions_tool_name ON tool_executions(tool_name);
CREATE INDEX IF NOT EXISTS idx_tool_executions_start_time ON tool_executions(start_time);
CREATE INDEX IF NOT EXISTS idx_tool_executions_status ON tool_executions(status);
CREATE INDEX IF NOT EXISTS idx_chain_nodes_conversation ON attack_chain_nodes(conversation_id);
CREATE INDEX IF NOT EXISTS idx_chain_edges_conversation ON attack_chain_edges(conversation_id);
CREATE INDEX IF NOT EXISTS idx_chain_edges_source ON attack_chain_edges(source_node_id);
CREATE INDEX IF NOT EXISTS idx_chain_edges_target ON attack_chain_edges(target_node_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_conversation ON conversation_group_mappings(conversation_id);
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id);
CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
CREATE INDEX IF NOT EXISTS idx_webshell_connections_created_at ON webshell_connections(created_at);
CREATE INDEX IF NOT EXISTS idx_webshell_connection_states_updated_at ON webshell_connection_states(updated_at);
`
if _, err := db.Exec(createConversationsTable); err != nil {
return fmt.Errorf("创建conversations表失败: %w", err)
}
if _, err := db.Exec(createMessagesTable); err != nil {
return fmt.Errorf("创建messages表失败: %w", err)
}
if _, err := db.Exec(createProcessDetailsTable); err != nil {
return fmt.Errorf("创建process_details表失败: %w", err)
}
if _, err := db.Exec(createToolExecutionsTable); err != nil {
return fmt.Errorf("创建tool_executions表失败: %w", err)
}
if _, err := db.Exec(createToolStatsTable); err != nil {
return fmt.Errorf("创建tool_stats表失败: %w", err)
}
if _, err := db.Exec(createSkillStatsTable); err != nil {
return fmt.Errorf("创建skill_stats表失败: %w", err)
}
if _, err := db.Exec(createAttackChainNodesTable); err != nil {
return fmt.Errorf("创建attack_chain_nodes表失败: %w", err)
}
if _, err := db.Exec(createAttackChainEdgesTable); err != nil {
return fmt.Errorf("创建attack_chain_edges表失败: %w", err)
}
if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil {
return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err)
}
if _, err := db.Exec(createConversationGroupsTable); err != nil {
return fmt.Errorf("创建conversation_groups表失败: %w", err)
}
if _, err := db.Exec(createConversationGroupMappingsTable); err != nil {
return fmt.Errorf("创建conversation_group_mappings表失败: %w", err)
}
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
}
if _, err := db.Exec(createBatchTaskQueuesTable); err != nil {
return fmt.Errorf("创建batch_task_queues表失败: %w", err)
}
if _, err := db.Exec(createBatchTasksTable); err != nil {
return fmt.Errorf("创建batch_tasks表失败: %w", err)
}
if _, err := db.Exec(createWebshellConnectionsTable); err != nil {
return fmt.Errorf("创建webshell_connections表失败: %w", err)
}
if _, err := db.Exec(createWebshellConnectionStatesTable); err != nil {
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
}
// 为已有表添加新字段(如果不存在)- 必须在创建索引之前
if err := db.migrateConversationsTable(); err != nil {
db.logger.Warn("迁移conversations表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if err := db.migrateConversationGroupsTable(); err != nil {
db.logger.Warn("迁移conversation_groups表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if err := db.migrateConversationGroupMappingsTable(); err != nil {
db.logger.Warn("迁移conversation_group_mappings表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if err := db.migrateBatchTaskQueuesTable(); err != nil {
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
db.logger.Info("数据库表初始化完成")
return nil
}
// migrateConversationsTable 迁移conversations表,添加新字段
func (db *DB) migrateConversationsTable() error {
// 检查last_react_input字段是否存在
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_input'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); addErr != nil {
// 如果字段已存在,忽略错误(SQLite错误信息可能不同)
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加last_react_input字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); err != nil {
db.logger.Warn("添加last_react_input字段失败", zap.Error(err))
}
}
// 检查last_react_output字段是否存在
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_output'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加last_react_output字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); err != nil {
db.logger.Warn("添加last_react_output字段失败", zap.Error(err))
}
}
// 检查pinned字段是否存在
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='pinned'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加pinned字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil {
db.logger.Warn("添加pinned字段失败", zap.Error(err))
}
}
// 检查 webshell_connection_id 字段是否存在(WebShell AI 助手对话关联)
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='webshell_connection_id'").Scan(&count)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN webshell_connection_id TEXT"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加webshell_connection_id字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN webshell_connection_id TEXT"); err != nil {
db.logger.Warn("添加webshell_connection_id字段失败", zap.Error(err))
}
}
return nil
}
// migrateConversationGroupsTable 迁移conversation_groups表,添加新字段
func (db *DB) migrateConversationGroupsTable() error {
// 检查pinned字段是否存在
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversation_groups') WHERE name='pinned'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE conversation_groups ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加pinned字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE conversation_groups ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil {
db.logger.Warn("添加pinned字段失败", zap.Error(err))
}
}
return nil
}
// migrateConversationGroupMappingsTable 迁移conversation_group_mappings表,添加新字段
func (db *DB) migrateConversationGroupMappingsTable() error {
// 检查pinned字段是否存在
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversation_group_mappings') WHERE name='pinned'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE conversation_group_mappings ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加pinned字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE conversation_group_mappings ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil {
db.logger.Warn("添加pinned字段失败", zap.Error(err))
}
}
return nil
}
// migrateBatchTaskQueuesTable 迁移batch_task_queues表,补充新字段
func (db *DB) migrateBatchTaskQueuesTable() error {
// 检查title字段是否存在
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='title'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加title字段失败", zap.Error(addErr))
}
}
} else if count == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); err != nil {
db.logger.Warn("添加title字段失败", zap.Error(err))
}
}
// 检查role字段是否存在
var roleCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='role'").Scan(&roleCount)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); addErr != nil {
// 如果字段已存在,忽略错误
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加role字段失败", zap.Error(addErr))
}
}
} else if roleCount == 0 {
// 字段不存在,添加它
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); err != nil {
db.logger.Warn("添加role字段失败", zap.Error(err))
}
}
// 检查agent_mode字段是否存在
var agentModeCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='agent_mode'").Scan(&agentModeCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加agent_mode字段失败", zap.Error(addErr))
}
}
} else if agentModeCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); err != nil {
db.logger.Warn("添加agent_mode字段失败", zap.Error(err))
}
}
// 检查schedule_mode字段是否存在
var scheduleModeCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_mode'").Scan(&scheduleModeCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加schedule_mode字段失败", zap.Error(addErr))
}
}
} else if scheduleModeCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); err != nil {
db.logger.Warn("添加schedule_mode字段失败", zap.Error(err))
}
}
// 检查cron_expr字段是否存在
var cronExprCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='cron_expr'").Scan(&cronExprCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加cron_expr字段失败", zap.Error(addErr))
}
}
} else if cronExprCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); err != nil {
db.logger.Warn("添加cron_expr字段失败", zap.Error(err))
}
}
// 检查next_run_at字段是否存在
var nextRunAtCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='next_run_at'").Scan(&nextRunAtCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加next_run_at字段失败", zap.Error(addErr))
}
}
} else if nextRunAtCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); err != nil {
db.logger.Warn("添加next_run_at字段失败", zap.Error(err))
}
}
// schedule_enabled0=暂停 Cron 自动调度,1=允许(手工执行不受影响)
var scheduleEnCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_enabled'").Scan(&scheduleEnCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加schedule_enabled字段失败", zap.Error(addErr))
}
}
} else if scheduleEnCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); err != nil {
db.logger.Warn("添加schedule_enabled字段失败", zap.Error(err))
}
}
var lastTrigCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_trigger_at'").Scan(&lastTrigCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(addErr))
}
}
} else if lastTrigCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); err != nil {
db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(err))
}
}
var lastSchedErrCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_error'").Scan(&lastSchedErrCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加last_schedule_error字段失败", zap.Error(addErr))
}
}
} else if lastSchedErrCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); err != nil {
db.logger.Warn("添加last_schedule_error字段失败", zap.Error(err))
}
}
var lastRunErrCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_run_error'").Scan(&lastRunErrCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加last_run_error字段失败", zap.Error(addErr))
}
}
} else if lastRunErrCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); err != nil {
db.logger.Warn("添加last_run_error字段失败", zap.Error(err))
}
}
return nil
}
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
if err != nil {
return nil, fmt.Errorf("打开知识库数据库失败: %w", err)
}
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("连接知识库数据库失败: %w", err)
}
database := &DB{
DB: sqlDB,
logger: logger,
}
// 初始化知识库表
if err := database.initKnowledgeTables(); err != nil {
return nil, fmt.Errorf("初始化知识库表失败: %w", err)
}
return database, nil
}
// initKnowledgeTables 初始化知识库数据库表(只包含知识库相关的表)
func (db *DB) initKnowledgeTables() error {
// 创建知识库项表
createKnowledgeBaseItemsTable := `
CREATE TABLE IF NOT EXISTS knowledge_base_items (
id TEXT PRIMARY KEY,
category TEXT NOT NULL,
title TEXT NOT NULL,
file_path TEXT NOT NULL,
content TEXT,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
);`
// 创建知识库向量表
createKnowledgeEmbeddingsTable := `
CREATE TABLE IF NOT EXISTS knowledge_embeddings (
id TEXT PRIMARY KEY,
item_id TEXT NOT NULL,
chunk_index INTEGER NOT NULL,
chunk_text TEXT NOT NULL,
embedding TEXT NOT NULL,
sub_indexes TEXT NOT NULL DEFAULT '',
embedding_model TEXT NOT NULL DEFAULT '',
embedding_dim INTEGER NOT NULL DEFAULT 0,
created_at DATETIME NOT NULL,
FOREIGN KEY (item_id) REFERENCES knowledge_base_items(id) ON DELETE CASCADE
);`
// 创建知识检索日志表(在独立知识库数据库中,不使用外键约束,因为conversations和messages表可能不在这个数据库中)
createKnowledgeRetrievalLogsTable := `
CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs (
id TEXT PRIMARY KEY,
conversation_id TEXT,
message_id TEXT,
query TEXT NOT NULL,
risk_type TEXT,
retrieved_items TEXT,
created_at DATETIME NOT NULL
);`
// 创建索引
createIndexes := `
CREATE INDEX IF NOT EXISTS idx_knowledge_items_category ON knowledge_base_items(category);
CREATE INDEX IF NOT EXISTS idx_knowledge_embeddings_item_id ON knowledge_embeddings(item_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at);
`
if _, err := db.Exec(createKnowledgeBaseItemsTable); err != nil {
return fmt.Errorf("创建knowledge_base_items表失败: %w", err)
}
if _, err := db.Exec(createKnowledgeEmbeddingsTable); err != nil {
return fmt.Errorf("创建knowledge_embeddings表失败: %w", err)
}
if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil {
return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err)
}
if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
if err := db.migrateKnowledgeEmbeddingsColumns(); err != nil {
return fmt.Errorf("迁移 knowledge_embeddings 列失败: %w", err)
}
db.logger.Info("知识库数据库表初始化完成")
return nil
}
// migrateKnowledgeEmbeddingsColumns 为已有库补充 sub_indexes、embedding_model、embedding_dim。
func (db *DB) migrateKnowledgeEmbeddingsColumns() error {
var n int
if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
return err
}
if n == 0 {
return nil
}
migrations := []struct {
col string
stmt string
}{
{"sub_indexes", `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`},
{"embedding_model", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`},
{"embedding_dim", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`},
}
for _, m := range migrations {
var colCount int
q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
if err := db.QueryRow(q, m.col).Scan(&colCount); err != nil {
return err
}
if colCount > 0 {
continue
}
if _, err := db.Exec(m.stmt); err != nil {
return err
}
}
return nil
}
// Close 关闭数据库连接
func (db *DB) Close() error {
return db.DB.Close()
}
-449
View File
@@ -1,449 +0,0 @@
package database
import (
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
)
// ConversationGroup 对话分组
type ConversationGroup struct {
ID string `json:"id"`
Name string `json:"name"`
Icon string `json:"icon"`
Pinned bool `json:"pinned"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// GroupExistsByName 检查分组名称是否已存在
func (db *DB) GroupExistsByName(name string, excludeID string) (bool, error) {
var count int
var err error
if excludeID != "" {
err = db.QueryRow(
"SELECT COUNT(*) FROM conversation_groups WHERE name = ? AND id != ?",
name, excludeID,
).Scan(&count)
} else {
err = db.QueryRow(
"SELECT COUNT(*) FROM conversation_groups WHERE name = ?",
name,
).Scan(&count)
}
if err != nil {
return false, fmt.Errorf("检查分组名称失败: %w", err)
}
return count > 0, nil
}
// CreateGroup 创建分组
func (db *DB) CreateGroup(name, icon string) (*ConversationGroup, error) {
// 检查名称是否已存在
exists, err := db.GroupExistsByName(name, "")
if err != nil {
return nil, err
}
if exists {
return nil, fmt.Errorf("分组名称已存在")
}
id := uuid.New().String()
now := time.Now()
if icon == "" {
icon = "📁"
}
_, err = db.Exec(
"INSERT INTO conversation_groups (id, name, icon, pinned, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
id, name, icon, 0, now, now,
)
if err != nil {
return nil, fmt.Errorf("创建分组失败: %w", err)
}
return &ConversationGroup{
ID: id,
Name: name,
Icon: icon,
Pinned: false,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// ListGroups 列出所有分组
func (db *DB) ListGroups() ([]*ConversationGroup, error) {
rows, err := db.Query(
"SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups ORDER BY COALESCE(pinned, 0) DESC, created_at ASC",
)
if err != nil {
return nil, fmt.Errorf("查询分组列表失败: %w", err)
}
defer rows.Close()
var groups []*ConversationGroup
for rows.Next() {
var group ConversationGroup
var createdAt, updatedAt string
var pinned int
if err := rows.Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描分组失败: %w", err)
}
group.Pinned = pinned != 0
// 尝试多种时间格式解析
var err1, err2 error
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
groups = append(groups, &group)
}
return groups, nil
}
// GetGroup 获取分组
func (db *DB) GetGroup(id string) (*ConversationGroup, error) {
var group ConversationGroup
var createdAt, updatedAt string
var pinned int
err := db.QueryRow(
"SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups WHERE id = ?",
id,
).Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("分组不存在")
}
return nil, fmt.Errorf("查询分组失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
group.Pinned = pinned != 0
return &group, nil
}
// UpdateGroup 更新分组
func (db *DB) UpdateGroup(id, name, icon string) error {
// 检查名称是否已存在(排除当前分组)
exists, err := db.GroupExistsByName(name, id)
if err != nil {
return err
}
if exists {
return fmt.Errorf("分组名称已存在")
}
_, err = db.Exec(
"UPDATE conversation_groups SET name = ?, icon = ?, updated_at = ? WHERE id = ?",
name, icon, time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新分组失败: %w", err)
}
return nil
}
// DeleteGroup 删除分组
func (db *DB) DeleteGroup(id string) error {
_, err := db.Exec("DELETE FROM conversation_groups WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除分组失败: %w", err)
}
return nil
}
// AddConversationToGroup 将对话添加到分组
// 注意:一个对话只能属于一个分组,所以在添加新分组之前,会先删除该对话的所有旧分组关联
func (db *DB) AddConversationToGroup(conversationID, groupID string) error {
// 先删除该对话的所有旧分组关联,确保一个对话只属于一个分组
_, err := db.Exec(
"DELETE FROM conversation_group_mappings WHERE conversation_id = ?",
conversationID,
)
if err != nil {
return fmt.Errorf("删除对话旧分组关联失败: %w", err)
}
// 然后插入新的分组关联
id := uuid.New().String()
_, err = db.Exec(
"INSERT INTO conversation_group_mappings (id, conversation_id, group_id, created_at) VALUES (?, ?, ?, ?)",
id, conversationID, groupID, time.Now(),
)
if err != nil {
return fmt.Errorf("添加对话到分组失败: %w", err)
}
return nil
}
// RemoveConversationFromGroup 从分组中移除对话
func (db *DB) RemoveConversationFromGroup(conversationID, groupID string) error {
_, err := db.Exec(
"DELETE FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?",
conversationID, groupID,
)
if err != nil {
return fmt.Errorf("从分组中移除对话失败: %w", err)
}
return nil
}
// GetConversationsByGroup 获取分组中的所有对话
func (db *DB) GetConversationsByGroup(groupID string) ([]*Conversation, error) {
rows, err := db.Query(
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
FROM conversations c
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
WHERE cgm.group_id = ?
ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC`,
groupID,
)
if err != nil {
return nil, fmt.Errorf("查询分组对话失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
var groupPinned int
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, nil
}
// SearchConversationsByGroup 搜索分组中的对话(按标题和消息内容模糊匹配)
func (db *DB) SearchConversationsByGroup(groupID string, searchQuery string) ([]*Conversation, error) {
// 构建SQL查询,支持按标题和消息内容搜索
// 使用 DISTINCT 避免因为一个对话有多条匹配消息而重复
query := `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
FROM conversations c
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
WHERE cgm.group_id = ?`
args := []interface{}{groupID}
// 如果有搜索关键词,添加标题和消息内容搜索条件
if searchQuery != "" {
searchPattern := "%" + searchQuery + "%"
// 搜索标题或消息内容
// 使用 LEFT JOIN 连接消息表,这样即使没有消息的对话也能被搜索到(通过标题)
query += ` AND (
LOWER(c.title) LIKE LOWER(?)
OR EXISTS (
SELECT 1 FROM messages m
WHERE m.conversation_id = c.id
AND LOWER(m.content) LIKE LOWER(?)
)
)`
args = append(args, searchPattern, searchPattern)
}
query += " ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC"
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("搜索分组对话失败: %w", err)
}
defer rows.Close()
var conversations []*Conversation
for rows.Next() {
var conv Conversation
var createdAt, updatedAt string
var pinned int
var groupPinned int
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
return nil, fmt.Errorf("扫描对话失败: %w", err)
}
// 尝试多种时间格式解析
var err1, err2 error
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err1 != nil {
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err1 != nil {
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
if err2 != nil {
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
}
if err2 != nil {
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
conv.Pinned = pinned != 0
conversations = append(conversations, &conv)
}
return conversations, nil
}
// GetGroupByConversation 获取对话所属的分组
func (db *DB) GetGroupByConversation(conversationID string) (string, error) {
var groupID string
err := db.QueryRow(
"SELECT group_id FROM conversation_group_mappings WHERE conversation_id = ? LIMIT 1",
conversationID,
).Scan(&groupID)
if err != nil {
if err == sql.ErrNoRows {
return "", nil // 没有分组
}
return "", fmt.Errorf("查询对话分组失败: %w", err)
}
return groupID, nil
}
// UpdateConversationPinned 更新对话置顶状态
func (db *DB) UpdateConversationPinned(id string, pinned bool) error {
pinnedValue := 0
if pinned {
pinnedValue = 1
}
// 注意:不更新 updated_at,因为置顶操作不应该改变对话的更新时间
_, err := db.Exec(
"UPDATE conversations SET pinned = ? WHERE id = ?",
pinnedValue, id,
)
if err != nil {
return fmt.Errorf("更新对话置顶状态失败: %w", err)
}
return nil
}
// UpdateGroupPinned 更新分组置顶状态
func (db *DB) UpdateGroupPinned(id string, pinned bool) error {
pinnedValue := 0
if pinned {
pinnedValue = 1
}
_, err := db.Exec(
"UPDATE conversation_groups SET pinned = ?, updated_at = ? WHERE id = ?",
pinnedValue, time.Now(), id,
)
if err != nil {
return fmt.Errorf("更新分组置顶状态失败: %w", err)
}
return nil
}
// GroupMapping 分组映射关系
type GroupMapping struct {
ConversationID string `json:"conversationId"`
GroupID string `json:"groupId"`
}
// GetAllGroupMappings 批量获取所有分组映射(消除 N+1 查询)
func (db *DB) GetAllGroupMappings() ([]GroupMapping, error) {
rows, err := db.Query("SELECT conversation_id, group_id FROM conversation_group_mappings")
if err != nil {
return nil, fmt.Errorf("查询分组映射失败: %w", err)
}
defer rows.Close()
var mappings []GroupMapping
for rows.Next() {
var m GroupMapping
if err := rows.Scan(&m.ConversationID, &m.GroupID); err != nil {
return nil, fmt.Errorf("扫描分组映射失败: %w", err)
}
mappings = append(mappings, m)
}
if mappings == nil {
mappings = []GroupMapping{}
}
return mappings, nil
}
// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态
func (db *DB) UpdateConversationPinnedInGroup(conversationID, groupID string, pinned bool) error {
pinnedValue := 0
if pinned {
pinnedValue = 1
}
_, err := db.Exec(
"UPDATE conversation_group_mappings SET pinned = ? WHERE conversation_id = ? AND group_id = ?",
pinnedValue, conversationID, groupID,
)
if err != nil {
return fmt.Errorf("更新分组对话置顶状态失败: %w", err)
}
return nil
}
-537
View File
@@ -1,537 +0,0 @@
package database
import (
"database/sql"
"encoding/json"
"strings"
"time"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
// SaveToolExecution 保存工具执行记录
func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
argsJSON, err := json.Marshal(exec.Arguments)
if err != nil {
db.logger.Warn("序列化执行参数失败", zap.Error(err))
argsJSON = []byte("{}")
}
var resultJSON sql.NullString
if exec.Result != nil {
resultBytes, err := json.Marshal(exec.Result)
if err != nil {
db.logger.Warn("序列化执行结果失败", zap.Error(err))
} else {
resultJSON = sql.NullString{String: string(resultBytes), Valid: true}
}
}
var errorText sql.NullString
if exec.Error != "" {
errorText = sql.NullString{String: exec.Error, Valid: true}
}
var endTime sql.NullTime
if exec.EndTime != nil {
endTime = sql.NullTime{Time: *exec.EndTime, Valid: true}
}
var durationMs sql.NullInt64
if exec.Duration > 0 {
durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true}
}
query := `
INSERT OR REPLACE INTO tool_executions
(id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err = db.Exec(query,
exec.ID,
exec.ToolName,
string(argsJSON),
exec.Status,
resultJSON,
errorText,
exec.StartTime,
endTime,
durationMs,
time.Now(),
)
if err != nil {
db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID))
return err
}
return nil
}
// CountToolExecutions 统计工具执行记录总数
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
query := `SELECT COUNT(*) FROM tool_executions`
args := []interface{}{}
conditions := []string{}
if status != "" {
conditions = append(conditions, "status = ?")
args = append(args, status)
}
if toolName != "" {
// 支持部分匹配(模糊搜索),不区分大小写
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
args = append(args, "%"+strings.ToLower(toolName)+"%")
}
if len(conditions) > 0 {
query += ` WHERE ` + conditions[0]
for i := 1; i < len(conditions); i++ {
query += ` AND ` + conditions[i]
}
}
var count int
err := db.QueryRow(query, args...).Scan(&count)
if err != nil {
return 0, err
}
return count, nil
}
// LoadToolExecutions 加载所有工具执行记录(支持分页)
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
return db.LoadToolExecutionsWithPagination(0, 1000, "", "")
}
// LoadToolExecutionsWithPagination 分页加载工具执行记录
// limit: 最大返回记录数,0 表示使用默认值 1000
// offset: 跳过的记录数,用于分页
// status: 状态筛选,空字符串表示不过滤
// toolName: 工具名称筛选,空字符串表示不过滤
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
if limit <= 0 {
limit = 1000 // 默认限制
}
if limit > 10000 {
limit = 10000 // 最大限制,防止一次性加载过多数据
}
query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions
`
args := []interface{}{}
conditions := []string{}
if status != "" {
conditions = append(conditions, "status = ?")
args = append(args, status)
}
if toolName != "" {
// 支持部分匹配(模糊搜索),不区分大小写
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
args = append(args, "%"+strings.ToLower(toolName)+"%")
}
if len(conditions) > 0 {
query += ` WHERE ` + conditions[0]
for i := 1; i < len(conditions); i++ {
query += ` AND ` + conditions[i]
}
}
query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var executions []*mcp.ToolExecution
for rows.Next() {
var exec mcp.ToolExecution
var argsJSON string
var resultJSON sql.NullString
var errorText sql.NullString
var endTime sql.NullTime
var durationMs sql.NullInt64
err := rows.Scan(
&exec.ID,
&exec.ToolName,
&argsJSON,
&exec.Status,
&resultJSON,
&errorText,
&exec.StartTime,
&endTime,
&durationMs,
)
if err != nil {
db.logger.Warn("加载执行记录失败", zap.Error(err))
continue
}
// 解析参数
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
db.logger.Warn("解析执行参数失败", zap.Error(err))
exec.Arguments = make(map[string]interface{})
}
// 解析结果
if resultJSON.Valid && resultJSON.String != "" {
var result mcp.ToolResult
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
db.logger.Warn("解析执行结果失败", zap.Error(err))
} else {
exec.Result = &result
}
}
// 设置错误
if errorText.Valid {
exec.Error = errorText.String
}
// 设置结束时间
if endTime.Valid {
exec.EndTime = &endTime.Time
}
// 设置持续时间
if durationMs.Valid {
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
}
executions = append(executions, &exec)
}
return executions, nil
}
// GetToolExecution 根据ID获取单条工具执行记录
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions
WHERE id = ?
`
row := db.QueryRow(query, id)
var exec mcp.ToolExecution
var argsJSON string
var resultJSON sql.NullString
var errorText sql.NullString
var endTime sql.NullTime
var durationMs sql.NullInt64
err := row.Scan(
&exec.ID,
&exec.ToolName,
&argsJSON,
&exec.Status,
&resultJSON,
&errorText,
&exec.StartTime,
&endTime,
&durationMs,
)
if err != nil {
return nil, err
}
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
db.logger.Warn("解析执行参数失败", zap.Error(err))
exec.Arguments = make(map[string]interface{})
}
if resultJSON.Valid && resultJSON.String != "" {
var result mcp.ToolResult
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
db.logger.Warn("解析执行结果失败", zap.Error(err))
} else {
exec.Result = &result
}
}
if errorText.Valid {
exec.Error = errorText.String
}
if endTime.Valid {
exec.EndTime = &endTime.Time
}
if durationMs.Valid {
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
}
return &exec, nil
}
// DeleteToolExecution 删除工具执行记录
func (db *DB) DeleteToolExecution(id string) error {
query := `DELETE FROM tool_executions WHERE id = ?`
_, err := db.Exec(query, id)
if err != nil {
db.logger.Error("删除工具执行记录失败", zap.Error(err), zap.String("executionId", id))
return err
}
return nil
}
// DeleteToolExecutions 批量删除工具执行记录
func (db *DB) DeleteToolExecutions(ids []string) error {
if len(ids) == 0 {
return nil
}
// 构建 IN 查询的占位符
placeholders := make([]string, len(ids))
args := make([]interface{}, len(ids))
for i, id := range ids {
placeholders[i] = "?"
args[i] = id
}
query := `DELETE FROM tool_executions WHERE id IN (` + strings.Join(placeholders, ",") + `)`
_, err := db.Exec(query, args...)
if err != nil {
db.logger.Error("批量删除工具执行记录失败", zap.Error(err), zap.Int("count", len(ids)))
return err
}
return nil
}
// GetToolExecutionsByIds 根据ID列表获取工具执行记录(用于批量删除前获取统计信息)
func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error) {
if len(ids) == 0 {
return []*mcp.ToolExecution{}, nil
}
// 构建 IN 查询的占位符
placeholders := make([]string, len(ids))
args := make([]interface{}, len(ids))
for i, id := range ids {
placeholders[i] = "?"
args[i] = id
}
query := `
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
FROM tool_executions
WHERE id IN (` + strings.Join(placeholders, ",") + `)
`
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var executions []*mcp.ToolExecution
for rows.Next() {
var exec mcp.ToolExecution
var argsJSON string
var resultJSON sql.NullString
var errorText sql.NullString
var endTime sql.NullTime
var durationMs sql.NullInt64
err := rows.Scan(
&exec.ID,
&exec.ToolName,
&argsJSON,
&exec.Status,
&resultJSON,
&errorText,
&exec.StartTime,
&endTime,
&durationMs,
)
if err != nil {
db.logger.Warn("加载执行记录失败", zap.Error(err))
continue
}
// 解析参数
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
db.logger.Warn("解析执行参数失败", zap.Error(err))
exec.Arguments = make(map[string]interface{})
}
// 解析结果
if resultJSON.Valid && resultJSON.String != "" {
var result mcp.ToolResult
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
db.logger.Warn("解析执行结果失败", zap.Error(err))
} else {
exec.Result = &result
}
}
// 设置错误
if errorText.Valid {
exec.Error = errorText.String
}
// 设置结束时间
if endTime.Valid {
exec.EndTime = &endTime.Time
}
// 设置持续时间
if durationMs.Valid {
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
}
executions = append(executions, &exec)
}
return executions, nil
}
// SaveToolStats 保存工具统计信息
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
var lastCallTime sql.NullTime
if stats.LastCallTime != nil {
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
}
query := `
INSERT OR REPLACE INTO tool_stats
(tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
toolName,
stats.TotalCalls,
stats.SuccessCalls,
stats.FailedCalls,
lastCallTime,
time.Now(),
)
if err != nil {
db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
return err
}
return nil
}
// LoadToolStats 加载所有工具统计信息
func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) {
query := `
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
FROM tool_stats
`
rows, err := db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
stats := make(map[string]*mcp.ToolStats)
for rows.Next() {
var stat mcp.ToolStats
var lastCallTime sql.NullTime
err := rows.Scan(
&stat.ToolName,
&stat.TotalCalls,
&stat.SuccessCalls,
&stat.FailedCalls,
&lastCallTime,
)
if err != nil {
db.logger.Warn("加载统计信息失败", zap.Error(err))
continue
}
if lastCallTime.Valid {
stat.LastCallTime = &lastCallTime.Time
}
stats[stat.ToolName] = &stat
}
return stats, nil
}
// UpdateToolStats 更新工具统计信息(累加模式)
func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
var lastCallTimeSQL sql.NullTime
if lastCallTime != nil {
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
}
query := `
INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(tool_name) DO UPDATE SET
total_calls = total_calls + ?,
success_calls = success_calls + ?,
failed_calls = failed_calls + ?,
last_call_time = COALESCE(?, last_call_time),
updated_at = ?
`
_, err := db.Exec(query,
toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
)
if err != nil {
db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
return err
}
return nil
}
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
// 如果统计信息变为0,则删除该统计记录
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
// 先更新统计信息
query := `
UPDATE tool_stats SET
total_calls = CASE WHEN total_calls - ? < 0 THEN 0 ELSE total_calls - ? END,
success_calls = CASE WHEN success_calls - ? < 0 THEN 0 ELSE success_calls - ? END,
failed_calls = CASE WHEN failed_calls - ? < 0 THEN 0 ELSE failed_calls - ? END,
updated_at = ?
WHERE tool_name = ?
`
_, err := db.Exec(query, totalCalls, totalCalls, successCalls, successCalls, failedCalls, failedCalls, time.Now(), toolName)
if err != nil {
db.logger.Error("减少工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
return err
}
// 检查更新后的 total_calls 是否为 0,如果是则删除该统计记录
checkQuery := `SELECT total_calls FROM tool_stats WHERE tool_name = ?`
var newTotalCalls int
err = db.QueryRow(checkQuery, toolName).Scan(&newTotalCalls)
if err != nil {
// 如果查询失败(记录不存在),直接返回
return nil
}
// 如果 total_calls 为 0,删除该统计记录
if newTotalCalls == 0 {
deleteQuery := `DELETE FROM tool_stats WHERE tool_name = ?`
_, err = db.Exec(deleteQuery, toolName)
if err != nil {
db.logger.Warn("删除零统计记录失败", zap.Error(err), zap.String("toolName", toolName))
// 不返回错误,因为主要操作(更新统计)已成功
} else {
db.logger.Info("已删除零统计记录", zap.String("toolName", toolName))
}
}
return nil
}
-142
View File
@@ -1,142 +0,0 @@
package database
import (
"database/sql"
"time"
"go.uber.org/zap"
)
// SkillStats Skills统计信息
type SkillStats struct {
SkillName string
TotalCalls int
SuccessCalls int
FailedCalls int
LastCallTime *time.Time
}
// SaveSkillStats 保存Skills统计信息
func (db *DB) SaveSkillStats(skillName string, stats *SkillStats) error {
var lastCallTime sql.NullTime
if stats.LastCallTime != nil {
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
}
query := `
INSERT OR REPLACE INTO skill_stats
(skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
skillName,
stats.TotalCalls,
stats.SuccessCalls,
stats.FailedCalls,
lastCallTime,
time.Now(),
)
if err != nil {
db.logger.Error("保存Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
return nil
}
// LoadSkillStats 加载所有Skills统计信息
func (db *DB) LoadSkillStats() (map[string]*SkillStats, error) {
query := `
SELECT skill_name, total_calls, success_calls, failed_calls, last_call_time
FROM skill_stats
`
rows, err := db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
stats := make(map[string]*SkillStats)
for rows.Next() {
var stat SkillStats
var lastCallTime sql.NullTime
err := rows.Scan(
&stat.SkillName,
&stat.TotalCalls,
&stat.SuccessCalls,
&stat.FailedCalls,
&lastCallTime,
)
if err != nil {
db.logger.Warn("加载Skills统计信息失败", zap.Error(err))
continue
}
if lastCallTime.Valid {
stat.LastCallTime = &lastCallTime.Time
}
stats[stat.SkillName] = &stat
}
return stats, nil
}
// UpdateSkillStats 更新Skills统计信息(累加模式)
func (db *DB) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
var lastCallTimeSQL sql.NullTime
if lastCallTime != nil {
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
}
query := `
INSERT INTO skill_stats (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(skill_name) DO UPDATE SET
total_calls = total_calls + ?,
success_calls = success_calls + ?,
failed_calls = failed_calls + ?,
last_call_time = COALESCE(?, last_call_time),
updated_at = ?
`
_, err := db.Exec(query,
skillName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
)
if err != nil {
db.logger.Error("更新Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
return nil
}
// ClearSkillStats 清空所有Skills统计信息
func (db *DB) ClearSkillStats() error {
query := `DELETE FROM skill_stats`
_, err := db.Exec(query)
if err != nil {
db.logger.Error("清空Skills统计信息失败", zap.Error(err))
return err
}
db.logger.Info("已清空所有Skills统计信息")
return nil
}
// ClearSkillStatsByName 清空指定skill的统计信息
func (db *DB) ClearSkillStatsByName(skillName string) error {
query := `DELETE FROM skill_stats WHERE skill_name = ?`
_, err := db.Exec(query, skillName)
if err != nil {
db.logger.Error("清空指定skill统计信息失败", zap.Error(err), zap.String("skillName", skillName))
return err
}
db.logger.Info("已清空指定skill统计信息", zap.String("skillName", skillName))
return nil
}
-281
View File
@@ -1,281 +0,0 @@
package database
import (
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Vulnerability 漏洞
type Vulnerability struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info
Status string `json:"status"` // open, confirmed, fixed, false_positive
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// CreateVulnerability 创建漏洞
func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
if vuln.ID == "" {
vuln.ID = uuid.New().String()
}
if vuln.Status == "" {
vuln.Status = "open"
}
now := time.Now()
if vuln.CreatedAt.IsZero() {
vuln.CreatedAt = now
}
vuln.UpdatedAt = now
query := `
INSERT INTO vulnerabilities (
id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(
query,
vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("创建漏洞失败: %w", err)
}
return vuln, nil
}
// GetVulnerability 获取漏洞
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability
query := `
SELECT id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
FROM vulnerabilities
WHERE id = ?
`
err := db.QueryRow(query, id).Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("漏洞不存在")
}
return nil, fmt.Errorf("获取漏洞失败: %w", err)
}
return &vuln, nil
}
// ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) {
query := `
SELECT id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
FROM vulnerabilities
WHERE 1=1
`
args := []interface{}{}
if id != "" {
query += " AND id = ?"
args = append(args, id)
}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询漏洞列表失败: %w", err)
}
defer rows.Close()
var vulnerabilities []*Vulnerability
for rows.Next() {
var vuln Vulnerability
err := rows.Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
db.logger.Warn("扫描漏洞记录失败", zap.Error(err))
continue
}
vulnerabilities = append(vulnerabilities, &vuln)
}
return vulnerabilities, nil
}
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) {
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
args := []interface{}{}
if id != "" {
query += " AND id = ?"
args = append(args, id)
}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
var count int
err := db.QueryRow(query, args...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("统计漏洞总数失败: %w", err)
}
return count, nil
}
// UpdateVulnerability 更新漏洞
func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
vuln.UpdatedAt = time.Now()
query := `
UPDATE vulnerabilities
SET title = ?, description = ?, severity = ?, status = ?,
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
recommendation = ?, updated_at = ?
WHERE id = ?
`
_, err := db.Exec(
query,
vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
vuln.Recommendation, vuln.UpdatedAt, id,
)
if err != nil {
return fmt.Errorf("更新漏洞失败: %w", err)
}
return nil
}
// DeleteVulnerability 删除漏洞
func (db *DB) DeleteVulnerability(id string) error {
_, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除漏洞失败: %w", err)
}
return nil
}
// GetVulnerabilityStats 获取漏洞统计
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) {
stats := make(map[string]interface{})
// 总漏洞数
var totalCount int
query := "SELECT COUNT(*) FROM vulnerabilities"
args := []interface{}{}
if conversationID != "" {
query += " WHERE conversation_id = ?"
args = append(args, conversationID)
}
err := db.QueryRow(query, args...).Scan(&totalCount)
if err != nil {
return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
}
stats["total"] = totalCount
// 按严重程度统计
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities"
if conversationID != "" {
severityQuery += " WHERE conversation_id = ?"
}
severityQuery += " GROUP BY severity"
rows, err := db.Query(severityQuery, args...)
if err != nil {
return nil, fmt.Errorf("获取严重程度统计失败: %w", err)
}
defer rows.Close()
severityStats := make(map[string]int)
for rows.Next() {
var severity string
var count int
if err := rows.Scan(&severity, &count); err != nil {
continue
}
severityStats[severity] = count
}
stats["by_severity"] = severityStats
// 按状态统计
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities"
if conversationID != "" {
statusQuery += " WHERE conversation_id = ?"
}
statusQuery += " GROUP BY status"
rows, err = db.Query(statusQuery, args...)
if err != nil {
return nil, fmt.Errorf("获取状态统计失败: %w", err)
}
defer rows.Close()
statusStats := make(map[string]int)
for rows.Next() {
var status string
var count int
if err := rows.Scan(&status, &count); err != nil {
continue
}
statusStats[status] = count
}
stats["by_status"] = statusStats
return stats, nil
}
-148
View File
@@ -1,148 +0,0 @@
package database
import (
"database/sql"
"time"
"go.uber.org/zap"
)
// WebShellConnection WebShell 连接配置
type WebShellConnection struct {
ID string `json:"id"`
URL string `json:"url"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"`
CmdParam string `json:"cmdParam"`
Remark string `json:"remark"`
CreatedAt time.Time `json:"createdAt"`
}
// GetWebshellConnectionState 获取连接关联的持久化状态 JSON,不存在时返回 "{}"
func (db *DB) GetWebshellConnectionState(connectionID string) (string, error) {
var stateJSON string
err := db.QueryRow(`SELECT state_json FROM webshell_connection_states WHERE connection_id = ?`, connectionID).Scan(&stateJSON)
if err == sql.ErrNoRows {
return "{}", nil
}
if err != nil {
db.logger.Error("查询 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID))
return "", err
}
if stateJSON == "" {
stateJSON = "{}"
}
return stateJSON, nil
}
// UpsertWebshellConnectionState 保存连接关联的持久化状态 JSON
func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) error {
if stateJSON == "" {
stateJSON = "{}"
}
query := `
INSERT INTO webshell_connection_states (connection_id, state_json, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(connection_id) DO UPDATE SET
state_json = excluded.state_json,
updated_at = excluded.updated_at
`
if _, err := db.Exec(query, connectionID, stateJSON, time.Now()); err != nil {
db.logger.Error("保存 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID))
return err
}
return nil
}
// ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序
func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
query := `
SELECT id, url, password, type, method, cmd_param, remark, created_at
FROM webshell_connections
ORDER BY created_at DESC
`
rows, err := db.Query(query)
if err != nil {
db.logger.Error("查询 WebShell 连接列表失败", zap.Error(err))
return nil, err
}
defer rows.Close()
var list []WebShellConnection
for rows.Next() {
var c WebShellConnection
err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt)
if err != nil {
db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err))
continue
}
list = append(list, c)
}
return list, rows.Err()
}
// GetWebshellConnection 根据 ID 获取一条连接
func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
query := `
SELECT id, url, password, type, method, cmd_param, remark, created_at
FROM webshell_connections WHERE id = ?
`
var c WebShellConnection
err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
db.logger.Error("查询 WebShell 连接失败", zap.Error(err), zap.String("id", id))
return nil, err
}
return &c, nil
}
// CreateWebshellConnection 创建 WebShell 连接
func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
query := `
INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.CreatedAt)
if err != nil {
db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
return err
}
return nil
}
// UpdateWebshellConnection 更新 WebShell 连接
func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error {
query := `
UPDATE webshell_connections
SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?
WHERE id = ?
`
result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.ID)
if err != nil {
db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
return err
}
affected, _ := result.RowsAffected()
if affected == 0 {
return sql.ErrNoRows
}
return nil
}
// DeleteWebshellConnection 删除 WebShell 连接
func (db *DB) DeleteWebshellConnection(id string) error {
result, err := db.Exec(`DELETE FROM webshell_connections WHERE id = ?`, id)
if err != nil {
db.logger.Error("删除 WebShell 连接失败", zap.Error(err), zap.String("id", id))
return err
}
affected, _ := result.RowsAffected()
if affected == 0 {
return sql.ErrNoRows
}
return nil
}
-21
View File
@@ -1,21 +0,0 @@
package einomcp
import "sync"
// ConversationHolder 在每次 DeepAgent 运行前写入会话 ID,供 MCP 工具桥接使用。
type ConversationHolder struct {
mu sync.RWMutex
id string
}
func (h *ConversationHolder) Set(id string) {
h.mu.Lock()
h.id = id
h.mu.Unlock()
}
func (h *ConversationHolder) Get() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.id
}
-186
View File
@@ -1,186 +0,0 @@
package einomcp
import (
"context"
"encoding/json"
"fmt"
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/security"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"github.com/eino-contrib/jsonschema"
)
// ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。
type ExecutionRecorder func(executionID string)
// ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。
// Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。
const ToolErrorPrefix = "__CYBERSTRIKE_AI_TOOL_ERROR__\n"
// ToolsFromDefinitions 将单 Agent 使用的 OpenAI 风格工具定义转为 Eino InvokableTool,执行时走 Agent 的 MCP 路径。
func ToolsFromDefinitions(
ag *agent.Agent,
holder *ConversationHolder,
defs []agent.Tool,
rec ExecutionRecorder,
toolOutputChunk func(toolName, toolCallID, chunk string),
) ([]tool.BaseTool, error) {
out := make([]tool.BaseTool, 0, len(defs))
for _, d := range defs {
if d.Type != "function" || d.Function.Name == "" {
continue
}
info, err := toolInfoFromDefinition(d)
if err != nil {
return nil, fmt.Errorf("tool %q: %w", d.Function.Name, err)
}
out = append(out, &mcpBridgeTool{
info: info,
name: d.Function.Name,
agent: ag,
holder: holder,
record: rec,
chunk: toolOutputChunk,
})
}
return out, nil
}
func toolInfoFromDefinition(d agent.Tool) (*schema.ToolInfo, error) {
fn := d.Function
raw, err := json.Marshal(fn.Parameters)
if err != nil {
return nil, err
}
var js jsonschema.Schema
if len(raw) > 0 && string(raw) != "null" && string(raw) != "{}" {
if err := json.Unmarshal(raw, &js); err != nil {
return nil, err
}
}
if js.Type == "" {
js.Type = string(schema.Object)
}
if js.Properties == nil && js.Type == string(schema.Object) {
// 空参数对象
}
return &schema.ToolInfo{
Name: fn.Name,
Desc: fn.Description,
ParamsOneOf: schema.NewParamsOneOfByJSONSchema(&js),
}, nil
}
type mcpBridgeTool struct {
info *schema.ToolInfo
name string
agent *agent.Agent
holder *ConversationHolder
record ExecutionRecorder
chunk func(toolName, toolCallID, chunk string)
}
func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
_ = ctx
return m.info, nil
}
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
_ = opts
return runMCPToolInvocation(ctx, m.agent, m.holder, m.name, argumentsInJSON, m.record, m.chunk)
}
// runMCPToolInvocation 与 mcpBridgeTool.InvokableRun 共用。
func runMCPToolInvocation(
ctx context.Context,
ag *agent.Agent,
holder *ConversationHolder,
toolName string,
argumentsInJSON string,
record ExecutionRecorder,
chunk func(toolName, toolCallID, chunk string),
) (string, error) {
var args map[string]interface{}
if argumentsInJSON != "" && argumentsInJSON != "null" {
if err := json.Unmarshal([]byte(argumentsInJSON), &args); err != nil {
// Return soft error (nil error) so the eino graph continues and the LLM can self-correct,
// instead of a hard error that terminates the iteration loop.
return ToolErrorPrefix + fmt.Sprintf(
"Invalid tool arguments JSON: %s\n\nPlease ensure the arguments are a valid JSON object "+
"(double-quoted keys, matched braces, no trailing commas) and retry.\n\n"+
"(工具参数 JSON 解析失败:%s。请确保 arguments 是合法的 JSON 对象并重试。)",
err.Error(), err.Error()), nil
}
}
if args == nil {
args = map[string]interface{}{}
}
if chunk != nil {
toolCallID := compose.GetToolCallID(ctx)
if toolCallID != "" {
if existing, ok := ctx.Value(security.ToolOutputCallbackCtxKey).(security.ToolOutputCallback); ok && existing != nil {
ctx = context.WithValue(ctx, security.ToolOutputCallbackCtxKey, security.ToolOutputCallback(func(c string) {
existing(c)
if strings.TrimSpace(c) == "" {
return
}
chunk(toolName, toolCallID, c)
}))
} else {
ctx = context.WithValue(ctx, security.ToolOutputCallbackCtxKey, security.ToolOutputCallback(func(c string) {
if strings.TrimSpace(c) == "" {
return
}
chunk(toolName, toolCallID, c)
}))
}
}
}
res, err := ag.ExecuteMCPToolForConversation(ctx, holder.Get(), toolName, args)
if err != nil {
return "", err
}
if res == nil {
return "", nil
}
if res.ExecutionID != "" && record != nil {
record(res.ExecutionID)
}
if res.IsError {
return ToolErrorPrefix + res.Result, nil
}
return res.Result, nil
}
// UnknownToolReminderHandler 供 compose.ToolsNodeConfig.UnknownToolsHandler 使用:
// 模型请求了未注册的工具名时,返回一个「可恢复」的错误,让上层 runner 触发重试与纠错提示,
// 同时避免 UI 永远停留在“执行中”(runner 会在 recoverable 分支 flush 掉 pending 的 tool_call)。
// 不进行名称猜测或映射,避免误执行。
func UnknownToolReminderHandler() func(ctx context.Context, name, input string) (string, error) {
return func(ctx context.Context, name, input string) (string, error) {
_ = ctx
_ = input
requested := strings.TrimSpace(name)
// Return a recoverable error that still carries a friendly, bilingual hint.
// This will be caught by multiagent runner as "tool not found" and trigger a retry.
return "", fmt.Errorf("tool %q not found: %s", requested, unknownToolReminderText(requested))
}
}
func unknownToolReminderText(requested string) string {
if requested == "" {
requested = "(empty)"
}
return fmt.Sprintf(`The tool name %q is not registered for this agent.
Please retry using only names that appear in the tool definitions for this turn (exact match, case-sensitive). Do not invent or rename tools; adjust your plan and continue.
(工具 %q 未注册:请仅使用本回合上下文中给出的工具名称,须完全一致;请勿自行改写或猜测名称,并继续后续步骤。)`, requested, requested)
}
-16
View File
@@ -1,16 +0,0 @@
package einomcp
import (
"strings"
"testing"
)
func TestUnknownToolReminderText(t *testing.T) {
s := unknownToolReminderText("bad_tool")
if !strings.Contains(s, "bad_tool") {
t.Fatalf("expected requested name in message: %s", s)
}
if strings.Contains(s, "Tools currently available") {
t.Fatal("unified message must not list tool names")
}
}
File diff suppressed because it is too large Load Diff
-173
View File
@@ -1,173 +0,0 @@
package handler
import (
"context"
"net/http"
"sync"
"time"
"cyberstrike-ai/internal/attackchain"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AttackChainHandler 攻击链处理器
type AttackChainHandler struct {
db *database.DB
logger *zap.Logger
openAIConfig *config.OpenAIConfig
mu sync.RWMutex // 保护 openAIConfig 的并发访问
// 用于防止同一对话的并发生成
generatingLocks sync.Map // map[string]*sync.Mutex
}
// NewAttackChainHandler 创建新的攻击链处理器
func NewAttackChainHandler(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *AttackChainHandler {
return &AttackChainHandler{
db: db,
logger: logger,
openAIConfig: openAIConfig,
}
}
// UpdateConfig 更新OpenAI配置
func (h *AttackChainHandler) UpdateConfig(cfg *config.OpenAIConfig) {
h.mu.Lock()
defer h.mu.Unlock()
h.openAIConfig = cfg
h.logger.Info("AttackChainHandler配置已更新",
zap.String("base_url", cfg.BaseURL),
zap.String("model", cfg.Model),
)
}
// getOpenAIConfig 获取OpenAI配置(线程安全)
func (h *AttackChainHandler) getOpenAIConfig() *config.OpenAIConfig {
h.mu.RLock()
defer h.mu.RUnlock()
return h.openAIConfig
}
// GetAttackChain 获取攻击链(按需生成)
// GET /api/attack-chain/:conversationId
func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
conversationID := c.Param("conversationId")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
return
}
// 检查对话是否存在
_, err := h.db.GetConversation(conversationID)
if err != nil {
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
// 先尝试从数据库加载(如果已生成过)
openAIConfig := h.getOpenAIConfig()
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
chain, err := builder.LoadChainFromDatabase(conversationID)
if err == nil && len(chain.Nodes) > 0 {
// 如果已存在,直接返回
h.logger.Info("返回已存在的攻击链", zap.String("conversationId", conversationID))
c.JSON(http.StatusOK, chain)
return
}
// 如果不存在,则生成新的攻击链(按需生成)
// 使用锁机制防止同一对话的并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
// 尝试获取锁,如果正在生成则返回错误
acquired := lock.TryLock()
if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
return
}
defer lock.Unlock()
// 再次检查是否已生成(可能在等待锁的过程中已经生成完成)
chain, err = builder.LoadChainFromDatabase(conversationID)
if err == nil && len(chain.Nodes) > 0 {
h.logger.Info("返回已存在的攻击链(在锁等待期间已生成)", zap.String("conversationId", conversationID))
c.JSON(http.StatusOK, chain)
return
}
h.logger.Info("开始生成攻击链", zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
chain, err = builder.BuildChainFromConversation(ctx, conversationID)
if err != nil {
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
return
}
// 生成完成后,从锁映射中删除(可选,保留也可以用于防止短时间内重复生成)
// h.generatingLocks.Delete(conversationID)
c.JSON(http.StatusOK, chain)
}
// RegenerateAttackChain 重新生成攻击链
// POST /api/attack-chain/:conversationId/regenerate
func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
conversationID := c.Param("conversationId")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversationId is required"})
return
}
// 检查对话是否存在
_, err := h.db.GetConversation(conversationID)
if err != nil {
h.logger.Warn("对话不存在", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
// 删除旧的攻击链
if err := h.db.DeleteAttackChain(conversationID); err != nil {
h.logger.Warn("删除旧攻击链失败", zap.Error(err))
}
// 使用锁机制防止并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
acquired := lock.TryLock()
if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
c.JSON(http.StatusConflict, gin.H{"error": "攻击链正在生成中,请稍后再试"})
return
}
defer lock.Unlock()
// 生成新的攻击链
h.logger.Info("重新生成攻击链", zap.String("conversationId", conversationID))
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
openAIConfig := h.getOpenAIConfig()
builder := attackchain.NewBuilder(h.db, openAIConfig, h.logger)
chain, err := builder.BuildChainFromConversation(ctx, conversationID)
if err != nil {
h.logger.Error("生成攻击链失败", zap.String("conversationId", conversationID), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成攻击链失败: " + err.Error()})
return
}
c.JSON(http.StatusOK, chain)
}
-156
View File
@@ -1,156 +0,0 @@
package handler
import (
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AuthHandler handles authentication-related endpoints.
type AuthHandler struct {
manager *security.AuthManager
config *config.Config
configPath string
logger *zap.Logger
}
// NewAuthHandler creates a new AuthHandler.
func NewAuthHandler(manager *security.AuthManager, cfg *config.Config, configPath string, logger *zap.Logger) *AuthHandler {
return &AuthHandler{
manager: manager,
config: cfg,
configPath: configPath,
logger: logger,
}
}
type loginRequest struct {
Password string `json:"password" binding:"required"`
}
type changePasswordRequest struct {
OldPassword string `json:"oldPassword"`
NewPassword string `json:"newPassword"`
}
// Login verifies password and returns a session token.
func (h *AuthHandler) Login(c *gin.Context) {
var req loginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "密码不能为空"})
return
}
token, expiresAt, err := h.manager.Authenticate(req.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"expires_at": expiresAt.UTC().Format(time.RFC3339),
"session_duration_hr": h.manager.SessionDurationHours(),
})
}
// Logout revokes the current session token.
func (h *AuthHandler) Logout(c *gin.Context) {
token := c.GetString(security.ContextAuthTokenKey)
if token == "" {
authHeader := c.GetHeader("Authorization")
if len(authHeader) > 7 && strings.EqualFold(authHeader[:7], "Bearer ") {
token = strings.TrimSpace(authHeader[7:])
} else {
token = strings.TrimSpace(authHeader)
}
}
h.manager.RevokeToken(token)
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
}
// ChangePassword updates the login password.
func (h *AuthHandler) ChangePassword(c *gin.Context) {
var req changePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "参数无效"})
return
}
oldPassword := strings.TrimSpace(req.OldPassword)
newPassword := strings.TrimSpace(req.NewPassword)
if oldPassword == "" || newPassword == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码和新密码均不能为空"})
return
}
if len(newPassword) < 8 {
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码长度至少需要 8 位"})
return
}
if oldPassword == newPassword {
c.JSON(http.StatusBadRequest, gin.H{"error": "新密码不能与旧密码相同"})
return
}
if !h.manager.CheckPassword(oldPassword) {
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
return
}
if err := config.PersistAuthPassword(h.configPath, newPassword); err != nil {
if h.logger != nil {
h.logger.Error("保存新密码失败", zap.Error(err))
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存新密码失败,请重试"})
return
}
if err := h.manager.UpdateConfig(newPassword, h.config.Auth.SessionDurationHours); err != nil {
if h.logger != nil {
h.logger.Error("更新认证配置失败", zap.Error(err))
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新认证配置失败"})
return
}
h.config.Auth.Password = newPassword
h.config.Auth.GeneratedPassword = ""
h.config.Auth.GeneratedPasswordPersisted = false
h.config.Auth.GeneratedPasswordPersistErr = ""
if h.logger != nil {
h.logger.Info("登录密码已更新,所有会话已失效")
}
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
}
// Validate returns the current session status.
func (h *AuthHandler) Validate(c *gin.Context) {
token := c.GetString(security.ContextAuthTokenKey)
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话无效"})
return
}
session, ok := h.manager.ValidateToken(token)
if !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "会话已过期"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": session.Token,
"expires_at": session.ExpiresAt.UTC().Format(time.RFC3339),
})
}
File diff suppressed because it is too large Load Diff
-813
View File
@@ -1,813 +0,0 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterBatchTaskMCPTools 注册批量任务队列相关 MCP 工具(需传入已初始化 DB 的 AgentHandler
func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *zap.Logger) {
if mcpServer == nil || h == nil || logger == nil {
return
}
reg := func(tool mcp.Tool, fn func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error)) {
mcpServer.RegisterTool(tool, fn)
}
// --- list ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskList,
Description: "列出批量任务队列(精简摘要,省上下文)。含队列元数据、子任务 id/status/截断后的 message、各状态计数。完整子任务(含 result/error/conversationId/时间等)请用 batch_task_get(queue_id)。",
ShortDescription: "列出批量任务队列",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"status": map[string]interface{}{
"type": "string",
"description": "筛选状态:all(默认)、pending、running、paused、completed、cancelled",
"enum": []string{"all", "pending", "running", "paused", "completed", "cancelled"},
},
"keyword": map[string]interface{}{
"type": "string",
"description": "按队列 ID 或标题模糊搜索",
},
"page": map[string]interface{}{
"type": "integer",
"description": "页码,从 1 开始,默认 1",
},
"page_size": map[string]interface{}{
"type": "integer",
"description": "每页条数,默认 20,最大 100",
},
},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
status := mcpArgString(args, "status")
if status == "" {
status = "all"
}
keyword := mcpArgString(args, "keyword")
page := int(mcpArgFloat(args, "page"))
if page <= 0 {
page = 1
}
pageSize := int(mcpArgFloat(args, "page_size"))
if pageSize <= 0 {
pageSize = 20
}
if pageSize > 100 {
pageSize = 100
}
offset := (page - 1) * pageSize
if offset > 100000 {
offset = 100000
}
queues, total, err := h.batchTaskManager.ListQueues(pageSize, offset, status, keyword)
if err != nil {
return batchMCPTextResult(fmt.Sprintf("列出队列失败: %v", err), true), nil
}
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
slim := make([]batchTaskQueueMCPListItem, 0, len(queues))
for _, q := range queues {
if q == nil {
continue
}
slim = append(slim, toBatchTaskQueueMCPListItem(q))
}
payload := map[string]interface{}{
"queues": slim,
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": totalPages,
}
logger.Info("MCP batch_task_list", zap.String("status", status), zap.Int("total", total))
return batchMCPJSONResult(payload)
})
// --- get ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskGet,
Description: "根据 queue_id 获取单个批量任务队列详情(含子任务列表、Cron、调度开关与最近错误信息)。",
ShortDescription: "获取批量任务队列详情",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
queue, ok := h.batchTaskManager.GetBatchQueue(qid)
if !ok {
return batchMCPTextResult("队列不存在: "+qid, true), nil
}
return batchMCPJSONResult(queue)
})
// --- create ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskCreate,
Description: `创建新的批量任务队列。任务列表使用 tasks(字符串数组)或 tasks_text(多行,每行一条)。
agent_mode: single(默认)或 multi(需系统启用多代理)。schedule_mode: manual(默认)或 cron;为 cron 时必须提供 cron_expr(如 "0 */6 * * *")。
默认创建后不会立即执行。可通过 execute_now=true 在创建后立即启动;也可后续调用 batch_task_start 手工启动。Cron 队列若需按表达式自动触发下一轮,还需保持调度开关开启(可用 batch_task_schedule_enabled)。`,
ShortDescription: "创建批量任务队列(可选立即执行)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"title": map[string]interface{}{
"type": "string",
"description": "可选标题",
},
"role": map[string]interface{}{
"type": "string",
"description": "角色名称,空表示默认",
},
"tasks": map[string]interface{}{
"type": "array",
"description": "任务指令列表,每项一条",
"items": map[string]interface{}{"type": "string"},
},
"tasks_text": map[string]interface{}{
"type": "string",
"description": "多行文本,每行一条任务(与 tasks 二选一)",
},
"agent_mode": map[string]interface{}{
"type": "string",
"description": "single 或 multi",
"enum": []string{"single", "multi"},
},
"schedule_mode": map[string]interface{}{
"type": "string",
"description": "manual 或 cron",
"enum": []string{"manual", "cron"},
},
"cron_expr": map[string]interface{}{
"type": "string",
"description": "schedule_mode 为 cron 时必填。标准 5 段格式:分钟 小时 日 月 星期,例如 \"0 */6 * * *\"(每6小时)、\"30 2 * * 1-5\"(工作日凌晨2:30",
},
"execute_now": map[string]interface{}{
"type": "boolean",
"description": "是否创建后立即执行,默认 false",
},
},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
tasks, errMsg := batchMCPTasksFromArgs(args)
if errMsg != "" {
return batchMCPTextResult(errMsg, true), nil
}
title := mcpArgString(args, "title")
role := mcpArgString(args, "role")
agentMode := normalizeBatchQueueAgentMode(mcpArgString(args, "agent_mode"))
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
var nextRunAt *time.Time
if scheduleMode == "cron" {
if cronExpr == "" {
return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
}
sch, err := h.batchCronParser.Parse(cronExpr)
if err != nil {
return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
}
n := sch.Next(time.Now())
nextRunAt = &n
}
executeNow, ok := mcpArgBool(args, "execute_now")
if !ok {
executeNow = false
}
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, nextRunAt, tasks)
if createErr != nil {
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
}
started := false
if executeNow {
ok, err := h.startBatchQueueExecution(queue.ID, false)
if !ok {
return batchMCPTextResult("队列不存在: "+queue.ID, true), nil
}
if err != nil {
return batchMCPTextResult("创建成功但启动失败: "+err.Error(), true), nil
}
started = true
if refreshed, exists := h.batchTaskManager.GetBatchQueue(queue.ID); exists {
queue = refreshed
}
}
logger.Info("MCP batch_task_create", zap.String("queueId", queue.ID), zap.Int("taskCount", len(tasks)))
return batchMCPJSONResult(map[string]interface{}{
"queue_id": queue.ID,
"queue": queue,
"started": started,
"execute_now": executeNow,
"reminder": func() string {
if started {
return "队列已创建并立即启动。"
}
return "队列已创建,当前为 pending。需要开始执行时请调用 MCP 工具 batch_task_startqueue_id 同上)。Cron 自动调度需 schedule_enabled 为 true,可用 batch_task_schedule_enabled。"
}(),
})
})
// --- start ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskStart,
Description: `启动或继续执行批量任务队列(pending / paused)。
与 batch_task_create 配合使用:仅创建队列不会自动执行,需调用本工具才会开始跑子任务。`,
ShortDescription: "启动/继续批量任务队列(创建后需调用才会执行)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
ok, err := h.startBatchQueueExecution(qid, false)
if !ok {
return batchMCPTextResult("队列不存在: "+qid, true), nil
}
if err != nil {
return batchMCPTextResult("启动失败: "+err.Error(), true), nil
}
logger.Info("MCP batch_task_start", zap.String("queueId", qid))
return batchMCPTextResult("已提交启动,队列将开始执行。", false), nil
})
// --- rerun (reset + start for completed/cancelled queues) ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskRerun,
Description: "重跑已完成或已取消的批量任务队列。会重置所有子任务状态后重新执行一轮。",
ShortDescription: "重跑批量任务队列",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
queue, exists := h.batchTaskManager.GetBatchQueue(qid)
if !exists {
return batchMCPTextResult("队列不存在: "+qid, true), nil
}
if queue.Status != "completed" && queue.Status != "cancelled" {
return batchMCPTextResult("仅已完成或已取消的队列可以重跑,当前状态: "+queue.Status, true), nil
}
if !h.batchTaskManager.ResetQueueForRerun(qid) {
return batchMCPTextResult("重置队列失败", true), nil
}
ok, err := h.startBatchQueueExecution(qid, false)
if !ok {
return batchMCPTextResult("启动失败", true), nil
}
if err != nil {
return batchMCPTextResult("启动失败: "+err.Error(), true), nil
}
logger.Info("MCP batch_task_rerun", zap.String("queueId", qid))
return batchMCPTextResult("已重置并重新启动队列。", false), nil
})
// --- pause ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskPause,
Description: "暂停正在运行的批量任务队列(当前子任务会被取消)。",
ShortDescription: "暂停批量任务队列",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
if !h.batchTaskManager.PauseQueue(qid) {
return batchMCPTextResult("无法暂停:队列不存在或当前非 running 状态", true), nil
}
logger.Info("MCP batch_task_pause", zap.String("queueId", qid))
return batchMCPTextResult("队列已暂停。", false), nil
})
// --- delete queue ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskDelete,
Description: "删除批量任务队列及其子任务记录。",
ShortDescription: "删除批量任务队列",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
if !h.batchTaskManager.DeleteQueue(qid) {
return batchMCPTextResult("删除失败:队列不存在", true), nil
}
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
return batchMCPTextResult("队列已删除。", false), nil
})
// --- update metadata (title/role/agentMode) ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskUpdateMetadata,
Description: "修改批量任务队列的标题、角色和代理模式。仅在队列非 running 状态下可修改。",
ShortDescription: "修改批量任务队列标题/角色/代理模式",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"title": map[string]interface{}{
"type": "string",
"description": "新标题(空字符串清除标题)",
},
"role": map[string]interface{}{
"type": "string",
"description": "新角色名(空字符串使用默认角色)",
},
"agent_mode": map[string]interface{}{
"type": "string",
"description": "代理模式:single(单代理 ReAct)或 multi(多代理)",
"enum": []string{"single", "multi"},
},
},
"required": []string{"queue_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
title := mcpArgString(args, "title")
role := mcpArgString(args, "role")
agentMode := mcpArgString(args, "agent_mode")
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil {
return batchMCPTextResult(err.Error(), true), nil
}
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_update_metadata", zap.String("queueId", qid))
return batchMCPJSONResult(updated)
})
// --- update schedule ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskUpdateSchedule,
Description: `修改批量任务队列的调度方式和 Cron 表达式。仅在队列非 running 状态下可修改。
schedule_mode 为 cron 时必须提供有效 cron_expr;为 manual 时会清除 Cron 配置。`,
ShortDescription: "修改批量任务调度配置(Cron 表达式)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"schedule_mode": map[string]interface{}{
"type": "string",
"description": "manual 或 cron",
"enum": []string{"manual", "cron"},
},
"cron_expr": map[string]interface{}{
"type": "string",
"description": "Cron 表达式(schedule_mode 为 cron 时必填)。标准 5 段格式:分钟 小时 日 月 星期,如 \"0 */6 * * *\"(每6小时)、\"30 2 * * 1-5\"(工作日凌晨2:30",
},
},
"required": []string{"queue_id", "schedule_mode"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
queue, exists := h.batchTaskManager.GetBatchQueue(qid)
if !exists {
return batchMCPTextResult("队列不存在: "+qid, true), nil
}
if queue.Status == "running" {
return batchMCPTextResult("队列正在运行中,无法修改调度配置", true), nil
}
scheduleMode := normalizeBatchQueueScheduleMode(mcpArgString(args, "schedule_mode"))
cronExpr := strings.TrimSpace(mcpArgString(args, "cron_expr"))
var nextRunAt *time.Time
if scheduleMode == "cron" {
if cronExpr == "" {
return batchMCPTextResult("Cron 调度模式下 cron_expr 不能为空", true), nil
}
sch, err := h.batchCronParser.Parse(cronExpr)
if err != nil {
return batchMCPTextResult("无效的 Cron 表达式: "+err.Error(), true), nil
}
n := sch.Next(time.Now())
nextRunAt = &n
}
h.batchTaskManager.UpdateQueueSchedule(qid, scheduleMode, cronExpr, nextRunAt)
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_update_schedule", zap.String("queueId", qid), zap.String("scheduleMode", scheduleMode), zap.String("cronExpr", cronExpr))
return batchMCPJSONResult(updated)
})
// --- schedule enabled ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskScheduleEnabled,
Description: `设置是否允许 Cron 自动触发该队列。关闭后仍保留 Cron 表达式,仅停止定时自动跑;可用手工「启动」执行。
仅对 schedule_mode 为 cron 的队列有意义。`,
ShortDescription: "开关批量任务 Cron 自动调度",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"schedule_enabled": map[string]interface{}{
"type": "boolean",
"description": "true 允许定时触发,false 仅手工执行",
},
},
"required": []string{"queue_id", "schedule_enabled"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil
}
en, ok := mcpArgBool(args, "schedule_enabled")
if !ok {
return batchMCPTextResult("schedule_enabled 必须为布尔值", true), nil
}
if _, exists := h.batchTaskManager.GetBatchQueue(qid); !exists {
return batchMCPTextResult("队列不存在", true), nil
}
if !h.batchTaskManager.SetScheduleEnabled(qid, en) {
return batchMCPTextResult("更新失败", true), nil
}
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_schedule_enabled", zap.String("queueId", qid), zap.Bool("enabled", en))
return batchMCPJSONResult(queue)
})
// --- add task ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskAdd,
Description: "向处于 pending 状态的队列追加一条子任务。",
ShortDescription: "批量队列添加子任务",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"message": map[string]interface{}{
"type": "string",
"description": "任务指令内容",
},
},
"required": []string{"queue_id", "message"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
msg := strings.TrimSpace(mcpArgString(args, "message"))
if qid == "" || msg == "" {
return batchMCPTextResult("queue_id 与 message 均不能为空", true), nil
}
task, err := h.batchTaskManager.AddTaskToQueue(qid, msg)
if err != nil {
return batchMCPTextResult(err.Error(), true), nil
}
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_add_task", zap.String("queueId", qid), zap.String("taskId", task.ID))
return batchMCPJSONResult(map[string]interface{}{"task": task, "queue": queue})
})
// --- update task ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskUpdate,
Description: "修改 pending 队列中仍为 pending 的子任务文案。",
ShortDescription: "更新批量子任务内容",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"task_id": map[string]interface{}{
"type": "string",
"description": "子任务 ID",
},
"message": map[string]interface{}{
"type": "string",
"description": "新的任务指令",
},
},
"required": []string{"queue_id", "task_id", "message"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
tid := mcpArgString(args, "task_id")
msg := strings.TrimSpace(mcpArgString(args, "message"))
if qid == "" || tid == "" || msg == "" {
return batchMCPTextResult("queue_id、task_id、message 均不能为空", true), nil
}
if err := h.batchTaskManager.UpdateTaskMessage(qid, tid, msg); err != nil {
return batchMCPTextResult(err.Error(), true), nil
}
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_update_task", zap.String("queueId", qid), zap.String("taskId", tid))
return batchMCPJSONResult(queue)
})
// --- remove task ---
reg(mcp.Tool{
Name: builtin.ToolBatchTaskRemove,
Description: "从 pending 队列中删除仍为 pending 的子任务。",
ShortDescription: "删除批量子任务",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"queue_id": map[string]interface{}{
"type": "string",
"description": "队列 ID",
},
"task_id": map[string]interface{}{
"type": "string",
"description": "子任务 ID",
},
},
"required": []string{"queue_id", "task_id"},
},
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
qid := mcpArgString(args, "queue_id")
tid := mcpArgString(args, "task_id")
if qid == "" || tid == "" {
return batchMCPTextResult("queue_id 与 task_id 均不能为空", true), nil
}
if err := h.batchTaskManager.DeleteTask(qid, tid); err != nil {
return batchMCPTextResult(err.Error(), true), nil
}
queue, _ := h.batchTaskManager.GetBatchQueue(qid)
logger.Info("MCP batch_task_remove_task", zap.String("queueId", qid), zap.String("taskId", tid))
return batchMCPJSONResult(queue)
})
logger.Info("批量任务 MCP 工具已注册", zap.Int("count", 12))
}
// --- batch_task_list 精简结构(避免把每条子任务的 result 等大段文本塞进列表上下文) ---
const mcpBatchListTaskMessageMaxRunes = 160
// batchTaskMCPListSummary 列表中的子任务摘要(完整字段用 batch_task_get
type batchTaskMCPListSummary struct {
ID string `json:"id"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
}
// batchTaskQueueMCPListItem 列表中的队列摘要
type batchTaskQueueMCPListItem struct {
ID string `json:"id"`
Title string `json:"title,omitempty"`
Role string `json:"role,omitempty"`
AgentMode string `json:"agentMode"`
ScheduleMode string `json:"scheduleMode"`
CronExpr string `json:"cronExpr,omitempty"`
NextRunAt *time.Time `json:"nextRunAt,omitempty"`
ScheduleEnabled bool `json:"scheduleEnabled"`
LastScheduleTriggerAt *time.Time `json:"lastScheduleTriggerAt,omitempty"`
Status string `json:"status"`
CreatedAt time.Time `json:"createdAt"`
StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"`
CurrentIndex int `json:"currentIndex"`
TaskTotal int `json:"task_total"`
TaskCounts map[string]int `json:"task_counts"`
Tasks []batchTaskMCPListSummary `json:"tasks"`
}
func truncateStringRunes(s string, maxRunes int) string {
if maxRunes <= 0 {
return ""
}
n := 0
for i := range s {
if n == maxRunes {
out := strings.TrimSpace(s[:i])
if out == "" {
return "…"
}
return out + "…"
}
n++
}
return s
}
const mcpBatchListMaxTasksPerQueue = 200 // 列表中每个队列最多返回的子任务摘要数
func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
counts := map[string]int{
"pending": 0,
"running": 0,
"completed": 0,
"failed": 0,
"cancelled": 0,
}
tasks := make([]batchTaskMCPListSummary, 0, len(q.Tasks))
for _, t := range q.Tasks {
if t == nil {
continue
}
counts[t.Status]++
// 列表视图限制子任务摘要数量,完整列表通过 batch_task_get 查看
if len(tasks) < mcpBatchListMaxTasksPerQueue {
tasks = append(tasks, batchTaskMCPListSummary{
ID: t.ID,
Status: t.Status,
Message: truncateStringRunes(t.Message, mcpBatchListTaskMessageMaxRunes),
})
}
}
return batchTaskQueueMCPListItem{
ID: q.ID,
Title: q.Title,
Role: q.Role,
AgentMode: q.AgentMode,
ScheduleMode: q.ScheduleMode,
CronExpr: q.CronExpr,
NextRunAt: q.NextRunAt,
ScheduleEnabled: q.ScheduleEnabled,
LastScheduleTriggerAt: q.LastScheduleTriggerAt,
Status: q.Status,
CreatedAt: q.CreatedAt,
StartedAt: q.StartedAt,
CompletedAt: q.CompletedAt,
CurrentIndex: q.CurrentIndex,
TaskTotal: len(tasks),
TaskCounts: counts,
Tasks: tasks,
}
}
func batchMCPTextResult(text string, isErr bool) *mcp.ToolResult {
return &mcp.ToolResult{
Content: []mcp.Content{{Type: "text", Text: text}},
IsError: isErr,
}
}
func batchMCPJSONResult(v interface{}) (*mcp.ToolResult, error) {
b, err := json.MarshalIndent(v, "", " ")
if err != nil {
return batchMCPTextResult(fmt.Sprintf("JSON 编码失败: %v", err), true), nil
}
return &mcp.ToolResult{Content: []mcp.Content{{Type: "text", Text: string(b)}}}, nil
}
func batchMCPTasksFromArgs(args map[string]interface{}) ([]string, string) {
if raw, ok := args["tasks"]; ok && raw != nil {
switch t := raw.(type) {
case []interface{}:
out := make([]string, 0, len(t))
for _, x := range t {
if s, ok := x.(string); ok {
if tr := strings.TrimSpace(s); tr != "" {
out = append(out, tr)
}
}
}
if len(out) > 0 {
return out, ""
}
}
}
if txt := mcpArgString(args, "tasks_text"); txt != "" {
lines := strings.Split(txt, "\n")
out := make([]string, 0, len(lines))
for _, line := range lines {
if tr := strings.TrimSpace(line); tr != "" {
out = append(out, tr)
}
}
if len(out) > 0 {
return out, ""
}
}
return nil, "需要提供 tasks(字符串数组)或 tasks_text(多行文本,每行一条任务)"
}
func mcpArgString(args map[string]interface{}, key string) string {
v, ok := args[key]
if !ok || v == nil {
return ""
}
switch t := v.(type) {
case string:
return strings.TrimSpace(t)
case float64:
return strings.TrimSpace(strconv.FormatFloat(t, 'f', -1, 64))
case json.Number:
return strings.TrimSpace(t.String())
default:
return strings.TrimSpace(fmt.Sprint(t))
}
}
func mcpArgFloat(args map[string]interface{}, key string) float64 {
v, ok := args[key]
if !ok || v == nil {
return 0
}
switch t := v.(type) {
case float64:
return t
case int:
return float64(t)
case int64:
return float64(t)
case json.Number:
f, _ := t.Float64()
return f
case string:
f, _ := strconv.ParseFloat(strings.TrimSpace(t), 64)
return f
default:
return 0
}
}
func mcpArgBool(args map[string]interface{}, key string) (val bool, ok bool) {
v, exists := args[key]
if !exists {
return false, false
}
switch t := v.(type) {
case bool:
return t, true
case string:
s := strings.ToLower(strings.TrimSpace(t))
if s == "true" || s == "1" || s == "yes" {
return true, true
}
if s == "false" || s == "0" || s == "no" {
return false, true
}
case float64:
return t != 0, true
}
return false, false
}
-512
View File
@@ -1,512 +0,0 @@
package handler
import (
"crypto/rand"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"time"
"unicode/utf8"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const (
chatUploadsRootDirName = "chat_uploads"
maxChatUploadEditBytes = 2 * 1024 * 1024 // 文本编辑上限
)
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
type ChatUploadsHandler struct {
logger *zap.Logger
}
// NewChatUploadsHandler 创建处理器
func NewChatUploadsHandler(logger *zap.Logger) *ChatUploadsHandler {
return &ChatUploadsHandler{logger: logger}
}
func (h *ChatUploadsHandler) absRoot() (string, error) {
cwd, err := os.Getwd()
if err != nil {
return "", err
}
return filepath.Abs(filepath.Join(cwd, chatUploadsRootDirName))
}
// resolveUnderChatUploads 校验 relativePath(使用 / 分隔)对应文件必须在 chat_uploads 根下
func (h *ChatUploadsHandler) resolveUnderChatUploads(relativePath string) (abs string, err error) {
root, err := h.absRoot()
if err != nil {
return "", err
}
rel := strings.TrimSpace(relativePath)
if rel == "" {
return "", fmt.Errorf("empty path")
}
rel = filepath.Clean(filepath.FromSlash(rel))
if rel == "." || strings.HasPrefix(rel, "..") {
return "", fmt.Errorf("invalid path")
}
full := filepath.Join(root, rel)
full, err = filepath.Abs(full)
if err != nil {
return "", err
}
rootAbs, _ := filepath.Abs(root)
if full != rootAbs && !strings.HasPrefix(full, rootAbs+string(filepath.Separator)) {
return "", fmt.Errorf("path escapes chat_uploads root")
}
return full, nil
}
// ChatUploadFileItem 列表项
type ChatUploadFileItem struct {
RelativePath string `json:"relativePath"`
AbsolutePath string `json:"absolutePath"` // 服务器上的绝对路径,便于在对话中引用(与附件落盘路径一致)
Name string `json:"name"`
Size int64 `json:"size"`
ModifiedUnix int64 `json:"modifiedUnix"`
Date string `json:"date"`
ConversationID string `json:"conversationId"`
// SubPath 为日期、会话目录之下的子路径(不含文件名),如 date/conv/a/b/file 则为 "a/b";无嵌套则为 ""。
SubPath string `json:"subPath"`
}
// List GET /api/chat-uploads
func (h *ChatUploadsHandler) List(c *gin.Context) {
conversationFilter := strings.TrimSpace(c.Query("conversation"))
root, err := h.absRoot()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 保证根目录存在,否则「按文件夹」浏览时无法 mkdir,且首次列表为空时界面无路径工具栏
if err := os.MkdirAll(root, 0755); err != nil {
h.logger.Warn("创建 chat_uploads 根目录失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var files []ChatUploadFileItem
var folders []string
err = filepath.WalkDir(root, func(path string, d os.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
if rel == "." {
return nil
}
relSlash := filepath.ToSlash(rel)
if d.IsDir() {
folders = append(folders, relSlash)
return nil
}
info, err := d.Info()
if err != nil {
return err
}
parts := strings.Split(relSlash, "/")
var dateStr, convID string
if len(parts) >= 2 {
dateStr = parts[0]
}
if len(parts) >= 3 {
convID = parts[1]
}
var subPath string
if len(parts) >= 4 {
subPath = strings.Join(parts[2:len(parts)-1], "/")
}
if conversationFilter != "" && convID != conversationFilter {
return nil
}
absPath, _ := filepath.Abs(path)
files = append(files, ChatUploadFileItem{
RelativePath: relSlash,
AbsolutePath: absPath,
Name: d.Name(),
Size: info.Size(),
ModifiedUnix: info.ModTime().Unix(),
Date: dateStr,
ConversationID: convID,
SubPath: subPath,
})
return nil
})
if err != nil {
h.logger.Warn("列举对话附件失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if conversationFilter != "" {
filteredFolders := make([]string, 0, len(folders))
for _, rel := range folders {
parts := strings.Split(rel, "/")
if len(parts) >= 2 && parts[1] == conversationFilter {
filteredFolders = append(filteredFolders, rel)
continue
}
if len(parts) == 1 {
prefix := rel + "/"
for _, f := range files {
if strings.HasPrefix(f.RelativePath, prefix) {
filteredFolders = append(filteredFolders, rel)
break
}
}
}
}
folders = filteredFolders
}
sort.Strings(folders)
sort.Slice(files, func(i, j int) bool {
return files[i].ModifiedUnix > files[j].ModifiedUnix
})
c.JSON(http.StatusOK, gin.H{"files": files, "folders": folders})
}
// Download GET /api/chat-uploads/download?path=...
func (h *ChatUploadsHandler) Download(c *gin.Context) {
p := c.Query("path")
abs, err := h.resolveUnderChatUploads(p)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(abs)
if err != nil || st.IsDir() {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
c.FileAttachment(abs, filepath.Base(abs))
}
type chatUploadPathBody struct {
Path string `json:"path"`
}
// Delete DELETE /api/chat-uploads
func (h *ChatUploadsHandler) Delete(c *gin.Context) {
var body chatUploadPathBody
if err := c.ShouldBindJSON(&body); err != nil || strings.TrimSpace(body.Path) == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
abs, err := h.resolveUnderChatUploads(body.Path)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(abs)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if st.IsDir() {
if err := os.RemoveAll(abs); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
if err := os.Remove(abs); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
type chatUploadMkdirBody struct {
Parent string `json:"parent"`
Name string `json:"name"`
}
// Mkdir POST /api/chat-uploads/mkdir — 在 parent 目录下新建子目录(parent 为 chat_uploads 下相对路径,空表示根目录;name 为单段目录名)
func (h *ChatUploadsHandler) Mkdir(c *gin.Context) {
var body chatUploadMkdirBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
name := strings.TrimSpace(body.Name)
if name == "" || strings.ContainsAny(name, `/\`) || name == "." || name == ".." {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid name"})
return
}
if utf8.RuneCountInString(name) > 200 {
c.JSON(http.StatusBadRequest, gin.H{"error": "name too long"})
return
}
parent := strings.TrimSpace(body.Parent)
parent = filepath.ToSlash(filepath.Clean(filepath.FromSlash(parent)))
parent = strings.Trim(parent, "/")
if parent == "." {
parent = ""
}
root, err := h.absRoot()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if parent != "" {
absParent, err := h.resolveUnderChatUploads(parent)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(absParent)
if err != nil || !st.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "parent not found"})
return
}
}
var rel string
if parent == "" {
rel = name
} else {
rel = parent + "/" + name
}
absNew, err := h.resolveUnderChatUploads(rel)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(absNew); err == nil {
c.JSON(http.StatusConflict, gin.H{"error": "already exists"})
return
}
if err := os.Mkdir(absNew, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
relOut, _ := filepath.Rel(root, absNew)
c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(relOut)})
}
type chatUploadRenameBody struct {
Path string `json:"path"`
NewName string `json:"newName"`
}
// Rename PUT /api/chat-uploads/rename
func (h *ChatUploadsHandler) Rename(c *gin.Context) {
var body chatUploadRenameBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
newName := strings.TrimSpace(body.NewName)
if newName == "" || strings.ContainsAny(newName, `/\`) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid newName"})
return
}
abs, err := h.resolveUnderChatUploads(body.Path)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
dir := filepath.Dir(abs)
newAbs := filepath.Join(dir, filepath.Base(newName))
root, _ := h.absRoot()
newAbs, _ = filepath.Abs(newAbs)
if newAbs != root && !strings.HasPrefix(newAbs, root+string(filepath.Separator)) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid target path"})
return
}
if err := os.Rename(abs, newAbs); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
newRel, _ := filepath.Rel(root, newAbs)
c.JSON(http.StatusOK, gin.H{"ok": true, "relativePath": filepath.ToSlash(newRel)})
}
type chatUploadContentBody struct {
Path string `json:"path"`
Content string `json:"content"`
}
// GetContent GET /api/chat-uploads/content?path=...
func (h *ChatUploadsHandler) GetContent(c *gin.Context) {
p := c.Query("path")
abs, err := h.resolveUnderChatUploads(p)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(abs)
if err != nil || st.IsDir() {
c.JSON(http.StatusNotFound, gin.H{"error": "file not found"})
return
}
if st.Size() > maxChatUploadEditBytes {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "file too large for editor"})
return
}
b, err := os.ReadFile(abs)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !utf8.Valid(b) {
c.JSON(http.StatusBadRequest, gin.H{"error": "binary file not editable in UI"})
return
}
c.JSON(http.StatusOK, gin.H{"content": string(b)})
}
// PutContent PUT /api/chat-uploads/content
func (h *ChatUploadsHandler) PutContent(c *gin.Context) {
var body chatUploadContentBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
if !utf8.ValidString(body.Content) {
c.JSON(http.StatusBadRequest, gin.H{"error": "content must be valid UTF-8"})
return
}
if len(body.Content) > maxChatUploadEditBytes {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "content too large"})
return
}
abs, err := h.resolveUnderChatUploads(body.Path)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := os.WriteFile(abs, []byte(body.Content), 0644); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
func chatUploadShortRand(n int) string {
const letters = "0123456789abcdef"
b := make([]byte, n)
_, _ = rand.Read(b)
for i := range b {
b[i] = letters[int(b[i])%len(letters)]
}
return string(b)
}
// Upload POST /api/chat-uploads multipart: fileconversationId 可选;relativeDir 可选(chat_uploads 下目录的相对路径,将文件直接上传至该目录)
func (h *ChatUploadsHandler) Upload(c *gin.Context) {
fh, err := c.FormFile("file")
if err != nil || fh == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing file"})
return
}
root, err := h.absRoot()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var targetDir string
targetRel := strings.TrimSpace(c.PostForm("relativeDir"))
if targetRel != "" {
absDir, err := h.resolveUnderChatUploads(targetRel)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
st, err := os.Stat(absDir)
if err != nil {
if os.IsNotExist(err) {
if err := os.MkdirAll(absDir, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} else if !st.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "relativeDir is not a directory"})
return
}
targetDir = absDir
} else {
convID := strings.TrimSpace(c.PostForm("conversationId"))
convDir := convID
if convDir == "" {
convDir = "_manual"
} else {
convDir = strings.ReplaceAll(convDir, string(filepath.Separator), "_")
}
dateStr := time.Now().Format("2006-01-02")
targetDir = filepath.Join(root, dateStr, convDir)
if err := os.MkdirAll(targetDir, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
baseName := filepath.Base(fh.Filename)
if baseName == "" || baseName == "." {
baseName = "file"
}
baseName = strings.ReplaceAll(baseName, string(filepath.Separator), "_")
ext := filepath.Ext(baseName)
nameNoExt := strings.TrimSuffix(baseName, ext)
suffix := fmt.Sprintf("_%s_%s", time.Now().Format("150405"), chatUploadShortRand(6))
var unique string
if ext != "" {
unique = nameNoExt + suffix + ext
} else {
unique = baseName + suffix
}
fullPath := filepath.Join(targetDir, unique)
src, err := fh.Open()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer src.Close()
dst, err := os.Create(fullPath)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer dst.Close()
if _, err := io.Copy(dst, src); err != nil {
_ = os.Remove(fullPath)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rel, _ := filepath.Rel(root, fullPath)
absSaved, _ := filepath.Abs(fullPath)
c.JSON(http.StatusOK, gin.H{
"ok": true,
"relativePath": filepath.ToSlash(rel),
"absolutePath": absSaved,
"name": unique,
})
}
File diff suppressed because it is too large Load Diff
-233
View File
@@ -1,233 +0,0 @@
package handler
import (
"encoding/json"
"net/http"
"strconv"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ConversationHandler 对话处理器
type ConversationHandler struct {
db *database.DB
logger *zap.Logger
}
// NewConversationHandler 创建新的对话处理器
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
return &ConversationHandler{
db: db,
logger: logger,
}
}
// CreateConversationRequest 创建对话请求
type CreateConversationRequest struct {
Title string `json:"title"`
}
// CreateConversation 创建新对话
func (h *ConversationHandler) CreateConversation(c *gin.Context) {
var req CreateConversationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
title := req.Title
if title == "" {
title = "新对话"
}
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conv)
}
// ListConversations 列出对话
func (h *ConversationHandler) ListConversations(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "50")
offsetStr := c.DefaultQuery("offset", "0")
search := c.Query("search") // 获取搜索参数
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
if limit <= 0 || limit > 100 {
limit = 50
}
conversations, err := h.db.ListConversations(limit, offset, search)
if err != nil {
h.logger.Error("获取对话列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conversations)
}
// GetConversation 获取对话
func (h *ConversationHandler) GetConversation(c *gin.Context) {
id := c.Param("id")
// 默认轻量加载,只有用户需要展开详情时再按需拉取
// include_process_details=1/true 时返回全量 processDetails(兼容旧行为)
includeStr := c.DefaultQuery("include_process_details", "0")
include := includeStr == "1" || includeStr == "true" || includeStr == "yes"
var (
conv *database.Conversation
err error
)
if include {
conv, err = h.db.GetConversation(id)
} else {
conv, err = h.db.GetConversationLite(id)
}
if err != nil {
h.logger.Error("获取对话失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
c.JSON(http.StatusOK, conv)
}
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
messageID := c.Param("id")
if messageID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "message id required"})
return
}
details, err := h.db.GetProcessDetails(messageID)
if err != nil {
h.logger.Error("获取过程详情失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
out := make([]map[string]interface{}, 0, len(details))
for _, d := range details {
var data interface{}
if d.Data != "" {
if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
h.logger.Warn("解析过程详情数据失败", zap.Error(err))
}
}
out = append(out, map[string]interface{}{
"id": d.ID,
"messageId": d.MessageID,
"conversationId": d.ConversationID,
"eventType": d.EventType,
"message": d.Message,
"data": data,
"createdAt": d.CreatedAt,
})
}
c.JSON(http.StatusOK, gin.H{"processDetails": out})
}
// UpdateConversationRequest 更新对话请求
type UpdateConversationRequest struct {
Title string `json:"title"`
}
// UpdateConversation 更新对话
func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
id := c.Param("id")
var req UpdateConversationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Title == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "标题不能为空"})
return
}
if err := h.db.UpdateConversationTitle(id, req.Title); err != nil {
h.logger.Error("更新对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 返回更新后的对话
conv, err := h.db.GetConversation(id)
if err != nil {
h.logger.Error("获取更新后的对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conv)
}
// DeleteConversation 删除对话
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteConversation(id); err != nil {
h.logger.Error("删除对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// DeleteTurnRequest 删除一轮对话(POST /api/conversations/:id/delete-turn
type DeleteTurnRequest struct {
MessageID string `json:"messageId"`
}
// DeleteConversationTurn 删除锚点消息所在轮次(从该轮 user 到下一轮 user 之前),并清空 last_react_*。
func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
conversationID := c.Param("id")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversation id required"})
return
}
var req DeleteTurnRequest
if err := c.ShouldBindJSON(&req); err != nil || req.MessageID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "messageId required"})
return
}
if _, err := h.db.GetConversation(conversationID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
deletedIDs, err := h.db.DeleteConversationTurn(conversationID, req.MessageID)
if err != nil {
h.logger.Warn("删除对话轮次失败",
zap.String("conversationId", conversationID),
zap.String("messageId", req.MessageID),
zap.Error(err),
)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"deletedMessageIds": deletedIDs,
"message": "ok",
})
}
-542
View File
@@ -1,542 +0,0 @@
package handler
import (
"fmt"
"net/http"
"os"
"sync"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// ExternalMCPHandler 外部MCP处理器
type ExternalMCPHandler struct {
manager *mcp.ExternalMCPManager
config *config.Config
configPath string
logger *zap.Logger
mu sync.RWMutex
}
// NewExternalMCPHandler 创建外部MCP处理器
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
return &ExternalMCPHandler{
manager: manager,
config: cfg,
configPath: configPath,
logger: logger,
}
}
// GetExternalMCPs 获取所有外部MCP配置
func (h *ExternalMCPHandler) GetExternalMCPs(c *gin.Context) {
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
// 获取所有外部MCP的工具数量
toolCounts := h.manager.GetToolCounts()
// 转换为响应格式
result := make(map[string]ExternalMCPResponse)
for name, cfg := range configs {
client, exists := h.manager.GetClient(name)
status := "disconnected"
if exists {
status = client.GetStatus()
} else if h.isEnabled(cfg) {
status = "disconnected"
} else {
status = "disabled"
}
toolCount := toolCounts[name]
errorMsg := ""
if status == "error" {
errorMsg = h.manager.GetError(name)
}
result[name] = ExternalMCPResponse{
Config: cfg,
Status: status,
ToolCount: toolCount,
Error: errorMsg,
}
}
c.JSON(http.StatusOK, gin.H{
"servers": result,
"stats": h.manager.GetStats(),
})
}
// GetExternalMCP 获取单个外部MCP配置
func (h *ExternalMCPHandler) GetExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.RLock()
defer h.mu.RUnlock()
configs := h.manager.GetConfigs()
cfg, exists := configs[name]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "外部MCP配置不存在"})
return
}
client, clientExists := h.manager.GetClient(name)
status := "disconnected"
if clientExists {
status = client.GetStatus()
} else if h.isEnabled(cfg) {
status = "disconnected"
} else {
status = "disabled"
}
// 获取工具数量
toolCount := 0
if clientExists && client.IsConnected() {
if count, err := h.manager.GetToolCount(name); err == nil {
toolCount = count
}
}
// 获取错误信息
errorMsg := ""
if status == "error" {
errorMsg = h.manager.GetError(name)
}
c.JSON(http.StatusOK, ExternalMCPResponse{
Config: cfg,
Status: status,
ToolCount: toolCount,
Error: errorMsg,
})
}
// AddOrUpdateExternalMCP 添加或更新外部MCP配置
func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
var req AddOrUpdateExternalMCPRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
name := c.Param("name")
if name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "名称不能为空"})
return
}
// 验证配置
if err := h.validateConfig(req.Config); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.mu.Lock()
defer h.mu.Unlock()
// 添加或更新配置
if err := h.manager.AddOrUpdateConfig(name, req.Config); err != nil {
h.logger.Error("添加或更新外部MCP配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加或更新配置失败: " + err.Error()})
return
}
// 更新内存中的配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
// 如果用户提供了 disabled 或 enabled 字段,保留它们以保持向后兼容
// 同时将值迁移到 external_mcp_enable
cfg := req.Config
if req.Config.Disabled {
// 用户设置了 disabled: true
cfg.ExternalMCPEnable = false
cfg.Disabled = true
cfg.Enabled = false
} else if req.Config.Enabled {
// 用户设置了 enabled: true
cfg.ExternalMCPEnable = true
cfg.Enabled = true
cfg.Disabled = false
} else if !req.Config.ExternalMCPEnable {
// 用户没有设置任何字段,且 external_mcp_enable 为 false
// 检查现有配置是否有旧字段
if existingCfg, exists := h.config.ExternalMCP.Servers[name]; exists {
// 保留现有的旧字段
cfg.Enabled = existingCfg.Enabled
cfg.Disabled = existingCfg.Disabled
}
} else {
// 用户通过新字段启用了(external_mcp_enable: true),但没有设置旧字段
// 为了向后兼容,我们设置 enabled: true
// 这样即使原始配置中有 disabled: false,也会被转换为 enabled: true
cfg.Enabled = true
cfg.Disabled = false
}
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
}
// DeleteExternalMCP 删除外部MCP配置
func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 移除配置
if err := h.manager.RemoveConfig(name); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "配置不存在"})
return
}
// 从内存配置中删除
if h.config.ExternalMCP.Servers != nil {
delete(h.config.ExternalMCP.Servers, name)
}
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
}
// StartExternalMCP 启动外部MCP
func (h *ExternalMCPHandler) StartExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 更新配置为启用
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = true
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
// 启动客户端(立即创建客户端并设置状态为connecting,实际连接在后台进行)
h.logger.Info("开始启动外部MCP", zap.String("name", name))
if err := h.manager.StartClient(name); err != nil {
h.logger.Error("启动外部MCP失败", zap.String("name", name), zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{
"error": err.Error(),
"status": "error",
})
return
}
// 获取客户端状态(应该是connecting)
client, exists := h.manager.GetClient(name)
status := "connecting"
if exists {
status = client.GetStatus()
}
// 立即返回,不等待连接完成
// 客户端会在后台异步连接,用户可以通过状态查询接口查看连接状态
c.JSON(http.StatusOK, gin.H{
"message": "外部MCP启动请求已提交,正在后台连接中",
"status": status,
})
}
// StopExternalMCP 停止外部MCP
func (h *ExternalMCPHandler) StopExternalMCP(c *gin.Context) {
name := c.Param("name")
h.mu.Lock()
defer h.mu.Unlock()
// 停止客户端
if err := h.manager.StopClient(name); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 更新配置
if h.config.ExternalMCP.Servers == nil {
h.config.ExternalMCP.Servers = make(map[string]config.ExternalMCPServerConfig)
}
cfg := h.config.ExternalMCP.Servers[name]
cfg.ExternalMCPEnable = false
h.config.ExternalMCP.Servers[name] = cfg
// 保存到配置文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("外部MCP已停止", zap.String("name", name))
c.JSON(http.StatusOK, gin.H{"message": "外部MCP已停止"})
}
// GetExternalMCPStats 获取统计信息
func (h *ExternalMCPHandler) GetExternalMCPStats(c *gin.Context) {
stats := h.manager.GetStats()
c.JSON(http.StatusOK, stats)
}
// validateConfig 验证配置
func (h *ExternalMCPHandler) validateConfig(cfg config.ExternalMCPServerConfig) error {
transport := cfg.Transport
if transport == "" {
// 如果没有指定transport,根据是否有command或url判断
if cfg.Command != "" {
transport = "stdio"
} else if cfg.URL != "" {
transport = "http"
} else {
return fmt.Errorf("需要指定commandstdio模式)或urlhttp/sse模式)")
}
}
switch transport {
case "http":
if cfg.URL == "" {
return fmt.Errorf("HTTP模式需要URL")
}
case "stdio":
if cfg.Command == "" {
return fmt.Errorf("stdio模式需要command")
}
case "sse":
if cfg.URL == "" {
return fmt.Errorf("SSE模式需要URL")
}
default:
return fmt.Errorf("不支持的传输模式: %s,支持的模式: http, stdio, sse", transport)
}
return nil
}
// isEnabled 检查是否启用
func (h *ExternalMCPHandler) isEnabled(cfg config.ExternalMCPServerConfig) bool {
// 优先使用 ExternalMCPEnable 字段
// 如果没有设置,检查旧的 enabled/disabled 字段(向后兼容)
if cfg.ExternalMCPEnable {
return true
}
// 向后兼容:检查旧字段
if cfg.Disabled {
return false
}
if cfg.Enabled {
return true
}
// 都没有设置,默认为启用
return true
}
// saveConfig 保存配置到文件
func (h *ExternalMCPHandler) saveConfig() error {
// 读取现有配置文件并创建备份
data, err := os.ReadFile(h.configPath)
if err != nil {
return fmt.Errorf("读取配置文件失败: %w", err)
}
if err := os.WriteFile(h.configPath+".backup", data, 0644); err != nil {
h.logger.Warn("创建配置备份失败", zap.Error(err))
}
root, err := loadYAMLDocument(h.configPath)
if err != nil {
return fmt.Errorf("解析配置文件失败: %w", err)
}
// 在更新前,读取原始配置中的 enabled/disabled 字段,以便保持向后兼容
originalConfigs := make(map[string]map[string]bool)
externalMCPNode := findMapValue(root.Content[0], "external_mcp")
if externalMCPNode != nil && externalMCPNode.Kind == yaml.MappingNode {
serversNode := findMapValue(externalMCPNode, "servers")
if serversNode != nil && serversNode.Kind == yaml.MappingNode {
// 遍历现有的服务器配置,保存 enabled/disabled 字段
for i := 0; i < len(serversNode.Content); i += 2 {
if i+1 >= len(serversNode.Content) {
break
}
nameNode := serversNode.Content[i]
serverNode := serversNode.Content[i+1]
if nameNode.Kind == yaml.ScalarNode && serverNode.Kind == yaml.MappingNode {
serverName := nameNode.Value
originalConfigs[serverName] = make(map[string]bool)
// 检查是否有 enabled 字段
if enabledVal := findBoolInMap(serverNode, "enabled"); enabledVal != nil {
originalConfigs[serverName]["enabled"] = *enabledVal
}
// 检查是否有 disabled 字段
if disabledVal := findBoolInMap(serverNode, "disabled"); disabledVal != nil {
originalConfigs[serverName]["disabled"] = *disabledVal
}
}
}
}
}
// 更新外部MCP配置
updateExternalMCPConfig(root, h.config.ExternalMCP, originalConfigs)
if err := writeYAMLDocument(h.configPath, root); err != nil {
return fmt.Errorf("保存配置文件失败: %w", err)
}
h.logger.Info("配置已保存", zap.String("path", h.configPath))
return nil
}
// updateExternalMCPConfig 更新外部MCP配置
func updateExternalMCPConfig(doc *yaml.Node, cfg config.ExternalMCPConfig, originalConfigs map[string]map[string]bool) {
root := doc.Content[0]
externalMCPNode := ensureMap(root, "external_mcp")
serversNode := ensureMap(externalMCPNode, "servers")
// 清空现有服务器配置
serversNode.Content = nil
// 添加新的服务器配置
for name, serverCfg := range cfg.Servers {
// 添加服务器名称键
nameNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: name}
serverNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
serversNode.Content = append(serversNode.Content, nameNode, serverNode)
// 设置服务器配置字段
if serverCfg.Command != "" {
setStringInMap(serverNode, "command", serverCfg.Command)
}
if len(serverCfg.Args) > 0 {
setStringArrayInMap(serverNode, "args", serverCfg.Args)
}
// 保存 env 字段(环境变量)
if serverCfg.Env != nil && len(serverCfg.Env) > 0 {
envNode := ensureMap(serverNode, "env")
for envKey, envValue := range serverCfg.Env {
setStringInMap(envNode, envKey, envValue)
}
}
if serverCfg.Transport != "" {
setStringInMap(serverNode, "transport", serverCfg.Transport)
}
if serverCfg.URL != "" {
setStringInMap(serverNode, "url", serverCfg.URL)
}
// 保存 headers 字段(HTTP/SSE 请求头)
if serverCfg.Headers != nil && len(serverCfg.Headers) > 0 {
headersNode := ensureMap(serverNode, "headers")
for k, v := range serverCfg.Headers {
setStringInMap(headersNode, k, v)
}
}
if serverCfg.Description != "" {
setStringInMap(serverNode, "description", serverCfg.Description)
}
if serverCfg.Timeout > 0 {
setIntInMap(serverNode, "timeout", serverCfg.Timeout)
}
// 保存 external_mcp_enable 字段(新字段)
setBoolInMap(serverNode, "external_mcp_enable", serverCfg.ExternalMCPEnable)
// 保存 tool_enabled 字段(每个工具的启用状态)
if serverCfg.ToolEnabled != nil && len(serverCfg.ToolEnabled) > 0 {
toolEnabledNode := ensureMap(serverNode, "tool_enabled")
for toolName, enabled := range serverCfg.ToolEnabled {
setBoolInMap(toolEnabledNode, toolName, enabled)
}
}
// 保留旧的 enabled/disabled 字段以保持向后兼容
originalFields, hasOriginal := originalConfigs[name]
// 如果原始配置中有 enabled 字段,保留它
if hasOriginal {
if enabledVal, hasEnabled := originalFields["enabled"]; hasEnabled {
setBoolInMap(serverNode, "enabled", enabledVal)
}
// 如果原始配置中有 disabled 字段,保留它
// 注意:由于 omitemptydisabled: false 不会被保存,但 disabled: true 会被保存
if disabledVal, hasDisabled := originalFields["disabled"]; hasDisabled {
if disabledVal {
setBoolInMap(serverNode, "disabled", disabledVal)
} else {
// 如果原始配置中有 disabled: false,我们保存 enabled: true 来等效表示
// 因为 disabled: false 等价于 enabled: true
setBoolInMap(serverNode, "enabled", true)
}
}
}
// 如果用户在当前请求中明确设置了这些字段,也保存它们
if serverCfg.Enabled {
setBoolInMap(serverNode, "enabled", serverCfg.Enabled)
}
if serverCfg.Disabled {
setBoolInMap(serverNode, "disabled", serverCfg.Disabled)
} else if !hasOriginal && serverCfg.ExternalMCPEnable {
// 如果用户通过新字段启用了,且原始配置中没有旧字段,保存 enabled: true 以保持向后兼容
setBoolInMap(serverNode, "enabled", true)
}
}
}
// setStringArrayInMap 设置字符串数组
func setStringArrayInMap(mapNode *yaml.Node, key string, values []string) {
_, valueNode := ensureKeyValue(mapNode, key)
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
for _, v := range values {
itemNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: v}
valueNode.Content = append(valueNode.Content, itemNode)
}
}
// AddOrUpdateExternalMCPRequest 添加或更新外部MCP请求
type AddOrUpdateExternalMCPRequest struct {
Config config.ExternalMCPServerConfig `json:"config"`
}
// ExternalMCPResponse 外部MCP响应
type ExternalMCPResponse struct {
Config config.ExternalMCPServerConfig `json:"config"`
Status string `json:"status"` // "connected", "disconnected", "disabled", "error", "connecting"
ToolCount int `json:"tool_count"` // 工具数量
Error string `json:"error,omitempty"` // 错误信息(仅在status为error时存在)
}
-518
View File
@@ -1,518 +0,0 @@
package handler
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
func setupTestRouter() (*gin.Engine, *ExternalMCPHandler, string) {
gin.SetMode(gin.TestMode)
router := gin.New()
// 创建临时配置文件
tmpFile, err := os.CreateTemp("", "test-config-*.yaml")
if err != nil {
panic(err)
}
tmpFile.WriteString("server:\n host: 0.0.0.0\n port: 8080\n")
tmpFile.Close()
configPath := tmpFile.Name()
logger := zap.NewNop()
manager := mcp.NewExternalMCPManager(logger)
cfg := &config.Config{
ExternalMCP: config.ExternalMCPConfig{
Servers: make(map[string]config.ExternalMCPServerConfig),
},
}
handler := NewExternalMCPHandler(manager, cfg, configPath, logger)
api := router.Group("/api")
api.GET("/external-mcp", handler.GetExternalMCPs)
api.GET("/external-mcp/stats", handler.GetExternalMCPStats)
api.GET("/external-mcp/:name", handler.GetExternalMCP)
api.PUT("/external-mcp/:name", handler.AddOrUpdateExternalMCP)
api.DELETE("/external-mcp/:name", handler.DeleteExternalMCP)
api.POST("/external-mcp/:name/start", handler.StartExternalMCP)
api.POST("/external-mcp/:name/stop", handler.StopExternalMCP)
return router, handler, configPath
}
func cleanupTestConfig(configPath string) {
os.Remove(configPath)
os.Remove(configPath + ".backup")
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_Stdio(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加stdio模式的配置
configJSON := `{
"command": "python3",
"args": ["/path/to/script.py", "--server", "http://example.com"],
"description": "Test stdio MCP",
"timeout": 300,
"enabled": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-stdio", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-stdio", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Command != "python3" {
t.Errorf("期望command为python3,实际%s", response.Config.Command)
}
if len(response.Config.Args) != 3 {
t.Errorf("期望args长度为3,实际%d", len(response.Config.Args))
}
if response.Config.Description != "Test stdio MCP" {
t.Errorf("期望description为'Test stdio MCP',实际%s", response.Config.Description)
}
if response.Config.Timeout != 300 {
t.Errorf("期望timeout为300,实际%d", response.Config.Timeout)
}
if !response.Config.Enabled {
t.Error("期望enabled为true")
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_HTTP(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 测试添加HTTP模式的配置
configJSON := `{
"transport": "http",
"url": "http://127.0.0.1:8081/mcp",
"enabled": true
}`
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-http", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已添加
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-http", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.Transport != "http" {
t.Errorf("期望transport为http,实际%s", response.Config.Transport)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
if !response.Config.Enabled {
t.Error("期望enabled为true")
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidConfig(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
testCases := []struct {
name string
configJSON string
expectedErr string
}{
{
name: "缺少command和url",
configJSON: `{"enabled": true}`,
expectedErr: "需要指定commandstdio模式)或urlhttp/sse模式)",
},
{
name: "stdio模式缺少command",
configJSON: `{"args": ["test"], "enabled": true}`,
expectedErr: "stdio模式需要command",
},
{
name: "http模式缺少url",
configJSON: `{"transport": "http", "enabled": true}`,
expectedErr: "HTTP模式需要URL",
},
{
name: "无效的transport",
configJSON: `{"transport": "invalid", "enabled": true}`,
expectedErr: "不支持的传输模式",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var configObj config.ExternalMCPServerConfig
if err := json.Unmarshal([]byte(tc.configJSON), &configObj); err != nil {
t.Fatalf("解析配置JSON失败: %v", err)
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-invalid", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
errorMsg := response["error"].(string)
// 对于stdio模式缺少command的情况,错误信息可能略有不同
if tc.name == "stdio模式缺少command" {
if !strings.Contains(errorMsg, "stdio") && !strings.Contains(errorMsg, "command") {
t.Errorf("期望错误信息包含'stdio'或'command',实际'%s'", errorMsg)
}
} else if !strings.Contains(errorMsg, tc.expectedErr) {
t.Errorf("期望错误信息包含'%s',实际'%s'", tc.expectedErr, errorMsg)
}
})
}
}
func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加一个配置
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
handler.manager.AddOrUpdateConfig("test-delete", configObj)
// 删除配置
req := httptest.NewRequest("DELETE", "/api/external-mcp/test-delete", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已删除
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-delete", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusNotFound {
t.Errorf("期望状态码404,实际%d: %s", w2.Code, w2.Body.String())
}
}
func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
})
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: false,
})
req := httptest.NewRequest("GET", "/api/external-mcp", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
servers := response["servers"].(map[string]interface{})
if len(servers) != 2 {
t.Errorf("期望2个服务器,实际%d", len(servers))
}
if _, ok := servers["test1"]; !ok {
t.Error("期望包含test1")
}
if _, ok := servers["test2"]; !ok {
t.Error("期望包含test2")
}
stats := response["stats"].(map[string]interface{})
if int(stats["total"].(float64)) != 2 {
t.Errorf("期望总数为2,实际%d", int(stats["total"].(float64)))
}
}
func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
router, handler, _ := setupTestRouter()
// 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
})
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
})
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true,
})
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if int(stats["total"].(float64)) != 3 {
t.Errorf("期望总数为3,实际%d", int(stats["total"].(float64)))
}
if int(stats["enabled"].(float64)) != 2 {
t.Errorf("期望启用数为2,实际%d", int(stats["enabled"].(float64)))
}
if int(stats["disabled"].(float64)) != 1 {
t.Errorf("期望停用数为1,实际%d", int(stats["disabled"].(float64)))
}
}
func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true,
})
// 测试启动(可能会失败,因为没有真实的服务器)
req := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/start", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 启动可能会失败,但应该返回合理的状态码
if w.Code != http.StatusOK {
// 如果启动失败,应该是400或500
if w.Code != http.StatusBadRequest && w.Code != http.StatusInternalServerError {
t.Errorf("期望状态码200/400/500,实际%d: %s", w.Code, w.Body.String())
}
}
// 测试停止
req2 := httptest.NewRequest("POST", "/api/external-mcp/test-start-stop/stop", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Errorf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
}
func TestExternalMCPHandler_GetExternalMCP_NotFound(t *testing.T) {
router, _, _ := setupTestRouter()
req := httptest.NewRequest("GET", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("期望状态码404,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_DeleteExternalMCP_NotFound(t *testing.T) {
router, _, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
req := httptest.NewRequest("DELETE", "/api/external-mcp/nonexistent", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 删除不存在的配置可能返回200(幂等操作)或404,都是合理的
if w.Code != http.StatusNotFound && w.Code != http.StatusOK {
t.Errorf("期望状态码404或200,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: configObj,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 空名称应该返回404或400
if w.Code != http.StatusNotFound && w.Code != http.StatusBadRequest {
t.Errorf("期望状态码404或400,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_AddOrUpdateExternalMCP_InvalidJSON(t *testing.T) {
router, _, _ := setupTestRouter()
// 发送无效的JSON
body := []byte(`{"config": invalid json}`)
req := httptest.NewRequest("PUT", "/api/external-mcp/test", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("期望状态码400,实际%d: %s", w.Code, w.Body.String())
}
}
func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
router, handler, configPath := setupTestRouter()
defer cleanupTestConfig(configPath)
// 先添加配置
config1 := config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
}
handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置
config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
}
reqBody := AddOrUpdateExternalMCPRequest{
Config: config2,
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/external-mcp/test-update", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w.Code, w.Body.String())
}
// 验证配置已更新
req2 := httptest.NewRequest("GET", "/api/external-mcp/test-update", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusOK {
t.Fatalf("期望状态码200,实际%d: %s", w2.Code, w2.Body.String())
}
var response ExternalMCPResponse
if err := json.Unmarshal(w2.Body.Bytes(), &response); err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if response.Config.URL != "http://127.0.0.1:8081/mcp" {
t.Errorf("期望url为'http://127.0.0.1:8081/mcp',实际%s", response.Config.URL)
}
if response.Config.Command != "" {
t.Errorf("期望command为空,实际%s", response.Config.Command)
}
}
-467
View File
@@ -1,467 +0,0 @@
package handler
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"
"cyberstrike-ai/internal/config"
openaiClient "cyberstrike-ai/internal/openai"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type FofaHandler struct {
cfg *config.Config
logger *zap.Logger
client *http.Client
openAIClient *openaiClient.Client
}
func NewFofaHandler(cfg *config.Config, logger *zap.Logger) *FofaHandler {
// LLM 请求通常比 FOFA 查询更慢一点,单独给一个更宽松的超时。
llmHTTPClient := &http.Client{Timeout: 2 * time.Minute}
var llmCfg *config.OpenAIConfig
if cfg != nil {
llmCfg = &cfg.OpenAI
}
return &FofaHandler{
cfg: cfg,
logger: logger,
client: &http.Client{Timeout: 30 * time.Second},
openAIClient: openaiClient.NewClient(llmCfg, llmHTTPClient, logger),
}
}
type fofaSearchRequest struct {
Query string `json:"query" binding:"required"`
Size int `json:"size,omitempty"`
Page int `json:"page,omitempty"`
Fields string `json:"fields,omitempty"`
Full bool `json:"full,omitempty"`
}
type fofaParseRequest struct {
Text string `json:"text" binding:"required"`
}
type fofaParseResponse struct {
Query string `json:"query"`
Explanation string `json:"explanation,omitempty"`
Warnings []string `json:"warnings,omitempty"`
}
type fofaAPIResponse struct {
Error bool `json:"error"`
ErrMsg string `json:"errmsg"`
Size int `json:"size"`
Page int `json:"page"`
Total int `json:"total"`
Mode string `json:"mode"`
Query string `json:"query"`
Results [][]interface{} `json:"results"`
}
type fofaSearchResponse struct {
Query string `json:"query"`
Size int `json:"size"`
Page int `json:"page"`
Total int `json:"total"`
Fields []string `json:"fields"`
ResultsCount int `json:"results_count"`
Results []map[string]interface{} `json:"results"`
}
func (h *FofaHandler) resolveCredentials() (email, apiKey string) {
// 优先环境变量(便于容器部署),其次配置文件
email = strings.TrimSpace(os.Getenv("FOFA_EMAIL"))
apiKey = strings.TrimSpace(os.Getenv("FOFA_API_KEY"))
if email != "" && apiKey != "" {
return email, apiKey
}
if h.cfg != nil {
if email == "" {
email = strings.TrimSpace(h.cfg.FOFA.Email)
}
if apiKey == "" {
apiKey = strings.TrimSpace(h.cfg.FOFA.APIKey)
}
}
return email, apiKey
}
func (h *FofaHandler) resolveBaseURL() string {
if h.cfg != nil {
if v := strings.TrimSpace(h.cfg.FOFA.BaseURL); v != "" {
return v
}
}
return "https://fofa.info/api/v1/search/all"
}
// ParseNaturalLanguage 将自然语言解析为 FOFA 查询语法(仅生成,不执行查询)
func (h *FofaHandler) ParseNaturalLanguage(c *gin.Context) {
var req fofaParseRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
req.Text = strings.TrimSpace(req.Text)
if req.Text == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "text 不能为空"})
return
}
if h.cfg == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "系统配置未初始化"})
return
}
if strings.TrimSpace(h.cfg.OpenAI.APIKey) == "" || strings.TrimSpace(h.cfg.OpenAI.Model) == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "未配置 AI 模型:请在系统设置中填写 openai.api_key 与 openai.model(支持 OpenAI 兼容 API,如 DeepSeek",
"need": []string{"openai.api_key", "openai.model"},
})
return
}
if h.openAIClient == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "AI 客户端未初始化"})
return
}
systemPrompt := strings.TrimSpace(`
你是“FOFA 查询语法生成器”。任务:把用户输入的自然语言搜索意图,转换成 FOFA 查询语法。
输出要求(非常重要):
1) 只输出 JSON(不要 markdown、不要代码块、不要额外解释文本)
2) JSON 结构必须是:
{
"query": "stringFOFA查询语法(可直接粘贴到 FOFA 或本系统查询框)",
"explanation": "string,可选,解释你如何映射字段/逻辑",
"warnings": ["string"...] 可选,列出歧义/风险/需要人工确认的点
}
3) 如果用户输入本身已经是 FOFA 查询语法(或非常接近 FOFA 语法的表达式),应当“原样返回”为 query:
- 不要擅自改写字段名、操作符、括号结构
- 不要改写任何字符串值(尤其是地理位置类值),不要做缩写/同义词替换/翻译/音译
查询语法要点(来自 FOFA 语法参考):
- 逻辑连接符:&&(与)、||(或),必要时用 () 包住子表达式以确认优先级(括号优先级最高)
- 当同一层级同时出现 && 与 ||(混用)时,用 () 明确优先级(避免歧义)
- 比较/匹配:
- = 匹配;当字段="" 时,可查询“不存在该字段”或“值为空”的情况
- == 完全匹配;当字段=="" 时,可查询“字段存在且值为空”的情况
- != 不匹配;当字段!="" 时,可查询“值不为空”的情况
- *= 模糊匹配;可使用 * 或 ? 进行搜索
- 直接输入关键词(不带字段)会在标题、HTML内容、HTTP头、URL字段中搜索;但当意图明确时优先用字段表达(更可控、更准确)
字段示例速查(来自用户提供的案例,可直接套用/拼接):
- 高级搜索操作符示例:
- title="beijing" (= 匹配)
- title=="" (== 完全匹配,字段存在且值为空)
- title="" (= 匹配,可能表示字段不存在或值为空)
- title!="" (!= 不匹配,可用于值不为空)
- title*="*Home*" (*= 模糊匹配,用 * 或 ?)
- (app="Apache" || app="Nginx") && country="CN" (混用 && / || 时用括号)
- 基础类(General):
- ip="1.1.1.1"
- ip="220.181.111.1/24"
- ip="2600:9000:202a:2600:18:4ab7:f600:93a1"
- port="6379"
- domain="qq.com"
- host=".fofa.info"
- os="centos"
- server="Microsoft-IIS/10"
- asn="19551"
- org="LLC Baxet"
- is_domain=true / is_domain=false
- is_ipv6=true / is_ipv6=false
- 标记类(Special Label):
- app="Microsoft-Exchange"
- fid="sSXXGNUO2FefBTcCLIT/2Q=="
- product="NGINX"
- product="Roundcube-Webmail" && product.version="1.6.10"
- category="服务"
- type="service" / type="subdomain"
- cloud_name="Aliyundun"
- is_cloud=true / is_cloud=false
- is_fraud=true / is_fraud=false
- is_honeypot=true / is_honeypot=false
- 协议类(type=service):
- protocol="quic"
- banner="users"
- banner_hash="7330105010150477363"
- banner_fid="zRpqmn0FXQRjZpH8MjMX55zpMy9SgsW8"
- base_protocol="udp" / base_protocol="tcp"
- 网站类(type=subdomain):
- title="beijing"
- header="elastic"
- header_hash="1258854265"
- body="网络空间测绘"
- body_hash="-2090962452"
- js_name="js/jquery.js"
- js_md5="82ac3f14327a8b7ba49baa208d4eaa15"
- cname="customers.spektrix.com"
- cname_domain="siteforce.com"
- icon_hash="-247388890"
- status_code="402"
- icp="京ICP证030173号"
- sdk_hash="Are3qNnP2Eqn7q5kAoUO3l+w3mgVIytO"
- 地理位置(Location):
- country="CN" 或 country="中国"
- region="Zhejiang" 或 region="浙江"(仅支持中国地区中文)
- city="Hangzhou"
- 证书类(Certificate):
- cert="baidu"
- cert.subject="Oracle Corporation"
- cert.issuer="DigiCert"
- cert.subject.org="Oracle Corporation"
- cert.subject.cn="baidu.com"
- cert.issuer.org="cPanel, Inc."
- cert.issuer.cn="Synology Inc. CA"
- cert.domain="huawei.com"
- cert.is_equal=true / cert.is_equal=false
- cert.is_valid=true / cert.is_valid=false
- cert.is_match=true / cert.is_match=false
- cert.is_expired=true / cert.is_expired=false
- jarm="2ad2ad0002ad2ad22c2ad2ad2ad2ad2eac92ec34bcc0cf7520e97547f83e81"
- tls.version="TLS 1.3"
- tls.ja3s="15af977ce25de452b96affa2addb1036"
- cert.sn="356078156165546797850343536942784588840297"
- cert.not_after.after="2025-03-01" / cert.not_after.before="2025-03-01"
- cert.not_before.after="2025-03-01" / cert.not_before.before="2025-03-01"
- 时间类(Last update time):
- after="2023-01-01"
- before="2023-12-01"
- after="2023-01-01" && before="2023-12-01"
- 独立IP语法(需配合 ip_filter / ip_exclude):
- ip_filter(banner="SSH-2.0-OpenSSH_6.7p2") && ip_filter(icon_hash="-1057022626")
- ip_filter(banner="SSH-2.0-OpenSSH_6.7p2" && asn="3462") && ip_exclude(title="EdgeOS")
- port_size="6" / port_size_gt="6" / port_size_lt="12"
- ip_ports="80,161"
- ip_country="CN"
- ip_region="Zhejiang"
- ip_city="Hangzhou"
- ip_after="2021-03-18"
- ip_before="2019-09-09"
生成约束与注意事项:
- 字符串值一律用英文双引号包裹,例如 title="登录"、country="CN"
- 字符串值保持字面一致:不要缩写(例如 city="beijing" 不要变成 city="BJ"),不要用别名(例如 Beijing/Peking),不要擅自翻译/音译/改写大小写
- 地理位置字段(country/region/city)更倾向于“按用户给定值输出”;不确定合法取值时,不要猜测,把备选写进 warnings
- 不要捏造不存在的 FOFA 字段;不确定时把不确定点写进 warnings,并输出一个保守的 query
- 当用户描述里有“多个与/或条件”,优先加 () 明确优先级,例如:(app="Apache" || app="Nginx") && country="CN"
- 当用户缺少关键条件导致范围过大或歧义(如地点/协议/端口/服务类型未说明),允许 query 为空字符串,并在 warnings 里明确需要补充的信息
`)
userPrompt := fmt.Sprintf("自然语言意图:%s", req.Text)
requestBody := map[string]interface{}{
"model": h.cfg.OpenAI.Model,
"messages": []map[string]interface{}{
{"role": "system", "content": systemPrompt},
{"role": "user", "content": userPrompt},
},
"temperature": 0.1,
"max_tokens": 1200,
}
// OpenAI 返回结构:只需要 choices[0].message.content
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
ctx, cancel := context.WithTimeout(c.Request.Context(), 90*time.Second)
defer cancel()
if err := h.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
var apiErr *openaiClient.APIError
if errors.As(err, &apiErr) {
h.logger.Warn("FOFA自然语言解析:LLM返回错误", zap.Int("status", apiErr.StatusCode))
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败(上游返回非 200),请检查模型配置或稍后重试"})
return
}
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 解析失败: " + err.Error()})
return
}
if len(apiResponse.Choices) == 0 {
c.JSON(http.StatusBadGateway, gin.H{"error": "AI 未返回有效结果"})
return
}
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
// 兼容模型偶尔返回 ```json ... ``` 的情况
content = strings.TrimPrefix(content, "```json")
content = strings.TrimPrefix(content, "```")
content = strings.TrimSuffix(content, "```")
content = strings.TrimSpace(content)
var parsed fofaParseResponse
if err := json.Unmarshal([]byte(content), &parsed); err != nil {
// 直接回传一部分原文,方便排查,但避免太大
snippet := content
if len(snippet) > 1200 {
snippet = snippet[:1200]
}
c.JSON(http.StatusBadGateway, gin.H{
"error": "AI 返回内容无法解析为 JSON,请稍后重试或换个描述方式",
"snippet": snippet,
})
return
}
parsed.Query = strings.TrimSpace(parsed.Query)
if parsed.Query == "" {
// query 允许为空(表示需求不明确),但前端需要明确提示
if len(parsed.Warnings) == 0 {
parsed.Warnings = []string{"需求信息不足,未能生成可用的 FOFA 查询语法,请补充关键条件(如国家/端口/产品/域名等)。"}
}
}
c.JSON(http.StatusOK, parsed)
}
// Search FOFA 查询(后端代理,避免前端暴露 key)
func (h *FofaHandler) Search(c *gin.Context) {
var req fofaSearchRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
req.Query = strings.TrimSpace(req.Query)
if req.Query == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "query 不能为空"})
return
}
if req.Size <= 0 {
req.Size = 100
}
if req.Page <= 0 {
req.Page = 1
}
// FOFA 接口 size 上限和账户权限相关,这里只做一个合理的保护
if req.Size > 10000 {
req.Size = 10000
}
if req.Fields == "" {
req.Fields = "host,ip,port,domain,title,protocol,country,province,city,server"
}
email, apiKey := h.resolveCredentials()
if email == "" || apiKey == "" {
c.JSON(http.StatusBadRequest, gin.H{
"error": "FOFA 未配置:请在系统设置中填写 FOFA Email/API Key,或设置环境变量 FOFA_EMAIL/FOFA_API_KEY",
"need": []string{"fofa.email", "fofa.api_key"},
"env_key": []string{"FOFA_EMAIL", "FOFA_API_KEY"},
})
return
}
baseURL := h.resolveBaseURL()
qb64 := base64.StdEncoding.EncodeToString([]byte(req.Query))
u, err := url.Parse(baseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "FOFA base_url 无效: " + err.Error()})
return
}
params := u.Query()
params.Set("email", email)
params.Set("key", apiKey)
params.Set("qbase64", qb64)
params.Set("size", fmt.Sprintf("%d", req.Size))
params.Set("page", fmt.Sprintf("%d", req.Page))
params.Set("fields", strings.TrimSpace(req.Fields))
if req.Full {
params.Set("full", "true")
} else {
// 明确传 false,便于排查
params.Set("full", "false")
}
u.RawQuery = params.Encode()
httpReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, u.String(), nil)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()})
return
}
resp, err := h.client.Do(httpReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "请求 FOFA 失败: " + err.Error()})
return
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("FOFA 返回非 2xx: %d", resp.StatusCode)})
return
}
var apiResp fofaAPIResponse
if err := json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "解析 FOFA 响应失败: " + err.Error()})
return
}
if apiResp.Error {
msg := strings.TrimSpace(apiResp.ErrMsg)
if msg == "" {
msg = "FOFA 返回错误"
}
c.JSON(http.StatusBadGateway, gin.H{"error": msg})
return
}
fields := splitAndCleanCSV(req.Fields)
results := make([]map[string]interface{}, 0, len(apiResp.Results))
for _, row := range apiResp.Results {
item := make(map[string]interface{}, len(fields))
for i, f := range fields {
if i < len(row) {
item[f] = row[i]
} else {
item[f] = nil
}
}
results = append(results, item)
}
c.JSON(http.StatusOK, fofaSearchResponse{
Query: req.Query,
Size: apiResp.Size,
Page: apiResp.Page,
Total: apiResp.Total,
Fields: fields,
ResultsCount: len(results),
Results: results,
})
}
func splitAndCleanCSV(s string) []string {
parts := strings.Split(s, ",")
out := make([]string, 0, len(parts))
seen := make(map[string]struct{}, len(parts))
for _, p := range parts {
v := strings.TrimSpace(p)
if v == "" {
continue
}
if _, ok := seen[v]; ok {
continue
}
seen[v] = struct{}{}
out = append(out, v)
}
return out
}
-320
View File
@@ -1,320 +0,0 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// GroupHandler 分组处理器
type GroupHandler struct {
db *database.DB
logger *zap.Logger
}
// NewGroupHandler 创建新的分组处理器
func NewGroupHandler(db *database.DB, logger *zap.Logger) *GroupHandler {
return &GroupHandler{
db: db,
logger: logger,
}
}
// CreateGroupRequest 创建分组请求
type CreateGroupRequest struct {
Name string `json:"name"`
Icon string `json:"icon"`
}
// CreateGroup 创建分组
func (h *GroupHandler) CreateGroup(c *gin.Context) {
var req CreateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"})
return
}
group, err := h.db.CreateGroup(req.Name, req.Icon)
if err != nil {
h.logger.Error("创建分组失败", zap.Error(err))
// 如果是名称重复错误,返回400状态码
if err.Error() == "分组名称已存在" {
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, group)
}
// ListGroups 列出所有分组
func (h *GroupHandler) ListGroups(c *gin.Context) {
groups, err := h.db.ListGroups()
if err != nil {
h.logger.Error("获取分组列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, groups)
}
// GetGroup 获取分组
func (h *GroupHandler) GetGroup(c *gin.Context) {
id := c.Param("id")
group, err := h.db.GetGroup(id)
if err != nil {
h.logger.Error("获取分组失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "分组不存在"})
return
}
c.JSON(http.StatusOK, group)
}
// UpdateGroupRequest 更新分组请求
type UpdateGroupRequest struct {
Name string `json:"name"`
Icon string `json:"icon"`
}
// UpdateGroup 更新分组
func (h *GroupHandler) UpdateGroup(c *gin.Context) {
id := c.Param("id")
var req UpdateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称不能为空"})
return
}
if err := h.db.UpdateGroup(id, req.Name, req.Icon); err != nil {
h.logger.Error("更新分组失败", zap.Error(err))
// 如果是名称重复错误,返回400状态码
if err.Error() == "分组名称已存在" {
c.JSON(http.StatusBadRequest, gin.H{"error": "分组名称已存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
group, err := h.db.GetGroup(id)
if err != nil {
h.logger.Error("获取更新后的分组失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, group)
}
// DeleteGroup 删除分组
func (h *GroupHandler) DeleteGroup(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteGroup(id); err != nil {
h.logger.Error("删除分组失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// AddConversationToGroupRequest 添加对话到分组请求
type AddConversationToGroupRequest struct {
ConversationID string `json:"conversationId"`
GroupID string `json:"groupId"`
}
// AddConversationToGroup 将对话添加到分组
func (h *GroupHandler) AddConversationToGroup(c *gin.Context) {
var req AddConversationToGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.AddConversationToGroup(req.ConversationID, req.GroupID); err != nil {
h.logger.Error("添加对话到分组失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "添加成功"})
}
// RemoveConversationFromGroup 从分组中移除对话
func (h *GroupHandler) RemoveConversationFromGroup(c *gin.Context) {
conversationID := c.Param("conversationId")
groupID := c.Param("id")
if err := h.db.RemoveConversationFromGroup(conversationID, groupID); err != nil {
h.logger.Error("从分组中移除对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "移除成功"})
}
// GroupConversation 分组对话响应结构
type GroupConversation struct {
ID string `json:"id"`
Title string `json:"title"`
Pinned bool `json:"pinned"`
GroupPinned bool `json:"groupPinned"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// GetGroupConversations 获取分组中的所有对话
func (h *GroupHandler) GetGroupConversations(c *gin.Context) {
groupID := c.Param("id")
searchQuery := c.Query("search") // 获取搜索参数
var conversations []*database.Conversation
var err error
// 如果有搜索关键词,使用搜索方法;否则使用普通方法
if searchQuery != "" {
conversations, err = h.db.SearchConversationsByGroup(groupID, searchQuery)
} else {
conversations, err = h.db.GetConversationsByGroup(groupID)
}
if err != nil {
h.logger.Error("获取分组对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 获取每个对话在分组中的置顶状态
groupConvs := make([]GroupConversation, 0, len(conversations))
for _, conv := range conversations {
// 查询分组内置顶状态
var groupPinned int
err := h.db.QueryRow(
"SELECT COALESCE(pinned, 0) FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?",
conv.ID, groupID,
).Scan(&groupPinned)
if err != nil {
h.logger.Warn("查询分组内置顶状态失败", zap.String("conversationId", conv.ID), zap.Error(err))
groupPinned = 0
}
groupConvs = append(groupConvs, GroupConversation{
ID: conv.ID,
Title: conv.Title,
Pinned: conv.Pinned,
GroupPinned: groupPinned != 0,
CreatedAt: conv.CreatedAt,
UpdatedAt: conv.UpdatedAt,
})
}
c.JSON(http.StatusOK, groupConvs)
}
// GetAllMappings 批量获取所有分组映射(消除前端 N+1 请求)
func (h *GroupHandler) GetAllMappings(c *gin.Context) {
mappings, err := h.db.GetAllGroupMappings()
if err != nil {
h.logger.Error("获取分组映射失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, mappings)
}
// UpdateConversationPinnedRequest 更新对话置顶状态请求
type UpdateConversationPinnedRequest struct {
Pinned bool `json:"pinned"`
}
// UpdateConversationPinned 更新对话置顶状态
func (h *GroupHandler) UpdateConversationPinned(c *gin.Context) {
conversationID := c.Param("id")
var req UpdateConversationPinnedRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.UpdateConversationPinned(conversationID, req.Pinned); err != nil {
h.logger.Error("更新对话置顶状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
}
// UpdateGroupPinnedRequest 更新分组置顶状态请求
type UpdateGroupPinnedRequest struct {
Pinned bool `json:"pinned"`
}
// UpdateGroupPinned 更新分组置顶状态
func (h *GroupHandler) UpdateGroupPinned(c *gin.Context) {
groupID := c.Param("id")
var req UpdateGroupPinnedRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.UpdateGroupPinned(groupID, req.Pinned); err != nil {
h.logger.Error("更新分组置顶状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
}
// UpdateConversationPinnedInGroupRequest 更新分组对话置顶状态请求
type UpdateConversationPinnedInGroupRequest struct {
Pinned bool `json:"pinned"`
}
// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态
func (h *GroupHandler) UpdateConversationPinnedInGroup(c *gin.Context) {
groupID := c.Param("id")
conversationID := c.Param("conversationId")
var req UpdateConversationPinnedInGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := h.db.UpdateConversationPinnedInGroup(conversationID, groupID, req.Pinned); err != nil {
h.logger.Error("更新分组对话置顶状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "更新成功"})
}
-517
View File
@@ -1,517 +0,0 @@
package handler
import (
"context"
"fmt"
"net/http"
"time"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/knowledge"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// KnowledgeHandler 知识库处理器
type KnowledgeHandler struct {
manager *knowledge.Manager
retriever *knowledge.Retriever
indexer *knowledge.Indexer
db *database.DB
logger *zap.Logger
}
// NewKnowledgeHandler 创建新的知识库处理器
func NewKnowledgeHandler(
manager *knowledge.Manager,
retriever *knowledge.Retriever,
indexer *knowledge.Indexer,
db *database.DB,
logger *zap.Logger,
) *KnowledgeHandler {
return &KnowledgeHandler{
manager: manager,
retriever: retriever,
indexer: indexer,
db: db,
logger: logger,
}
}
// GetCategories 获取所有分类
func (h *KnowledgeHandler) GetCategories(c *gin.Context) {
categories, err := h.manager.GetCategories()
if err != nil {
h.logger.Error("获取分类失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"categories": categories})
}
// GetItems 获取知识项列表(支持按分类分页和关键字搜索,默认不返回完整内容)
func (h *KnowledgeHandler) GetItems(c *gin.Context) {
category := c.Query("category")
searchKeyword := c.Query("search") // 搜索关键字
// 如果提供了搜索关键字,执行关键字搜索(在所有数据中搜索)
if searchKeyword != "" {
items, err := h.manager.SearchItemsByKeyword(searchKeyword, category)
if err != nil {
h.logger.Error("搜索知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 按分类分组结果
groupedByCategory := make(map[string][]*knowledge.KnowledgeItemSummary)
for _, item := range items {
cat := item.Category
if cat == "" {
cat = "未分类"
}
groupedByCategory[cat] = append(groupedByCategory[cat], item)
}
// 转换为 CategoryWithItems 格式
categoriesWithItems := make([]*knowledge.CategoryWithItems, 0, len(groupedByCategory))
for cat, catItems := range groupedByCategory {
categoriesWithItems = append(categoriesWithItems, &knowledge.CategoryWithItems{
Category: cat,
ItemCount: len(catItems),
Items: catItems,
})
}
// 按分类名称排序
for i := 0; i < len(categoriesWithItems)-1; i++ {
for j := i + 1; j < len(categoriesWithItems); j++ {
if categoriesWithItems[i].Category > categoriesWithItems[j].Category {
categoriesWithItems[i], categoriesWithItems[j] = categoriesWithItems[j], categoriesWithItems[i]
}
}
}
c.JSON(http.StatusOK, gin.H{
"categories": categoriesWithItems,
"total": len(categoriesWithItems),
"search": searchKeyword,
"is_search": true,
})
return
}
// 分页模式:categoryPage=true 表示按分类分页,否则按项分页(向后兼容)
categoryPageMode := c.Query("categoryPage") != "false" // 默认使用分类分页
// 分页参数
limit := 50 // 默认每页 50 条(分类分页时为分类数,项分页时为项数)
offset := 0
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 && parsed <= 500 {
limit = parsed
}
}
if offsetStr := c.Query("offset"); offsetStr != "" {
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
offset = parsed
}
}
// 如果指定了 category 参数,且使用分类分页模式,则只返回该分类
if category != "" && categoryPageMode {
// 单分类模式:返回该分类的所有知识项(不分页)
items, total, err := h.manager.GetItemsSummary(category, 0, 0)
if err != nil {
h.logger.Error("获取知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 包装成分类结构
categoriesWithItems := []*knowledge.CategoryWithItems{
{
Category: category,
ItemCount: total,
Items: items,
},
}
c.JSON(http.StatusOK, gin.H{
"categories": categoriesWithItems,
"total": 1, // 只有一个分类
"limit": limit,
"offset": offset,
})
return
}
if categoryPageMode {
// 按分类分页模式(默认)
// limit 表示每页分类数,推荐 5-10 个分类
if limit <= 0 || limit > 100 {
limit = 10 // 默认每页 10 个分类
}
categoriesWithItems, totalCategories, err := h.manager.GetCategoriesWithItems(limit, offset)
if err != nil {
h.logger.Error("获取分类知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"categories": categoriesWithItems,
"total": totalCategories,
"limit": limit,
"offset": offset,
})
return
}
// 按项分页模式(向后兼容)
// 是否包含完整内容(默认 false,只返回摘要)
includeContent := c.Query("includeContent") == "true"
if includeContent {
// 返回完整内容(向后兼容)
items, err := h.manager.GetItemsWithOptions(category, limit, offset, true)
if err != nil {
h.logger.Error("获取知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 获取总数
total, err := h.manager.GetItemsCount(category)
if err != nil {
h.logger.Warn("获取知识项总数失败", zap.Error(err))
total = len(items)
}
c.JSON(http.StatusOK, gin.H{
"items": items,
"total": total,
"limit": limit,
"offset": offset,
})
} else {
// 返回摘要(不包含完整内容,推荐方式)
items, total, err := h.manager.GetItemsSummary(category, limit, offset)
if err != nil {
h.logger.Error("获取知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"items": items,
"total": total,
"limit": limit,
"offset": offset,
})
}
}
// GetItem 获取单个知识项
func (h *KnowledgeHandler) GetItem(c *gin.Context) {
id := c.Param("id")
item, err := h.manager.GetItem(id)
if err != nil {
h.logger.Error("获取知识项失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, item)
}
// CreateItem 创建知识项
func (h *KnowledgeHandler) CreateItem(c *gin.Context) {
var req struct {
Category string `json:"category" binding:"required"`
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
item, err := h.manager.CreateItem(req.Category, req.Title, req.Content)
if err != nil {
h.logger.Error("创建知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 异步索引
go func() {
ctx := context.Background()
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
h.logger.Warn("索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
}
}()
c.JSON(http.StatusOK, item)
}
// UpdateItem 更新知识项
func (h *KnowledgeHandler) UpdateItem(c *gin.Context) {
id := c.Param("id")
var req struct {
Category string `json:"category" binding:"required"`
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
item, err := h.manager.UpdateItem(id, req.Category, req.Title, req.Content)
if err != nil {
h.logger.Error("更新知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 异步重新索引
go func() {
ctx := context.Background()
if err := h.indexer.IndexItem(ctx, item.ID); err != nil {
h.logger.Warn("重新索引知识项失败", zap.String("itemId", item.ID), zap.Error(err))
}
}()
c.JSON(http.StatusOK, item)
}
// DeleteItem 删除知识项
func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DeleteItem(id); err != nil {
h.logger.Error("删除知识项失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// RebuildIndex 重建索引
func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
// 异步重建索引
go func() {
ctx := context.Background()
if err := h.indexer.RebuildIndex(ctx); err != nil {
h.logger.Error("重建索引失败", zap.Error(err))
}
}()
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
}
// ScanKnowledgeBase 扫描知识库
func (h *KnowledgeHandler) ScanKnowledgeBase(c *gin.Context) {
itemsToIndex, err := h.manager.ScanKnowledgeBase()
if err != nil {
h.logger.Error("扫描知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if len(itemsToIndex) == 0 {
c.JSON(http.StatusOK, gin.H{"message": "扫描完成,没有需要索引的新项或更新项"})
return
}
// 异步索引新添加或更新的项(增量索引)
go func() {
ctx := context.Background()
h.logger.Info("开始增量索引", zap.Int("count", len(itemsToIndex)))
failedCount := 0
consecutiveFailures := 0
var firstFailureItemID string
var firstFailureError error
for i, itemID := range itemsToIndex {
if err := h.indexer.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
// 只在第一个失败时记录详细日志
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
h.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemsToIndex)),
zap.Error(err),
)
}
// 如果连续失败 2 次,立即停止增量索引
if consecutiveFailures >= 2 {
h.logger.Error("连续索引失败次数过多,立即停止增量索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemsToIndex)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
break
}
continue
}
// 成功时重置连续失败计数
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
// 减少进度日志频率
if (i+1)%10 == 0 || i+1 == len(itemsToIndex) {
h.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemsToIndex)), zap.Int("failed", failedCount))
}
}
h.logger.Info("增量索引完成", zap.Int("totalItems", len(itemsToIndex)), zap.Int("failedCount", failedCount))
}()
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("扫描完成,开始索引 %d 个新添加或更新的知识项", len(itemsToIndex)),
"items_to_index": len(itemsToIndex),
})
}
// GetRetrievalLogs 获取检索日志
func (h *KnowledgeHandler) GetRetrievalLogs(c *gin.Context) {
conversationID := c.Query("conversationId")
messageID := c.Query("messageId")
limit := 50 // 默认 50 条
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
limit = parsed
}
}
logs, err := h.manager.GetRetrievalLogs(conversationID, messageID, limit)
if err != nil {
h.logger.Error("获取检索日志失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"logs": logs})
}
// DeleteRetrievalLog 删除检索日志
func (h *KnowledgeHandler) DeleteRetrievalLog(c *gin.Context) {
id := c.Param("id")
if err := h.manager.DeleteRetrievalLog(id); err != nil {
h.logger.Error("删除检索日志失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// GetIndexStatus 获取索引状态
func (h *KnowledgeHandler) GetIndexStatus(c *gin.Context) {
status, err := h.manager.GetIndexStatus()
if err != nil {
h.logger.Error("获取索引状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 获取索引器的错误信息
if h.indexer != nil {
lastError, lastErrorTime := h.indexer.GetLastError()
if lastError != "" {
// 如果错误是最近发生的(5 分钟内),则返回错误信息
if time.Since(lastErrorTime) < 5*time.Minute {
status["last_error"] = lastError
status["last_error_time"] = lastErrorTime.Format(time.RFC3339)
}
}
// 获取重建索引状态
isRebuilding, totalItems, current, failed, lastItemID, lastChunks, startTime := h.indexer.GetRebuildStatus()
if isRebuilding {
status["is_rebuilding"] = true
status["rebuild_total"] = totalItems
status["rebuild_current"] = current
status["rebuild_failed"] = failed
status["rebuild_start_time"] = startTime.Format(time.RFC3339)
if lastItemID != "" {
status["rebuild_last_item_id"] = lastItemID
}
if lastChunks > 0 {
status["rebuild_last_chunks"] = lastChunks
}
// 重建中时,is_complete 为 false
status["is_complete"] = false
// 计算重建进度百分比
if totalItems > 0 {
status["progress_percent"] = float64(current) / float64(totalItems) * 100
}
}
}
c.JSON(http.StatusOK, status)
}
// Search 搜索知识库(用于 API 调用,Agent 内部使用 Retriever
func (h *KnowledgeHandler) Search(c *gin.Context) {
var req knowledge.SearchRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Retriever.Search 经 Eino VectorEinoRetriever,与 MCP 工具链一致。
results, err := h.retriever.Search(c.Request.Context(), &req)
if err != nil {
h.logger.Error("搜索知识库失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"results": results})
}
// GetStats 获取知识库统计信息
func (h *KnowledgeHandler) GetStats(c *gin.Context) {
totalCategories, totalItems, err := h.manager.GetStats()
if err != nil {
h.logger.Error("获取知识库统计信息失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"enabled": true,
"total_categories": totalCategories,
"total_items": totalItems,
})
}
// 辅助函数:解析整数
func parseInt(s string) (int, error) {
var result int
_, err := fmt.Sscanf(s, "%d", &result)
return result, err
}
-299
View File
@@ -1,299 +0,0 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/config"
"github.com/gin-gonic/gin"
)
var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.md$`)
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
type MarkdownAgentsHandler struct {
dir string
}
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
}
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
filename = strings.TrimSpace(filename)
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
return "", fmt.Errorf("非法文件名")
}
clean := filepath.Clean(filename)
if clean != filename || strings.Contains(clean, "..") {
return "", fmt.Errorf("非法文件名")
}
return filepath.Join(h.dir, clean), nil
}
// existingOtherOrchestrator 若目录中已有别的主代理文件,返回其文件名;writingBasename 为当前正在写入的文件名时视为同一文件不冲突。
func existingOtherOrchestrator(dir, writingBasename string) (other string, err error) {
load, err := agents.LoadMarkdownAgentsDir(dir)
if err != nil {
return "", err
}
if load.Orchestrator == nil {
return "", nil
}
if strings.EqualFold(load.Orchestrator.Filename, writingBasename) {
return "", nil
}
return load.Orchestrator.Filename, nil
}
// ListMarkdownAgents GET /api/multi-agent/markdown-agents
func (h *MarkdownAgentsHandler) ListMarkdownAgents(c *gin.Context) {
if h.dir == "" {
c.JSON(http.StatusOK, gin.H{"agents": []any{}, "dir": "", "error": "未配置 agents 目录"})
return
}
files, err := agents.LoadMarkdownAgentFiles(h.dir)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
out := make([]gin.H, 0, len(files))
for _, fa := range files {
sub := fa.Config
out = append(out, gin.H{
"filename": fa.Filename,
"id": sub.ID,
"name": sub.Name,
"description": sub.Description,
"is_orchestrator": fa.IsOrchestrator,
"kind": sub.Kind,
})
}
c.JSON(http.StatusOK, gin.H{"agents": out, "dir": h.dir})
}
// GetMarkdownAgent GET /api/multi-agent/markdown-agents/:filename
func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
filename := c.Param("filename")
path, err := h.safeJoin(filename)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
b, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sub, err := agents.ParseMarkdownSubAgent(filename, string(b))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
isOrch := agents.IsOrchestratorMarkdown(filename, agents.FrontMatter{Kind: sub.Kind})
c.JSON(http.StatusOK, gin.H{
"filename": filename,
"raw": string(b),
"id": sub.ID,
"name": sub.Name,
"description": sub.Description,
"tools": sub.RoleTools,
"instruction": sub.Instruction,
"bind_role": sub.BindRole,
"max_iterations": sub.MaxIterations,
"kind": sub.Kind,
"is_orchestrator": isOrch,
})
}
type markdownAgentBody struct {
Filename string `json:"filename"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Tools []string `json:"tools"`
Instruction string `json:"instruction"`
BindRole string `json:"bind_role"`
MaxIterations int `json:"max_iterations"`
Kind string `json:"kind"`
Raw string `json:"raw"`
}
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents
func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
if h.dir == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "未配置 agents 目录"})
return
}
var body markdownAgentBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
filename := strings.TrimSpace(body.Filename)
if filename == "" {
if strings.EqualFold(strings.TrimSpace(body.Kind), "orchestrator") {
filename = agents.OrchestratorMarkdownFilename
} else {
base := agents.SlugID(body.Name)
if base == "" {
base = "agent"
}
filename = base + ".md"
}
}
path, err := h.safeJoin(filename)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err == nil {
c.JSON(http.StatusConflict, gin.H{"error": "文件已存在"})
return
}
sub := config.MultiAgentSubConfig{
ID: strings.TrimSpace(body.ID),
Name: strings.TrimSpace(body.Name),
Description: strings.TrimSpace(body.Description),
Instruction: strings.TrimSpace(body.Instruction),
RoleTools: body.Tools,
BindRole: strings.TrimSpace(body.BindRole),
MaxIterations: body.MaxIterations,
Kind: strings.TrimSpace(body.Kind),
}
if strings.EqualFold(filepath.Base(path), agents.OrchestratorMarkdownFilename) && sub.Kind == "" {
sub.Kind = "orchestrator"
}
if sub.ID == "" {
sub.ID = agents.SlugID(sub.Name)
}
if sub.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"})
return
}
var out []byte
if strings.TrimSpace(body.Raw) != "" {
out = []byte(body.Raw)
} else {
out, err = agents.BuildMarkdownFile(sub)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if want := agents.WantsMarkdownOrchestrator(filepath.Base(path), body.Kind, string(out)); want {
other, oerr := existingOtherOrchestrator(h.dir, filepath.Base(path))
if oerr != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()})
return
}
if other != "" {
c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)})
return
}
}
if err := os.MkdirAll(h.dir, 0755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.WriteFile(path, out, 0644); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
}
// UpdateMarkdownAgent PUT /api/multi-agent/markdown-agents/:filename
func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
filename := c.Param("filename")
path, err := h.safeJoin(filename)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var body markdownAgentBody
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
sub := config.MultiAgentSubConfig{
ID: strings.TrimSpace(body.ID),
Name: strings.TrimSpace(body.Name),
Description: strings.TrimSpace(body.Description),
Instruction: strings.TrimSpace(body.Instruction),
RoleTools: body.Tools,
BindRole: strings.TrimSpace(body.BindRole),
MaxIterations: body.MaxIterations,
Kind: strings.TrimSpace(body.Kind),
}
if strings.EqualFold(filename, agents.OrchestratorMarkdownFilename) && sub.Kind == "" {
sub.Kind = "orchestrator"
}
if sub.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "name 必填"})
return
}
if sub.ID == "" {
sub.ID = agents.SlugID(sub.Name)
}
var out []byte
if strings.TrimSpace(body.Raw) != "" {
out = []byte(body.Raw)
} else {
out, err = agents.BuildMarkdownFile(sub)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
if want := agents.WantsMarkdownOrchestrator(filename, body.Kind, string(out)); want {
other, oerr := existingOtherOrchestrator(h.dir, filename)
if oerr != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": oerr.Error()})
return
}
if other != "" {
c.JSON(http.StatusConflict, gin.H{"error": fmt.Sprintf("已存在主代理定义:%s,请先删除或取消其主代理标记", other)})
return
}
}
if err := os.WriteFile(path, out, 0644); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
}
// DeleteMarkdownAgent DELETE /api/multi-agent/markdown-agents/:filename
func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
filename := c.Param("filename")
path, err := h.safeJoin(filename)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
}
-420
View File
@@ -1,420 +0,0 @@
package handler
import (
"net/http"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// MonitorHandler 监控处理器
type MonitorHandler struct {
mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager
executor *security.Executor
db *database.DB
logger *zap.Logger
}
// NewMonitorHandler 创建新的监控处理器
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, db *database.DB, logger *zap.Logger) *MonitorHandler {
return &MonitorHandler{
mcpServer: mcpServer,
externalMCPMgr: nil, // 将在创建后设置
executor: executor,
db: db,
logger: logger,
}
}
// SetExternalMCPManager 设置外部MCP管理器
func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
h.externalMCPMgr = mgr
}
// MonitorResponse 监控响应
type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"`
Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"`
}
// Monitor 获取监控信息
func (h *MonitorHandler) Monitor(c *gin.Context) {
// 解析分页参数
page := 1
pageSize := 20
if pageStr := c.Query("page"); pageStr != "" {
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
page = p
}
}
if pageSizeStr := c.Query("page_size"); pageSizeStr != "" {
if ps, err := strconv.Atoi(pageSizeStr); err == nil && ps > 0 && ps <= 100 {
pageSize = ps
}
}
// 解析状态筛选参数
status := c.Query("status")
// 解析工具筛选参数
toolName := c.Query("tool")
executions, total := h.loadExecutionsWithPagination(page, pageSize, status, toolName)
stats := h.loadStats()
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
c.JSON(http.StatusOK, MonitorResponse{
Executions: executions,
Stats: stats,
Timestamp: time.Now(),
Total: total,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
})
}
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
return executions
}
func (h *MonitorHandler) loadExecutionsWithPagination(page, pageSize int, status, toolName string) ([]*mcp.ToolExecution, int) {
if h.db == nil {
allExecutions := h.mcpServer.GetAllExecutions()
// 如果指定了状态筛选或工具筛选,先进行筛选
if status != "" || toolName != "" {
filtered := make([]*mcp.ToolExecution, 0)
for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索)
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName))
if matchStatus && matchTool {
filtered = append(filtered, exec)
}
}
allExecutions = filtered
}
total := len(allExecutions)
offset := (page - 1) * pageSize
end := offset + pageSize
if end > total {
end = total
}
if offset >= total {
return []*mcp.ToolExecution{}, total
}
return allExecutions[offset:end], total
}
offset := (page - 1) * pageSize
executions, err := h.db.LoadToolExecutionsWithPagination(offset, pageSize, status, toolName)
if err != nil {
h.logger.Warn("从数据库加载执行记录失败,回退到内存数据", zap.Error(err))
allExecutions := h.mcpServer.GetAllExecutions()
// 如果指定了状态筛选或工具筛选,先进行筛选
if status != "" || toolName != "" {
filtered := make([]*mcp.ToolExecution, 0)
for _, exec := range allExecutions {
matchStatus := status == "" || exec.Status == status
// 支持部分匹配(模糊搜索)
matchTool := toolName == "" || strings.Contains(strings.ToLower(exec.ToolName), strings.ToLower(toolName))
if matchStatus && matchTool {
filtered = append(filtered, exec)
}
}
allExecutions = filtered
}
total := len(allExecutions)
offset := (page - 1) * pageSize
end := offset + pageSize
if end > total {
end = total
}
if offset >= total {
return []*mcp.ToolExecution{}, total
}
return allExecutions[offset:end], total
}
// 获取总数(考虑状态筛选和工具筛选)
total, err := h.db.CountToolExecutions(status, toolName)
if err != nil {
h.logger.Warn("获取执行记录总数失败", zap.Error(err))
// 回退:使用已加载的记录数估算
total = offset + len(executions)
if len(executions) == pageSize {
total = offset + len(executions) + 1
}
}
return executions, total
}
func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
// 合并内部MCP服务器和外部MCP管理器的统计信息
stats := make(map[string]*mcp.ToolStats)
// 加载内部MCP服务器的统计信息
if h.db == nil {
internalStats := h.mcpServer.GetStats()
for k, v := range internalStats {
stats[k] = v
}
} else {
dbStats, err := h.db.LoadToolStats()
if err != nil {
h.logger.Warn("从数据库加载统计信息失败,回退到内存数据", zap.Error(err))
internalStats := h.mcpServer.GetStats()
for k, v := range internalStats {
stats[k] = v
}
} else {
for k, v := range dbStats {
stats[k] = v
}
}
}
// 合并外部MCP管理器的统计信息
if h.externalMCPMgr != nil {
externalStats := h.externalMCPMgr.GetToolStats()
for k, v := range externalStats {
// 如果已存在,合并统计信息
if existing, exists := stats[k]; exists {
existing.TotalCalls += v.TotalCalls
existing.SuccessCalls += v.SuccessCalls
existing.FailedCalls += v.FailedCalls
// 使用最新的调用时间
if v.LastCallTime != nil && (existing.LastCallTime == nil || v.LastCallTime.After(*existing.LastCallTime)) {
existing.LastCallTime = v.LastCallTime
}
} else {
stats[k] = v
}
}
}
return stats
}
// GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id")
// 先从内部MCP服务器查找
exec, exists := h.mcpServer.GetExecution(id)
if exists {
c.JSON(http.StatusOK, exec)
return
}
// 如果找不到,尝试从外部MCP管理器查找
if h.externalMCPMgr != nil {
exec, exists = h.externalMCPMgr.GetExecution(id)
if exists {
c.JSON(http.StatusOK, exec)
return
}
}
// 如果都找不到,尝试从数据库查找(如果使用数据库存储)
if h.db != nil {
exec, err := h.db.GetToolExecution(id)
if err == nil && exec != nil {
c.JSON(http.StatusOK, exec)
return
}
}
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
}
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
var req struct {
IDs []string `json:"ids"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
result := make(map[string]string, len(req.IDs))
for _, id := range req.IDs {
// 先从内部MCP服务器查找
if exec, exists := h.mcpServer.GetExecution(id); exists {
result[id] = exec.ToolName
continue
}
// 再从外部MCP管理器查找
if h.externalMCPMgr != nil {
if exec, exists := h.externalMCPMgr.GetExecution(id); exists {
result[id] = exec.ToolName
continue
}
}
// 最后从数据库查找
if h.db != nil {
if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil {
result[id] = exec.ToolName
}
}
}
c.JSON(http.StatusOK, result)
}
// GetStats 获取统计信息
func (h *MonitorHandler) GetStats(c *gin.Context) {
stats := h.loadStats()
c.JSON(http.StatusOK, stats)
}
// DeleteExecution 删除执行记录
func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"})
return
}
// 如果使用数据库,先获取执行记录信息,然后删除并更新统计
if h.db != nil {
// 先获取执行记录信息(用于更新统计)
exec, err := h.db.GetToolExecution(id)
if err != nil {
// 如果找不到记录,可能已经被删除,直接返回成功
h.logger.Warn("执行记录不存在,可能已被删除", zap.String("executionId", id), zap.Error(err))
c.JSON(http.StatusOK, gin.H{"message": "执行记录不存在或已被删除"})
return
}
// 删除执行记录
err = h.db.DeleteToolExecution(id)
if err != nil {
h.logger.Error("删除执行记录失败", zap.Error(err), zap.String("executionId", id))
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除执行记录失败: " + err.Error()})
return
}
// 更新统计信息(减少相应的计数)
totalCalls := 1
successCalls := 0
failedCalls := 0
if exec.Status == "failed" {
failedCalls = 1
} else if exec.Status == "completed" {
successCalls = 1
}
if exec.ToolName != "" {
if err := h.db.DecreaseToolStats(exec.ToolName, totalCalls, successCalls, failedCalls); err != nil {
h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", exec.ToolName))
// 不返回错误,因为记录已经删除成功
}
}
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
return
}
// 如果不使用数据库,尝试从内存中删除(内部MCP服务器)
// 注意:内存中的记录可能已经被清理,所以这里只记录日志
h.logger.Info("尝试删除内存中的执行记录", zap.String("executionId", id))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
}
// DeleteExecutions 批量删除执行记录
func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
var request struct {
IDs []string `json:"ids"`
}
if err := c.ShouldBindJSON(&request); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
if len(request.IDs) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID列表不能为空"})
return
}
// 如果使用数据库,先获取执行记录信息,然后删除并更新统计
if h.db != nil {
// 先获取执行记录信息(用于更新统计)
executions, err := h.db.GetToolExecutionsByIds(request.IDs)
if err != nil {
h.logger.Error("获取执行记录失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取执行记录失败: " + err.Error()})
return
}
// 按工具名称分组统计需要减少的数量
toolStats := make(map[string]struct {
totalCalls int
successCalls int
failedCalls int
})
for _, exec := range executions {
if exec.ToolName == "" {
continue
}
stats := toolStats[exec.ToolName]
stats.totalCalls++
if exec.Status == "failed" {
stats.failedCalls++
} else if exec.Status == "completed" {
stats.successCalls++
}
toolStats[exec.ToolName] = stats
}
// 批量删除执行记录
err = h.db.DeleteToolExecutions(request.IDs)
if err != nil {
h.logger.Error("批量删除执行记录失败", zap.Error(err), zap.Int("count", len(request.IDs)))
c.JSON(http.StatusInternalServerError, gin.H{"error": "批量删除执行记录失败: " + err.Error()})
return
}
// 更新统计信息(减少相应的计数)
for toolName, stats := range toolStats {
if err := h.db.DecreaseToolStats(toolName, stats.totalCalls, stats.successCalls, stats.failedCalls); err != nil {
h.logger.Warn("更新统计信息失败", zap.Error(err), zap.String("toolName", toolName))
// 不返回错误,因为记录已经删除成功
}
}
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
return
}
// 如果不使用数据库,尝试从内存中删除(内部MCP服务器)
// 注意:内存中的记录可能已经被清理,所以这里只记录日志
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
}
-316
View File
@@ -1,316 +0,0 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/multiagent"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// MultiAgentLoopStream Eino DeepAgent 流式对话(需 config.multi_agent.enabled)。
func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
if h.config == nil || !h.config.MultiAgent.Enabled {
ev := StreamEvent{Type: "error", Message: "多代理未启用,请在设置或 config.yaml 中开启 multi_agent.enabled"}
b, _ := json.Marshal(ev)
fmt.Fprintf(c.Writer, "data: %s\n\n", b)
done := StreamEvent{Type: "done", Message: ""}
db, _ := json.Marshal(done)
fmt.Fprintf(c.Writer, "data: %s\n\n", db)
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
return
}
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
event := StreamEvent{Type: "error", Message: "请求参数错误: " + err.Error()}
b, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", b)
c.Writer.Flush()
return
}
c.Header("X-Accel-Buffering", "no")
// 用于在 sendEvent 中判断是否为用户主动停止导致的取消。
// 注意:baseCtx 会在后面创建;该变量用于闭包提前捕获引用。
var baseCtx context.Context
clientDisconnected := false
// 与 sseKeepalive 共用:禁止并发写 ResponseWriter,否则会破坏 chunked 编码(ERR_INVALID_CHUNKED_ENCODING)。
var sseWriteMu sync.Mutex
sendEvent := func(eventType, message string, data interface{}) {
if clientDisconnected {
return
}
// 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) {
return
}
select {
case <-c.Request.Context().Done():
clientDisconnected = true
return
default:
}
ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, _ := json.Marshal(ev)
sseWriteMu.Lock()
_, err := fmt.Fprintf(c.Writer, "data: %s\n\n", b)
if err != nil {
sseWriteMu.Unlock()
clientDisconnected = true
return
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
c.Writer.Flush()
}
sseWriteMu.Unlock()
}
h.logger.Info("收到 Eino DeepAgent 流式请求",
zap.String("conversationId", req.ConversationID),
)
prep, err := h.prepareMultiAgentSession(&req)
if err != nil {
sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil)
return
}
if prep.CreatedNew {
sendEvent("conversation", "会话已创建", map[string]interface{}{
"conversationId": prep.ConversationID,
})
}
conversationID := prep.ConversationID
assistantMessageID := prep.AssistantMessageID
if prep.UserMessageID != "" {
sendEvent("message_saved", "", map[string]interface{}{
"conversationId": conversationID,
"userMessageId": prep.UserMessageID,
})
}
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent)
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
defer timeoutCancel()
defer cancelWithCause(nil)
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
var errorMsg string
if errors.Is(err, ErrTaskAlreadyRunning) {
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
sendEvent("error", errorMsg, map[string]interface{}{
"conversationId": conversationID,
"errorType": "task_already_running",
})
} else {
errorMsg = "❌ 无法启动任务: " + err.Error()
sendEvent("error", errorMsg, nil)
}
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errorMsg, assistantMessageID)
}
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
return
}
taskStatus := "completed"
defer h.tasks.FinishTask(conversationID, taskStatus)
sendEvent("progress", "正在启动 Eino DeepAgent...", map[string]interface{}{
"conversationId": conversationID,
})
stopKeepalive := make(chan struct{})
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
defer close(stopKeepalive)
result, runErr := multiagent.RunDeepAgent(
taskCtx,
h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
conversationID,
prep.FinalMessage,
prep.History,
prep.RoleTools,
progressCallback,
h.agentsMarkdownDir,
)
if runErr != nil {
cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
cancelMsg := "任务已被用户取消,后续操作已停止。"
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", cancelMsg, assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
}
sendEvent("cancelled", cancelMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
return
}
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
taskStatus = "failed"
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
errMsg := "执行失败: " + runErr.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
}
sendEvent("error", errMsg, map[string]interface{}{
"conversationId": conversationID,
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
return
}
if assistantMessageID != "" {
mcpIDsJSON := ""
if len(result.MCPExecutionIDs) > 0 {
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
mcpIDsJSON = string(jsonData)
}
_, _ = h.db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
result.Response,
mcpIDsJSON,
assistantMessageID,
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
}
}
sendEvent("response", result.Response, map[string]interface{}{
"mcpExecutionIds": result.MCPExecutionIDs,
"conversationId": conversationID,
"messageId": assistantMessageID,
"agentMode": "eino_deep",
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
}
// MultiAgentLoop Eino DeepAgent 非流式对话(与 POST /api/agent-loop 对齐,需 multi_agent.enabled)。
func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
if h.config == nil || !h.config.MultiAgent.Enabled {
c.JSON(http.StatusNotFound, gin.H{"error": "多代理未启用,请在 config.yaml 中设置 multi_agent.enabled: true"})
return
}
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req)
if err != nil {
status, msg := multiAgentHTTPErrorStatus(err)
c.JSON(status, gin.H{"error": msg})
return
}
result, runErr := multiagent.RunDeepAgent(
c.Request.Context(),
h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
prep.ConversationID,
prep.FinalMessage,
prep.History,
prep.RoleTools,
nil,
h.agentsMarkdownDir,
)
if runErr != nil {
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
errMsg := "执行失败: " + runErr.Error()
if prep.AssistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, prep.AssistantMessageID)
}
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg})
return
}
if prep.AssistantMessageID != "" {
mcpIDsJSON := ""
if len(result.MCPExecutionIDs) > 0 {
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
mcpIDsJSON = string(jsonData)
}
_, _ = h.db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
result.Response,
mcpIDsJSON,
prep.AssistantMessageID,
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
}
}
c.JSON(http.StatusOK, ChatResponse{
Response: result.Response,
MCPExecutionIDs: result.MCPExecutionIDs,
ConversationID: prep.ConversationID,
Time: time.Now(),
})
}
func multiAgentHTTPErrorStatus(err error) (int, string) {
msg := err.Error()
switch {
case strings.Contains(msg, "对话不存在"):
return http.StatusNotFound, msg
case strings.Contains(msg, "未找到该 WebShell"):
return http.StatusBadRequest, msg
case strings.Contains(msg, "附件最多"):
return http.StatusBadRequest, msg
case strings.Contains(msg, "保存用户消息失败"), strings.Contains(msg, "创建对话失败"):
return http.StatusInternalServerError, msg
case strings.Contains(msg, "保存上传文件失败"):
return http.StatusInternalServerError, msg
default:
return http.StatusBadRequest, msg
}
}
-140
View File
@@ -1,140 +0,0 @@
package handler
import (
"fmt"
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// multiAgentPrepared 多代理请求在调用 Eino 前的会话与消息准备结果。
type multiAgentPrepared struct {
ConversationID string
CreatedNew bool
History []agent.ChatMessage
FinalMessage string
RoleTools []string
AssistantMessageID string
UserMessageID string
}
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) {
if len(req.Attachments) > maxAttachments {
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
}
conversationID := strings.TrimSpace(req.ConversationID)
createdNew := false
if conversationID == "" {
title := safeTruncateString(req.Message, 50)
var conv *database.Conversation
var err error
if strings.TrimSpace(req.WebShellConnectionID) != "" {
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
} else {
conv, err = h.db.CreateConversation(title)
}
if err != nil {
return nil, fmt.Errorf("创建对话失败: %w", err)
}
conversationID = conv.ID
createdNew = true
} else {
if _, err := h.db.GetConversation(conversationID); err != nil {
return nil, fmt.Errorf("对话不存在")
}
}
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil {
agentHistoryMessages = []agent.ChatMessage{}
} else {
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
}
}
}
finalMessage := req.Message
var roleTools []string
if req.WebShellConnectionID != "" {
conn, errConn := h.db.GetWebshellConnection(strings.TrimSpace(req.WebShellConnectionID))
if errConn != nil || conn == nil {
h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn))
return nil, fmt.Errorf("未找到该 WebShell 连接")
}
remark := conn.Remark
if remark == "" {
remark = conn.URL
}
finalMessage = fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base、list_skills、read_skill。请根据用户输入决定下一步:若仅为问候、闲聊或简单问题,直接简短回复即可,不必调用工具;当用户明确需要执行命令、列目录、读写文件、记录漏洞或检索知识库/查看 Skills 等操作时再调用上述工具。\n\n用户请求:%s",
conn.ID, remark, conn.ID, req.Message)
roleTools = []string{
builtin.ToolWebshellExec,
builtin.ToolWebshellFileList,
builtin.ToolWebshellFileRead,
builtin.ToolWebshellFileWrite,
builtin.ToolRecordVulnerability,
builtin.ToolListKnowledgeRiskTypes,
builtin.ToolSearchKnowledgeBase,
builtin.ToolListSkills,
builtin.ToolReadSkill,
}
} else if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
}
roleTools = role.Tools
}
}
var savedPaths []string
if len(req.Attachments) > 0 {
var aerr error
savedPaths, aerr = saveAttachmentsToDateAndConversationDir(req.Attachments, conversationID, h.logger)
if aerr != nil {
return nil, fmt.Errorf("保存上传文件失败: %w", aerr)
}
}
finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths)
userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths)
userMsgRow, uerr := h.db.AddMessage(conversationID, "user", userContent, nil)
if uerr != nil {
h.logger.Error("保存用户消息失败", zap.Error(uerr))
return nil, fmt.Errorf("保存用户消息失败: %w", uerr)
}
userMessageID := ""
if userMsgRow != nil {
userMessageID = userMsgRow.ID
}
assistantMsg, aerr := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
var assistantMessageID string
if aerr != nil {
h.logger.Warn("创建助手消息占位失败", zap.Error(aerr))
} else if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
return &multiAgentPrepared{
ConversationID: conversationID,
CreatedNew: createdNew,
History: agentHistoryMessages,
FinalMessage: finalMessage,
RoleTools: roleTools,
AssistantMessageID: assistantMessageID,
UserMessageID: userMessageID,
}, nil
}
File diff suppressed because it is too large Load Diff
-139
View File
@@ -1,139 +0,0 @@
package handler
// apiDocI18n 为 OpenAPI 文档提供 x-i18n-* 扩展键,供前端 apiDocs 国际化使用。
// 前端通过 apiDocs.tags.* / apiDocs.summary.* / apiDocs.response.* 翻译。
var apiDocI18nTagToKey = map[string]string{
"认证": "auth", "对话管理": "conversationManagement", "对话交互": "conversationInteraction",
"批量任务": "batchTasks", "对话分组": "conversationGroups", "漏洞管理": "vulnerabilityManagement",
"角色管理": "roleManagement", "Skills管理": "skillsManagement", "监控": "monitoring",
"配置管理": "configManagement", "外部MCP管理": "externalMCPManagement", "攻击链": "attackChain",
"知识库": "knowledgeBase", "MCP": "mcp",
}
var apiDocI18nSummaryToKey = map[string]string{
"用户登录": "login", "用户登出": "logout", "修改密码": "changePassword", "验证Token": "validateToken",
"创建对话": "createConversation", "列出对话": "listConversations", "查看对话详情": "getConversationDetail",
"更新对话": "updateConversation", "删除对话": "deleteConversation", "获取对话结果": "getConversationResult",
"发送消息并获取AI回复(非流式)": "sendMessageNonStream", "发送消息并获取AI回复(流式)": "sendMessageStream",
"取消任务": "cancelTask", "列出运行中的任务": "listRunningTasks", "列出已完成的任务": "listCompletedTasks",
"创建批量任务队列": "createBatchQueue", "列出批量任务队列": "listBatchQueues", "获取批量任务队列": "getBatchQueue",
"删除批量任务队列": "deleteBatchQueue", "启动批量任务队列": "startBatchQueue", "暂停批量任务队列": "pauseBatchQueue",
"添加任务到队列": "addTaskToQueue", "SQL注入扫描": "sqlInjectionScan", "端口扫描": "portScan",
"更新批量任务": "updateBatchTask", "删除批量任务": "deleteBatchTask",
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
"从分组移除对话": "removeConversationFromGroup",
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
"获取Skill统计": "getSkillStats", "清空Skill统计": "clearSkillStats", "获取Skill": "getSkill",
"更新Skill": "updateSkill", "删除Skill": "deleteSkill", "获取绑定角色": "getBoundRoles",
"获取监控信息": "getMonitorInfo", "获取执行记录": "getExecutionRecords", "删除执行记录": "deleteExecutionRecord",
"批量删除执行记录": "batchDeleteExecutionRecords", "获取统计信息": "getStats",
"获取配置": "getConfig", "更新配置": "updateConfig", "获取工具配置": "getToolConfig", "应用配置": "applyConfig",
"列出外部MCP": "listExternalMCP", "获取外部MCP统计": "getExternalMCPStats", "获取外部MCP": "getExternalMCP",
"添加或更新外部MCP": "addOrUpdateExternalMCP", "stdio模式配置": "stdioModeConfig", "SSE模式配置": "sseModeConfig",
"删除外部MCP": "deleteExternalMCP", "启动外部MCP": "startExternalMCP", "停止外部MCP": "stopExternalMCP",
"获取攻击链": "getAttackChain", "重新生成攻击链": "regenerateAttackChain",
"设置对话置顶": "pinConversation", "设置分组置顶": "pinGroup", "设置分组中对话的置顶": "pinGroupConversation",
"获取分类": "getCategories", "列出知识项": "listKnowledgeItems", "创建知识项": "createKnowledgeItem",
"获取知识项": "getKnowledgeItem", "更新知识项": "updateKnowledgeItem", "删除知识项": "deleteKnowledgeItem",
"获取索引状态": "getIndexStatus", "重建索引": "rebuildIndex", "扫描知识库": "scanKnowledgeBase",
"搜索知识库": "searchKnowledgeBase", "基础搜索": "basicSearch", "按风险类型搜索": "searchByRiskType",
"获取检索日志": "getRetrievalLogs", "删除检索日志": "deleteRetrievalLog",
"MCP端点": "mcpEndpoint", "列出所有工具": "listAllTools", "调用工具": "invokeTool", "初始化连接": "initConnection",
"成功响应": "successResponse", "错误响应": "errorResponse",
}
var apiDocI18nResponseDescToKey = map[string]string{
"获取成功": "getSuccess", "未授权": "unauthorized", "未授权,需要有效的Token": "unauthorizedToken",
"创建成功": "createSuccess", "请求参数错误": "badRequest", "对话不存在": "conversationNotFound",
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
"删除成功": "deleteSuccess", "队列不存在": "queueNotFound", "启动成功": "startSuccess",
"暂停成功": "pauseSuccess", "添加成功": "addSuccess",
"任务不存在": "taskNotFound", "对话或分组不存在": "conversationOrGroupNotFound",
"取消请求已提交": "cancelSubmitted", "未找到正在执行的任务": "noRunningTask",
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events": "streamResponse",
}
// enrichSpecWithI18nKeys 在 spec 的每个 operation 上写入 x-i18n-tags、x-i18n-summary
// 在每个 response 上写入 x-i18n-description,供前端按 key 做国际化。
func enrichSpecWithI18nKeys(spec map[string]interface{}) {
paths, _ := spec["paths"].(map[string]interface{})
if paths == nil {
return
}
for _, pathItem := range paths {
pm, _ := pathItem.(map[string]interface{})
if pm == nil {
continue
}
for _, method := range []string{"get", "post", "put", "delete", "patch"} {
opVal, ok := pm[method]
if !ok {
continue
}
op, _ := opVal.(map[string]interface{})
if op == nil {
continue
}
// x-i18n-tags: 与 tags 一一对应的 i18n 键数组(spec 中 tags 为 []string
switch tags := op["tags"].(type) {
case []string:
if len(tags) > 0 {
keys := make([]string, 0, len(tags))
for _, s := range tags {
if k := apiDocI18nTagToKey[s]; k != "" {
keys = append(keys, k)
} else {
keys = append(keys, s)
}
}
op["x-i18n-tags"] = keys
}
case []interface{}:
if len(tags) > 0 {
keys := make([]interface{}, 0, len(tags))
for _, t := range tags {
if s, ok := t.(string); ok {
if k := apiDocI18nTagToKey[s]; k != "" {
keys = append(keys, k)
} else {
keys = append(keys, s)
}
}
}
if len(keys) > 0 {
op["x-i18n-tags"] = keys
}
}
}
// x-i18n-summary
if summary, _ := op["summary"].(string); summary != "" {
if k := apiDocI18nSummaryToKey[summary]; k != "" {
op["x-i18n-summary"] = k
}
}
// responses -> 每个 status -> x-i18n-description
if respMap, _ := op["responses"].(map[string]interface{}); respMap != nil {
for _, rv := range respMap {
if r, _ := rv.(map[string]interface{}); r != nil {
if desc, _ := r["description"].(string); desc != "" {
if k := apiDocI18nResponseDescToKey[desc]; k != "" {
r["x-i18n-description"] = k
}
}
}
}
}
}
}
}
-907
View File
@@ -1,907 +0,0 @@
package handler
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"sort"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const (
robotCmdHelp = "帮助"
robotCmdList = "列表"
robotCmdListAlt = "对话列表"
robotCmdSwitch = "切换"
robotCmdContinue = "继续"
robotCmdNew = "新对话"
robotCmdClear = "清空"
robotCmdCurrent = "当前"
robotCmdStop = "停止"
robotCmdRoles = "角色"
robotCmdRolesList = "角色列表"
robotCmdSwitchRole = "切换角色"
robotCmdDelete = "删除"
robotCmdVersion = "版本"
)
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
type RobotHandler struct {
config *config.Config
db *database.DB
agentHandler *AgentHandler
logger *zap.Logger
mu sync.RWMutex
sessions map[string]string // key: "platform_userID", value: conversationID
sessionRoles map[string]string // key: "platform_userID", value: roleName(默认"默认"
cancelMu sync.Mutex // 保护 runningCancels
runningCancels map[string]context.CancelFunc // key: "platform_userID", 用于停止命令中断任务
}
// NewRobotHandler 创建机器人处理器
func NewRobotHandler(cfg *config.Config, db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *RobotHandler {
return &RobotHandler{
config: cfg,
db: db,
agentHandler: agentHandler,
logger: logger,
sessions: make(map[string]string),
sessionRoles: make(map[string]string),
runningCancels: make(map[string]context.CancelFunc),
}
}
// sessionKey 生成会话 key
func (h *RobotHandler) sessionKey(platform, userID string) string {
return platform + "_" + userID
}
// getOrCreateConversation 获取或创建当前会话,title 用于新对话的标题(取用户首条消息前50字)
func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (convID string, isNew bool) {
h.mu.RLock()
convID = h.sessions[h.sessionKey(platform, userID)]
h.mu.RUnlock()
if convID != "" {
return convID, false
}
t := strings.TrimSpace(title)
if t == "" {
t = "新对话 " + time.Now().Format("01-02 15:04")
} else {
t = safeTruncateString(t, 50)
}
conv, err := h.db.CreateConversation(t)
if err != nil {
h.logger.Warn("创建机器人会话失败", zap.Error(err))
return "", false
}
convID = conv.ID
h.mu.Lock()
h.sessions[h.sessionKey(platform, userID)] = convID
h.mu.Unlock()
return convID, true
}
// setConversation 切换当前会话
func (h *RobotHandler) setConversation(platform, userID, convID string) {
h.mu.Lock()
h.sessions[h.sessionKey(platform, userID)] = convID
h.mu.Unlock()
}
// getRole 获取当前用户使用的角色,未设置时返回"默认"
func (h *RobotHandler) getRole(platform, userID string) string {
h.mu.RLock()
role := h.sessionRoles[h.sessionKey(platform, userID)]
h.mu.RUnlock()
if role == "" {
return "默认"
}
return role
}
// setRole 设置当前用户使用的角色
func (h *RobotHandler) setRole(platform, userID, roleName string) {
h.mu.Lock()
h.sessionRoles[h.sessionKey(platform, userID)] = roleName
h.mu.Unlock()
}
// clearConversation 清空当前会话(切换到新对话)
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
title := "新对话 " + time.Now().Format("01-02 15:04")
conv, err := h.db.CreateConversation(title)
if err != nil {
h.logger.Warn("创建新对话失败", zap.Error(err))
return ""
}
h.setConversation(platform, userID, conv.ID)
return conv.ID
}
// HandleMessage 处理用户输入,返回回复文本(供各平台 webhook 调用)
func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply string) {
text = strings.TrimSpace(text)
if text == "" {
return "请输入内容或发送「帮助」/ help 查看命令。"
}
// 先尝试作为命令处理(支持中英文)
if cmdReply, ok := h.handleRobotCommand(platform, userID, text); ok {
return cmdReply
}
// 普通消息:走 Agent
convID, _ := h.getOrCreateConversation(platform, userID, text)
if convID == "" {
return "无法创建或获取对话,请稍后再试。"
}
// 若对话标题为「新对话 xx:xx」格式(由「新对话」命令创建),将标题更新为首条消息内容,与 Web 端体验一致
if conv, err := h.db.GetConversation(convID); err == nil && strings.HasPrefix(conv.Title, "新对话 ") {
newTitle := safeTruncateString(text, 50)
if newTitle != "" {
_ = h.db.UpdateConversationTitle(convID, newTitle)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
sk := h.sessionKey(platform, userID)
h.cancelMu.Lock()
h.runningCancels[sk] = cancel
h.cancelMu.Unlock()
defer func() {
cancel()
h.cancelMu.Lock()
delete(h.runningCancels, sk)
h.cancelMu.Unlock()
}()
role := h.getRole(platform, userID)
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role)
if err != nil {
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
if errors.Is(err, context.Canceled) {
return "任务已取消。"
}
return "处理失败: " + err.Error()
}
if newConvID != convID {
h.setConversation(platform, userID, newConvID)
}
return resp
}
func (h *RobotHandler) cmdHelp() string {
return "**【CyberStrikeAI 机器人命令】**\n\n" +
"- `帮助` `help` — 显示本帮助 | Show this help\n" +
"- `列表` `list` — 列出所有对话标题与 ID | List conversations\n" +
"- `切换 <ID>` `switch <ID>` — 指定对话继续 | Switch to conversation\n" +
"- `新对话` `new` — 开启新对话 | Start new conversation\n" +
"- `清空` `clear` — 清空当前上下文 | Clear context\n" +
"- `当前` `current` — 显示当前对话 ID 与标题 | Show current conversation\n" +
"- `停止` `stop` — 中断当前任务 | Stop running task\n" +
"- `角色` `roles` — 列出所有可用角色 | List roles\n" +
"- `角色 <名>` `role <name>` — 切换当前角色 | Switch role\n" +
"- `删除 <ID>` `delete <ID>` — 删除指定对话 | Delete conversation\n" +
"- `版本` `version` — 显示当前版本号 | Show version\n\n" +
"---\n" +
"除以上命令外,直接输入内容将发送给 AI 进行渗透测试/安全分析。\n" +
"Otherwise, send any text for AI penetration testing / security analysis."
}
func (h *RobotHandler) cmdList() string {
convs, err := h.db.ListConversations(50, 0, "")
if err != nil {
return "获取对话列表失败: " + err.Error()
}
if len(convs) == 0 {
return "暂无对话。发送任意内容将自动创建新对话。"
}
var b strings.Builder
b.WriteString("【对话列表】\n")
for i, c := range convs {
if i >= 20 {
b.WriteString("… 仅显示前 20 条\n")
break
}
b.WriteString(fmt.Sprintf("· %s\n ID: %s\n", c.Title, c.ID))
}
return strings.TrimSuffix(b.String(), "\n")
}
func (h *RobotHandler) cmdSwitch(platform, userID, convID string) string {
if convID == "" {
return "请指定对话 ID,例如:切换 xxx-xxx-xxx"
}
conv, err := h.db.GetConversation(convID)
if err != nil {
return "对话不存在或 ID 错误。"
}
h.setConversation(platform, userID, conv.ID)
return fmt.Sprintf("已切换到对话:「%s」\nID: %s", conv.Title, conv.ID)
}
func (h *RobotHandler) cmdNew(platform, userID string) string {
newID := h.clearConversation(platform, userID)
if newID == "" {
return "创建新对话失败,请重试。"
}
return "已开启新对话,可直接发送内容。"
}
func (h *RobotHandler) cmdClear(platform, userID string) string {
return h.cmdNew(platform, userID)
}
func (h *RobotHandler) cmdStop(platform, userID string) string {
sk := h.sessionKey(platform, userID)
h.cancelMu.Lock()
cancel, ok := h.runningCancels[sk]
if ok {
delete(h.runningCancels, sk)
cancel()
}
h.cancelMu.Unlock()
if !ok {
return "当前没有正在执行的任务。"
}
return "已停止当前任务。"
}
func (h *RobotHandler) cmdCurrent(platform, userID string) string {
h.mu.RLock()
convID := h.sessions[h.sessionKey(platform, userID)]
h.mu.RUnlock()
if convID == "" {
return "当前没有进行中的对话。发送任意内容将创建新对话。"
}
conv, err := h.db.GetConversation(convID)
if err != nil {
return "当前对话 ID: " + convID + "(获取标题失败)"
}
role := h.getRole(platform, userID)
return fmt.Sprintf("当前对话:「%s」\nID: %s\n当前角色: %s", conv.Title, conv.ID, role)
}
func (h *RobotHandler) cmdRoles() string {
if h.config.Roles == nil || len(h.config.Roles) == 0 {
return "暂无可用角色。"
}
names := make([]string, 0, len(h.config.Roles))
for name, role := range h.config.Roles {
if role.Enabled {
names = append(names, name)
}
}
if len(names) == 0 {
return "暂无可用角色。"
}
sort.Slice(names, func(i, j int) bool {
if names[i] == "默认" {
return true
}
if names[j] == "默认" {
return false
}
return names[i] < names[j]
})
var b strings.Builder
b.WriteString("【角色列表】\n")
for _, name := range names {
role := h.config.Roles[name]
desc := role.Description
if desc == "" {
desc = "无描述"
}
b.WriteString(fmt.Sprintf("· %s — %s\n", name, desc))
}
return strings.TrimSuffix(b.String(), "\n")
}
func (h *RobotHandler) cmdSwitchRole(platform, userID, roleName string) string {
if roleName == "" {
return "请指定角色名称,例如:角色 渗透测试"
}
if h.config.Roles == nil {
return "暂无可用角色。"
}
role, exists := h.config.Roles[roleName]
if !exists {
return fmt.Sprintf("角色「%s」不存在。发送「角色」查看可用角色。", roleName)
}
if !role.Enabled {
return fmt.Sprintf("角色「%s」已禁用。", roleName)
}
h.setRole(platform, userID, roleName)
return fmt.Sprintf("已切换到角色:「%s」\n%s", roleName, role.Description)
}
func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
if convID == "" {
return "请指定对话 ID,例如:删除 xxx-xxx-xxx"
}
sk := h.sessionKey(platform, userID)
h.mu.RLock()
currentConvID := h.sessions[sk]
h.mu.RUnlock()
if convID == currentConvID {
// 删除当前对话时,先清空会话绑定
h.mu.Lock()
delete(h.sessions, sk)
h.mu.Unlock()
}
if err := h.db.DeleteConversation(convID); err != nil {
return "删除失败: " + err.Error()
}
return fmt.Sprintf("已删除对话 ID: %s", convID)
}
func (h *RobotHandler) cmdVersion() string {
v := h.config.Version
if v == "" {
v = "未知"
}
return "CyberStrikeAI " + v
}
// handleRobotCommand 处理机器人内置命令;若匹配到命令返回 (回复内容, true),否则返回 ("", false)
func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string, bool) {
switch {
case text == robotCmdHelp || text == "help" || text == "" || text == "?":
return h.cmdHelp(), true
case text == robotCmdList || text == robotCmdListAlt || text == "list":
return h.cmdList(), true
case strings.HasPrefix(text, robotCmdSwitch+" ") || strings.HasPrefix(text, robotCmdContinue+" ") || strings.HasPrefix(text, "switch ") || strings.HasPrefix(text, "continue "):
var id string
switch {
case strings.HasPrefix(text, robotCmdSwitch+" "):
id = strings.TrimSpace(text[len(robotCmdSwitch)+1:])
case strings.HasPrefix(text, robotCmdContinue+" "):
id = strings.TrimSpace(text[len(robotCmdContinue)+1:])
case strings.HasPrefix(text, "switch "):
id = strings.TrimSpace(text[7:])
default:
id = strings.TrimSpace(text[9:])
}
return h.cmdSwitch(platform, userID, id), true
case text == robotCmdNew || text == "new":
return h.cmdNew(platform, userID), true
case text == robotCmdClear || text == "clear":
return h.cmdClear(platform, userID), true
case text == robotCmdCurrent || text == "current":
return h.cmdCurrent(platform, userID), true
case text == robotCmdStop || text == "stop":
return h.cmdStop(platform, userID), true
case text == robotCmdRoles || text == robotCmdRolesList || text == "roles":
return h.cmdRoles(), true
case strings.HasPrefix(text, robotCmdRoles+" ") || strings.HasPrefix(text, robotCmdSwitchRole+" ") || strings.HasPrefix(text, "role "):
var roleName string
switch {
case strings.HasPrefix(text, robotCmdRoles+" "):
roleName = strings.TrimSpace(text[len(robotCmdRoles)+1:])
case strings.HasPrefix(text, robotCmdSwitchRole+" "):
roleName = strings.TrimSpace(text[len(robotCmdSwitchRole)+1:])
default:
roleName = strings.TrimSpace(text[5:])
}
return h.cmdSwitchRole(platform, userID, roleName), true
case strings.HasPrefix(text, robotCmdDelete+" ") || strings.HasPrefix(text, "delete "):
var convID string
if strings.HasPrefix(text, robotCmdDelete+" ") {
convID = strings.TrimSpace(text[len(robotCmdDelete)+1:])
} else {
convID = strings.TrimSpace(text[7:])
}
return h.cmdDelete(platform, userID, convID), true
case text == robotCmdVersion || text == "version":
return h.cmdVersion(), true
default:
return "", false
}
}
// —————— 企业微信 ——————
// wecomXML 企业微信回调 XML(明文模式下的简化结构;加密模式需先解密再解析)
type wecomXML struct {
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
MsgID string `xml:"MsgId"`
AgentID int64 `xml:"AgentID"`
Encrypt string `xml:"Encrypt"` // 加密模式下消息在此
}
// wecomReplyXML 被动回复 XML(仅用于兼容,当前使用手动构造 XML)
type wecomReplyXML struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
}
// HandleWecomGET 企业微信 URL 校验(GET
func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
if !h.config.Robots.Wecom.Enabled {
c.String(http.StatusNotFound, "")
return
}
// Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串
echostr := c.Query("echostr")
msgSignature := c.Query("msg_signature")
timestamp := c.Query("timestamp")
nonce := c.Query("nonce")
// 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1
signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr)
if signature != msgSignature {
h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature))
c.String(http.StatusBadRequest, "invalid signature")
return
}
if echostr == "" {
c.String(http.StatusBadRequest, "missing echostr")
return
}
// 如果配置了 EncodingAESKey,说明是加密模式,需要解密 echostr
if h.config.Robots.Wecom.EncodingAESKey != "" {
decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, echostr)
if err != nil {
h.logger.Warn("企业微信 echostr 解密失败", zap.Error(err))
c.String(http.StatusBadRequest, "decrypt failed")
return
}
c.String(http.StatusOK, string(decrypted))
return
}
// 明文模式直接返回 echostr
c.String(http.StatusOK, echostr)
}
// signWecomRequest 生成企业微信请求签名
// 企业微信签名算法:将 token、timestamp、nonce、echostr 四个值排序后拼接成字符串,再计算 SHA1
func (h *RobotHandler) signWecomRequest(token, timestamp, nonce, echostr string) string {
strs := []string{token, timestamp, nonce, echostr}
sort.Strings(strs)
s := strings.Join(strs, "")
hash := sha1.Sum([]byte(s))
return fmt.Sprintf("%x", hash)
}
// wecomDecrypt 企业微信消息解密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID)
func wecomDecrypt(encodingAESKey, encryptedB64 string) ([]byte, error) {
key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
if err != nil {
return nil, err
}
if len(key) != 32 {
return nil, fmt.Errorf("encoding_aes_key 解码后应为 32 字节")
}
ciphertext, err := base64.StdEncoding.DecodeString(encryptedB64)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
iv := key[:16]
mode := cipher.NewCBCDecrypter(block, iv)
if len(ciphertext)%aes.BlockSize != 0 {
return nil, fmt.Errorf("密文长度不是块大小的倍数")
}
plain := make([]byte, len(ciphertext))
mode.CryptBlocks(plain, ciphertext)
// 去除 PKCS7 填充
n := int(plain[len(plain)-1])
if n < 1 || n > 32 {
return nil, fmt.Errorf("无效的 PKCS7 填充")
}
plain = plain[:len(plain)-n]
// 企业微信格式:16 字节随机 + 4 字节长度(大端) + 消息 + corpID
if len(plain) < 20 {
return nil, fmt.Errorf("明文过短")
}
msgLen := binary.BigEndian.Uint32(plain[16:20])
if int(20+msgLen) > len(plain) {
return nil, fmt.Errorf("消息长度越界")
}
return plain[20 : 20+msgLen], nil
}
// wecomEncrypt 企业微信消息加密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID)
func wecomEncrypt(encodingAESKey, message, corpID string) (string, error) {
key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
if err != nil {
return "", err
}
if len(key) != 32 {
return "", fmt.Errorf("encoding_aes_key 解码后应为 32 字节")
}
// 构造明文:16 字节随机 + 4 字节长度 (大端) + 消息 + corpID
random := make([]byte, 16)
if _, err := rand.Read(random); err != nil {
// 降级方案:使用时间戳生成随机数
for i := range random {
random[i] = byte(time.Now().UnixNano() % 256)
}
}
msgLen := len(message)
msgBytes := []byte(message)
corpBytes := []byte(corpID)
plain := make([]byte, 16+4+msgLen+len(corpBytes))
copy(plain[:16], random)
binary.BigEndian.PutUint32(plain[16:20], uint32(msgLen))
copy(plain[20:20+msgLen], msgBytes)
copy(plain[20+msgLen:], corpBytes)
// PKCS7 填充
padding := aes.BlockSize - len(plain)%aes.BlockSize
pad := bytes.Repeat([]byte{byte(padding)}, padding)
plain = append(plain, pad...)
// AES-256-CBC 加密
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
iv := key[:16]
ciphertext := make([]byte, len(plain))
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ciphertext, plain)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// HandleWecomPOST 企业微信消息回调(POST),支持明文与加密模式
func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
if !h.config.Robots.Wecom.Enabled {
h.logger.Debug("企业微信机器人未启用,跳过请求")
c.String(http.StatusOK, "")
return
}
// 从 URL 获取签名参数(加密模式回复时需要用到)
timestamp := c.Query("timestamp")
nonce := c.Query("nonce")
msgSignature := c.Query("msg_signature")
// 先读取请求体,后续解析/签名验证都会用到
bodyRaw, err := io.ReadAll(c.Request.Body)
if err != nil {
h.logger.Warn("企业微信 POST 读取请求体失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw)))
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段
// 若配置了 Token 则必须校验签名,避免未授权请求触发 Agent(防止平台被接管)
token := h.config.Robots.Wecom.Token
if token != "" {
if msgSignature == "" {
h.logger.Warn("企业微信 POST 缺少签名,已拒绝(需配置 token 并确保回调携带 msg_signature")
c.String(http.StatusOK, "")
return
}
var tmp wecomXML
if err := xml.Unmarshal(bodyRaw, &tmp); err != nil {
h.logger.Warn("企业微信 POST 签名验证前解析 XML 失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
expected := h.signWecomRequest(token, timestamp, nonce, tmp.Encrypt)
if expected != msgSignature {
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
c.String(http.StatusOK, "")
return
}
}
var body wecomXML
if err := xml.Unmarshal(bodyRaw, &body); err != nil {
h.logger.Warn("企业微信 POST 解析 XML 失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信 XML 解析成功", zap.String("ToUserName", body.ToUserName), zap.String("FromUserName", body.FromUserName), zap.String("MsgType", body.MsgType), zap.String("Content", body.Content), zap.String("Encrypt", body.Encrypt))
// 保存企业 ID(用于明文模式回复)
enterpriseID := body.ToUserName
// 加密模式:先解密再解析内层 XML
if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" {
h.logger.Debug("企业微信进入加密模式解密流程")
decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, body.Encrypt)
if err != nil {
h.logger.Warn("企业微信消息解密失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信解密成功", zap.String("decrypted", string(decrypted)))
if err := xml.Unmarshal(decrypted, &body); err != nil {
h.logger.Warn("企业微信解密后 XML 解析失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content))
}
userID := body.FromUserName
text := strings.TrimSpace(body.Content)
// 限制回复内容长度(企业微信限制 2048 字节)
maxReplyLen := 2000
limitReply := func(s string) string {
if len(s) > maxReplyLen {
return s[:maxReplyLen] + "\n\n(内容过长,已截断)"
}
return s
}
if body.MsgType != "text" {
h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType))
h.sendWecomReply(c, userID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce)
return
}
// 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。
if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok {
h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text))
h.sendWecomReply(c, userID, enterpriseID, limitReply(cmdReply), timestamp, nonce)
return
}
h.logger.Debug("企业微信开始处理消息(异步 AI)", zap.String("userID", userID), zap.String("text", text))
// 企业微信被动回复有 5 秒超时限制,而 AI 调用通常超过该时长。
// 这里采用推荐做法:立即返回 success(或空串),然后通过主动发送接口推送完整回复。
c.String(http.StatusOK, "success")
// 异步处理消息并通过企业微信主动消息接口发送结果
go func() {
reply := h.HandleMessage("wecom", userID, text)
reply = limitReply(reply)
h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply))
// 调用企业微信 API 主动发送消息
h.sendWecomMessageViaAPI(userID, enterpriseID, reply)
}()
}
// sendWecomReply 发送企业微信回复(加密模式自动加密)
// 参数:toUser=用户 ID, fromUser=企业 ID(明文模式)/CorpID(加密模式), content=回复内容,timestamp/nonce=请求参数
func (h *RobotHandler) sendWecomReply(c *gin.Context, toUser, fromUser, content, timestamp, nonce string) {
// 加密模式:判断 EncodingAESKey 是否配置
if h.config.Robots.Wecom.EncodingAESKey != "" {
// 加密模式使用 CorpID 进行加密
corpID := h.config.Robots.Wecom.CorpID
if corpID == "" {
h.logger.Warn("企业微信加密模式缺少 CorpID 配置")
c.String(http.StatusOK, "")
return
}
// 构造完整的明文 XML 回复(格式严格按企业微信文档要求)
plainResp := fmt.Sprintf(`<xml>
<ToUserName><![CDATA[%s]]></ToUserName>
<FromUserName><![CDATA[%s]]></FromUserName>
<CreateTime>%d</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[%s]]></Content>
</xml>`, toUser, fromUser, time.Now().Unix(), content)
encrypted, err := wecomEncrypt(h.config.Robots.Wecom.EncodingAESKey, plainResp, corpID)
if err != nil {
h.logger.Warn("企业微信回复加密失败", zap.Error(err))
c.String(http.StatusOK, "")
return
}
// 使用请求中的 timestamp/nonce 生成签名(企业微信要求回复时使用与请求相同的 timestamp 和 nonce
msgSignature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, encrypted)
h.logger.Debug("企业微信发送加密回复",
zap.String("Encrypt", encrypted[:50]+"..."),
zap.String("MsgSignature", msgSignature),
zap.String("TimeStamp", timestamp),
zap.String("Nonce", nonce))
// 加密模式仅返回 4 个核心字段(企业微信官方要求)
xmlResp := fmt.Sprintf(`<xml><Encrypt><![CDATA[%s]]></Encrypt><MsgSignature><![CDATA[%s]]></MsgSignature><TimeStamp><![CDATA[%s]]></TimeStamp><Nonce><![CDATA[%s]]></Nonce></xml>`, encrypted, msgSignature, timestamp, nonce)
// also log the final response body so we can cross-check with the
// network traffic or developer console
h.logger.Debug("企业微信加密回复包", zap.String("xml", xmlResp))
// for additional confidence, decrypt the payload ourselves and log it
if dec, err2 := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, encrypted); err2 == nil {
h.logger.Debug("企业微信加密回复解密检查", zap.String("plain", string(dec)))
} else {
h.logger.Warn("企业微信加密回复解密检查失败", zap.Error(err2))
}
// 使用 c.Writer.Write 直接写入响应,避免 c.String 的转义问题
c.Writer.WriteHeader(http.StatusOK)
// use text/xml as that's what WeCom examples show
c.Writer.Header().Set("Content-Type", "text/xml; charset=utf-8")
_, _ = c.Writer.Write([]byte(xmlResp))
h.logger.Debug("企业微信加密回复已发送")
return
}
// 明文模式
h.logger.Debug("企业微信发送明文回复", zap.String("ToUserName", toUser), zap.String("FromUserName", fromUser), zap.String("Content", content[:50]+"..."))
// 手动构造 XML 响应(使用 CDATA 包裹所有字段,并包含 AgentID)
xmlResp := fmt.Sprintf(`<xml>
<ToUserName><![CDATA[%s]]></ToUserName>
<FromUserName><![CDATA[%s]]></FromUserName>
<CreateTime>%d</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[%s]]></Content>
</xml>`, toUser, fromUser, time.Now().Unix(), content)
// log the exact plaintext response for debugging
h.logger.Debug("企业微信明文回复包", zap.String("xml", xmlResp))
// use text/xml as recommended by WeCom docs
c.Header("Content-Type", "text/xml; charset=utf-8")
c.String(http.StatusOK, xmlResp)
h.logger.Debug("企业微信明文回复已发送")
}
// —————— 测试接口(需登录,用于验证机器人逻辑,无需钉钉/飞书客户端) ——————
// RobotTestRequest 模拟机器人消息请求
type RobotTestRequest struct {
Platform string `json:"platform"` // 如 "dingtalk"、"lark"、"wecom"
UserID string `json:"user_id"`
Text string `json:"text"`
}
// HandleRobotTest 供本地验证:POST JSON { "platform", "user_id", "text" },返回 { "reply": "..." }
func (h *RobotHandler) HandleRobotTest(c *gin.Context) {
var req RobotTestRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体需为 JSON,包含 platform、user_id、text"})
return
}
platform := strings.TrimSpace(req.Platform)
if platform == "" {
platform = "test"
}
userID := strings.TrimSpace(req.UserID)
if userID == "" {
userID = "test_user"
}
reply := h.HandleMessage(platform, userID, req.Text)
c.JSON(http.StatusOK, gin.H{"reply": reply})
}
// sendWecomMessageViaAPI 通过企业微信 API 主动发送消息(用于异步处理后的结果发送)
func (h *RobotHandler) sendWecomMessageViaAPI(toUser, toParty, content string) {
if !h.config.Robots.Wecom.Enabled {
return
}
secret := h.config.Robots.Wecom.Secret
corpID := h.config.Robots.Wecom.CorpID
agentID := h.config.Robots.Wecom.AgentID
if secret == "" || corpID == "" {
h.logger.Warn("企业微信主动 API 缺少 secret 或 corpID 配置")
return
}
// 第 1 步:获取 access_token
tokenURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", corpID, secret)
resp, err := http.Get(tokenURL)
if err != nil {
h.logger.Warn("企业微信获取 token 失败", zap.Error(err))
return
}
defer resp.Body.Close()
var tokenResp struct {
AccessToken string `json:"access_token"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
h.logger.Warn("企业微信 token 响应解析失败", zap.Error(err))
return
}
if tokenResp.ErrCode != 0 {
h.logger.Warn("企业微信 token 获取错误", zap.String("errmsg", tokenResp.ErrMsg), zap.Int("errcode", tokenResp.ErrCode))
return
}
// 第 2 步:构造发送消息请求
msgReq := map[string]interface{}{
"touser": toUser,
"msgtype": "text",
"agentid": agentID,
"text": map[string]interface{}{
"content": content,
},
}
msgBody, err := json.Marshal(msgReq)
if err != nil {
h.logger.Warn("企业微信消息序列化失败", zap.Error(err))
return
}
// 第 3 步:发送消息
sendURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", tokenResp.AccessToken)
msgResp, err := http.Post(sendURL, "application/json", bytes.NewReader(msgBody))
if err != nil {
h.logger.Warn("企业微信主动发送消息失败", zap.Error(err))
return
}
defer msgResp.Body.Close()
var sendResp struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
InvalidUser string `json:"invaliduser"`
MsgID string `json:"msgid"`
}
if err := json.NewDecoder(msgResp.Body).Decode(&sendResp); err != nil {
h.logger.Warn("企业微信发送响应解析失败", zap.Error(err))
return
}
if sendResp.ErrCode == 0 {
h.logger.Debug("企业微信主动发送消息成功", zap.String("msgid", sendResp.MsgID))
} else {
h.logger.Warn("企业微信主动发送消息失败", zap.String("errmsg", sendResp.ErrMsg), zap.Int("errcode", sendResp.ErrCode), zap.String("invaliduser", sendResp.InvalidUser))
}
}
// —————— 钉钉 ——————
// HandleDingtalkPOST 钉钉事件回调(流式接入等);当前为占位,返回 200
func (h *RobotHandler) HandleDingtalkPOST(c *gin.Context) {
if !h.config.Robots.Dingtalk.Enabled {
c.JSON(http.StatusOK, gin.H{})
return
}
// 钉钉流式/事件回调格式需按官方文档解析并异步回复,此处仅返回 200
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
// —————— 飞书 ——————
// HandleLarkPOST 飞书事件回调;当前为占位,返回 200;验证时需返回 challenge
func (h *RobotHandler) HandleLarkPOST(c *gin.Context) {
if !h.config.Robots.Lark.Enabled {
c.JSON(http.StatusOK, gin.H{})
return
}
var body struct {
Challenge string `json:"challenge"`
}
if err := c.ShouldBindJSON(&body); err == nil && body.Challenge != "" {
c.JSON(http.StatusOK, gin.H{"challenge": body.Challenge})
return
}
c.JSON(http.StatusOK, gin.H{})
}
-487
View File
@@ -1,487 +0,0 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// RoleHandler 角色处理器
type RoleHandler struct {
config *config.Config
configPath string
logger *zap.Logger
skillsManager SkillsManager // Skills管理器接口(可选)
}
// SkillsManager Skills管理器接口
type SkillsManager interface {
ListSkills() ([]string, error)
}
// NewRoleHandler 创建新的角色处理器
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
return &RoleHandler{
config: cfg,
configPath: configPath,
logger: logger,
}
}
// SetSkillsManager 设置Skills管理器
func (h *RoleHandler) SetSkillsManager(manager SkillsManager) {
h.skillsManager = manager
}
// GetSkills 获取所有可用的skills列表
func (h *RoleHandler) GetSkills(c *gin.Context) {
if h.skillsManager == nil {
c.JSON(http.StatusOK, gin.H{
"skills": []string{},
})
return
}
skills, err := h.skillsManager.ListSkills()
if err != nil {
h.logger.Warn("获取skills列表失败", zap.Error(err))
c.JSON(http.StatusOK, gin.H{
"skills": []string{},
})
return
}
c.JSON(http.StatusOK, gin.H{
"skills": skills,
})
}
// GetRoles 获取所有角色
func (h *RoleHandler) GetRoles(c *gin.Context) {
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
for key, role := range h.config.Roles {
// 确保角色的key与name一致
if role.Name == "" {
role.Name = key
}
roles = append(roles, role)
}
c.JSON(http.StatusOK, gin.H{
"roles": roles,
})
}
// GetRole 获取单个角色
func (h *RoleHandler) GetRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
role, exists := h.config.Roles[roleName]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 确保角色的name与key一致
if role.Name == "" {
role.Name = roleName
}
c.JSON(http.StatusOK, gin.H{
"role": role,
})
}
// UpdateRole 更新角色
func (h *RoleHandler) UpdateRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 确保角色名称与请求中的name一致
if req.Name == "" {
req.Name = roleName
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 删除所有与角色name相同但key不同的旧角色(避免重复)
// 使用角色name作为key,确保唯一性
finalKey := req.Name
keysToDelete := make([]string, 0)
for key := range h.config.Roles {
// 如果key与最终的key不同,但name相同,则标记为删除
if key != finalKey {
role := h.config.Roles[key]
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = key
}
if role.Name == req.Name {
keysToDelete = append(keysToDelete, key)
}
}
}
// 删除旧的角色
for _, key := range keysToDelete {
delete(h.config.Roles, key)
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
}
// 如果当前更新的key与最终key不同,也需要删除旧的
if roleName != finalKey {
delete(h.config.Roles, roleName)
}
// 如果角色名称改变,需要删除旧文件
if roleName != finalKey {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 删除旧的角色文件
oldSafeFileName := sanitizeFileName(roleName)
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
if _, err := os.Stat(oldRoleFileYaml); err == nil {
if err := os.Remove(oldRoleFileYaml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
}
}
if _, err := os.Stat(oldRoleFileYml); err == nil {
if err := os.Remove(oldRoleFileYml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
}
}
}
// 使用角色name作为key来保存(确保唯一性)
h.config.Roles[finalKey] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已更新",
"role": req,
})
}
// CreateRole 创建新角色
func (h *RoleHandler) CreateRole(c *gin.Context) {
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 检查角色是否已存在
if _, exists := h.config.Roles[req.Name]; exists {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
return
}
// 创建角色(默认启用)
if !req.Enabled {
req.Enabled = true
}
h.config.Roles[req.Name] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("创建角色", zap.String("roleName", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已创建",
"role": req,
})
}
// DeleteRole 删除角色
func (h *RoleHandler) DeleteRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
if _, exists := h.config.Roles[roleName]; !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 不允许删除"默认"角色
if roleName == "默认" {
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
return
}
delete(h.config.Roles, roleName)
// 删除对应的角色文件
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 尝试删除角色文件(.yaml 和 .yml)
safeFileName := sanitizeFileName(roleName)
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
// 删除 .yaml 文件(如果存在)
if _, err := os.Stat(roleFileYaml); err == nil {
if err := os.Remove(roleFileYaml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
}
}
// 删除 .yml 文件(如果存在)
if _, err := os.Stat(roleFileYml); err == nil {
if err := os.Remove(roleFileYml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
}
}
h.logger.Info("删除角色", zap.String("roleName", roleName))
c.JSON(http.StatusOK, gin.H{
"message": "角色已删除",
})
}
// saveConfig 保存配置到目录中的文件
func (h *RoleHandler) saveConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeFileName 将角色名称转换为安全的文件名
func sanitizeFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// updateRolesConfig 更新角色配置
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
root := doc.Content[0]
rolesNode := ensureMap(root, "roles")
// 清空现有角色
if rolesNode.Kind == yaml.MappingNode {
rolesNode.Content = nil
}
// 添加新角色(使用name作为key,确保唯一性)
if cfg.Roles != nil {
// 先建立一个以name为key的map,去重(保留最后一个)
rolesByName := make(map[string]config.RoleConfig)
for roleKey, role := range cfg.Roles {
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = roleKey
}
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
rolesByName[role.Name] = role
}
// 将去重后的角色写入YAML
for roleName, role := range rolesByName {
roleNode := ensureMap(rolesNode, roleName)
setStringInMap(roleNode, "name", role.Name)
setStringInMap(roleNode, "description", role.Description)
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
if role.Icon != "" {
setStringInMap(roleNode, "icon", role.Icon)
}
setBoolInMap(roleNode, "enabled", role.Enabled)
// 添加工具列表(优先使用tools字段)
if len(role.Tools) > 0 {
toolsNode := ensureArray(roleNode, "tools")
toolsNode.Content = nil
for _, toolKey := range role.Tools {
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
toolsNode.Content = append(toolsNode.Content, toolNode)
}
} else if len(role.MCPs) > 0 {
// 向后兼容:如果没有tools但有mcps,保存mcps
mcpsNode := ensureArray(roleNode, "mcps")
mcpsNode.Content = nil
for _, mcpName := range role.MCPs {
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
}
}
}
}
}
// ensureArray 确保数组中存在指定key的数组节点
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
_, valueNode := ensureKeyValue(parent, key)
if valueNode.Kind != yaml.SequenceNode {
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
}
return valueNode
}
-781
View File
@@ -1,781 +0,0 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/skills"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
)
// SkillsHandler Skills处理器
type SkillsHandler struct {
manager *skills.Manager
config *config.Config
configPath string
logger *zap.Logger
db *database.DB // 数据库连接(用于获取调用统计)
}
// NewSkillsHandler 创建新的Skills处理器
func NewSkillsHandler(manager *skills.Manager, cfg *config.Config, configPath string, logger *zap.Logger) *SkillsHandler {
return &SkillsHandler{
manager: manager,
config: cfg,
configPath: configPath,
logger: logger,
}
}
// SetDB 设置数据库连接(用于获取调用统计)
func (h *SkillsHandler) SetDB(db *database.DB) {
h.db = db
}
// GetSkills 获取所有skills列表(支持分页和搜索)
func (h *SkillsHandler) GetSkills(c *gin.Context) {
skillList, err := h.manager.ListSkills()
if err != nil {
h.logger.Error("获取skills列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 搜索参数
searchKeyword := strings.TrimSpace(c.Query("search"))
// 先加载所有skills的详细信息用于搜索过滤
allSkillsInfo := make([]map[string]interface{}, 0, len(skillList))
for _, skillName := range skillList {
skill, err := h.manager.LoadSkill(skillName)
if err != nil {
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
continue
}
// 获取文件信息
skillPath := skill.Path
skillFile := filepath.Join(skillPath, "SKILL.md")
// 尝试其他可能的文件名
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
alternatives := []string{
filepath.Join(skillPath, "skill.md"),
filepath.Join(skillPath, "README.md"),
filepath.Join(skillPath, "readme.md"),
}
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skillFile = alt
break
}
}
}
fileInfo, _ := os.Stat(skillFile)
var fileSize int64
var modTime string
if fileInfo != nil {
fileSize = fileInfo.Size()
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
}
skillInfo := map[string]interface{}{
"name": skill.Name,
"description": skill.Description,
"path": skill.Path,
"file_size": fileSize,
"mod_time": modTime,
}
allSkillsInfo = append(allSkillsInfo, skillInfo)
}
// 如果有搜索关键词,进行过滤
filteredSkillsInfo := allSkillsInfo
if searchKeyword != "" {
keywordLower := strings.ToLower(searchKeyword)
filteredSkillsInfo = make([]map[string]interface{}, 0)
for _, skillInfo := range allSkillsInfo {
name := strings.ToLower(fmt.Sprintf("%v", skillInfo["name"]))
description := strings.ToLower(fmt.Sprintf("%v", skillInfo["description"]))
path := strings.ToLower(fmt.Sprintf("%v", skillInfo["path"]))
if strings.Contains(name, keywordLower) ||
strings.Contains(description, keywordLower) ||
strings.Contains(path, keywordLower) {
filteredSkillsInfo = append(filteredSkillsInfo, skillInfo)
}
}
}
// 分页参数
limit := 20 // 默认每页20条
offset := 0
if limitStr := c.Query("limit"); limitStr != "" {
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
// 允许更大的limit用于搜索场景,但设置一个合理的上限(10000)
if parsed <= 10000 {
limit = parsed
} else {
limit = 10000
}
}
}
if offsetStr := c.Query("offset"); offsetStr != "" {
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
offset = parsed
}
}
// 计算分页范围
total := len(filteredSkillsInfo)
start := offset
end := offset + limit
if start > total {
start = total
}
if end > total {
end = total
}
// 获取当前页的skill列表
var paginatedSkillsInfo []map[string]interface{}
if start < end {
paginatedSkillsInfo = filteredSkillsInfo[start:end]
} else {
paginatedSkillsInfo = []map[string]interface{}{}
}
c.JSON(http.StatusOK, gin.H{
"skills": paginatedSkillsInfo,
"total": total,
"limit": limit,
"offset": offset,
})
}
// GetSkill 获取单个skill的详细信息
func (h *SkillsHandler) GetSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
skill, err := h.manager.LoadSkill(skillName)
if err != nil {
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()})
return
}
// 获取文件信息
skillPath := skill.Path
skillFile := filepath.Join(skillPath, "SKILL.md")
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
alternatives := []string{
filepath.Join(skillPath, "skill.md"),
filepath.Join(skillPath, "README.md"),
filepath.Join(skillPath, "readme.md"),
}
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skillFile = alt
break
}
}
}
fileInfo, _ := os.Stat(skillFile)
var fileSize int64
var modTime string
if fileInfo != nil {
fileSize = fileInfo.Size()
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
}
c.JSON(http.StatusOK, gin.H{
"skill": map[string]interface{}{
"name": skill.Name,
"description": skill.Description,
"content": skill.Content,
"path": skill.Path,
"file_size": fileSize,
"mod_time": modTime,
},
})
}
// GetSkillBoundRoles 获取绑定指定skill的角色列表
func (h *SkillsHandler) GetSkillBoundRoles(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
boundRoles := h.getRolesBoundToSkill(skillName)
c.JSON(http.StatusOK, gin.H{
"skill": skillName,
"bound_roles": boundRoles,
"bound_count": len(boundRoles),
})
}
// getRolesBoundToSkill 获取绑定指定skill的角色列表(不修改配置)
func (h *SkillsHandler) getRolesBoundToSkill(skillName string) []string {
if h.config.Roles == nil {
return []string{}
}
boundRoles := make([]string, 0)
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 检查角色的Skills列表中是否包含该skill
if len(role.Skills) > 0 {
for _, skill := range role.Skills {
if skill == skillName {
boundRoles = append(boundRoles, roleName)
break
}
}
}
}
return boundRoles
}
// CreateSkill 创建新skill
func (h *SkillsHandler) CreateSkill(c *gin.Context) {
var req struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 验证skill名称(只允许字母、数字、连字符和下划线)
if !isValidSkillName(req.Name) {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称只能包含字母、数字、连字符和下划线"})
return
}
// 获取skills目录
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
// 创建skill目录
skillDir := filepath.Join(skillsDir, req.Name)
if err := os.MkdirAll(skillDir, 0755); err != nil {
h.logger.Error("创建skill目录失败", zap.String("skill", req.Name), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill目录失败: " + err.Error()})
return
}
// 检查是否已存在
skillFile := filepath.Join(skillDir, "SKILL.md")
if _, err := os.Stat(skillFile); err == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill已存在"})
return
}
// 构建SKILL.md内容
var content strings.Builder
content.WriteString("---\n")
content.WriteString(fmt.Sprintf("name: %s\n", req.Name))
if req.Description != "" {
// 如果描述包含特殊字符,需要加引号
desc := req.Description
if strings.Contains(desc, ":") || strings.Contains(desc, "\n") {
desc = fmt.Sprintf(`"%s"`, strings.ReplaceAll(desc, `"`, `\"`))
}
content.WriteString(fmt.Sprintf("description: %s\n", desc))
}
content.WriteString("version: 1.0.0\n")
content.WriteString("---\n\n")
content.WriteString(req.Content)
// 写入文件
if err := os.WriteFile(skillFile, []byte(content.String()), 0644); err != nil {
h.logger.Error("创建skill文件失败", zap.String("skill", req.Name), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill文件失败: " + err.Error()})
return
}
h.manager.InvalidateSkill(req.Name)
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "skill已创建",
"skill": map[string]interface{}{
"name": req.Name,
"path": skillDir,
},
})
}
// UpdateSkill 更新skill
func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
var req struct {
Description string `json:"description"`
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 获取skills目录
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
// 查找skill文件
skillDir := filepath.Join(skillsDir, skillName)
skillFile := filepath.Join(skillDir, "SKILL.md")
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
alternatives := []string{
filepath.Join(skillDir, "skill.md"),
filepath.Join(skillDir, "README.md"),
filepath.Join(skillDir, "readme.md"),
}
found := false
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skillFile = alt
found = true
break
}
}
if !found {
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在"})
return
}
}
// 读取现有文件以保留front matter中的name
existingContent, err := os.ReadFile(skillFile)
if err != nil {
h.logger.Error("读取skill文件失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取skill文件失败: " + err.Error()})
return
}
// 解析现有内容,提取name
existingName := skillName
contentStr := string(existingContent)
if strings.HasPrefix(contentStr, "---") {
parts := strings.SplitN(contentStr, "---", 3)
if len(parts) >= 2 {
frontMatter := parts[1]
lines := strings.Split(frontMatter, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "name:") {
name := strings.TrimSpace(strings.TrimPrefix(line, "name:"))
name = strings.Trim(name, `"'`)
if name != "" {
existingName = name
}
break
}
}
}
}
// 构建新的SKILL.md内容
var newContent strings.Builder
newContent.WriteString("---\n")
newContent.WriteString(fmt.Sprintf("name: %s\n", existingName))
if req.Description != "" {
// 如果描述包含特殊字符,需要加引号
desc := req.Description
if strings.Contains(desc, ":") || strings.Contains(desc, "\n") {
desc = fmt.Sprintf(`"%s"`, strings.ReplaceAll(desc, `"`, `\"`))
}
newContent.WriteString(fmt.Sprintf("description: %s\n", desc))
}
newContent.WriteString("version: 1.0.0\n")
newContent.WriteString("---\n\n")
newContent.WriteString(req.Content)
// 写入文件(统一使用SKILL.md)
targetFile := filepath.Join(skillDir, "SKILL.md")
if err := os.WriteFile(targetFile, []byte(newContent.String()), 0644); err != nil {
h.logger.Error("更新skill文件失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新skill文件失败: " + err.Error()})
return
}
// 如果原文件不是SKILL.md,删除旧文件
if skillFile != targetFile {
os.Remove(skillFile)
}
h.manager.InvalidateSkill(skillName)
h.logger.Info("更新skill成功", zap.String("skill", skillName))
c.JSON(http.StatusOK, gin.H{
"message": "skill已更新",
})
}
// DeleteSkill 删除skill
func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
// 检查是否有角色绑定了该skill,如果有则自动移除绑定
affectedRoles := h.removeSkillFromRoles(skillName)
if len(affectedRoles) > 0 {
h.logger.Info("从角色中移除skill绑定",
zap.String("skill", skillName),
zap.Strings("roles", affectedRoles))
}
// 获取skills目录
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
// 删除skill目录
skillDir := filepath.Join(skillsDir, skillName)
if err := os.RemoveAll(skillDir); err != nil {
h.logger.Error("删除skill失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除skill失败: " + err.Error()})
return
}
h.manager.InvalidateSkill(skillName)
responseMsg := "skill已删除"
if len(affectedRoles) > 0 {
responseMsg = fmt.Sprintf("skill已删除,已自动从 %d 个角色中移除绑定: %s",
len(affectedRoles), strings.Join(affectedRoles, ", "))
}
h.logger.Info("删除skill成功", zap.String("skill", skillName))
c.JSON(http.StatusOK, gin.H{
"message": responseMsg,
"affected_roles": affectedRoles,
})
}
// GetSkillStats 获取skills调用统计信息
func (h *SkillsHandler) GetSkillStats(c *gin.Context) {
skillList, err := h.manager.ListSkills()
if err != nil {
h.logger.Error("获取skills列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 获取skills目录
skillsDir := h.config.SkillsDir
if skillsDir == "" {
skillsDir = "skills"
}
configDir := filepath.Dir(h.configPath)
if !filepath.IsAbs(skillsDir) {
skillsDir = filepath.Join(configDir, skillsDir)
}
// 从数据库加载调用统计
var skillStatsMap map[string]*database.SkillStats
if h.db != nil {
dbStats, err := h.db.LoadSkillStats()
if err != nil {
h.logger.Warn("从数据库加载Skills统计信息失败", zap.Error(err))
skillStatsMap = make(map[string]*database.SkillStats)
} else {
skillStatsMap = dbStats
}
} else {
skillStatsMap = make(map[string]*database.SkillStats)
}
// 构建统计信息(包含所有skills,即使没有调用记录)
statsList := make([]map[string]interface{}, 0, len(skillList))
totalCalls := 0
totalSuccess := 0
totalFailed := 0
for _, skillName := range skillList {
stat, exists := skillStatsMap[skillName]
if !exists {
stat = &database.SkillStats{
SkillName: skillName,
TotalCalls: 0,
SuccessCalls: 0,
FailedCalls: 0,
}
}
totalCalls += stat.TotalCalls
totalSuccess += stat.SuccessCalls
totalFailed += stat.FailedCalls
lastCallTimeStr := ""
if stat.LastCallTime != nil {
lastCallTimeStr = stat.LastCallTime.Format("2006-01-02 15:04:05")
}
statsList = append(statsList, map[string]interface{}{
"skill_name": stat.SkillName,
"total_calls": stat.TotalCalls,
"success_calls": stat.SuccessCalls,
"failed_calls": stat.FailedCalls,
"last_call_time": lastCallTimeStr,
})
}
c.JSON(http.StatusOK, gin.H{
"total_skills": len(skillList),
"total_calls": totalCalls,
"total_success": totalSuccess,
"total_failed": totalFailed,
"skills_dir": skillsDir,
"stats": statsList,
})
}
// ClearSkillStats 清空所有Skills统计信息
func (h *SkillsHandler) ClearSkillStats(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
return
}
if err := h.db.ClearSkillStats(); err != nil {
h.logger.Error("清空Skills统计信息失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
return
}
h.logger.Info("已清空所有Skills统计信息")
c.JSON(http.StatusOK, gin.H{
"message": "已清空所有Skills统计信息",
})
}
// ClearSkillStatsByName 清空指定skill的统计信息
func (h *SkillsHandler) ClearSkillStatsByName(c *gin.Context) {
skillName := c.Param("name")
if skillName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
return
}
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
return
}
if err := h.db.ClearSkillStatsByName(skillName); err != nil {
h.logger.Error("清空指定skill统计信息失败", zap.String("skill", skillName), zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
return
}
h.logger.Info("已清空指定skill统计信息", zap.String("skill", skillName))
c.JSON(http.StatusOK, gin.H{
"message": fmt.Sprintf("已清空skill '%s' 的统计信息", skillName),
})
}
// removeSkillFromRoles 从所有角色中移除指定的skill绑定
// 返回受影响角色名称列表
func (h *SkillsHandler) removeSkillFromRoles(skillName string) []string {
if h.config.Roles == nil {
return []string{}
}
affectedRoles := make([]string, 0)
rolesToUpdate := make(map[string]config.RoleConfig)
// 遍历所有角色,查找并移除skill绑定
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 检查角色的Skills列表中是否包含要删除的skill
if len(role.Skills) > 0 {
updated := false
newSkills := make([]string, 0, len(role.Skills))
for _, skill := range role.Skills {
if skill != skillName {
newSkills = append(newSkills, skill)
} else {
updated = true
}
}
if updated {
role.Skills = newSkills
rolesToUpdate[roleName] = role
affectedRoles = append(affectedRoles, roleName)
}
}
}
// 如果有角色需要更新,保存到文件
if len(rolesToUpdate) > 0 {
// 更新内存中的配置
for roleName, role := range rolesToUpdate {
h.config.Roles[roleName] = role
}
// 保存更新后的角色配置到文件
if err := h.saveRolesConfig(); err != nil {
h.logger.Error("保存角色配置失败", zap.Error(err))
}
}
return affectedRoles
}
// saveRolesConfig 保存角色配置到文件(从SkillsHandler调用)
func (h *SkillsHandler) saveRolesConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeRoleFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeRoleFileName 将角色名称转换为安全的文件名
func sanitizeRoleFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// isValidSkillName 验证skill名称是否有效
func isValidSkillName(name string) bool {
if name == "" || len(name) > 100 {
return false
}
// 只允许字母、数字、连字符和下划线
for _, r := range name {
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_') {
return false
}
}
return true
}
-58
View File
@@ -1,58 +0,0 @@
package handler
import (
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// sseInterval is how often we write on long SSE streams. Shorter intervals help NATs and
// some proxies that treat connections as idle; 10s is a reasonable balance with traffic.
const sseKeepaliveInterval = 10 * time.Second
// sseKeepalive sends periodic SSE traffic so proxies (e.g. nginx proxy_read_timeout), NATs,
// and load balancers do not close long-running streams. Some intermediaries ignore comment-only
// lines, so we send both a comment and a minimal data frame (type heartbeat) per tick.
//
// writeMu must be the same mutex used by sendEvent for this request: concurrent writes to
// http.ResponseWriter break chunked transfer encoding (browser: net::ERR_INVALID_CHUNKED_ENCODING).
func sseKeepalive(c *gin.Context, stop <-chan struct{}, writeMu *sync.Mutex) {
if writeMu == nil {
return
}
ticker := time.NewTicker(sseKeepaliveInterval)
defer ticker.Stop()
for {
select {
case <-stop:
return
case <-c.Request.Context().Done():
return
case <-ticker.C:
select {
case <-stop:
return
case <-c.Request.Context().Done():
return
default:
}
writeMu.Lock()
if _, err := fmt.Fprintf(c.Writer, ": keepalive\n\n"); err != nil {
writeMu.Unlock()
return
}
// data: frame so strict proxies still see downstream bytes (comments alone may not reset timers)
if _, err := fmt.Fprintf(c.Writer, `data: {"type":"heartbeat"}`+"\n\n"); err != nil {
writeMu.Unlock()
return
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
writeMu.Unlock()
}
}
}
-276
View File
@@ -1,276 +0,0 @@
package handler
import (
"context"
"errors"
"sync"
"time"
)
// ErrTaskCancelled 用户取消任务的错误
var ErrTaskCancelled = errors.New("agent task cancelled by user")
// ErrTaskAlreadyRunning 会话已有任务正在执行
var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation")
// AgentTask 描述正在运行的Agent任务
type AgentTask struct {
ConversationID string `json:"conversationId"`
Message string `json:"message,omitempty"`
StartedAt time.Time `json:"startedAt"`
Status string `json:"status"`
CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务
cancel func(error)
}
// CompletedTask 已完成的任务(用于历史记录)
type CompletedTask struct {
ConversationID string `json:"conversationId"`
Message string `json:"message,omitempty"`
StartedAt time.Time `json:"startedAt"`
CompletedAt time.Time `json:"completedAt"`
Status string `json:"status"`
}
// AgentTaskManager 管理正在运行的Agent任务
type AgentTaskManager struct {
mu sync.RWMutex
tasks map[string]*AgentTask
completedTasks []*CompletedTask // 最近完成的任务历史
maxHistorySize int // 最大历史记录数
historyRetention time.Duration // 历史记录保留时间
}
const (
// cancellingStuckThreshold 处于「取消中」超过此时长则强制从运行列表移除。正常取消会在当前步骤内返回,
// 超过则视为卡住,尽快释放会话。常见做法多为 30–60s 内释放。
cancellingStuckThreshold = 45 * time.Second
// cancellingStuckThresholdLegacy 未记录 CancellingAt 时用 StartedAt 判断的兜底时长
cancellingStuckThresholdLegacy = 2 * time.Minute
cleanupInterval = 15 * time.Second // 与上面阈值配合,最长约 60s 内移除
)
// NewAgentTaskManager 创建任务管理器
func NewAgentTaskManager() *AgentTaskManager {
m := &AgentTaskManager{
tasks: make(map[string]*AgentTask),
completedTasks: make([]*CompletedTask, 0),
maxHistorySize: 50, // 最多保留50条历史记录
historyRetention: 24 * time.Hour, // 保留24小时
}
go m.runStuckCancellingCleanup()
return m
}
// runStuckCancellingCleanup 定期将长时间处于「取消中」的任务强制结束,避免卡住无法发新消息
func (m *AgentTaskManager) runStuckCancellingCleanup() {
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for range ticker.C {
m.cleanupStuckCancelling()
}
}
func (m *AgentTaskManager) cleanupStuckCancelling() {
m.mu.Lock()
var toFinish []string
now := time.Now()
for id, task := range m.tasks {
if task.Status != "cancelling" {
continue
}
var elapsed time.Duration
if !task.CancellingAt.IsZero() {
elapsed = now.Sub(task.CancellingAt)
if elapsed < cancellingStuckThreshold {
continue
}
} else {
elapsed = now.Sub(task.StartedAt)
if elapsed < cancellingStuckThresholdLegacy {
continue
}
}
toFinish = append(toFinish, id)
}
m.mu.Unlock()
for _, id := range toFinish {
m.FinishTask(id, "cancelled")
}
}
// StartTask 注册并开始一个新的任务
func (m *AgentTaskManager) StartTask(conversationID, message string, cancel context.CancelCauseFunc) (*AgentTask, error) {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.tasks[conversationID]; exists {
return nil, ErrTaskAlreadyRunning
}
task := &AgentTask{
ConversationID: conversationID,
Message: message,
StartedAt: time.Now(),
Status: "running",
cancel: func(err error) {
if cancel != nil {
cancel(err)
}
},
}
m.tasks[conversationID] = task
return task, nil
}
// CancelTask 取消指定会话的任务。若任务已在取消中,仍返回 (true, nil) 以便接口幂等、前端不报错。
func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool, error) {
m.mu.Lock()
task, exists := m.tasks[conversationID]
if !exists {
m.mu.Unlock()
return false, nil
}
// 如果已经处于取消流程,视为成功(幂等),避免前端重复点击报「未找到任务」
if task.Status == "cancelling" {
m.mu.Unlock()
return true, nil
}
task.Status = "cancelling"
task.CancellingAt = time.Now()
cancel := task.cancel
m.mu.Unlock()
if cause == nil {
cause = ErrTaskCancelled
}
if cancel != nil {
cancel(cause)
}
return true, nil
}
// UpdateTaskStatus 更新任务状态但不删除任务(用于在发送事件前更新状态)
func (m *AgentTaskManager) UpdateTaskStatus(conversationID string, status string) {
m.mu.Lock()
defer m.mu.Unlock()
task, exists := m.tasks[conversationID]
if !exists {
return
}
if status != "" {
task.Status = status
}
}
// FinishTask 完成任务并从管理器中移除
func (m *AgentTaskManager) FinishTask(conversationID string, finalStatus string) {
m.mu.Lock()
defer m.mu.Unlock()
task, exists := m.tasks[conversationID]
if !exists {
return
}
if finalStatus != "" {
task.Status = finalStatus
}
// 保存到历史记录
completedTask := &CompletedTask{
ConversationID: task.ConversationID,
Message: task.Message,
StartedAt: task.StartedAt,
CompletedAt: time.Now(),
Status: finalStatus,
}
// 添加到历史记录
m.completedTasks = append(m.completedTasks, completedTask)
// 清理过期和过多的历史记录
m.cleanupHistory()
// 从运行任务中移除
delete(m.tasks, conversationID)
}
// cleanupHistory 清理过期的历史记录
func (m *AgentTaskManager) cleanupHistory() {
now := time.Now()
cutoffTime := now.Add(-m.historyRetention)
// 过滤掉过期的记录
validTasks := make([]*CompletedTask, 0, len(m.completedTasks))
for _, task := range m.completedTasks {
if task.CompletedAt.After(cutoffTime) {
validTasks = append(validTasks, task)
}
}
// 如果仍然超过最大数量,只保留最新的
if len(validTasks) > m.maxHistorySize {
// 按完成时间排序,保留最新的
// 由于是追加的,最新的在最后,所以直接取最后N个
start := len(validTasks) - m.maxHistorySize
validTasks = validTasks[start:]
}
m.completedTasks = validTasks
}
// GetActiveTasks 返回所有正在运行的任务
func (m *AgentTaskManager) GetActiveTasks() []*AgentTask {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]*AgentTask, 0, len(m.tasks))
for _, task := range m.tasks {
result = append(result, &AgentTask{
ConversationID: task.ConversationID,
Message: task.Message,
StartedAt: task.StartedAt,
Status: task.Status,
})
}
return result
}
// GetCompletedTasks 返回最近完成的任务历史
func (m *AgentTaskManager) GetCompletedTasks() []*CompletedTask {
m.mu.RLock()
defer m.mu.RUnlock()
// 清理过期记录(只读锁,不影响其他操作)
// 注意:这里不能直接调用cleanupHistory,因为需要写锁
// 所以返回时过滤过期记录
now := time.Now()
cutoffTime := now.Add(-m.historyRetention)
result := make([]*CompletedTask, 0, len(m.completedTasks))
for _, task := range m.completedTasks {
if task.CompletedAt.After(cutoffTime) {
result = append(result, task)
}
}
// 按完成时间倒序排序(最新的在前)
// 由于是追加的,最新的在最后,需要反转
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
result[i], result[j] = result[j], result[i]
}
// 限制返回数量
if len(result) > m.maxHistorySize {
result = result[:m.maxHistorySize]
}
return result
}
-257
View File
@@ -1,257 +0,0 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const (
terminalMaxCommandLen = 4096
terminalMaxOutputLen = 256 * 1024 // 256KB
terminalTimeout = 30 * time.Minute
)
// TerminalHandler 处理系统设置中的终端命令执行
type TerminalHandler struct {
logger *zap.Logger
}
// maskTerminalCommand 对可能包含敏感信息的终端命令做脱敏,避免在日志中直接记录密码等内容
func maskTerminalCommand(cmd string) string {
trimmed := strings.TrimSpace(cmd)
lower := strings.ToLower(trimmed)
if strings.Contains(lower, "sudo") || strings.Contains(lower, "password") {
return "[masked sensitive terminal command]"
}
if len(trimmed) > 256 {
return trimmed[:256] + "..."
}
return trimmed
}
// NewTerminalHandler 创建终端处理器
func NewTerminalHandler(logger *zap.Logger) *TerminalHandler {
return &TerminalHandler{logger: logger}
}
// RunCommandRequest 执行命令请求
type RunCommandRequest struct {
Command string `json:"command"`
Shell string `json:"shell,omitempty"`
Cwd string `json:"cwd,omitempty"`
}
// RunCommandResponse 执行命令响应
type RunCommandResponse struct {
Stdout string `json:"stdout"`
Stderr string `json:"stderr"`
ExitCode int `json:"exit_code"`
Error string `json:"error,omitempty"`
}
// RunCommand 执行终端命令(需登录)
func (h *TerminalHandler) RunCommand(c *gin.Context) {
var req RunCommandRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"})
return
}
cmdStr := strings.TrimSpace(req.Command)
if cmdStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"})
return
}
if len(cmdStr) > terminalMaxCommandLen {
c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"})
return
}
shell := req.Shell
if shell == "" {
if runtime.GOOS == "windows" {
shell = "cmd"
} else {
shell = "sh"
}
}
ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout)
defer cancel()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
} else {
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
// 无 TTY 时设置 COLUMNS/TERM,使 ping 等工具的 usage 排版与真实终端一致
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
}
if req.Cwd != "" {
absCwd, err := filepath.Abs(req.Cwd)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"})
return
}
cur, _ := os.Getwd()
curAbs, _ := filepath.Abs(cur)
rel, err := filepath.Rel(curAbs, absCwd)
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"})
return
}
cmd.Dir = absCwd
}
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
stdoutBytes := stdout.Bytes()
stderrBytes := stderr.Bytes()
// 限制输出长度,防止内存占用过大(复制后截断,避免修改原 buffer)
truncSuffix := []byte("\n...(输出已截断)\n")
if len(stdoutBytes) > terminalMaxOutputLen {
tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix))
n := copy(tmp, stdoutBytes[:terminalMaxOutputLen])
copy(tmp[n:], truncSuffix)
stdoutBytes = tmp
}
if len(stderrBytes) > terminalMaxOutputLen {
tmp := make([]byte, terminalMaxOutputLen+len(truncSuffix))
n := copy(tmp, stderrBytes[:terminalMaxOutputLen])
copy(tmp[n:], truncSuffix)
stderrBytes = tmp
}
exitCode := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
if ctx.Err() == context.DeadlineExceeded {
so := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n")
so = strings.ReplaceAll(so, "\r", "\n")
se := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n")
se = strings.ReplaceAll(se, "\r", "\n")
resp := RunCommandResponse{
Stdout: so,
Stderr: se,
ExitCode: -1,
Error: "命令执行超时(" + terminalTimeout.String() + "",
}
c.JSON(http.StatusOK, resp)
return
}
h.logger.Debug("终端命令执行异常", zap.String("command", maskTerminalCommand(cmdStr)), zap.Error(err))
}
// 统一为 \n,避免前端因 \r 出现错位/对角线排版
stdoutStr := strings.ReplaceAll(string(stdoutBytes), "\r\n", "\n")
stdoutStr = strings.ReplaceAll(stdoutStr, "\r", "\n")
stderrStr := strings.ReplaceAll(string(stderrBytes), "\r\n", "\n")
stderrStr = strings.ReplaceAll(stderrStr, "\r", "\n")
resp := RunCommandResponse{
Stdout: stdoutStr,
Stderr: stderrStr,
ExitCode: exitCode,
}
if err != nil && exitCode != 0 {
resp.Error = err.Error()
}
c.JSON(http.StatusOK, resp)
}
// streamEvent SSE 事件
type streamEvent struct {
T string `json:"t"` // "out" | "err" | "exit"
D string `json:"d,omitempty"`
C int `json:"c"` // exit code(不用 omitempty,否则 0 不序列化导致前端显示 [exit undefined]
}
// RunCommandStream 流式执行命令,输出实时推送到前端(SSE)
func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
var req RunCommandRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体无效,需要 command 字段"})
return
}
cmdStr := strings.TrimSpace(req.Command)
if cmdStr == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "command 不能为空"})
return
}
if len(cmdStr) > terminalMaxCommandLen {
c.JSON(http.StatusBadRequest, gin.H{"error": "命令过长"})
return
}
shell := req.Shell
if shell == "" {
if runtime.GOOS == "windows" {
shell = "cmd"
} else {
shell = "sh"
}
}
ctx, cancel := context.WithTimeout(c.Request.Context(), terminalTimeout)
defer cancel()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(ctx, "cmd", "/c", cmdStr)
} else {
cmd = exec.CommandContext(ctx, shell, "-c", cmdStr)
cmd.Env = append(os.Environ(), "COLUMNS=256", "LINES=40", "TERM=xterm-256color")
}
if req.Cwd != "" {
absCwd, err := filepath.Abs(req.Cwd)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录无效"})
return
}
cur, _ := os.Getwd()
curAbs, _ := filepath.Abs(cur)
rel, err := filepath.Rel(curAbs, absCwd)
if err != nil || strings.HasPrefix(rel, "..") || rel == ".." {
c.JSON(http.StatusBadRequest, gin.H{"error": "工作目录必须在当前进程目录下"})
return
}
cmd.Dir = absCwd
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
cancel()
return
}
sendEvent := func(ev streamEvent) {
body, _ := json.Marshal(ev)
c.SSEvent("", string(body))
flusher.Flush()
}
runCommandStreamImpl(cmd, sendEvent, ctx)
}
-46
View File
@@ -1,46 +0,0 @@
//go:build !windows
package handler
import (
"bufio"
"context"
"os/exec"
"strings"
"github.com/creack/pty"
)
const ptyCols = 256
const ptyRows = 40
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
}
defer ptmx.Close()
normalize := func(s string) string {
s = strings.ReplaceAll(s, "\r\n", "\n")
return strings.ReplaceAll(s, "\r", "\n")
}
sc := bufio.NewScanner(ptmx)
for sc.Scan() {
sendEvent(streamEvent{T: "out", D: normalize(sc.Text())})
}
exitCode := 0
if err := cmd.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
}
if ctx.Err() == context.DeadlineExceeded {
exitCode = -1
}
sendEvent(streamEvent{T: "exit", C: exitCode})
}
@@ -1,65 +0,0 @@
//go:build windows
package handler
import (
"bufio"
"context"
"os/exec"
"strings"
"sync"
)
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
}
if err := cmd.Start(); err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
}
normalize := func(s string) string {
s = strings.ReplaceAll(s, "\r\n", "\n")
return strings.ReplaceAll(s, "\r", "\n")
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
sc := bufio.NewScanner(stdoutPipe)
for sc.Scan() {
sendEvent(streamEvent{T: "out", D: normalize(sc.Text())})
}
}()
go func() {
defer wg.Done()
sc := bufio.NewScanner(stderrPipe)
for sc.Scan() {
sendEvent(streamEvent{T: "err", D: normalize(sc.Text())})
}
}()
wg.Wait()
exitCode := 0
if err := cmd.Wait(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
}
if ctx.Err() == context.DeadlineExceeded {
exitCode = -1
}
sendEvent(streamEvent{T: "exit", C: exitCode})
}
-112
View File
@@ -1,112 +0,0 @@
//go:build !windows
package handler
import (
"encoding/json"
"net/http"
"os"
"os/exec"
"time"
"github.com/creack/pty"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// terminalResize is sent by the frontend when the xterm.js terminal is resized.
type terminalResize struct {
Type string `json:"type"`
Cols uint16 `json:"cols"`
Rows uint16 `json:"rows"`
}
// wsUpgrader 仅用于系统设置中的终端 WebSocket,会复用已有的登录保护(JWT 中间件在上层路由组)
var wsUpgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
// 由于已在 Gin 路由层做了认证,这里放宽 Origin,方便在同一域名下通过 HTTPS/WSS 访问
return true
},
}
// RunCommandWS 提供真正交互式 Shell:基于 WebSocket + PTY 的长会话
// 前端建立 WebSocket 连接后,所有键盘输入都会透传到 Shell,Shell 的输出也会实时写回前端。
func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
conn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return
}
defer conn.Close()
// 启动交互式 Shell,这里优先使用 bash,找不到则退回 sh
shell := "bash"
if _, err := exec.LookPath(shell); err != nil {
shell = "sh"
}
cmd := exec.Command(shell)
cmd.Env = append(os.Environ(),
"COLUMNS=80",
"LINES=24",
"TERM=xterm-256color",
)
// Use 80x24 as a safe default; the frontend will send the actual size immediately after connecting.
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: 80, Rows: 24})
if err != nil {
return
}
defer ptmx.Close()
// Shell -> WebSocket:将 PTY 输出实时发给前端
doneChan := make(chan struct{})
go func() {
buf := make([]byte, 4096)
for {
n, err := ptmx.Read(buf)
if n > 0 {
_ = conn.WriteMessage(websocket.BinaryMessage, buf[:n])
}
if err != nil {
break
}
}
close(doneChan)
}()
// WebSocket -> Shell:将前端输入写入 PTY(包括 sudo 密码、Ctrl+C 等)
conn.SetReadLimit(64 * 1024)
_ = conn.SetReadDeadline(time.Now().Add(terminalTimeout))
conn.SetPongHandler(func(string) error {
_ = conn.SetReadDeadline(time.Now().Add(terminalTimeout))
return nil
})
for {
msgType, data, err := conn.ReadMessage()
if err != nil {
_ = cmd.Process.Kill()
break
}
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
continue
}
if len(data) == 0 {
continue
}
// Check if this is a resize message (JSON with type:"resize")
if msgType == websocket.TextMessage && len(data) > 0 && data[0] == '{' {
var resize terminalResize
if json.Unmarshal(data, &resize) == nil && resize.Type == "resize" && resize.Cols > 0 && resize.Rows > 0 {
_ = pty.Setsize(ptmx, &pty.Winsize{Cols: resize.Cols, Rows: resize.Rows})
continue
}
}
if _, err := ptmx.Write(data); err != nil {
_ = cmd.Process.Kill()
break
}
}
<-doneChan
}
-263
View File
@@ -1,263 +0,0 @@
package handler
import (
"net/http"
"strconv"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// VulnerabilityHandler 漏洞处理器
type VulnerabilityHandler struct {
db *database.DB
logger *zap.Logger
}
// NewVulnerabilityHandler 创建新的漏洞处理器
func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *VulnerabilityHandler {
return &VulnerabilityHandler{
db: db,
logger: logger,
}
}
// CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
Title string `json:"title" binding:"required"`
Description string `json:"description"`
Severity string `json:"severity" binding:"required"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// CreateVulnerability 创建漏洞
func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
var req CreateVulnerabilityRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
vuln := &database.Vulnerability{
ConversationID: req.ConversationID,
Title: req.Title,
Description: req.Description,
Severity: req.Severity,
Status: req.Status,
Type: req.Type,
Target: req.Target,
Proof: req.Proof,
Impact: req.Impact,
Recommendation: req.Recommendation,
}
created, err := h.db.CreateVulnerability(vuln)
if err != nil {
h.logger.Error("创建漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, created)
}
// GetVulnerability 获取漏洞
func (h *VulnerabilityHandler) GetVulnerability(c *gin.Context) {
id := c.Param("id")
vuln, err := h.db.GetVulnerability(id)
if err != nil {
h.logger.Error("获取漏洞失败", zap.Error(err))
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
return
}
c.JSON(http.StatusOK, vuln)
}
// ListVulnerabilitiesResponse 漏洞列表响应
type ListVulnerabilitiesResponse struct {
Vulnerabilities []*database.Vulnerability `json:"vulnerabilities"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}
// ListVulnerabilities 列出漏洞
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "20")
offsetStr := c.DefaultQuery("offset", "0")
pageStr := c.Query("page")
id := c.Query("id")
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
page := 1
// 如果提供了page参数,优先使用page计算offset
if pageStr != "" {
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
page = p
offset = (page - 1) * limit
}
}
if limit <= 0 || limit > 100 {
limit = 20
}
if offset < 0 {
offset = 0
}
// 获取总数
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status)
if err != nil {
h.logger.Error("获取漏洞总数失败", zap.Error(err))
// 继续执行,使用0作为总数
total = 0
}
// 获取漏洞列表
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status)
if err != nil {
h.logger.Error("获取漏洞列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 计算总页数
totalPages := (total + limit - 1) / limit
if totalPages == 0 {
totalPages = 1
}
// 如果使用offset计算page,需要重新计算
if pageStr == "" {
page = (offset / limit) + 1
}
response := ListVulnerabilitiesResponse{
Vulnerabilities: vulnerabilities,
Total: total,
Page: page,
PageSize: limit,
TotalPages: totalPages,
}
c.JSON(http.StatusOK, response)
}
// UpdateVulnerabilityRequest 更新漏洞请求
type UpdateVulnerabilityRequest struct {
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// UpdateVulnerability 更新漏洞
func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
id := c.Param("id")
var req UpdateVulnerabilityRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 获取现有漏洞
existing, err := h.db.GetVulnerability(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "漏洞不存在"})
return
}
// 更新字段
if req.Title != "" {
existing.Title = req.Title
}
if req.Description != "" {
existing.Description = req.Description
}
if req.Severity != "" {
existing.Severity = req.Severity
}
if req.Status != "" {
existing.Status = req.Status
}
if req.Type != "" {
existing.Type = req.Type
}
if req.Target != "" {
existing.Target = req.Target
}
if req.Proof != "" {
existing.Proof = req.Proof
}
if req.Impact != "" {
existing.Impact = req.Impact
}
if req.Recommendation != "" {
existing.Recommendation = req.Recommendation
}
if err := h.db.UpdateVulnerability(id, existing); err != nil {
h.logger.Error("更新漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// 返回更新后的漏洞
updated, err := h.db.GetVulnerability(id)
if err != nil {
h.logger.Error("获取更新后的漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, updated)
}
// DeleteVulnerability 删除漏洞
func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
id := c.Param("id")
if err := h.db.DeleteVulnerability(id); err != nil {
h.logger.Error("删除漏洞失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// GetVulnerabilityStats 获取漏洞统计
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
conversationID := c.Query("conversation_id")
stats, err := h.db.GetVulnerabilityStats(conversationID)
if err != nil {
h.logger.Error("获取漏洞统计失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, stats)
}
-706
View File
@@ -1,706 +0,0 @@
package handler
import (
"bytes"
"database/sql"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// WebShellHandler 代理执行 WebShell 命令(类似冰蝎/蚁剑),避免前端跨域并统一构建请求
type WebShellHandler struct {
logger *zap.Logger
client *http.Client
db *database.DB
}
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
return &WebShellHandler{
logger: logger,
client: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{DisableKeepAlives: false},
},
db: db,
}
}
// CreateConnectionRequest 创建连接请求
type CreateConnectionRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"`
CmdParam string `json:"cmd_param"`
Remark string `json:"remark"`
}
// UpdateConnectionRequest 更新连接请求
type UpdateConnectionRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"`
CmdParam string `json:"cmd_param"`
Remark string `json:"remark"`
}
// ListConnections 列出所有 WebShell 连接(GET /api/webshell/connections
func (h *WebShellHandler) ListConnections(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
list, err := h.db.ListWebshellConnections()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if list == nil {
list = []database.WebShellConnection{}
}
c.JSON(http.StatusOK, list)
}
// CreateConnection 创建 WebShell 连接(POST /api/webshell/connections
func (h *WebShellHandler) CreateConnection(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
var req CreateConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
if req.URL == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
return
}
if _, err := url.Parse(req.URL); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
method := strings.ToLower(strings.TrimSpace(req.Method))
if method != "get" && method != "post" {
method = "post"
}
shellType := strings.ToLower(strings.TrimSpace(req.Type))
if shellType == "" {
shellType = "php"
}
conn := &database.WebShellConnection{
ID: "ws_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:12],
URL: req.URL,
Password: strings.TrimSpace(req.Password),
Type: shellType,
Method: method,
CmdParam: strings.TrimSpace(req.CmdParam),
Remark: strings.TrimSpace(req.Remark),
CreatedAt: time.Now(),
}
if err := h.db.CreateWebshellConnection(conn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, conn)
}
// UpdateConnection 更新 WebShell 连接(PUT /api/webshell/connections/:id
func (h *WebShellHandler) UpdateConnection(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
var req UpdateConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
if req.URL == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
return
}
if _, err := url.Parse(req.URL); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
return
}
method := strings.ToLower(strings.TrimSpace(req.Method))
if method != "get" && method != "post" {
method = "post"
}
shellType := strings.ToLower(strings.TrimSpace(req.Type))
if shellType == "" {
shellType = "php"
}
conn := &database.WebShellConnection{
ID: id,
URL: req.URL,
Password: strings.TrimSpace(req.Password),
Type: shellType,
Method: method,
CmdParam: strings.TrimSpace(req.CmdParam),
Remark: strings.TrimSpace(req.Remark),
}
if err := h.db.UpdateWebshellConnection(conn); err != nil {
if err == sql.ErrNoRows {
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
updated, _ := h.db.GetWebshellConnection(id)
if updated != nil {
c.JSON(http.StatusOK, updated)
} else {
c.JSON(http.StatusOK, conn)
}
}
// DeleteConnection 删除 WebShell 连接(DELETE /api/webshell/connections/:id
func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
if err := h.db.DeleteWebshellConnection(id); err != nil {
if err == sql.ErrNoRows {
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
// GetConnectionState 获取 WebShell 连接关联的前端持久化状态(GET /api/webshell/connections/:id/state
func (h *WebShellHandler) GetConnectionState(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
conn, err := h.db.GetWebshellConnection(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if conn == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
return
}
stateJSON, err := h.db.GetWebshellConnectionState(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var state interface{}
if err := json.Unmarshal([]byte(stateJSON), &state); err != nil {
state = map[string]interface{}{}
}
c.JSON(http.StatusOK, gin.H{"state": state})
}
// SaveConnectionState 保存 WebShell 连接关联的前端持久化状态(PUT /api/webshell/connections/:id/state
func (h *WebShellHandler) SaveConnectionState(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
conn, err := h.db.GetWebshellConnection(id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if conn == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "connection not found"})
return
}
var req struct {
State json.RawMessage `json:"state"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
raw := req.State
if len(raw) == 0 {
raw = json.RawMessage(`{}`)
}
if len(raw) > 2*1024*1024 {
c.JSON(http.StatusBadRequest, gin.H{"error": "state payload too large (max 2MB)"})
return
}
var anyJSON interface{}
if err := json.Unmarshal(raw, &anyJSON); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "state must be valid json"})
return
}
if err := h.db.UpsertWebshellConnectionState(id, string(raw)); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
// GetAIHistory 获取指定 WebShell 连接的 AI 助手对话历史(GET /api/webshell/connections/:id/ai-history
func (h *WebShellHandler) GetAIHistory(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
conv, err := h.db.GetConversationByWebshellConnectionID(id)
if err != nil {
h.logger.Warn("获取 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err))
c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}})
return
}
if conv == nil {
c.JSON(http.StatusOK, gin.H{"conversationId": nil, "messages": []database.Message{}})
return
}
c.JSON(http.StatusOK, gin.H{"conversationId": conv.ID, "messages": conv.Messages})
}
// ListAIConversations 列出该 WebShell 连接下的所有 AI 对话(供侧边栏)
func (h *WebShellHandler) ListAIConversations(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "database not available"})
return
}
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "id is required"})
return
}
list, err := h.db.ListConversationsByWebshellConnectionID(id)
if err != nil {
h.logger.Warn("列出 WebShell AI 对话失败", zap.String("connectionId", id), zap.Error(err))
c.JSON(http.StatusOK, []database.WebShellConversationItem{})
return
}
if list == nil {
list = []database.WebShellConversationItem{}
}
c.JSON(http.StatusOK, list)
}
// ExecRequest 执行命令请求(前端传入连接信息 + 命令)
type ExecRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"` // php, asp, aspx, jsp, custom
Method string `json:"method"` // GET 或 POST,空则默认 POST
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
Command string `json:"command" binding:"required"`
}
// ExecResponse 执行命令响应
type ExecResponse struct {
OK bool `json:"ok"`
Output string `json:"output"`
Error string `json:"error,omitempty"`
HTTPCode int `json:"http_code,omitempty"`
}
// FileOpRequest 文件操作请求
type FileOpRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"` // GET 或 POST,空则默认 POST
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk
Path string `json:"path"`
TargetPath string `json:"target_path"` // rename 时目标路径
Content string `json:"content"` // write/upload 时使用
ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块
}
// FileOpResponse 文件操作响应
type FileOpResponse struct {
OK bool `json:"ok"`
Output string `json:"output"`
Error string `json:"error,omitempty"`
}
func (h *WebShellHandler) Exec(c *gin.Context) {
var req ExecRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
req.Command = strings.TrimSpace(req.Command)
if req.URL == "" || req.Command == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "url and command are required"})
return
}
parsed, err := url.Parse(req.URL)
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"})
return
}
useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET"
cmdParam := strings.TrimSpace(req.CmdParam)
if cmdParam == "" {
cmdParam = "cmd"
}
var httpReq *http.Request
if useGET {
targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, req.Command)
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
} else {
body := h.buildExecBody(req.Type, req.Password, cmdParam, req.Command)
httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if err != nil {
h.logger.Warn("webshell exec NewRequest", zap.Error(err))
c.JSON(http.StatusInternalServerError, ExecResponse{OK: false, Error: err.Error()})
return
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
h.logger.Warn("webshell exec Do", zap.String("url", req.URL), zap.Error(err))
c.JSON(http.StatusOK, ExecResponse{OK: false, Error: err.Error()})
return
}
defer resp.Body.Close()
out, _ := io.ReadAll(resp.Body)
output := string(out)
httpCode := resp.StatusCode
c.JSON(http.StatusOK, ExecResponse{
OK: resp.StatusCode == http.StatusOK,
Output: output,
HTTPCode: httpCode,
})
}
// buildExecBody 按常见 WebShell 约定构建 POST 体(多数使用 pass + cmd,可配置命令参数名)
func (h *WebShellHandler) buildExecBody(shellType, password, cmdParam, command string) []byte {
form := h.execParams(shellType, password, cmdParam, command)
return []byte(form.Encode())
}
// buildExecURL 构建 GET 请求的完整 URLbaseURL + ?pass=xxx&cmd=yyycmd 可配置)
func (h *WebShellHandler) buildExecURL(baseURL, shellType, password, cmdParam, command string) string {
form := h.execParams(shellType, password, cmdParam, command)
if parsed, err := url.Parse(baseURL); err == nil {
parsed.RawQuery = form.Encode()
return parsed.String()
}
return baseURL + "?" + form.Encode()
}
func (h *WebShellHandler) execParams(shellType, password, cmdParam, command string) url.Values {
shellType = strings.ToLower(strings.TrimSpace(shellType))
if shellType == "" {
shellType = "php"
}
if strings.TrimSpace(cmdParam) == "" {
cmdParam = "cmd"
}
form := url.Values{}
form.Set("pass", password)
form.Set(cmdParam, command)
return form
}
func (h *WebShellHandler) FileOp(c *gin.Context) {
var req FileOpRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.URL = strings.TrimSpace(req.URL)
req.Action = strings.ToLower(strings.TrimSpace(req.Action))
if req.URL == "" || req.Action == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "url and action are required"})
return
}
parsed, err := url.Parse(req.URL)
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url: only http(s) allowed"})
return
}
// 通过执行系统命令实现文件操作(与通用一句话兼容)
var command string
shellType := strings.ToLower(strings.TrimSpace(req.Type))
switch req.Action {
case "list":
path := strings.TrimSpace(req.Path)
if path == "" {
path = "."
}
if shellType == "asp" || shellType == "aspx" {
command = "dir " + h.escapePath(path)
} else {
command = "ls -la " + h.escapePath(path)
}
case "read":
if shellType == "asp" || shellType == "aspx" {
command = "type " + h.escapePath(strings.TrimSpace(req.Path))
} else {
command = "cat " + h.escapePath(strings.TrimSpace(req.Path))
}
case "delete":
if shellType == "asp" || shellType == "aspx" {
command = "del " + h.escapePath(strings.TrimSpace(req.Path))
} else {
command = "rm -f " + h.escapePath(strings.TrimSpace(req.Path))
}
case "write":
path := h.escapePath(strings.TrimSpace(req.Path))
command = "echo " + h.escapeForEcho(req.Content) + " > " + path
case "mkdir":
path := strings.TrimSpace(req.Path)
if path == "" {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for mkdir"})
return
}
if shellType == "asp" || shellType == "aspx" {
command = "md " + h.escapePath(path)
} else {
command = "mkdir -p " + h.escapePath(path)
}
case "rename":
oldPath := strings.TrimSpace(req.Path)
newPath := strings.TrimSpace(req.TargetPath)
if oldPath == "" || newPath == "" {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path and target_path are required for rename"})
return
}
if shellType == "asp" || shellType == "aspx" {
command = "move /y " + h.escapePath(oldPath) + " " + h.escapePath(newPath)
} else {
command = "mv " + h.escapePath(oldPath) + " " + h.escapePath(newPath)
}
case "upload":
path := strings.TrimSpace(req.Path)
if path == "" {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload"})
return
}
if len(req.Content) > 512*1024 {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "upload content too large (max 512KB base64)"})
return
}
// base64 仅含 A-Za-z0-9+/=,用单引号包裹安全
command = "echo " + "'" + req.Content + "'" + " | base64 -d > " + h.escapePath(path)
case "upload_chunk":
path := strings.TrimSpace(req.Path)
if path == "" {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload_chunk"})
return
}
redir := ">>"
if req.ChunkIndex == 0 {
redir = ">"
}
command = "echo " + "'" + req.Content + "'" + " | base64 -d " + redir + " " + h.escapePath(path)
default:
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "unsupported action: " + req.Action})
return
}
useGET := strings.ToUpper(strings.TrimSpace(req.Method)) == "GET"
cmdParam := strings.TrimSpace(req.CmdParam)
if cmdParam == "" {
cmdParam = "cmd"
}
var httpReq *http.Request
if useGET {
targetURL := h.buildExecURL(req.URL, req.Type, req.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
} else {
body := h.buildExecBody(req.Type, req.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodPost, req.URL, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if err != nil {
c.JSON(http.StatusInternalServerError, FileOpResponse{OK: false, Error: err.Error()})
return
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
c.JSON(http.StatusOK, FileOpResponse{OK: false, Error: err.Error()})
return
}
defer resp.Body.Close()
out, _ := io.ReadAll(resp.Body)
output := string(out)
c.JSON(http.StatusOK, FileOpResponse{
OK: resp.StatusCode == http.StatusOK,
Output: output,
})
}
func (h *WebShellHandler) escapePath(p string) string {
if p == "" {
return "."
}
// 简单转义空格与敏感字符,避免命令注入
return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'"
}
func (h *WebShellHandler) escapeForEcho(s string) string {
// 仅用于 write:base64 写入更安全,这里简单用单引号包裹
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
}
// ExecWithConnection 在指定 WebShell 连接上执行命令(供 MCP/Agent 等非 HTTP 调用)
func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, command string) (output string, ok bool, errMsg string) {
if conn == nil {
return "", false, "connection is nil"
}
command = strings.TrimSpace(command)
if command == "" {
return "", false, "command is required"
}
useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET"
cmdParam := strings.TrimSpace(conn.CmdParam)
if cmdParam == "" {
cmdParam = "cmd"
}
var httpReq *http.Request
var err error
if useGET {
targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
} else {
body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if err != nil {
return "", false, err.Error()
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
return "", false, err.Error()
}
defer resp.Body.Close()
out, _ := io.ReadAll(resp.Body)
return string(out), resp.StatusCode == http.StatusOK, ""
}
// FileOpWithConnection 在指定 WebShell 连接上执行文件操作(供 MCP/Agent 调用),支持 list / read / write
func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection, action, path, content, targetPath string) (output string, ok bool, errMsg string) {
if conn == nil {
return "", false, "connection is nil"
}
action = strings.ToLower(strings.TrimSpace(action))
shellType := strings.ToLower(strings.TrimSpace(conn.Type))
if shellType == "" {
shellType = "php"
}
var command string
switch action {
case "list":
if path == "" {
path = "."
}
if shellType == "asp" || shellType == "aspx" {
command = "dir " + h.escapePath(strings.TrimSpace(path))
} else {
command = "ls -la " + h.escapePath(strings.TrimSpace(path))
}
case "read":
path = strings.TrimSpace(path)
if path == "" {
return "", false, "path is required for read"
}
if shellType == "asp" || shellType == "aspx" {
command = "type " + h.escapePath(path)
} else {
command = "cat " + h.escapePath(path)
}
case "write":
path = strings.TrimSpace(path)
if path == "" {
return "", false, "path is required for write"
}
command = "echo " + h.escapeForEcho(content) + " > " + h.escapePath(path)
default:
return "", false, "unsupported action: " + action + " (supported: list, read, write)"
}
useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET"
cmdParam := strings.TrimSpace(conn.CmdParam)
if cmdParam == "" {
cmdParam = "cmd"
}
var httpReq *http.Request
var err error
if useGET {
targetURL := h.buildExecURL(conn.URL, conn.Type, conn.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodGet, targetURL, nil)
} else {
body := h.buildExecBody(conn.Type, conn.Password, cmdParam, command)
httpReq, err = http.NewRequest(http.MethodPost, conn.URL, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if err != nil {
return "", false, err.Error()
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
return "", false, err.Error()
}
defer resp.Body.Close()
out, _ := io.ReadAll(resp.Body)
return string(out), resp.StatusCode == http.StatusOK, ""
}
-67
View File
@@ -1,67 +0,0 @@
package knowledge
import (
"context"
"fmt"
"strings"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
"github.com/cloudwego/eino/components/document"
"github.com/pkoukk/tiktoken-go"
)
func tokenizerLenFunc(embeddingModel string) func(string) int {
fallback := func(s string) int {
r := []rune(s)
if len(r) == 0 {
return 0
}
return (len(r) + 3) / 4
}
m := strings.TrimSpace(embeddingModel)
if m == "" {
return fallback
}
tok, err := tiktoken.EncodingForModel(m)
if err != nil {
return fallback
}
return func(s string) int {
return len(tok.Encode(s, nil, nil))
}
}
// newKnowledgeSplitter builds an Eino recursive text splitter. LenFunc uses tiktoken for
// embeddingModel when available, else rune/4 approximation.
func newKnowledgeSplitter(chunkSize, overlap int, embeddingModel string) (document.Transformer, error) {
if chunkSize <= 0 {
return nil, fmt.Errorf("chunk size must be positive")
}
if overlap < 0 {
overlap = 0
}
return recursive.NewSplitter(context.Background(), &recursive.Config{
ChunkSize: chunkSize,
OverlapSize: overlap,
LenFunc: tokenizerLenFunc(embeddingModel),
Separators: []string{
"\n\n", "\n## ", "\n### ", "\n#### ", "\n",
"。", "", "", ". ", "? ", "! ",
" ",
},
})
}
// newMarkdownHeaderSplitter Eino-ext Markdown 按标题切分(#####),适合技术/Markdown 知识库。
func newMarkdownHeaderSplitter(ctx context.Context) (document.Transformer, error) {
return markdown.NewHeaderSplitter(ctx, &markdown.HeaderConfig{
Headers: map[string]string{
"#": "h1",
"##": "h2",
"###": "h3",
"####": "h4",
},
TrimHeaders: false,
})
}
-129
View File
@@ -1,129 +0,0 @@
package knowledge
import (
"fmt"
"strings"
)
// Document metadata keys for Eino schema.Document flowing through the RAG pipeline.
const (
metaKBCategory = "kb_category"
metaKBTitle = "kb_title"
metaKBItemID = "kb_item_id"
metaKBChunkIndex = "kb_chunk_index"
metaSimilarity = "similarity"
)
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
const (
DSLRiskType = "risk_type"
DSLSimilarityThreshold = "similarity_threshold"
DSLSubIndexFilter = "sub_index_filter"
)
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
// stay comparable if users skip reindex; new indexes use the same string shape.
func FormatEmbeddingInput(category, title, chunkText string) string {
return fmt.Sprintf("[风险类型:%s] [标题:%s]\n%s", category, title, chunkText)
}
// FormatQueryEmbeddingText builds the string embedded at query time so it matches
// [FormatEmbeddingInput] for the same risk category (title left empty for queries).
func FormatQueryEmbeddingText(riskType, query string) string {
q := strings.TrimSpace(query)
rt := strings.TrimSpace(riskType)
if rt != "" {
return FormatEmbeddingInput(rt, "", q)
}
return q
}
// MetaLookupString returns metadata string value or "" if absent.
func MetaLookupString(md map[string]any, key string) string {
if md == nil {
return ""
}
v, ok := md[key]
if !ok || v == nil {
return ""
}
switch t := v.(type) {
case string:
return t
default:
return strings.TrimSpace(fmt.Sprint(t))
}
}
// MetaStringOK returns trimmed non-empty string and true if present and non-empty.
func MetaStringOK(md map[string]any, key string) (string, bool) {
s := strings.TrimSpace(MetaLookupString(md, key))
if s == "" {
return "", false
}
return s, true
}
// RequireMetaString requires a non-empty string metadata field.
func RequireMetaString(md map[string]any, key string) (string, error) {
s, ok := MetaStringOK(md, key)
if !ok {
return "", fmt.Errorf("missing or empty metadata %q", key)
}
return s, nil
}
// RequireMetaInt requires an integer metadata field.
func RequireMetaInt(md map[string]any, key string) (int, error) {
if md == nil {
return 0, fmt.Errorf("missing metadata key %q", key)
}
v, ok := md[key]
if !ok {
return 0, fmt.Errorf("missing metadata key %q", key)
}
switch t := v.(type) {
case int:
return t, nil
case int32:
return int(t), nil
case int64:
return int(t), nil
case float64:
return int(t), nil
default:
return 0, fmt.Errorf("metadata %q: unsupported type %T", key, v)
}
}
// DSLNumeric coerces DSL map values (e.g. from JSON) to float64.
func DSLNumeric(v any) (float64, bool) {
switch t := v.(type) {
case float64:
return t, true
case float32:
return float64(t), true
case int:
return float64(t), true
case int64:
return float64(t), true
case uint32:
return float64(t), true
case uint64:
return float64(t), true
default:
return 0, false
}
}
// MetaFloat64OK reads a float metadata value.
func MetaFloat64OK(md map[string]any, key string) (float64, bool) {
if md == nil {
return 0, false
}
v, ok := md[key]
if !ok {
return 0, false
}
return DSLNumeric(v)
}
-14
View File
@@ -1,14 +0,0 @@
package knowledge
import "testing"
func TestFormatQueryEmbeddingText_AlignsWithIndexPrefix(t *testing.T) {
q := FormatQueryEmbeddingText("XSS", "payload")
want := FormatEmbeddingInput("XSS", "", "payload")
if q != want {
t.Fatalf("query embed text mismatch:\n got: %q\nwant: %q", q, want)
}
if FormatQueryEmbeddingText("", "hello") != "hello" {
t.Fatalf("expected bare query without risk type")
}
}
-25
View File
@@ -1,25 +0,0 @@
package knowledge
import (
"context"
"fmt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
// BuildKnowledgeRetrieveChain 编译「查询字符串 → 文档列表」的 Eino Chain,底层为 SQLite 向量检索([VectorEinoRetriever])。
// 去重、上下文预算截断与最终 Top-K 均在 [VectorEinoRetriever.Retrieve] 内完成,与 HTTP/MCP 检索路径一致。
func BuildKnowledgeRetrieveChain(ctx context.Context, r *Retriever) (compose.Runnable[string, []*schema.Document], error) {
if r == nil {
return nil, fmt.Errorf("retriever is nil")
}
ch := compose.NewChain[string, []*schema.Document]()
ch.AppendRetriever(r.AsEinoRetriever())
return ch.Compile(ctx)
}
// CompileRetrieveChain 等价于 [BuildKnowledgeRetrieveChain](ctx, r)。
func (r *Retriever) CompileRetrieveChain(ctx context.Context) (compose.Runnable[string, []*schema.Document], error) {
return BuildKnowledgeRetrieveChain(ctx, r)
}
@@ -1,23 +0,0 @@
package knowledge
import (
"context"
"testing"
"go.uber.org/zap"
)
func TestBuildKnowledgeRetrieveChain_Compile(t *testing.T) {
r := NewRetriever(nil, nil, &RetrievalConfig{TopK: 3, SimilarityThreshold: 0.5}, zap.NewNop())
_, err := BuildKnowledgeRetrieveChain(context.Background(), r)
if err != nil {
t.Fatal(err)
}
}
func TestBuildKnowledgeRetrieveChain_NilRetriever(t *testing.T) {
_, err := BuildKnowledgeRetrieveChain(context.Background(), nil)
if err == nil {
t.Fatal("expected error for nil retriever")
}
}
@@ -1,202 +0,0 @@
package knowledge
import (
"context"
"fmt"
"strings"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// VectorEinoRetriever implements [retriever.Retriever] on top of SQLite-stored embeddings + cosine similarity.
//
// Options:
// - [retriever.WithTopK]
// - [retriever.WithDSLInfo] with [DSLRiskType] (string), [DSLSimilarityThreshold] (float, cosine 01), [DSLSubIndexFilter] (string)
//
// Document scores are cosine similarity; [retriever.WithScoreThreshold] is not mapped to a different metric.
//
// After vector search: optional [DocumentReranker] (see [Retriever.SetDocumentReranker]), then
// [ApplyPostRetrieve] (normalized-text dedupe, context budget, final Top-K) using [config.PostRetrieveConfig].
type VectorEinoRetriever struct {
inner *Retriever
}
// NewVectorEinoRetriever wraps r for Eino compose / tooling.
func NewVectorEinoRetriever(r *Retriever) *VectorEinoRetriever {
if r == nil {
return nil
}
return &VectorEinoRetriever{inner: r}
}
// GetType identifies this retriever for Eino callbacks.
func (h *VectorEinoRetriever) GetType() string {
return "SQLiteVectorKnowledgeRetriever"
}
// Retrieve runs vector search and returns [schema.Document] rows.
func (h *VectorEinoRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (out []*schema.Document, err error) {
if h == nil || h.inner == nil {
return nil, fmt.Errorf("VectorEinoRetriever: nil retriever")
}
q := strings.TrimSpace(query)
if q == "" {
return nil, fmt.Errorf("查询不能为空")
}
ro := retriever.GetCommonOptions(nil, opts...)
cfg := h.inner.config
req := &SearchRequest{Query: q}
if ro.TopK != nil && *ro.TopK > 0 {
req.TopK = *ro.TopK
} else if cfg != nil && cfg.TopK > 0 {
req.TopK = cfg.TopK
} else {
req.TopK = 5
}
req.Threshold = 0
if ro.DSLInfo != nil {
if rt, ok := ro.DSLInfo[DSLRiskType].(string); ok {
req.RiskType = strings.TrimSpace(rt)
}
if v, ok := ro.DSLInfo[DSLSimilarityThreshold]; ok {
if f, ok2 := DSLNumeric(v); ok2 && f > 0 {
req.Threshold = f
}
}
if sf, ok := ro.DSLInfo[DSLSubIndexFilter].(string); ok {
req.SubIndexFilter = strings.TrimSpace(sf)
}
}
if req.SubIndexFilter == "" && cfg != nil && strings.TrimSpace(cfg.SubIndexFilter) != "" {
req.SubIndexFilter = strings.TrimSpace(cfg.SubIndexFilter)
}
if req.Threshold <= 0 && cfg != nil && cfg.SimilarityThreshold > 0 {
req.Threshold = cfg.SimilarityThreshold
}
if req.Threshold <= 0 {
req.Threshold = 0.7
}
finalTopK := req.TopK
var postPO *config.PostRetrieveConfig
if cfg != nil {
postPO = &cfg.PostRetrieve
}
fetchK := EffectivePrefetchTopK(finalTopK, postPO)
searchReq := *req
searchReq.TopK = fetchK
ctx = callbacks.EnsureRunInfo(ctx, h.GetType(), components.ComponentOfRetriever)
th := req.Threshold
st := &th
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
Query: q,
TopK: finalTopK,
ScoreThreshold: st,
Extra: ro.DSLInfo,
})
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
return
}
_ = callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: out})
}()
results, err := h.inner.vectorSearch(ctx, &searchReq)
if err != nil {
return nil, err
}
out = retrievalResultsToDocuments(results)
if rr := h.inner.documentReranker(); rr != nil && len(out) > 1 {
reranked, rerr := rr.Rerank(ctx, q, out)
if rerr != nil {
if h.inner.logger != nil {
h.inner.logger.Warn("知识检索重排失败,已使用向量序", zap.Error(rerr))
}
} else if len(reranked) > 0 {
out = reranked
}
}
tokenModel := ""
if h.inner.embedder != nil {
tokenModel = h.inner.embedder.EmbeddingModelName()
}
out, err = ApplyPostRetrieve(out, postPO, tokenModel, finalTopK)
if err != nil {
return nil, err
}
return out, nil
}
func retrievalResultsToDocuments(results []*RetrievalResult) []*schema.Document {
out := make([]*schema.Document, 0, len(results))
for _, res := range results {
if res == nil || res.Chunk == nil || res.Item == nil {
continue
}
d := &schema.Document{
ID: res.Chunk.ID,
Content: res.Chunk.ChunkText,
MetaData: map[string]any{
metaKBItemID: res.Item.ID,
metaKBCategory: res.Item.Category,
metaKBTitle: res.Item.Title,
metaKBChunkIndex: res.Chunk.ChunkIndex,
metaSimilarity: res.Similarity,
},
}
d.WithScore(res.Score)
out = append(out, d)
}
return out
}
func documentsToRetrievalResults(docs []*schema.Document) ([]*RetrievalResult, error) {
out := make([]*RetrievalResult, 0, len(docs))
for i, d := range docs {
if d == nil {
continue
}
itemID, err := RequireMetaString(d.MetaData, metaKBItemID)
if err != nil {
return nil, fmt.Errorf("document %d: %w", i, err)
}
cat := MetaLookupString(d.MetaData, metaKBCategory)
title := MetaLookupString(d.MetaData, metaKBTitle)
chunkIdx, err := RequireMetaInt(d.MetaData, metaKBChunkIndex)
if err != nil {
return nil, fmt.Errorf("document %d: %w", i, err)
}
sim, _ := MetaFloat64OK(d.MetaData, metaSimilarity)
item := &KnowledgeItem{ID: itemID, Category: cat, Title: title}
chunk := &KnowledgeChunk{
ID: d.ID,
ItemID: itemID,
ChunkIndex: chunkIdx,
ChunkText: d.Content,
}
out = append(out, &RetrievalResult{
Chunk: chunk,
Item: item,
Similarity: sim,
Score: d.Score(),
})
}
return out, nil
}
var _ retriever.Retriever = (*VectorEinoRetriever)(nil)
-142
View File
@@ -1,142 +0,0 @@
package knowledge
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
)
// SQLiteIndexer implements [indexer.Indexer] against knowledge_embeddings + existing schema.
type SQLiteIndexer struct {
db *sql.DB
batchSize int
embeddingModel string
}
// NewSQLiteIndexer returns an indexer that writes chunk rows for one knowledge item per Store call.
// batchSize is the embedding batch size; if <= 0, default 64 is used.
// embeddingModel is persisted per row for retrieval-time consistency checks (may be empty).
func NewSQLiteIndexer(db *sql.DB, batchSize int, embeddingModel string) *SQLiteIndexer {
return &SQLiteIndexer{db: db, batchSize: batchSize, embeddingModel: strings.TrimSpace(embeddingModel)}
}
// GetType implements eino callback run info.
func (s *SQLiteIndexer) GetType() string {
return "SQLiteKnowledgeIndexer"
}
// Store embeds documents and inserts rows. Each doc must carry MetaData:
// kb_item_id, kb_category, kb_title, kb_chunk_index (int). Content is chunk text only.
func (s *SQLiteIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (ids []string, err error) {
options := indexer.GetCommonOptions(nil, opts...)
if options.Embedding == nil {
return nil, fmt.Errorf("sqlite indexer: embedding is required")
}
if len(docs) == 0 {
return nil, nil
}
ctx = callbacks.EnsureRunInfo(ctx, s.GetType(), components.ComponentOfIndexer)
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
defer func() {
if err != nil {
_ = callbacks.OnError(ctx, err)
return
}
_ = callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: ids})
}()
subIdxStr := strings.Join(options.SubIndexes, ",")
texts := make([]string, len(docs))
for i, d := range docs {
if d == nil {
return nil, fmt.Errorf("sqlite indexer: nil document at %d", i)
}
cat := MetaLookupString(d.MetaData, metaKBCategory)
title := MetaLookupString(d.MetaData, metaKBTitle)
texts[i] = FormatEmbeddingInput(cat, title, d.Content)
}
bs := s.batchSize
if bs <= 0 {
bs = 64
}
var allVecs [][]float64
for start := 0; start < len(texts); start += bs {
end := start + bs
if end > len(texts) {
end = len(texts)
}
batch := texts[start:end]
vecs, embedErr := options.Embedding.EmbedStrings(ctx, batch)
if embedErr != nil {
return nil, fmt.Errorf("sqlite indexer: embed batch %d-%d: %w", start, end, embedErr)
}
if len(vecs) != len(batch) {
return nil, fmt.Errorf("sqlite indexer: embed count mismatch: got %d want %d", len(vecs), len(batch))
}
allVecs = append(allVecs, vecs...)
}
embedDim := 0
if len(allVecs) > 0 {
embedDim = len(allVecs[0])
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, fmt.Errorf("sqlite indexer: begin tx: %w", err)
}
defer tx.Rollback()
ids = make([]string, 0, len(docs))
for i, d := range docs {
chunkID := uuid.New().String()
itemID, metaErr := RequireMetaString(d.MetaData, metaKBItemID)
if metaErr != nil {
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
}
chunkIdx, metaErr := RequireMetaInt(d.MetaData, metaKBChunkIndex)
if metaErr != nil {
return nil, fmt.Errorf("sqlite indexer: doc %d: %w", i, metaErr)
}
vec := allVecs[i]
if embedDim > 0 && len(vec) != embedDim {
return nil, fmt.Errorf("sqlite indexer: inconsistent embedding dim at doc %d: got %d want %d", i, len(vec), embedDim)
}
vec32 := make([]float32, len(vec))
for j, v := range vec {
vec32[j] = float32(v)
}
embeddingJSON, jsonErr := json.Marshal(vec32)
if jsonErr != nil {
return nil, fmt.Errorf("sqlite indexer: marshal embedding: %w", jsonErr)
}
_, err = tx.ExecContext(ctx,
`INSERT INTO knowledge_embeddings (id, item_id, chunk_index, chunk_text, embedding, sub_indexes, embedding_model, embedding_dim, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))`,
chunkID, itemID, chunkIdx, d.Content, string(embeddingJSON), subIdxStr, s.embeddingModel, embedDim,
)
if err != nil {
return nil, fmt.Errorf("sqlite indexer: insert chunk %d: %w", i, err)
}
ids = append(ids, chunkID)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("sqlite indexer: commit: %w", err)
}
return ids, nil
}
var _ indexer.Indexer = (*SQLiteIndexer)(nil)
-251
View File
@@ -1,251 +0,0 @@
package knowledge
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
einoembedopenai "github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino/components/embedding"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
// Embedder 使用 CloudWeGo Eino 的 OpenAI Embedding 组件,并保留速率限制与重试。
type Embedder struct {
eino embedding.Embedder
config *config.KnowledgeConfig
logger *zap.Logger
rateLimiter *rate.Limiter
rateLimitDelay time.Duration
maxRetries int
retryDelay time.Duration
mu sync.Mutex
}
// NewEmbedder 基于 Eino eino-ext OpenAI EmbedderopenAIConfig 用于在知识库未单独配置 key 时回退 API Key。
func NewEmbedder(ctx context.Context, cfg *config.KnowledgeConfig, openAIConfig *config.OpenAIConfig, logger *zap.Logger) (*Embedder, error) {
if cfg == nil {
return nil, fmt.Errorf("knowledge config is nil")
}
var rateLimiter *rate.Limiter
var rateLimitDelay time.Duration
if cfg.Indexing.MaxRPM > 0 {
rpm := cfg.Indexing.MaxRPM
rateLimiter = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), rpm)
if logger != nil {
logger.Info("知识库索引速率限制已启用", zap.Int("maxRPM", rpm))
}
} else if cfg.Indexing.RateLimitDelayMs > 0 {
rateLimitDelay = time.Duration(cfg.Indexing.RateLimitDelayMs) * time.Millisecond
if logger != nil {
logger.Info("知识库索引固定延迟已启用", zap.Duration("delay", rateLimitDelay))
}
}
maxRetries := 3
retryDelay := 1000 * time.Millisecond
if cfg.Indexing.MaxRetries > 0 {
maxRetries = cfg.Indexing.MaxRetries
}
if cfg.Indexing.RetryDelayMs > 0 {
retryDelay = time.Duration(cfg.Indexing.RetryDelayMs) * time.Millisecond
}
model := strings.TrimSpace(cfg.Embedding.Model)
if model == "" {
model = "text-embedding-3-small"
}
baseURL := strings.TrimSpace(cfg.Embedding.BaseURL)
baseURL = strings.TrimSuffix(baseURL, "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
apiKey := strings.TrimSpace(cfg.Embedding.APIKey)
if apiKey == "" && openAIConfig != nil {
apiKey = strings.TrimSpace(openAIConfig.APIKey)
}
if apiKey == "" {
return nil, fmt.Errorf("embedding API key 未配置")
}
timeout := 120 * time.Second
if cfg.Indexing.RequestTimeoutSeconds > 0 {
timeout = time.Duration(cfg.Indexing.RequestTimeoutSeconds) * time.Second
}
httpClient := &http.Client{Timeout: timeout}
inner, err := einoembedopenai.NewEmbedder(ctx, &einoembedopenai.EmbeddingConfig{
APIKey: apiKey,
BaseURL: baseURL,
ByAzure: false,
Model: model,
HTTPClient: httpClient,
})
if err != nil {
return nil, fmt.Errorf("eino OpenAI embedder: %w", err)
}
return &Embedder{
eino: inner,
config: cfg,
logger: logger,
rateLimiter: rateLimiter,
rateLimitDelay: rateLimitDelay,
maxRetries: maxRetries,
retryDelay: retryDelay,
}, nil
}
// EmbeddingModelName 返回配置的嵌入模型名(用于 tiktoken 分块与向量行元数据)。
func (e *Embedder) EmbeddingModelName() string {
if e == nil || e.config == nil {
return ""
}
s := strings.TrimSpace(e.config.Embedding.Model)
if s != "" {
return s
}
return "text-embedding-3-small"
}
func (e *Embedder) waitRateLimiter() {
e.mu.Lock()
defer e.mu.Unlock()
if e.rateLimiter != nil {
ctx := context.Background()
if err := e.rateLimiter.Wait(ctx); err != nil && e.logger != nil {
e.logger.Warn("速率限制器等待失败", zap.Error(err))
}
}
if e.rateLimitDelay > 0 {
time.Sleep(e.rateLimitDelay)
}
}
// EmbedText 单条嵌入(float32,与历史存储格式一致)。
func (e *Embedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
vecs, err := e.EmbedStrings(ctx, []string{text})
if err != nil {
return nil, err
}
if len(vecs) != 1 {
return nil, fmt.Errorf("unexpected embedding count: %d", len(vecs))
}
return vecs[0], nil
}
// EmbedStrings 批量嵌入,带重试;实现 [embedding.Embedder],可供 Eino Indexer 使用。
func (e *Embedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float32, error) {
if e == nil || e.eino == nil {
return nil, fmt.Errorf("embedder not initialized")
}
if len(texts) == 0 {
return nil, nil
}
var lastErr error
for attempt := 0; attempt < e.maxRetries; attempt++ {
if attempt > 0 {
wait := e.retryDelay * time.Duration(attempt)
if e.logger != nil {
e.logger.Debug("嵌入重试前等待", zap.Int("attempt", attempt+1), zap.Duration("wait", wait))
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(wait):
}
} else {
e.waitRateLimiter()
}
raw, err := e.eino.EmbedStrings(ctx, texts, opts...)
if err == nil {
out := make([][]float32, len(raw))
for i, row := range raw {
out[i] = make([]float32, len(row))
for j, v := range row {
out[i][j] = float32(v)
}
}
return out, nil
}
lastErr = err
if !e.isRetryableError(err) {
return nil, err
}
if e.logger != nil {
e.logger.Debug("嵌入失败,将重试", zap.Int("attempt", attempt+1), zap.Error(err))
}
}
return nil, fmt.Errorf("达到最大重试次数 (%d): %v", e.maxRetries, lastErr)
}
// EmbedTexts 批量 float32 嵌入(兼容旧调用;单次请求批量以减小延迟)。
func (e *Embedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) {
return e.EmbedStrings(ctx, texts)
}
func (e *Embedder) isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
if strings.Contains(errStr, "429") || strings.Contains(errStr, "rate limit") {
return true
}
if strings.Contains(errStr, "500") || strings.Contains(errStr, "502") ||
strings.Contains(errStr, "503") || strings.Contains(errStr, "504") {
return true
}
if strings.Contains(errStr, "timeout") || strings.Contains(errStr, "connection") ||
strings.Contains(errStr, "network") || strings.Contains(errStr, "EOF") {
return true
}
return false
}
// einoFloatEmbedder adapts [][]float32 embedder to Eino's [][]float64 [embedding.Embedder] for Indexer.Store.
type einoFloatEmbedder struct {
inner *Embedder
}
func (w *einoFloatEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
vec32, err := w.inner.EmbedStrings(ctx, texts, opts...)
if err != nil {
return nil, err
}
out := make([][]float64, len(vec32))
for i, row := range vec32 {
out[i] = make([]float64, len(row))
for j, v := range row {
out[i][j] = float64(v)
}
}
return out, nil
}
func (w *einoFloatEmbedder) GetType() string {
return "CyberStrikeKnowledgeEmbedder"
}
func (w *einoFloatEmbedder) IsCallbacksEnabled() bool {
return false
}
// EinoEmbeddingComponent returns an [embedding.Embedder] that uses the same retry/rate-limit path
// and produces float64 vectors expected by generic Eino indexer helpers.
func (e *Embedder) EinoEmbeddingComponent() embedding.Embedder {
return &einoFloatEmbedder{inner: e}
}
-91
View File
@@ -1,91 +0,0 @@
package knowledge
import (
"context"
"database/sql"
"fmt"
"strings"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/schema"
)
// normalizeChunkStrategy returns "recursive" or "markdown_then_recursive".
func normalizeChunkStrategy(s string) string {
v := strings.TrimSpace(strings.ToLower(s))
switch v {
case "recursive":
return "recursive"
case "markdown_then_recursive", "markdown_recursive", "markdown":
return "markdown_then_recursive"
case "":
return "markdown_then_recursive"
default:
return "markdown_then_recursive"
}
}
func buildKnowledgeIndexChain(
ctx context.Context,
indexingCfg *config.IndexingConfig,
db *sql.DB,
recursive document.Transformer,
embeddingModel string,
) (compose.Runnable[[]*schema.Document, []string], error) {
if recursive == nil {
return nil, fmt.Errorf("recursive transformer is nil")
}
if db == nil {
return nil, fmt.Errorf("db is nil")
}
strategy := normalizeChunkStrategy("markdown_then_recursive")
batch := 64
maxChunks := 0
if indexingCfg != nil {
strategy = normalizeChunkStrategy(indexingCfg.ChunkStrategy)
if indexingCfg.BatchSize > 0 {
batch = indexingCfg.BatchSize
}
maxChunks = indexingCfg.MaxChunksPerItem
}
si := NewSQLiteIndexer(db, batch, embeddingModel)
ch := compose.NewChain[[]*schema.Document, []string]()
if strategy != "recursive" {
md, err := newMarkdownHeaderSplitter(ctx)
if err != nil {
return nil, fmt.Errorf("markdown splitter: %w", err)
}
ch.AppendDocumentTransformer(md)
}
ch.AppendDocumentTransformer(recursive)
ch.AppendLambda(newChunkEnrichLambda(maxChunks))
ch.AppendIndexer(si)
return ch.Compile(ctx)
}
func newChunkEnrichLambda(maxChunks int) *compose.Lambda {
return compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) ([]*schema.Document, error) {
_ = ctx
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil || strings.TrimSpace(d.Content) == "" {
continue
}
out = append(out, d)
}
if maxChunks > 0 && len(out) > maxChunks {
out = out[:maxChunks]
}
for i, d := range out {
if d.MetaData == nil {
d.MetaData = make(map[string]any)
}
d.MetaData[metaKBChunkIndex] = i
}
return out, nil
})
}
-21
View File
@@ -1,21 +0,0 @@
package knowledge
import "testing"
func TestNormalizeChunkStrategy(t *testing.T) {
cases := []struct {
in, want string
}{
{"", "markdown_then_recursive"},
{"recursive", "recursive"},
{"RECURSIVE", "recursive"},
{"markdown_then_recursive", "markdown_then_recursive"},
{"markdown", "markdown_then_recursive"},
{"unknown", "markdown_then_recursive"},
}
for _, tc := range cases {
if got := normalizeChunkStrategy(tc.in); got != tc.want {
t.Errorf("normalizeChunkStrategy(%q) = %q, want %q", tc.in, got, tc.want)
}
}
}
-352
View File
@@ -1,352 +0,0 @@
package knowledge
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// Indexer 使用 Eino Compose 索引链(Markdown/递归分块、Lambda enrich、SQLite 索引)与嵌入写入。
type Indexer struct {
db *sql.DB
embedder *Embedder
logger *zap.Logger
chunkSize int
overlap int
indexingCfg *config.IndexingConfig
indexChain compose.Runnable[[]*schema.Document, []string]
fileLoader *fileloader.FileLoader
mu sync.RWMutex
lastError string
lastErrorTime time.Time
errorCount int
rebuildMu sync.RWMutex
isRebuilding bool
rebuildTotalItems int
rebuildCurrent int
rebuildFailed int
rebuildStartTime time.Time
rebuildLastItemID string
rebuildLastChunks int
}
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
func NewIndexer(ctx context.Context, db *sql.DB, embedder *Embedder, logger *zap.Logger, kcfg *config.KnowledgeConfig) (*Indexer, error) {
if db == nil {
return nil, fmt.Errorf("db is nil")
}
if embedder == nil {
return nil, fmt.Errorf("embedder is nil")
}
if err := EnsureKnowledgeEmbeddingsSchema(db); err != nil {
return nil, fmt.Errorf("knowledge_embeddings 结构迁移: %w", err)
}
if kcfg == nil {
kcfg = &config.KnowledgeConfig{}
}
indexingCfg := &kcfg.Indexing
chunkSize := 512
overlap := 50
if indexingCfg.ChunkSize > 0 {
chunkSize = indexingCfg.ChunkSize
}
if indexingCfg.ChunkOverlap >= 0 {
overlap = indexingCfg.ChunkOverlap
}
embedModel := embedder.EmbeddingModelName()
splitter, err := newKnowledgeSplitter(chunkSize, overlap, embedModel)
if err != nil {
return nil, fmt.Errorf("eino recursive splitter: %w", err)
}
chain, err := buildKnowledgeIndexChain(ctx, indexingCfg, db, splitter, embedModel)
if err != nil {
return nil, fmt.Errorf("knowledge index chain: %w", err)
}
var fl *fileloader.FileLoader
fl, err = fileloader.NewFileLoader(ctx, nil)
if err != nil {
if logger != nil {
logger.Warn("Eino FileLoader 初始化失败,prefer_source_file 将回退数据库正文", zap.Error(err))
}
fl = nil
err = nil
}
return &Indexer{
db: db,
embedder: embedder,
logger: logger,
chunkSize: chunkSize,
overlap: overlap,
indexingCfg: indexingCfg,
indexChain: chain,
fileLoader: fl,
}, nil
}
// RecompileIndexChain 在配置或嵌入模型变更后重建 Eino 索引链(无需重启进程)。
func (idx *Indexer) RecompileIndexChain(ctx context.Context) error {
if idx == nil || idx.db == nil || idx.embedder == nil {
return fmt.Errorf("indexer 未初始化")
}
if err := EnsureKnowledgeEmbeddingsSchema(idx.db); err != nil {
return err
}
embedModel := idx.embedder.EmbeddingModelName()
splitter, err := newKnowledgeSplitter(idx.chunkSize, idx.overlap, embedModel)
if err != nil {
return fmt.Errorf("eino recursive splitter: %w", err)
}
chain, err := buildKnowledgeIndexChain(ctx, idx.indexingCfg, idx.db, splitter, embedModel)
if err != nil {
return fmt.Errorf("knowledge index chain: %w", err)
}
idx.indexChain = chain
return nil
}
// IndexItem 索引单个知识项:先清空旧向量,再走 Compose 链(分块、嵌入、写入)。
func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
if idx.indexChain == nil {
return fmt.Errorf("索引链未初始化")
}
if idx.embedder == nil {
return fmt.Errorf("嵌入器未初始化")
}
var content, category, title, filePath string
err := idx.db.QueryRow("SELECT content, category, title, file_path FROM knowledge_base_items WHERE id = ?", itemID).Scan(&content, &category, &title, &filePath)
if err != nil {
return fmt.Errorf("获取知识项失败:%w", err)
}
if _, err := idx.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", itemID); err != nil {
return fmt.Errorf("删除旧向量失败:%w", err)
}
body := strings.TrimSpace(content)
if idx.indexingCfg != nil && idx.indexingCfg.PreferSourceFile && strings.TrimSpace(filePath) != "" && idx.fileLoader != nil {
docs, lerr := idx.fileLoader.Load(ctx, document.Source{URI: strings.TrimSpace(filePath)})
if lerr == nil && len(docs) > 0 {
var b strings.Builder
for i, d := range docs {
if d == nil {
continue
}
if i > 0 {
b.WriteString("\n\n")
}
b.WriteString(d.Content)
}
if s := strings.TrimSpace(b.String()); s != "" {
body = s
}
} else if idx.logger != nil {
idx.logger.Warn("优先源文件读取失败,使用数据库正文",
zap.String("itemId", itemID),
zap.String("path", filePath),
zap.Error(lerr))
}
}
root := &schema.Document{
ID: itemID,
Content: body,
MetaData: map[string]any{
metaKBCategory: category,
metaKBTitle: title,
metaKBItemID: itemID,
},
}
idxOpts := []indexer.Option{indexer.WithEmbedding(idx.embedder.EinoEmbeddingComponent())}
if idx.indexingCfg != nil && len(idx.indexingCfg.SubIndexes) > 0 {
idxOpts = append(idxOpts, indexer.WithSubIndexes(idx.indexingCfg.SubIndexes))
}
ids, err := idx.indexChain.Invoke(ctx, []*schema.Document{root}, compose.WithIndexerOption(idxOpts...))
if err != nil {
msg := fmt.Sprintf("索引写入失败 (知识项:%s): %v", itemID, err)
idx.mu.Lock()
idx.lastError = msg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
return err
}
if idx.logger != nil {
idx.logger.Info("知识项索引完成", zap.String("itemId", itemID), zap.Int("chunks", len(ids)))
}
idx.rebuildMu.Lock()
idx.rebuildLastItemID = itemID
idx.rebuildLastChunks = len(ids)
idx.rebuildMu.Unlock()
return nil
}
// HasIndex 检查是否存在索引
func (idx *Indexer) HasIndex() (bool, error) {
var count int
err := idx.db.QueryRow("SELECT COUNT(*) FROM knowledge_embeddings").Scan(&count)
if err != nil {
return false, fmt.Errorf("检查索引失败:%w", err)
}
return count > 0, nil
}
// RebuildIndex 重建所有索引
func (idx *Indexer) RebuildIndex(ctx context.Context) error {
idx.rebuildMu.Lock()
idx.isRebuilding = true
idx.rebuildTotalItems = 0
idx.rebuildCurrent = 0
idx.rebuildFailed = 0
idx.rebuildStartTime = time.Now()
idx.rebuildLastItemID = ""
idx.rebuildLastChunks = 0
idx.rebuildMu.Unlock()
idx.mu.Lock()
idx.lastError = ""
idx.lastErrorTime = time.Time{}
idx.errorCount = 0
idx.mu.Unlock()
rows, err := idx.db.Query("SELECT id FROM knowledge_base_items")
if err != nil {
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("查询知识项失败:%w", err)
}
defer rows.Close()
var itemIDs []string
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
return fmt.Errorf("扫描知识项 ID 失败:%w", err)
}
itemIDs = append(itemIDs, id)
}
idx.rebuildMu.Lock()
idx.rebuildTotalItems = len(itemIDs)
idx.rebuildMu.Unlock()
idx.logger.Info("开始重建索引", zap.Int("totalItems", len(itemIDs)))
failedCount := 0
consecutiveFailures := 0
maxConsecutiveFailures := 5
firstFailureItemID := ""
var firstFailureError error
for i, itemID := range itemIDs {
if err := idx.IndexItem(ctx, itemID); err != nil {
failedCount++
consecutiveFailures++
if consecutiveFailures == 1 {
firstFailureItemID = itemID
firstFailureError = err
idx.logger.Warn("索引知识项失败",
zap.String("itemId", itemID),
zap.Int("totalItems", len(itemIDs)),
zap.Error(err),
)
}
if consecutiveFailures >= maxConsecutiveFailures {
errorMsg := fmt.Sprintf("连续 %d 个知识项索引失败,可能存在配置问题(如嵌入模型配置错误、API 密钥无效、余额不足等)。第一个失败项:%s, 错误:%v", consecutiveFailures, firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("连续索引失败次数过多,立即停止索引",
zap.Int("consecutiveFailures", consecutiveFailures),
zap.Int("totalItems", len(itemIDs)),
zap.Int("processedItems", i+1),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
return fmt.Errorf("连续索引失败次数过多:%v", firstFailureError)
}
if failedCount > len(itemIDs)*3/10 && failedCount == len(itemIDs)*3/10+1 {
errorMsg := fmt.Sprintf("索引失败的知识项过多 (%d/%d),可能存在配置问题。第一个失败项:%s, 错误:%v", failedCount, len(itemIDs), firstFailureItemID, firstFailureError)
idx.mu.Lock()
idx.lastError = errorMsg
idx.lastErrorTime = time.Now()
idx.mu.Unlock()
idx.logger.Error("索引失败的知识项过多,可能存在配置问题",
zap.Int("failedCount", failedCount),
zap.Int("totalItems", len(itemIDs)),
zap.String("firstFailureItemId", firstFailureItemID),
zap.Error(firstFailureError),
)
}
continue
}
if consecutiveFailures > 0 {
consecutiveFailures = 0
firstFailureItemID = ""
firstFailureError = nil
}
idx.rebuildMu.Lock()
idx.rebuildCurrent = i + 1
idx.rebuildFailed = failedCount
idx.rebuildMu.Unlock()
if (i+1)%10 == 0 || (len(itemIDs) > 0 && (i+1)*100/len(itemIDs)%10 == 0 && (i+1)*100/len(itemIDs) > 0) {
idx.logger.Info("索引进度", zap.Int("current", i+1), zap.Int("total", len(itemIDs)), zap.Int("failed", failedCount))
}
}
idx.rebuildMu.Lock()
idx.isRebuilding = false
idx.rebuildMu.Unlock()
idx.logger.Info("索引重建完成", zap.Int("totalItems", len(itemIDs)), zap.Int("failedCount", failedCount))
return nil
}
// GetLastError 获取最近一次错误信息
func (idx *Indexer) GetLastError() (string, time.Time) {
idx.mu.RLock()
defer idx.mu.RUnlock()
return idx.lastError, idx.lastErrorTime
}
// GetRebuildStatus 获取重建索引状态
func (idx *Indexer) GetRebuildStatus() (isRebuilding bool, totalItems int, current int, failed int, lastItemID string, lastChunks int, startTime time.Time) {
idx.rebuildMu.RLock()
defer idx.rebuildMu.RUnlock()
return idx.isRebuilding, idx.rebuildTotalItems, idx.rebuildCurrent, idx.rebuildFailed, idx.rebuildLastItemID, idx.rebuildLastChunks, idx.rebuildStartTime
}
-885
View File
@@ -1,885 +0,0 @@
package knowledge
import (
"database/sql"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Manager 知识库管理器
type Manager struct {
db *sql.DB
basePath string
logger *zap.Logger
}
// NewManager 创建新的知识库管理器
func NewManager(db *sql.DB, basePath string, logger *zap.Logger) *Manager {
return &Manager{
db: db,
basePath: basePath,
logger: logger,
}
}
// ScanKnowledgeBase 扫描知识库目录,更新数据库
// 返回需要索引的知识项ID列表(新添加的或更新的)
func (m *Manager) ScanKnowledgeBase() ([]string, error) {
if m.basePath == "" {
return nil, fmt.Errorf("知识库路径未配置")
}
// 确保目录存在
if err := os.MkdirAll(m.basePath, 0755); err != nil {
return nil, fmt.Errorf("创建知识库目录失败: %w", err)
}
var itemsToIndex []string
// 遍历知识库目录
err := filepath.WalkDir(m.basePath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
// 跳过目录和非markdown文件
if d.IsDir() || !strings.HasSuffix(strings.ToLower(path), ".md") {
return nil
}
// 计算相对路径和分类
relPath, err := filepath.Rel(m.basePath, path)
if err != nil {
return err
}
// 第一个目录名作为分类(风险类型)
parts := strings.Split(relPath, string(filepath.Separator))
category := "未分类"
if len(parts) > 1 {
category = parts[0]
}
// 文件名为标题
title := strings.TrimSuffix(filepath.Base(path), ".md")
// 读取文件内容
content, err := os.ReadFile(path)
if err != nil {
m.logger.Warn("读取知识库文件失败", zap.String("path", path), zap.Error(err))
return nil // 继续处理其他文件
}
// 检查是否已存在
var existingID string
var existingContent string
var existingUpdatedAt time.Time
err = m.db.QueryRow(
"SELECT id, content, updated_at FROM knowledge_base_items WHERE file_path = ?",
path,
).Scan(&existingID, &existingContent, &existingUpdatedAt)
if err == sql.ErrNoRows {
// 创建新项
id := uuid.New().String()
now := time.Now()
_, err = m.db.Exec(
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, category, title, path, string(content), now, now,
)
if err != nil {
return fmt.Errorf("插入知识项失败: %w", err)
}
m.logger.Info("添加知识项", zap.String("id", id), zap.String("title", title), zap.String("category", category))
// 新添加的项需要索引
itemsToIndex = append(itemsToIndex, id)
} else if err == nil {
// 检查内容是否有变化
contentChanged := existingContent != string(content)
if contentChanged {
// 更新现有项
_, err = m.db.Exec(
"UPDATE knowledge_base_items SET category = ?, title = ?, content = ?, updated_at = ? WHERE id = ?",
category, title, string(content), time.Now(), existingID,
)
if err != nil {
return fmt.Errorf("更新知识项失败: %w", err)
}
m.logger.Info("更新知识项", zap.String("id", existingID), zap.String("title", title))
// 内容已更新的项需要重新索引
itemsToIndex = append(itemsToIndex, existingID)
} else {
m.logger.Debug("知识项未变化,跳过", zap.String("id", existingID), zap.String("title", title))
}
} else {
return fmt.Errorf("查询知识项失败: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
return itemsToIndex, nil
}
// GetCategories 获取所有分类(风险类型)
func (m *Manager) GetCategories() ([]string, error) {
rows, err := m.db.Query("SELECT DISTINCT category FROM knowledge_base_items ORDER BY category")
if err != nil {
return nil, fmt.Errorf("查询分类失败: %w", err)
}
defer rows.Close()
var categories []string
for rows.Next() {
var category string
if err := rows.Scan(&category); err != nil {
return nil, fmt.Errorf("扫描分类失败: %w", err)
}
categories = append(categories, category)
}
return categories, nil
}
// GetStats 获取知识库统计信息
func (m *Manager) GetStats() (int, int, error) {
// 获取分类总数
categories, err := m.GetCategories()
if err != nil {
return 0, 0, fmt.Errorf("获取分类失败: %w", err)
}
totalCategories := len(categories)
// 获取知识项总数
var totalItems int
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
if err != nil {
return totalCategories, 0, fmt.Errorf("获取知识项总数失败: %w", err)
}
return totalCategories, totalItems, nil
}
// GetCategoriesWithItems 按分类分页获取知识项(每个分类包含其下的所有知识项)
// limit: 每页分类数量(0表示不限制)
// offset: 偏移量(按分类偏移)
func (m *Manager) GetCategoriesWithItems(limit, offset int) ([]*CategoryWithItems, int, error) {
// 首先获取所有分类(带数量统计)
rows, err := m.db.Query(`
SELECT category, COUNT(*) as item_count
FROM knowledge_base_items
GROUP BY category
ORDER BY category
`)
if err != nil {
return nil, 0, fmt.Errorf("查询分类失败: %w", err)
}
defer rows.Close()
// 收集所有分类信息
type categoryInfo struct {
name string
itemCount int
}
var allCategories []categoryInfo
for rows.Next() {
var info categoryInfo
if err := rows.Scan(&info.name, &info.itemCount); err != nil {
return nil, 0, fmt.Errorf("扫描分类失败: %w", err)
}
allCategories = append(allCategories, info)
}
totalCategories := len(allCategories)
// 应用分页(按分类分页)
var paginatedCategories []categoryInfo
if limit > 0 {
start := offset
end := offset + limit
if start >= totalCategories {
paginatedCategories = []categoryInfo{}
} else {
if end > totalCategories {
end = totalCategories
}
paginatedCategories = allCategories[start:end]
}
} else {
paginatedCategories = allCategories
}
// 为每个分类获取其下的知识项(只返回摘要,不包含完整内容)
result := make([]*CategoryWithItems, 0, len(paginatedCategories))
for _, catInfo := range paginatedCategories {
// 获取该分类下的所有知识项
items, _, err := m.GetItemsSummary(catInfo.name, 0, 0)
if err != nil {
return nil, 0, fmt.Errorf("获取分类 %s 的知识项失败: %w", catInfo.name, err)
}
result = append(result, &CategoryWithItems{
Category: catInfo.name,
ItemCount: catInfo.itemCount,
Items: items,
})
}
return result, totalCategories, nil
}
// GetItems 获取知识项列表(完整内容,用于向后兼容)
func (m *Manager) GetItems(category string) ([]*KnowledgeItem, error) {
return m.GetItemsWithOptions(category, 0, 0, true)
}
// GetItemsWithOptions 获取知识项列表(支持分页和可选内容)
// category: 分类筛选(空字符串表示所有分类)
// limit: 每页数量(0表示不限制)
// offset: 偏移量
// includeContent: 是否包含完整内容(false时只返回摘要)
func (m *Manager) GetItemsWithOptions(category string, limit, offset int, includeContent bool) ([]*KnowledgeItem, error) {
var rows *sql.Rows
var err error
// 构建SQL查询
var query string
var args []interface{}
if includeContent {
query = "SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items"
} else {
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
}
if category != "" {
query += " WHERE category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
if limit > 0 {
query += " LIMIT ?"
args = append(args, limit)
if offset > 0 {
query += " OFFSET ?"
args = append(args, offset)
}
}
rows, err = m.db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("查询知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItem
for rows.Next() {
item := &KnowledgeItem{}
var createdAt, updatedAt string
if includeContent {
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
} else {
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
// 不包含内容时,Content为空字符串
item.Content = ""
}
// 解析时间 - 支持多种格式
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
// 解析创建时间
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
// 解析更新时间
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
// 如果更新时间为空,使用创建时间
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, nil
}
// GetItemsCount 获取知识项总数
func (m *Manager) GetItemsCount(category string) (int, error) {
var count int
var err error
if category != "" {
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items WHERE category = ?", category).Scan(&count)
} else {
err = m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&count)
}
if err != nil {
return 0, fmt.Errorf("查询知识项总数失败: %w", err)
}
return count, nil
}
// SearchItemsByKeyword 按关键字搜索知识项(在所有数据中搜索,支持标题、分类、路径、内容匹配)
func (m *Manager) SearchItemsByKeyword(keyword string, category string) ([]*KnowledgeItemSummary, error) {
if keyword == "" {
return nil, fmt.Errorf("搜索关键字不能为空")
}
// 构建SQL查询,使用LIKE进行关键字匹配(不区分大小写)
var query string
var args []interface{}
// SQLite的LIKE不区分大小写,使用COLLATE NOCASE或LOWER()函数
// 使用%keyword%进行模糊匹配
searchPattern := "%" + keyword + "%"
query = `
SELECT id, category, title, file_path, created_at, updated_at
FROM knowledge_base_items
WHERE (LOWER(title) LIKE LOWER(?) OR LOWER(category) LIKE LOWER(?) OR LOWER(file_path) LIKE LOWER(?) OR LOWER(content) LIKE LOWER(?))
`
args = append(args, searchPattern, searchPattern, searchPattern, searchPattern)
// 如果指定了分类,添加分类过滤
if category != "" {
query += " AND category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
rows, err := m.db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("搜索知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItemSummary
for rows.Next() {
item := &KnowledgeItemSummary{}
var createdAt, updatedAt string
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描知识项失败: %w", err)
}
// 解析时间
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, nil
}
// GetItemsSummary 获取知识项摘要列表(不包含完整内容,支持分页)
func (m *Manager) GetItemsSummary(category string, limit, offset int) ([]*KnowledgeItemSummary, int, error) {
// 获取总数
total, err := m.GetItemsCount(category)
if err != nil {
return nil, 0, err
}
// 获取列表数据(不包含内容)
var rows *sql.Rows
var query string
var args []interface{}
query = "SELECT id, category, title, file_path, created_at, updated_at FROM knowledge_base_items"
if category != "" {
query += " WHERE category = ?"
args = append(args, category)
}
query += " ORDER BY category, title"
if limit > 0 {
query += " LIMIT ?"
args = append(args, limit)
if offset > 0 {
query += " OFFSET ?"
args = append(args, offset)
}
}
rows, err = m.db.Query(query, args...)
if err != nil {
return nil, 0, fmt.Errorf("查询知识项失败: %w", err)
}
defer rows.Close()
var items []*KnowledgeItemSummary
for rows.Next() {
item := &KnowledgeItemSummary{}
var createdAt, updatedAt string
if err := rows.Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &createdAt, &updatedAt); err != nil {
return nil, 0, fmt.Errorf("扫描知识项失败: %w", err)
}
// 解析时间
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
items = append(items, item)
}
return items, total, nil
}
// GetItem 获取单个知识项
func (m *Manager) GetItem(id string) (*KnowledgeItem, error) {
item := &KnowledgeItem{}
var createdAt, updatedAt string
err := m.db.QueryRow(
"SELECT id, category, title, file_path, content, created_at, updated_at FROM knowledge_base_items WHERE id = ?",
id,
).Scan(&item.ID, &item.Category, &item.Title, &item.FilePath, &item.Content, &createdAt, &updatedAt)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("知识项不存在")
}
if err != nil {
return nil, fmt.Errorf("查询知识项失败: %w", err)
}
// 解析时间 - 支持多种格式
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
// 解析创建时间
if createdAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, createdAt)
if err == nil && !parsed.IsZero() {
item.CreatedAt = parsed
break
}
}
}
// 解析更新时间
if updatedAt != "" {
for _, format := range timeFormats {
parsed, err := time.Parse(format, updatedAt)
if err == nil && !parsed.IsZero() {
item.UpdatedAt = parsed
break
}
}
}
// 如果更新时间为空,使用创建时间
if item.UpdatedAt.IsZero() && !item.CreatedAt.IsZero() {
item.UpdatedAt = item.CreatedAt
}
return item, nil
}
// CreateItem 创建知识项
func (m *Manager) CreateItem(category, title, content string) (*KnowledgeItem, error) {
id := uuid.New().String()
now := time.Now()
// 构建文件路径
filePath := filepath.Join(m.basePath, category, title+".md")
// 确保目录存在
if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil {
return nil, fmt.Errorf("创建目录失败: %w", err)
}
// 写入文件
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
return nil, fmt.Errorf("写入文件失败: %w", err)
}
// 插入数据库
_, err := m.db.Exec(
"INSERT INTO knowledge_base_items (id, category, title, file_path, content, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, category, title, filePath, content, now, now,
)
if err != nil {
return nil, fmt.Errorf("插入知识项失败: %w", err)
}
return &KnowledgeItem{
ID: id,
Category: category,
Title: title,
FilePath: filePath,
Content: content,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
// UpdateItem 更新知识项
func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeItem, error) {
// 获取现有项
item, err := m.GetItem(id)
if err != nil {
return nil, err
}
// 构建新文件路径
newFilePath := filepath.Join(m.basePath, category, title+".md")
// 如果路径改变,需要移动文件
if item.FilePath != newFilePath {
// 确保新目录存在
if err := os.MkdirAll(filepath.Dir(newFilePath), 0755); err != nil {
return nil, fmt.Errorf("创建目录失败: %w", err)
}
// 移动文件
if err := os.Rename(item.FilePath, newFilePath); err != nil {
return nil, fmt.Errorf("移动文件失败: %w", err)
}
// 删除旧目录(如果为空)
oldDir := filepath.Dir(item.FilePath)
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if oldDir != m.basePath {
if err := os.Remove(oldDir); err != nil {
m.logger.Warn("删除空目录失败", zap.String("dir", oldDir), zap.Error(err))
}
}
}
}
// 写入文件
if err := os.WriteFile(newFilePath, []byte(content), 0644); err != nil {
return nil, fmt.Errorf("写入文件失败: %w", err)
}
// 更新数据库
_, err = m.db.Exec(
"UPDATE knowledge_base_items SET category = ?, title = ?, file_path = ?, content = ?, updated_at = ? WHERE id = ?",
category, title, newFilePath, content, time.Now(), id,
)
if err != nil {
return nil, fmt.Errorf("更新知识项失败: %w", err)
}
// 删除旧的向量嵌入(需要重新索引)
_, err = m.db.Exec("DELETE FROM knowledge_embeddings WHERE item_id = ?", id)
if err != nil {
m.logger.Warn("删除旧向量嵌入失败", zap.Error(err))
}
return m.GetItem(id)
}
// DeleteItem 删除知识项
func (m *Manager) DeleteItem(id string) error {
// 获取文件路径
var filePath string
err := m.db.QueryRow("SELECT file_path FROM knowledge_base_items WHERE id = ?", id).Scan(&filePath)
if err != nil {
return fmt.Errorf("查询知识项失败: %w", err)
}
// 删除文件
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
m.logger.Warn("删除文件失败", zap.String("path", filePath), zap.Error(err))
}
// 删除数据库记录(级联删除向量)
_, err = m.db.Exec("DELETE FROM knowledge_base_items WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除知识项失败: %w", err)
}
// 删除空目录(如果为空)
dir := filepath.Dir(filePath)
if isEmpty, _ := isEmptyDir(dir); isEmpty {
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
if dir != m.basePath {
if err := os.Remove(dir); err != nil {
m.logger.Warn("删除空目录失败", zap.String("dir", dir), zap.Error(err))
}
}
}
return nil
}
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
func isEmptyDir(dir string) (bool, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return false, err
}
for _, entry := range entries {
// 忽略隐藏文件(以 . 开头)
if !strings.HasPrefix(entry.Name(), ".") {
return false, nil
}
}
return true, nil
}
// LogRetrieval 记录检索日志
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
id := uuid.New().String()
itemsJSON, _ := json.Marshal(retrievedItems)
_, err := m.db.Exec(
"INSERT INTO knowledge_retrieval_logs (id, conversation_id, message_id, query, risk_type, retrieved_items, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
id, conversationID, messageID, query, riskType, string(itemsJSON), time.Now(),
)
return err
}
// GetIndexStatus 获取索引状态
func (m *Manager) GetIndexStatus() (map[string]interface{}, error) {
// 获取总知识项数
var totalItems int
err := m.db.QueryRow("SELECT COUNT(*) FROM knowledge_base_items").Scan(&totalItems)
if err != nil {
return nil, fmt.Errorf("查询总知识项数失败: %w", err)
}
// 获取已索引的知识项数(有向量嵌入的)
var indexedItems int
err = m.db.QueryRow(`
SELECT COUNT(DISTINCT item_id)
FROM knowledge_embeddings
`).Scan(&indexedItems)
if err != nil {
return nil, fmt.Errorf("查询已索引项数失败: %w", err)
}
// 计算进度百分比
var progressPercent float64
if totalItems > 0 {
progressPercent = float64(indexedItems) / float64(totalItems) * 100
} else {
progressPercent = 100.0
}
// 判断是否完成
isComplete := indexedItems >= totalItems && totalItems > 0
return map[string]interface{}{
"total_items": totalItems,
"indexed_items": indexedItems,
"progress_percent": progressPercent,
"is_complete": isComplete,
}, nil
}
// GetRetrievalLogs 获取检索日志
func (m *Manager) GetRetrievalLogs(conversationID, messageID string, limit int) ([]*RetrievalLog, error) {
var rows *sql.Rows
var err error
if messageID != "" {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE message_id = ? ORDER BY created_at DESC LIMIT ?",
messageID, limit,
)
} else if conversationID != "" {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs WHERE conversation_id = ? ORDER BY created_at DESC LIMIT ?",
conversationID, limit,
)
} else {
rows, err = m.db.Query(
"SELECT id, conversation_id, message_id, query, risk_type, retrieved_items, created_at FROM knowledge_retrieval_logs ORDER BY created_at DESC LIMIT ?",
limit,
)
}
if err != nil {
return nil, fmt.Errorf("查询检索日志失败: %w", err)
}
defer rows.Close()
var logs []*RetrievalLog
for rows.Next() {
log := &RetrievalLog{}
var createdAt string
var itemsJSON sql.NullString
if err := rows.Scan(&log.ID, &log.ConversationID, &log.MessageID, &log.Query, &log.RiskType, &itemsJSON, &createdAt); err != nil {
return nil, fmt.Errorf("扫描检索日志失败: %w", err)
}
// 解析时间 - 支持多种格式
var err error
timeFormats := []string{
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02 15:04:05",
time.RFC3339,
time.RFC3339Nano,
}
for _, format := range timeFormats {
log.CreatedAt, err = time.Parse(format, createdAt)
if err == nil && !log.CreatedAt.IsZero() {
break
}
}
// 如果所有格式都失败,记录警告但继续处理
if log.CreatedAt.IsZero() {
m.logger.Warn("解析检索日志时间失败",
zap.String("timeStr", createdAt),
zap.Error(err),
)
// 使用当前时间作为fallback
log.CreatedAt = time.Now()
}
// 解析检索项
if itemsJSON.Valid {
json.Unmarshal([]byte(itemsJSON.String), &log.RetrievedItems)
}
logs = append(logs, log)
}
return logs, nil
}
// DeleteRetrievalLog 删除检索日志
func (m *Manager) DeleteRetrievalLog(id string) error {
result, err := m.db.Exec("DELETE FROM knowledge_retrieval_logs WHERE id = ?", id)
if err != nil {
return fmt.Errorf("删除检索日志失败: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("获取删除行数失败: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("检索日志不存在")
}
return nil
}
-213
View File
@@ -1,213 +0,0 @@
package knowledge
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"sync"
"unicode"
"unicode/utf8"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/schema"
"github.com/pkoukk/tiktoken-go"
)
// postRetrieveMaxPrefetchCap 限制单次向量候选上限,避免误配置导致全表扫压力过大。
const postRetrieveMaxPrefetchCap = 200
// DocumentReranker 可选重排(如交叉编码器 / 第三方 Rerank API),由 [Retriever.SetDocumentReranker] 注入;失败时在适配层降级为向量序。
type DocumentReranker interface {
Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error)
}
// NopDocumentReranker 占位实现,便于测试或未启用重排时显式注入。
type NopDocumentReranker struct{}
// Rerank implements [DocumentReranker] as no-op.
func (NopDocumentReranker) Rerank(_ context.Context, _ string, docs []*schema.Document) ([]*schema.Document, error) {
return docs, nil
}
var tiktokenEncMu sync.Mutex
var tiktokenEncCache = map[string]*tiktoken.Tiktoken{}
func encodingForTokenizerModel(model string) (*tiktoken.Tiktoken, error) {
m := strings.TrimSpace(model)
if m == "" {
m = "gpt-4"
}
tiktokenEncMu.Lock()
defer tiktokenEncMu.Unlock()
if enc, ok := tiktokenEncCache[m]; ok {
return enc, nil
}
enc, err := tiktoken.EncodingForModel(m)
if err != nil {
enc, err = tiktoken.GetEncoding("cl100k_base")
if err != nil {
return nil, err
}
}
tiktokenEncCache[m] = enc
return enc, nil
}
func countDocTokens(text, model string) (int, error) {
enc, err := encodingForTokenizerModel(model)
if err != nil {
return 0, err
}
toks := enc.Encode(text, nil, nil)
return len(toks), nil
}
// normalizeContentFingerprintKey 去重键:trim + 空白折叠(不改动大小写,避免合并仅大小写不同的代码片段)。
func normalizeContentFingerprintKey(s string) string {
s = strings.TrimSpace(s)
var b strings.Builder
b.Grow(len(s))
prevSpace := false
for _, r := range s {
if unicode.IsSpace(r) {
if !prevSpace {
b.WriteByte(' ')
prevSpace = true
}
continue
}
prevSpace = false
b.WriteRune(r)
}
return b.String()
}
func contentNormKey(d *schema.Document) string {
if d == nil {
return ""
}
n := normalizeContentFingerprintKey(d.Content)
if n == "" {
return ""
}
sum := sha256.Sum256([]byte(n))
return hex.EncodeToString(sum[:])
}
// dedupeByNormalizedContent 按规范化正文去重,保留向量检索顺序中首次出现的文档(同正文仅保留一条)。
func dedupeByNormalizedContent(docs []*schema.Document) []*schema.Document {
if len(docs) < 2 {
return docs
}
seen := make(map[string]struct{}, len(docs))
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil {
continue
}
k := contentNormKey(d)
if k == "" {
out = append(out, d)
continue
}
if _, ok := seen[k]; ok {
continue
}
seen[k] = struct{}{}
out = append(out, d)
}
return out
}
// truncateDocumentsByBudget 按检索顺序整段保留文档,直至字符数或 token 数(任一启用)超限则停止。
func truncateDocumentsByBudget(docs []*schema.Document, maxRunes, maxTokens int, tokenModel string) ([]*schema.Document, error) {
if len(docs) == 0 {
return docs, nil
}
unlimitedChars := maxRunes <= 0
unlimitedTok := maxTokens <= 0
if unlimitedChars && unlimitedTok {
return docs, nil
}
remRunes := maxRunes
remTok := maxTokens
out := make([]*schema.Document, 0, len(docs))
for _, d := range docs {
if d == nil || strings.TrimSpace(d.Content) == "" {
continue
}
runes := utf8.RuneCountInString(d.Content)
if !unlimitedChars && runes > remRunes {
break
}
var tok int
var err error
if !unlimitedTok {
tok, err = countDocTokens(d.Content, tokenModel)
if err != nil {
return nil, fmt.Errorf("token count: %w", err)
}
if tok > remTok {
break
}
}
out = append(out, d)
if !unlimitedChars {
remRunes -= runes
}
if !unlimitedTok {
remTok -= tok
}
}
return out, nil
}
// EffectivePrefetchTopK 计算向量检索应拉取的候选条数(供粗排 / 去重 / 重排)。
func EffectivePrefetchTopK(topK int, po *config.PostRetrieveConfig) int {
if topK < 1 {
topK = 5
}
fetch := topK
if po != nil && po.PrefetchTopK > fetch {
fetch = po.PrefetchTopK
}
if fetch > postRetrieveMaxPrefetchCap {
fetch = postRetrieveMaxPrefetchCap
}
return fetch
}
// ApplyPostRetrieve 检索后处理:规范化正文去重 → 预算截断 → 最终 TopK。重排在 [VectorEinoRetriever] 中单独调用以便失败时降级。
func ApplyPostRetrieve(docs []*schema.Document, po *config.PostRetrieveConfig, tokenModel string, finalTopK int) ([]*schema.Document, error) {
if finalTopK < 1 {
finalTopK = 5
}
if len(docs) == 0 {
return docs, nil
}
maxChars := 0
maxTok := 0
if po != nil {
maxChars = po.MaxContextChars
maxTok = po.MaxContextTokens
}
out := dedupeByNormalizedContent(docs)
var err error
out, err = truncateDocumentsByBudget(out, maxChars, maxTok, tokenModel)
if err != nil {
return nil, err
}
if len(out) > finalTopK {
out = out[:finalTopK]
}
return out, nil
}
@@ -1,62 +0,0 @@
package knowledge
import (
"testing"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/schema"
)
func doc(id, content string, score float64) *schema.Document {
d := &schema.Document{ID: id, Content: content, MetaData: map[string]any{metaKBItemID: "it1"}}
d.WithScore(score)
return d
}
func TestDedupeByNormalizedContent(t *testing.T) {
a := doc("1", "hello world", 0.9)
b := doc("2", "hello world", 0.8)
c := doc("3", "other", 0.7)
out := dedupeByNormalizedContent([]*schema.Document{a, b, c})
if len(out) != 2 {
t.Fatalf("len=%d want 2", len(out))
}
if out[0].ID != "1" || out[1].ID != "3" {
t.Fatalf("order/ids wrong: %#v", out)
}
}
func TestEffectivePrefetchTopK(t *testing.T) {
if g := EffectivePrefetchTopK(5, nil); g != 5 {
t.Fatalf("got %d", g)
}
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 50}); g != 50 {
t.Fatalf("got %d", g)
}
if g := EffectivePrefetchTopK(5, &config.PostRetrieveConfig{PrefetchTopK: 9999}); g != postRetrieveMaxPrefetchCap {
t.Fatalf("cap: got %d", g)
}
}
func TestApplyPostRetrieveTruncateAndTopK(t *testing.T) {
d1 := doc("1", "ab", 0.9)
d2 := doc("2", "cd", 0.8)
d3 := doc("3", "ef", 0.7)
po := &config.PostRetrieveConfig{MaxContextChars: 3}
out, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, po, "gpt-4", 5)
if err != nil {
t.Fatal(err)
}
if len(out) != 1 || out[0].ID != "1" {
t.Fatalf("got %#v", out)
}
out2, err := ApplyPostRetrieve([]*schema.Document{d1, d2, d3}, nil, "gpt-4", 2)
if err != nil {
t.Fatal(err)
}
if len(out2) != 2 {
t.Fatalf("topk: len=%d", len(out2))
}
}
-305
View File
@@ -1,305 +0,0 @@
package knowledge
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math"
"sort"
"strings"
"sync"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// Retriever 检索器:SQLite 存向量 + Eino 嵌入,**纯向量检索**(余弦相似度、TopK、阈值),
// 实现语义与 [retriever.Retriever] 适配层 [VectorEinoRetriever] 一致。
type Retriever struct {
db *sql.DB
embedder *Embedder
config *RetrievalConfig
logger *zap.Logger
rerankMu sync.RWMutex
reranker DocumentReranker
}
// RetrievalConfig 检索配置
type RetrievalConfig struct {
TopK int
SimilarityThreshold float64
// SubIndexFilter 非空时仅检索 sub_indexes 包含该标签(逗号分隔之一)的行;空 sub_indexes 的旧行仍保留以兼容。
SubIndexFilter string
PostRetrieve config.PostRetrieveConfig
}
// NewRetriever 创建新的检索器
func NewRetriever(db *sql.DB, embedder *Embedder, config *RetrievalConfig, logger *zap.Logger) *Retriever {
return &Retriever{
db: db,
embedder: embedder,
config: config,
logger: logger,
}
}
// UpdateConfig 更新检索配置
func (r *Retriever) UpdateConfig(cfg *RetrievalConfig) {
if cfg != nil {
r.config = cfg
if r.logger != nil {
r.logger.Info("检索器配置已更新",
zap.Int("top_k", cfg.TopK),
zap.Float64("similarity_threshold", cfg.SimilarityThreshold),
zap.String("sub_index_filter", cfg.SubIndexFilter),
zap.Int("post_retrieve_prefetch_top_k", cfg.PostRetrieve.PrefetchTopK),
zap.Int("post_retrieve_max_context_chars", cfg.PostRetrieve.MaxContextChars),
zap.Int("post_retrieve_max_context_tokens", cfg.PostRetrieve.MaxContextTokens),
)
}
}
}
// SetDocumentReranker 注入可选重排器(并发安全);nil 表示禁用。
func (r *Retriever) SetDocumentReranker(rr DocumentReranker) {
if r == nil {
return
}
r.rerankMu.Lock()
defer r.rerankMu.Unlock()
r.reranker = rr
}
func (r *Retriever) documentReranker() DocumentReranker {
if r == nil {
return nil
}
r.rerankMu.RLock()
defer r.rerankMu.RUnlock()
return r.reranker
}
func cosineSimilarity(a, b []float32) float64 {
if len(a) != len(b) {
return 0.0
}
var dotProduct, normA, normB float64
for i := range a {
dotProduct += float64(a[i] * b[i])
normA += float64(a[i] * a[i])
normB += float64(b[i] * b[i])
}
if normA == 0 || normB == 0 {
return 0.0
}
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
}
// Search 搜索知识库。统一经 [VectorEinoRetriever]Eino retriever.Retriever 边界)。
func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
if req == nil {
return nil, fmt.Errorf("请求不能为空")
}
q := strings.TrimSpace(req.Query)
if q == "" {
return nil, fmt.Errorf("查询不能为空")
}
opts := r.einoRetrieverOptions(req)
docs, err := NewVectorEinoRetriever(r).Retrieve(ctx, q, opts...)
if err != nil {
return nil, err
}
return documentsToRetrievalResults(docs)
}
func (r *Retriever) einoRetrieverOptions(req *SearchRequest) []retriever.Option {
var opts []retriever.Option
if req.TopK > 0 {
opts = append(opts, retriever.WithTopK(req.TopK))
}
dsl := map[string]any{}
if strings.TrimSpace(req.RiskType) != "" {
dsl[DSLRiskType] = strings.TrimSpace(req.RiskType)
}
if req.Threshold > 0 {
dsl[DSLSimilarityThreshold] = req.Threshold
}
if strings.TrimSpace(req.SubIndexFilter) != "" {
dsl[DSLSubIndexFilter] = strings.TrimSpace(req.SubIndexFilter)
}
if len(dsl) > 0 {
opts = append(opts, retriever.WithDSLInfo(dsl))
}
return opts
}
// EinoRetrieve 直接返回 [schema.Document],供 Eino Graph / Chain 使用。
func (r *Retriever) EinoRetrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
return NewVectorEinoRetriever(r).Retrieve(ctx, query, opts...)
}
func (r *Retriever) knowledgeEmbeddingSelectSQL(riskType, subIndexFilter string) (string, []interface{}) {
q := `SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, e.embedding_model, e.embedding_dim, i.category, i.title
FROM knowledge_embeddings e
JOIN knowledge_base_items i ON e.item_id = i.id
WHERE 1=1`
var args []interface{}
if strings.TrimSpace(riskType) != "" {
q += ` AND TRIM(i.category) = TRIM(?) COLLATE NOCASE`
args = append(args, riskType)
}
if tag := strings.TrimSpace(subIndexFilter); tag != "" {
tag = strings.ToLower(strings.ReplaceAll(tag, " ", ""))
q += ` AND (TRIM(COALESCE(e.sub_indexes,'')) = '' OR INSTR(',' || LOWER(REPLACE(e.sub_indexes,' ','')) || ',', ',' || ? || ',') > 0)`
args = append(args, tag)
}
return q, args
}
// vectorSearch 纯向量检索:余弦相似度排序,按相似度阈值与 TopK 截断(无 BM25、无混合分、无邻块扩展)。
func (r *Retriever) vectorSearch(ctx context.Context, req *SearchRequest) ([]*RetrievalResult, error) {
if req.Query == "" {
return nil, fmt.Errorf("查询不能为空")
}
topK := req.TopK
if topK <= 0 && r.config != nil {
topK = r.config.TopK
}
if topK <= 0 {
topK = 5
}
threshold := req.Threshold
if threshold <= 0 && r.config != nil {
threshold = r.config.SimilarityThreshold
}
if threshold <= 0 {
threshold = 0.7
}
subIdxFilter := strings.TrimSpace(req.SubIndexFilter)
if subIdxFilter == "" && r.config != nil {
subIdxFilter = strings.TrimSpace(r.config.SubIndexFilter)
}
queryText := FormatQueryEmbeddingText(req.RiskType, req.Query)
queryEmbedding, err := r.embedder.EmbedText(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("向量化查询失败: %w", err)
}
queryDim := len(queryEmbedding)
expectedModel := ""
if r.embedder != nil {
expectedModel = r.embedder.EmbeddingModelName()
}
sqlStr, sqlArgs := r.knowledgeEmbeddingSelectSQL(strings.TrimSpace(req.RiskType), subIdxFilter)
rows, err := r.db.QueryContext(ctx, sqlStr, sqlArgs...)
if err != nil {
return nil, fmt.Errorf("查询向量失败: %w", err)
}
defer rows.Close()
type candidate struct {
chunk *KnowledgeChunk
item *KnowledgeItem
similarity float64
}
candidates := make([]candidate, 0)
rowNum := 0
for rows.Next() {
rowNum++
if rowNum%48 == 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
var chunkID, itemID, chunkText, embeddingJSON, category, title, rowModel string
var chunkIndex, rowDim int
if err := rows.Scan(&chunkID, &itemID, &chunkIndex, &chunkText, &embeddingJSON, &rowModel, &rowDim, &category, &title); err != nil {
r.logger.Warn("扫描向量失败", zap.Error(err))
continue
}
var embedding []float32
if err := json.Unmarshal([]byte(embeddingJSON), &embedding); err != nil {
r.logger.Warn("解析向量失败", zap.Error(err))
continue
}
if rowDim > 0 && len(embedding) != rowDim {
r.logger.Debug("跳过维度不一致的向量行", zap.String("chunkId", chunkID), zap.Int("rowDim", rowDim), zap.Int("got", len(embedding)))
continue
}
if queryDim > 0 && len(embedding) != queryDim {
r.logger.Debug("跳过与查询维度不一致的向量", zap.String("chunkId", chunkID), zap.Int("queryDim", queryDim), zap.Int("got", len(embedding)))
continue
}
if expectedModel != "" && strings.TrimSpace(rowModel) != "" && strings.TrimSpace(rowModel) != expectedModel {
r.logger.Debug("跳过嵌入模型不一致的行", zap.String("chunkId", chunkID), zap.String("rowModel", rowModel), zap.String("expected", expectedModel))
continue
}
similarity := cosineSimilarity(queryEmbedding, embedding)
candidates = append(candidates, candidate{
chunk: &KnowledgeChunk{
ID: chunkID,
ItemID: itemID,
ChunkIndex: chunkIndex,
ChunkText: chunkText,
Embedding: embedding,
},
item: &KnowledgeItem{
ID: itemID,
Category: category,
Title: title,
},
similarity: similarity,
})
}
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].similarity > candidates[j].similarity
})
filtered := make([]candidate, 0, len(candidates))
for _, c := range candidates {
if c.similarity >= threshold {
filtered = append(filtered, c)
}
}
if len(filtered) > topK {
filtered = filtered[:topK]
}
results := make([]*RetrievalResult, len(filtered))
for i, c := range filtered {
results[i] = &RetrievalResult{
Chunk: c.chunk,
Item: c.item,
Similarity: c.similarity,
Score: c.similarity,
}
}
return results, nil
}
// AsEinoRetriever 将纯向量检索暴露为 Eino [retriever.Retriever]。
func (r *Retriever) AsEinoRetriever() retriever.Retriever {
return NewVectorEinoRetriever(r)
}
-51
View File
@@ -1,51 +0,0 @@
package knowledge
import (
"database/sql"
"fmt"
)
// EnsureKnowledgeEmbeddingsSchema migrates knowledge_embeddings for sub_indexes + embedding metadata.
func EnsureKnowledgeEmbeddingsSchema(db *sql.DB) error {
if db == nil {
return fmt.Errorf("db is nil")
}
var n int
if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
return err
}
if n == 0 {
return nil
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "sub_indexes",
`ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`); err != nil {
return err
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_model",
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`); err != nil {
return err
}
if err := addKnowledgeEmbeddingsColumnIfMissing(db, "embedding_dim",
`ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`); err != nil {
return err
}
return nil
}
func addKnowledgeEmbeddingsColumnIfMissing(db *sql.DB, column, alterSQL string) error {
var colCount int
q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
if err := db.QueryRow(q, column).Scan(&colCount); err != nil {
return err
}
if colCount > 0 {
return nil
}
_, err := db.Exec(alterSQL)
return err
}
// ensureKnowledgeEmbeddingsSubIndexesColumn 向后兼容;请使用 [EnsureKnowledgeEmbeddingsSchema]。
func ensureKnowledgeEmbeddingsSubIndexesColumn(db *sql.DB) error {
return EnsureKnowledgeEmbeddingsSchema(db)
}
-323
View File
@@ -1,323 +0,0 @@
package knowledge
import (
"context"
"encoding/json"
"fmt"
"sort"
"strings"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterKnowledgeTool 注册知识检索工具到MCP服务器
func RegisterKnowledgeTool(
mcpServer *mcp.Server,
retriever *Retriever,
manager *Manager,
logger *zap.Logger,
) {
// 注册第一个工具:获取所有可用的风险类型列表
listRiskTypesTool := mcp.Tool{
Name: builtin.ToolListKnowledgeRiskTypes,
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
ShortDescription: "获取知识库中所有可用的风险类型列表",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
},
}
listRiskTypesHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
categories, err := manager.GetCategories()
if err != nil {
logger.Error("获取风险类型列表失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("获取风险类型列表失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(categories) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "知识库中暂无风险类型。",
},
},
}, nil
}
var resultText strings.Builder
resultText.WriteString(fmt.Sprintf("知识库中共有 %d 个风险类型:\n\n", len(categories)))
for i, category := range categories {
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
}
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: resultText.String(),
},
},
}, nil
}
mcpServer.RegisterTool(listRiskTypesTool, listRiskTypesHandler)
logger.Info("风险类型列表工具已注册", zap.String("toolName", listRiskTypesTool.Name))
// 注册第二个工具:搜索知识库(保持原有功能)
searchTool := mcp.Tool{
Name: builtin.ToolSearchKnowledgeBase,
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具基于向量嵌入与余弦相似度检索(与 Eino retriever 语义一致)。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
ShortDescription: "搜索知识库中的安全知识(向量语义检索)",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "搜索查询内容,描述你想要了解的安全知识主题",
},
"risk_type": map[string]interface{}{
"type": "string",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
},
},
"required": []string{"query"},
},
}
searchHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
query, ok := args["query"].(string)
if !ok || query == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: 查询参数不能为空",
},
},
IsError: true,
}, nil
}
riskType := ""
if rt, ok := args["risk_type"].(string); ok && rt != "" {
riskType = rt
}
logger.Info("执行知识库检索",
zap.String("query", query),
zap.String("riskType", riskType),
)
// 检索统一走 Retriever.Search → VectorEinoRetrieverEino retriever 语义)。
searchReq := &SearchRequest{
Query: query,
RiskType: riskType,
TopK: 5,
}
results, err := retriever.Search(ctx, searchReq)
if err != nil {
logger.Error("知识库检索失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("检索失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(results) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("未找到与查询 '%s' 相关的知识。建议:\n1. 尝试使用不同的关键词\n2. 检查风险类型是否正确\n3. 确认知识库中是否包含相关内容", query),
},
},
}, nil
}
// 格式化结果
var resultText strings.Builder
// 按余弦相似度(Score)降序
sort.Slice(results, func(i, j int) bool {
return results[i].Score > results[j].Score
})
// 按文档分组结果,以便更好地展示上下文
type itemGroup struct {
itemID string
results []*RetrievalResult
maxScore float64 // 该文档块的最高相似度
}
itemGroups := make([]*itemGroup, 0)
itemMap := make(map[string]*itemGroup)
for _, result := range results {
itemID := result.Item.ID
group, exists := itemMap[itemID]
if !exists {
group = &itemGroup{
itemID: itemID,
results: make([]*RetrievalResult, 0),
maxScore: result.Score,
}
itemMap[itemID] = group
itemGroups = append(itemGroups, group)
}
group.results = append(group.results, result)
if result.Score > group.maxScore {
group.maxScore = result.Score
}
}
// 按文档内最高相似度排序
sort.Slice(itemGroups, func(i, j int) bool {
return itemGroups[i].maxScore > itemGroups[j].maxScore
})
// 收集检索到的知识项ID(用于日志)
retrievedItemIDs := make([]string, 0, len(itemGroups))
resultText.WriteString(fmt.Sprintf("找到 %d 条相关知识片段:\n\n", len(results)))
resultIndex := 1
for _, group := range itemGroups {
itemResults := group.results
mainResult := itemResults[0]
maxScore := mainResult.Score
for _, result := range itemResults {
if result.Score > maxScore {
maxScore = result.Score
mainResult = result
}
}
// 按chunk_index排序,保证阅读的逻辑顺序(文档的原始顺序)
sort.Slice(itemResults, func(i, j int) bool {
return itemResults[i].Chunk.ChunkIndex < itemResults[j].Chunk.ChunkIndex
})
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%) ---\n",
resultIndex, mainResult.Similarity*100))
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
// 按逻辑顺序显示所有chunk(包括主结果和扩展的chunk)
if len(itemResults) == 1 {
// 只有一个chunk,直接显示
resultText.WriteString(fmt.Sprintf("内容片段:\n%s\n", mainResult.Chunk.ChunkText))
} else {
// 多个chunk,按逻辑顺序显示
resultText.WriteString("内容片段(按文档顺序):\n")
for i, result := range itemResults {
// 标记主结果
marker := ""
if result.Chunk.ID == mainResult.Chunk.ID {
marker = " [主匹配]"
}
resultText.WriteString(fmt.Sprintf(" [片段 %d%s]\n%s\n", i+1, marker, result.Chunk.ChunkText))
}
}
resultText.WriteString("\n")
if !contains(retrievedItemIDs, group.itemID) {
retrievedItemIDs = append(retrievedItemIDs, group.itemID)
}
resultIndex++
}
// 在结果末尾添加元数据(JSON格式,用于提取知识项ID)
// 使用特殊标记,避免影响AI阅读结果
if len(retrievedItemIDs) > 0 {
metadataJSON, _ := json.Marshal(map[string]interface{}{
"_metadata": map[string]interface{}{
"retrievedItemIDs": retrievedItemIDs,
},
})
resultText.WriteString(fmt.Sprintf("\n<!-- METADATA: %s -->", string(metadataJSON)))
}
// 记录检索日志(异步,不阻塞)
// 注意:这里没有conversationID和messageID,需要在Agent层面记录
// 实际的日志记录应该在Agent的progressCallback中完成
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: resultText.String(),
},
},
}, nil
}
mcpServer.RegisterTool(searchTool, searchHandler)
logger.Info("知识检索工具已注册", zap.String("toolName", searchTool.Name))
}
// contains 检查切片是否包含元素
func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
// GetRetrievalMetadata 从工具调用中提取检索元数据(用于日志记录)
func GetRetrievalMetadata(args map[string]interface{}) (query string, riskType string) {
if q, ok := args["query"].(string); ok {
query = q
}
if rt, ok := args["risk_type"].(string); ok {
riskType = rt
}
return
}
// FormatRetrievalResults 格式化检索结果为字符串(用于日志)
func FormatRetrievalResults(results []*RetrievalResult) string {
if len(results) == 0 {
return "未找到相关结果"
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("检索到 %d 条结果:\n", len(results)))
itemIDs := make(map[string]bool)
for i, result := range results {
builder.WriteString(fmt.Sprintf("%d. [%s] %s (相似度: %.2f%%)\n",
i+1, result.Item.Category, result.Item.Title, result.Similarity*100))
itemIDs[result.Item.ID] = true
}
// 返回知识项ID列表(JSON格式)
ids := make([]string, 0, len(itemIDs))
for id := range itemIDs {
ids = append(ids, id)
}
idsJSON, _ := json.Marshal(ids)
builder.WriteString(fmt.Sprintf("\n检索到的知识项ID: %s", string(idsJSON)))
return builder.String()
}
-123
View File
@@ -1,123 +0,0 @@
package knowledge
import (
"encoding/json"
"time"
)
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
func formatTime(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format(time.RFC3339)
}
// KnowledgeItem 知识库项
type KnowledgeItem struct {
ID string `json:"id"`
Category string `json:"category"` // 风险类型(文件夹名)
Title string `json:"title"` // 标题(文件名)
FilePath string `json:"filePath"` // 文件路径
Content string `json:"content"` // 文件内容
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// KnowledgeItemSummary 知识库项摘要(用于列表,不包含完整内容)
type KnowledgeItemSummary struct {
ID string `json:"id"`
Category string `json:"category"`
Title string `json:"title"`
FilePath string `json:"filePath"`
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符)
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItemSummary
aux := &struct {
*Alias
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}{
Alias: (*Alias)(k),
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
type Alias KnowledgeItem
aux := &struct {
*Alias
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
}{
Alias: (*Alias)(k),
}
aux.CreatedAt = formatTime(k.CreatedAt)
aux.UpdatedAt = formatTime(k.UpdatedAt)
return json.Marshal(aux)
}
// KnowledgeChunk 知识块(用于向量化)
type KnowledgeChunk struct {
ID string `json:"id"`
ItemID string `json:"itemId"`
ChunkIndex int `json:"chunkIndex"`
ChunkText string `json:"chunkText"`
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON
CreatedAt time.Time `json:"createdAt"`
}
// RetrievalResult 检索结果
type RetrievalResult struct {
Chunk *KnowledgeChunk `json:"chunk"`
Item *KnowledgeItem `json:"item"`
Similarity float64 `json:"similarity"` // 相似度分数
Score float64 `json:"score"` // 与 Similarity 相同:余弦相似度
}
// RetrievalLog 检索日志
type RetrievalLog struct {
ID string `json:"id"`
ConversationID string `json:"conversationId,omitempty"`
MessageID string `json:"messageId,omitempty"`
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"`
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表
CreatedAt time.Time `json:"createdAt"`
}
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
type Alias RetrievalLog
return json.Marshal(&struct {
*Alias
CreatedAt string `json:"createdAt"`
}{
Alias: (*Alias)(r),
CreatedAt: formatTime(r.CreatedAt),
})
}
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
type CategoryWithItems struct {
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
}
// SearchRequest 搜索请求
type SearchRequest struct {
Query string `json:"query"`
RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型
SubIndexFilter string `json:"subIndexFilter,omitempty"` // 可选:仅保留 sub_indexes 含该标签的行(含未打标旧数据)
TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7
}
-68
View File
@@ -1,68 +0,0 @@
package logger
import (
"os"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type Logger struct {
*zap.Logger
}
func New(level, output string) *Logger {
var zapLevel zapcore.Level
switch level {
case "debug":
zapLevel = zapcore.DebugLevel
case "info":
zapLevel = zapcore.InfoLevel
case "warn":
zapLevel = zapcore.WarnLevel
case "error":
zapLevel = zapcore.ErrorLevel
default:
zapLevel = zapcore.InfoLevel
}
config := zap.NewProductionConfig()
config.Level = zap.NewAtomicLevelAt(zapLevel)
config.EncoderConfig.TimeKey = "timestamp"
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
var writeSyncer zapcore.WriteSyncer
if output == "stdout" {
writeSyncer = zapcore.AddSync(os.Stdout)
} else {
file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
writeSyncer = zapcore.AddSync(os.Stdout)
} else {
writeSyncer = zapcore.AddSync(file)
}
}
core := zapcore.NewCore(
zapcore.NewJSONEncoder(config.EncoderConfig),
writeSyncer,
zapLevel,
)
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
return &Logger{Logger: logger}
}
func (l *Logger) Fatal(msg string, fields ...interface{}) {
zapFields := make([]zap.Field, 0, len(fields))
for _, f := range fields {
switch v := f.(type) {
case error:
zapFields = append(zapFields, zap.Error(v))
default:
zapFields = append(zapFields, zap.Any("field", v))
}
}
l.Logger.Fatal(msg, zapFields...)
}
-113
View File
@@ -1,113 +0,0 @@
package builtin
// 内置工具名称常量
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
const (
// 漏洞管理工具
ToolRecordVulnerability = "record_vulnerability"
// 知识库工具
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
ToolSearchKnowledgeBase = "search_knowledge_base"
// Skills工具
ToolListSkills = "list_skills"
ToolReadSkill = "read_skill"
// WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用)
ToolWebshellExec = "webshell_exec"
ToolWebshellFileList = "webshell_file_list"
ToolWebshellFileRead = "webshell_file_read"
ToolWebshellFileWrite = "webshell_file_write"
// WebShell 连接管理工具(用于通过 MCP 管理 webshell 连接)
ToolManageWebshellList = "manage_webshell_list"
ToolManageWebshellAdd = "manage_webshell_add"
ToolManageWebshellUpdate = "manage_webshell_update"
ToolManageWebshellDelete = "manage_webshell_delete"
ToolManageWebshellTest = "manage_webshell_test"
// 批量任务队列(与 Web 端批量任务一致,供模型创建/启停/查询队列)
ToolBatchTaskList = "batch_task_list"
ToolBatchTaskGet = "batch_task_get"
ToolBatchTaskCreate = "batch_task_create"
ToolBatchTaskStart = "batch_task_start"
ToolBatchTaskRerun = "batch_task_rerun"
ToolBatchTaskPause = "batch_task_pause"
ToolBatchTaskDelete = "batch_task_delete"
ToolBatchTaskUpdateMetadata = "batch_task_update_metadata"
ToolBatchTaskUpdateSchedule = "batch_task_update_schedule"
ToolBatchTaskScheduleEnabled = "batch_task_schedule_enabled"
ToolBatchTaskAdd = "batch_task_add_task"
ToolBatchTaskUpdate = "batch_task_update_task"
ToolBatchTaskRemove = "batch_task_remove_task"
)
// IsBuiltinTool 检查工具名称是否是内置工具
func IsBuiltinTool(toolName string) bool {
switch toolName {
case ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase,
ToolListSkills,
ToolReadSkill,
ToolWebshellExec,
ToolWebshellFileList,
ToolWebshellFileRead,
ToolWebshellFileWrite,
ToolManageWebshellList,
ToolManageWebshellAdd,
ToolManageWebshellUpdate,
ToolManageWebshellDelete,
ToolManageWebshellTest,
ToolBatchTaskList,
ToolBatchTaskGet,
ToolBatchTaskCreate,
ToolBatchTaskStart,
ToolBatchTaskRerun,
ToolBatchTaskPause,
ToolBatchTaskDelete,
ToolBatchTaskUpdateMetadata,
ToolBatchTaskUpdateSchedule,
ToolBatchTaskScheduleEnabled,
ToolBatchTaskAdd,
ToolBatchTaskUpdate,
ToolBatchTaskRemove:
return true
default:
return false
}
}
// GetAllBuiltinTools 返回所有内置工具名称列表
func GetAllBuiltinTools() []string {
return []string{
ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase,
ToolListSkills,
ToolReadSkill,
ToolWebshellExec,
ToolWebshellFileList,
ToolWebshellFileRead,
ToolWebshellFileWrite,
ToolManageWebshellList,
ToolManageWebshellAdd,
ToolManageWebshellUpdate,
ToolManageWebshellDelete,
ToolManageWebshellTest,
ToolBatchTaskList,
ToolBatchTaskGet,
ToolBatchTaskCreate,
ToolBatchTaskStart,
ToolBatchTaskRerun,
ToolBatchTaskPause,
ToolBatchTaskDelete,
ToolBatchTaskUpdateMetadata,
ToolBatchTaskUpdateSchedule,
ToolBatchTaskScheduleEnabled,
ToolBatchTaskAdd,
ToolBatchTaskUpdate,
ToolBatchTaskRemove,
}
}
-551
View File
@@ -1,551 +0,0 @@
// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性
package mcp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
"go.uber.org/zap"
)
const (
clientName = "CyberStrikeAI"
clientVersion = "1.0.0"
)
// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口
type sdkClient struct {
session *mcp.ClientSession
client *mcp.Client
logger *zap.Logger
mu sync.RWMutex
status string // "disconnected", "connecting", "connected", "error"
}
// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用)
func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient {
return &sdkClient{
session: session,
client: client,
logger: logger,
status: "connected",
}
}
// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
type lazySDKClient struct {
serverCfg config.ExternalMCPServerConfig
logger *zap.Logger
inner ExternalMCPClient // 连接成功后为 *sdkClient
mu sync.RWMutex
status string
}
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
return &lazySDKClient{
serverCfg: serverCfg,
logger: logger,
status: "connecting",
}
}
func (c *lazySDKClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *lazySDKClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.inner != nil {
return c.inner.GetStatus()
}
return c.status
}
func (c *lazySDKClient) IsConnected() bool {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner != nil {
return inner.IsConnected()
}
return false
}
func (c *lazySDKClient) Initialize(ctx context.Context) error {
c.mu.Lock()
if c.inner != nil {
c.mu.Unlock()
return nil
}
c.mu.Unlock()
inner, err := createSDKClient(ctx, c.serverCfg, c.logger)
if err != nil {
c.setStatus("error")
return err
}
c.mu.Lock()
c.inner = inner
c.mu.Unlock()
c.setStatus("connected")
return nil
}
func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner == nil {
return nil, fmt.Errorf("未连接")
}
return inner.ListTools(ctx)
}
func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
c.mu.RLock()
inner := c.inner
c.mu.RUnlock()
if inner == nil {
return nil, fmt.Errorf("未连接")
}
return inner.CallTool(ctx, name, args)
}
func (c *lazySDKClient) Close() error {
c.mu.Lock()
inner := c.inner
c.inner = nil
c.mu.Unlock()
c.setStatus("disconnected")
if inner != nil {
return inner.Close()
}
return nil
}
func (c *sdkClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *sdkClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *sdkClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *sdkClient) Initialize(ctx context.Context) error {
// sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接
// 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成
return nil
}
func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) {
if c.session == nil {
return nil, fmt.Errorf("未连接")
}
res, err := c.session.ListTools(ctx, nil)
if err != nil {
return nil, err
}
if res == nil {
return nil, nil
}
return sdkToolsToOur(res.Tools), nil
}
func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
if c.session == nil {
return nil, fmt.Errorf("未连接")
}
params := &mcp.CallToolParams{
Name: name,
Arguments: args,
}
res, err := c.session.CallTool(ctx, params)
if err != nil {
return nil, err
}
return sdkCallToolResultToOurs(res), nil
}
func (c *sdkClient) Close() error {
c.setStatus("disconnected")
if c.session != nil {
err := c.session.Close()
c.session = nil
return err
}
return nil
}
// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool
func sdkToolsToOur(tools []*mcp.Tool) []Tool {
if len(tools) == 0 {
return nil
}
out := make([]Tool, 0, len(tools))
for _, t := range tools {
if t == nil {
continue
}
schema := make(map[string]interface{})
if t.InputSchema != nil {
// SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map
if m, ok := t.InputSchema.(map[string]interface{}); ok {
schema = m
} else {
_ = json.Unmarshal(mustJSON(t.InputSchema), &schema)
}
}
desc := t.Description
shortDesc := desc
if t.Annotations != nil && t.Annotations.Title != "" {
shortDesc = t.Annotations.Title
}
out = append(out, Tool{
Name: t.Name,
Description: desc,
ShortDescription: shortDesc,
InputSchema: schema,
})
}
return out
}
// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult
func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult {
if res == nil {
return &ToolResult{Content: []Content{}}
}
content := sdkContentToOurs(res.Content)
return &ToolResult{
Content: content,
IsError: res.IsError,
}
}
func sdkContentToOurs(list []mcp.Content) []Content {
if len(list) == 0 {
return nil
}
out := make([]Content, 0, len(list))
for _, c := range list {
switch v := c.(type) {
case *mcp.TextContent:
out = append(out, Content{Type: "text", Text: v.Text})
default:
out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)})
}
}
return out
}
func mustJSON(v interface{}) []byte {
b, _ := json.Marshal(v)
return b
}
// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。
// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。
type simpleHTTPClient struct {
url string
client *http.Client
logger *zap.Logger
mu sync.RWMutex
status string
}
func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) {
c := &simpleHTTPClient{
url: url,
client: httpClientWithTimeoutAndHeaders(timeout, headers),
logger: logger,
status: "connecting",
}
if err := c.initialize(ctx); err != nil {
return nil, err
}
c.mu.Lock()
c.status = "connected"
c.mu.Unlock()
return c, nil
}
func (c *simpleHTTPClient) setStatus(s string) {
c.mu.Lock()
defer c.mu.Unlock()
c.status = s
}
func (c *simpleHTTPClient) GetStatus() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.status
}
func (c *simpleHTTPClient) IsConnected() bool {
return c.GetStatus() == "connected"
}
func (c *simpleHTTPClient) Initialize(context.Context) error {
return nil // 已在 newSimpleHTTPClient 中完成
}
func (c *simpleHTTPClient) initialize(ctx context.Context) error {
params := InitializeRequest{
ProtocolVersion: ProtocolVersion,
Capabilities: make(map[string]interface{}),
ClientInfo: ClientInfo{Name: clientName, Version: clientVersion},
}
paramsJSON, _ := json.Marshal(params)
req := &Message{
ID: MessageID{value: "1"},
Method: "initialize",
Version: "2.0",
Params: paramsJSON,
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return fmt.Errorf("initialize: %w", err)
}
if resp.Error != nil {
return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
// 发送 notifications/initialized(协议要求)
notify := &Message{
ID: MessageID{value: nil},
Method: "notifications/initialized",
Version: "2.0",
Params: json.RawMessage("{}"),
}
_ = c.sendNotification(notify)
return nil
}
func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
body, err := json.Marshal(msg)
if err != nil {
return nil, err
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b))
}
var out Message
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return nil, err
}
return &out, nil
}
func (c *simpleHTTPClient) sendNotification(msg *Message) error {
body, _ := json.Marshal(msg)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) {
req := &Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/list",
Version: "2.0",
Params: json.RawMessage("{}"),
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
var listResp ListToolsResponse
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
return nil, err
}
return listResp.Tools, nil
}
func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
params := CallToolRequest{Name: name, Arguments: args}
paramsJSON, _ := json.Marshal(params)
req := &Message{
ID: MessageID{value: uuid.New().String()},
Method: "tools/call",
Version: "2.0",
Params: paramsJSON,
}
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code)
}
var callResp CallToolResponse
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
return nil, err
}
return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil
}
func (c *simpleHTTPClient) Close() error {
c.setStatus("disconnected")
return nil
}
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
timeout := time.Duration(serverCfg.Timeout) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second
}
transport := serverCfg.Transport
if transport == "" {
if serverCfg.Command != "" {
transport = "stdio"
} else if serverCfg.URL != "" {
transport = "http"
} else {
return nil, fmt.Errorf("配置缺少 command 或 url")
}
}
client := mcp.NewClient(&mcp.Implementation{
Name: clientName,
Version: clientVersion,
}, nil)
var t mcp.Transport
switch transport {
case "stdio":
if serverCfg.Command == "" {
return nil, fmt.Errorf("stdio 模式需要配置 command")
}
// 必须用 exec.Command 而非 CommandContextdoConnect 返回后 ctx 会被 cancel
// 若用 CommandContext(ctx) 会立刻杀掉子进程,导致 ListTools 等后续请求失败、显示 0 工具
cmd := exec.Command(serverCfg.Command, serverCfg.Args...)
if len(serverCfg.Env) > 0 {
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
}
t = &mcp.CommandTransport{Command: cmd}
case "sse":
if serverCfg.URL == "" {
return nil, fmt.Errorf("sse 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
t = &mcp.SSEClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
case "http":
if serverCfg.URL == "" {
return nil, fmt.Errorf("http 模式需要配置 url")
}
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
t = &mcp.StreamableClientTransport{
Endpoint: serverCfg.URL,
HTTPClient: httpClient,
}
case "simple_http":
// 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp
if serverCfg.URL == "" {
return nil, fmt.Errorf("simple_http 模式需要配置 url")
}
return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger)
default:
return nil, fmt.Errorf("不支持的传输模式: %s", transport)
}
session, err := client.Connect(ctx, t, nil)
if err != nil {
return nil, fmt.Errorf("连接失败: %w", err)
}
return newSDKClientFromSession(session, client, logger), nil
}
func envMapToSlice(env map[string]string) []string {
m := make(map[string]string)
for _, s := range os.Environ() {
if i := strings.IndexByte(s, '='); i > 0 {
m[s[:i]] = s[i+1:]
}
}
for k, v := range env {
m[k] = v
}
out := make([]string, 0, len(m))
for k, v := range m {
out = append(out, k+"="+v)
}
return out
}
func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client {
transport := http.DefaultTransport
if len(headers) > 0 {
transport = &headerRoundTripper{
headers: headers,
base: http.DefaultTransport,
}
}
return &http.Client{
Timeout: timeout,
Transport: transport,
}
}
type headerRoundTripper struct {
headers map[string]string
base http.RoundTripper
}
func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range h.headers {
req.Header.Set(k, v)
}
return h.base.RoundTrip(req)
}
File diff suppressed because it is too large Load Diff
-239
View File
@@ -1,239 +0,0 @@
package mcp
import (
"context"
"testing"
"time"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
// 测试添加stdio配置
stdioCfg := config.ExternalMCPServerConfig{
Command: "python3",
Args: []string{"/path/to/script.py"},
Transport: "stdio",
Description: "Test stdio MCP",
Timeout: 30,
Enabled: true,
}
err := manager.AddOrUpdateConfig("test-stdio", stdioCfg)
if err != nil {
t.Fatalf("添加stdio配置失败: %v", err)
}
// 测试添加HTTP配置
httpCfg := config.ExternalMCPServerConfig{
Transport: "http",
URL: "http://127.0.0.1:8081/mcp",
Description: "Test HTTP MCP",
Timeout: 30,
Enabled: false,
}
err = manager.AddOrUpdateConfig("test-http", httpCfg)
if err != nil {
t.Fatalf("添加HTTP配置失败: %v", err)
}
// 验证配置已保存
configs := manager.GetConfigs()
if len(configs) != 2 {
t.Fatalf("期望2个配置,实际%d个", len(configs))
}
if configs["test-stdio"].Command != stdioCfg.Command {
t.Errorf("stdio配置命令不匹配")
}
if configs["test-http"].URL != httpCfg.URL {
t.Errorf("HTTP配置URL不匹配")
}
}
func TestExternalMCPManager_RemoveConfig(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
cfg := config.ExternalMCPServerConfig{
Command: "python3",
Transport: "stdio",
Enabled: false,
}
manager.AddOrUpdateConfig("test-remove", cfg)
// 移除配置
err := manager.RemoveConfig("test-remove")
if err != nil {
t.Fatalf("移除配置失败: %v", err)
}
configs := manager.GetConfigs()
if _, exists := configs["test-remove"]; exists {
t.Error("配置应该已被移除")
}
}
func TestExternalMCPManager_GetStats(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
// 添加多个配置
manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: true,
})
manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
Enabled: true,
})
manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
Enabled: false,
Disabled: true, // 明确设置为禁用
})
stats := manager.GetStats()
if stats["total"].(int) != 3 {
t.Errorf("期望总数3,实际%d", stats["total"])
}
if stats["enabled"].(int) != 2 {
t.Errorf("期望启用数2,实际%d", stats["enabled"])
}
if stats["disabled"].(int) != 1 {
t.Errorf("期望停用数1,实际%d", stats["disabled"])
}
}
func TestExternalMCPManager_LoadConfigs(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
externalMCPConfig := config.ExternalMCPConfig{
Servers: map[string]config.ExternalMCPServerConfig{
"loaded1": {
Command: "python3",
Enabled: true,
},
"loaded2": {
URL: "http://127.0.0.1:8081/mcp",
Enabled: false,
},
},
}
manager.LoadConfigs(&externalMCPConfig)
configs := manager.GetConfigs()
if len(configs) != 2 {
t.Fatalf("期望2个配置,实际%d个", len(configs))
}
if configs["loaded1"].Command != "python3" {
t.Error("配置1加载失败")
}
if configs["loaded2"].URL != "http://127.0.0.1:8081/mcp" {
t.Error("配置2加载失败")
}
}
// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态
func TestLazySDKClient_InitializeFails(t *testing.T) {
logger := zap.NewNop()
// 使用不存在的 HTTP 地址,Initialize 应失败
cfg := config.ExternalMCPServerConfig{
Transport: "http",
URL: "http://127.0.0.1:19999/nonexistent",
Timeout: 2,
}
c := newLazySDKClient(cfg, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := c.Initialize(ctx)
if err == nil {
t.Fatal("expected error when connecting to invalid server")
}
if c.GetStatus() != "error" {
t.Errorf("expected status error, got %s", c.GetStatus())
}
c.Close()
}
func TestExternalMCPManager_StartStopClient(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
// 添加一个禁用的配置
cfg := config.ExternalMCPServerConfig{
Command: "python3",
Transport: "stdio",
Enabled: false,
}
manager.AddOrUpdateConfig("test-start-stop", cfg)
// 尝试启动(可能会失败,因为没有真实的服务器)
err := manager.StartClient("test-start-stop")
if err != nil {
t.Logf("启动失败(可能是没有服务器): %v", err)
}
// 停止
err = manager.StopClient("test-start-stop")
if err != nil {
t.Fatalf("停止失败: %v", err)
}
// 验证配置已更新为禁用
configs := manager.GetConfigs()
if configs["test-start-stop"].Enabled {
t.Error("配置应该已被禁用")
}
}
func TestExternalMCPManager_CallTool(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
// 测试调用不存在的工具
_, _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{})
if err == nil {
t.Error("应该返回错误")
}
// 测试无效的工具名称格式
_, _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{})
if err == nil {
t.Error("应该返回错误(无效格式)")
}
}
func TestExternalMCPManager_GetAllTools(t *testing.T) {
logger := zap.NewNop()
manager := NewExternalMCPManager(logger)
ctx := context.Background()
tools, err := manager.GetAllTools(ctx)
if err != nil {
t.Fatalf("获取工具列表失败: %v", err)
}
// 如果没有连接的客户端,应该返回空列表
if len(tools) != 0 {
t.Logf("获取到%d个工具", len(tools))
}
}
File diff suppressed because it is too large Load Diff
-295
View File
@@ -1,295 +0,0 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"time"
)
// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现)
type ExternalMCPClient interface {
Initialize(ctx context.Context) error
ListTools(ctx context.Context) ([]Tool, error)
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
Close() error
IsConnected() bool
GetStatus() string
}
// MCP消息类型
const (
MessageTypeRequest = "request"
MessageTypeResponse = "response"
MessageTypeError = "error"
MessageTypeNotify = "notify"
)
// MCP协议版本
const ProtocolVersion = "2024-11-05"
// MessageID 表示JSON-RPC 2.0的id字段,可以是字符串、数字或null
type MessageID struct {
value interface{}
}
// UnmarshalJSON 自定义反序列化,支持字符串、数字和null
func (m *MessageID) UnmarshalJSON(data []byte) error {
// 尝试解析为null
if string(data) == "null" {
m.value = nil
return nil
}
// 尝试解析为字符串
var str string
if err := json.Unmarshal(data, &str); err == nil {
m.value = str
return nil
}
// 尝试解析为数字
var num json.Number
if err := json.Unmarshal(data, &num); err == nil {
m.value = num
return nil
}
return fmt.Errorf("invalid id type")
}
// MarshalJSON 自定义序列化
func (m MessageID) MarshalJSON() ([]byte, error) {
if m.value == nil {
return []byte("null"), nil
}
return json.Marshal(m.value)
}
// String 返回字符串表示
func (m MessageID) String() string {
if m.value == nil {
return ""
}
return fmt.Sprintf("%v", m.value)
}
// Value 返回原始值
func (m MessageID) Value() interface{} {
return m.value
}
// Message 表示MCP消息(符合JSON-RPC 2.0规范)
type Message struct {
ID MessageID `json:"id,omitempty"`
Type string `json:"-"` // 内部使用,不序列化到JSON
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *Error `json:"error,omitempty"`
Version string `json:"jsonrpc,omitempty"` // JSON-RPC 2.0 版本标识
}
// Error 表示MCP错误
type Error struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// Tool 表示MCP工具定义
type Tool struct {
Name string `json:"name"`
Description string `json:"description"` // 详细描述
ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗)
InputSchema map[string]interface{} `json:"inputSchema"`
}
// ToolCall 表示工具调用
type ToolCall struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// ToolResult 表示工具执行结果
type ToolResult struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// Content 表示内容
type Content struct {
Type string `json:"type"`
Text string `json:"text"`
}
// InitializeRequest 初始化请求
type InitializeRequest struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities map[string]interface{} `json:"capabilities"`
ClientInfo ClientInfo `json:"clientInfo"`
}
// ClientInfo 客户端信息
type ClientInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// InitializeResponse 初始化响应
type InitializeResponse struct {
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo ServerInfo `json:"serverInfo"`
}
// ServerCapabilities 服务器能力
type ServerCapabilities struct {
Tools map[string]interface{} `json:"tools,omitempty"`
Prompts map[string]interface{} `json:"prompts,omitempty"`
Resources map[string]interface{} `json:"resources,omitempty"`
Sampling map[string]interface{} `json:"sampling,omitempty"`
}
// ServerInfo 服务器信息
type ServerInfo struct {
Name string `json:"name"`
Version string `json:"version"`
}
// ListToolsRequest 列出工具请求
type ListToolsRequest struct{}
// ListToolsResponse 列出工具响应
type ListToolsResponse struct {
Tools []Tool `json:"tools"`
}
// ListPromptsResponse 列出提示词响应
type ListPromptsResponse struct {
Prompts []Prompt `json:"prompts"`
}
// ListResourcesResponse 列出资源响应
type ListResourcesResponse struct {
Resources []Resource `json:"resources"`
}
// CallToolRequest 调用工具请求
type CallToolRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
}
// CallToolResponse 调用工具响应
type CallToolResponse struct {
Content []Content `json:"content"`
IsError bool `json:"isError,omitempty"`
}
// ToolExecution 工具执行记录
type ToolExecution struct {
ID string `json:"id"`
ToolName string `json:"toolName"`
Arguments map[string]interface{} `json:"arguments"`
Status string `json:"status"` // pending, running, completed, failed
Result *ToolResult `json:"result,omitempty"`
Error string `json:"error,omitempty"`
StartTime time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
}
// ToolStats 工具统计信息
type ToolStats struct {
ToolName string `json:"toolName"`
TotalCalls int `json:"totalCalls"`
SuccessCalls int `json:"successCalls"`
FailedCalls int `json:"failedCalls"`
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
}
// Prompt 提示词模板
type Prompt struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Arguments []PromptArgument `json:"arguments,omitempty"`
}
// PromptArgument 提示词参数
type PromptArgument struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Required bool `json:"required,omitempty"`
}
// GetPromptRequest 获取提示词请求
type GetPromptRequest struct {
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments,omitempty"`
}
// GetPromptResponse 获取提示词响应
type GetPromptResponse struct {
Messages []PromptMessage `json:"messages"`
}
// PromptMessage 提示词消息
type PromptMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// Resource 资源
type Resource struct {
URI string `json:"uri"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
MimeType string `json:"mimeType,omitempty"`
}
// ReadResourceRequest 读取资源请求
type ReadResourceRequest struct {
URI string `json:"uri"`
}
// ReadResourceResponse 读取资源响应
type ReadResourceResponse struct {
Contents []ResourceContent `json:"contents"`
}
// ResourceContent 资源内容
type ResourceContent struct {
URI string `json:"uri"`
MimeType string `json:"mimeType,omitempty"`
Text string `json:"text,omitempty"`
Blob string `json:"blob,omitempty"`
}
// SamplingRequest 采样请求
type SamplingRequest struct {
Messages []SamplingMessage `json:"messages"`
Model string `json:"model,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
}
// SamplingMessage 采样消息
type SamplingMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// SamplingResponse 采样响应
type SamplingResponse struct {
Content []SamplingContent `json:"content"`
Model string `json:"model,omitempty"`
StopReason string `json:"stopReason,omitempty"`
}
// SamplingContent 采样内容
type SamplingContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
}
-140
View File
@@ -1,140 +0,0 @@
package multiagent
import (
"context"
"fmt"
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/adk/middlewares/summarization"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// einoSummarizeUserInstruction 与单 Agent MemoryCompressor 目标一致:压缩时保留渗透关键信息。
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史
必须保留已确认漏洞与攻击路径工具输出中的核心发现凭证与认证细节架构与薄弱点当前进度失败尝试与死路策略决策
保留精确技术细节URL路径参数Payload版本号报错原文可摘要但要点不丢
将冗长扫描输出概括为结论重复发现合并表述
输出须使后续代理能无缝继续同一授权测试任务`
// newEinoSummarizationMiddleware 使用 Eino ADK Summarization 中间件(见 https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/)。
// 触发阈值与单 Agent MemoryCompressor 一致:当估算 token 超过 openai.max_total_tokens 的 90% 时摘要。
func newEinoSummarizationMiddleware(
ctx context.Context,
summaryModel model.BaseChatModel,
appCfg *config.Config,
logger *zap.Logger,
) (adk.ChatModelAgentMiddleware, error) {
if summaryModel == nil || appCfg == nil {
return nil, fmt.Errorf("multiagent: summarization 需要 model 与配置")
}
maxTotal := appCfg.OpenAI.MaxTotalTokens
if maxTotal <= 0 {
maxTotal = 120000
}
trigger := int(float64(maxTotal) * 0.9)
if trigger < 4096 {
trigger = maxTotal
if trigger < 4096 {
trigger = 4096
}
}
preserveMax := trigger / 3
if preserveMax < 2048 {
preserveMax = 2048
}
modelName := strings.TrimSpace(appCfg.OpenAI.Model)
if modelName == "" {
modelName = "gpt-4o"
}
mw, err := summarization.New(ctx, &summarization.Config{
Model: summaryModel,
Trigger: &summarization.TriggerCondition{
ContextTokens: trigger,
},
TokenCounter: einoSummarizationTokenCounter(modelName),
UserInstruction: einoSummarizeUserInstruction,
EmitInternalEvents: false,
PreserveUserMessages: &summarization.PreserveUserMessages{
Enabled: true,
MaxTokens: preserveMax,
},
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
if logger == nil {
return nil
}
logger.Info("eino summarization 已压缩上下文",
zap.Int("messages_before", len(before.Messages)),
zap.Int("messages_after", len(after.Messages)),
zap.Int("max_total_tokens", maxTotal),
zap.Int("trigger_context_tokens", trigger),
)
return nil
},
})
if err != nil {
return nil, fmt.Errorf("summarization.New: %w", err)
}
return mw, nil
}
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
tc := agent.NewTikTokenCounter()
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
var sb strings.Builder
for _, msg := range input.Messages {
if msg == nil {
continue
}
sb.WriteString(string(msg.Role))
sb.WriteByte('\n')
if msg.Content != "" {
sb.WriteString(msg.Content)
sb.WriteByte('\n')
}
if msg.ReasoningContent != "" {
sb.WriteString(msg.ReasoningContent)
sb.WriteByte('\n')
}
if len(msg.ToolCalls) > 0 {
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
sb.Write(b)
sb.WriteByte('\n')
}
}
for _, part := range msg.UserInputMultiContent {
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
sb.WriteString(part.Text)
sb.WriteByte('\n')
}
}
}
for _, tl := range input.Tools {
if tl == nil {
continue
}
cp := *tl
cp.Extra = nil
if text, err := sonic.MarshalString(cp); err == nil {
sb.WriteString(text)
sb.WriteByte('\n')
}
}
text := sb.String()
n, err := tc.Count(openAIModel, text)
if err != nil {
return (len(text) + 3) / 4, nil
}
return n, nil
}
}
-62
View File
@@ -1,62 +0,0 @@
package multiagent
import (
"context"
"strings"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/components/tool"
)
// noNestedTaskMiddleware 禁止在已经处于 task(sub-agent) 执行链中再次调用 task
// 避免子代理再次委派子代理造成的无限委派/递归。
//
// 通过在 ctx 中设置临时标记来实现嵌套检测:外层 task 调用会先标记 ctx,
// 子代理内再调用 task 时会命中该标记并拒绝。
type noNestedTaskMiddleware struct {
adk.BaseChatModelAgentMiddleware
}
type nestedTaskCtxKey struct{}
func newNoNestedTaskMiddleware() adk.ChatModelAgentMiddleware {
return &noNestedTaskMiddleware{}
}
func (m *noNestedTaskMiddleware) WrapInvokableToolCall(
ctx context.Context,
endpoint adk.InvokableToolCallEndpoint,
tCtx *adk.ToolContext,
) (adk.InvokableToolCallEndpoint, error) {
if tCtx == nil || strings.TrimSpace(tCtx.Name) == "" {
return endpoint, nil
}
// Deep 内置 task 工具名固定为 "task";为兼容可能的大小写/空白,仅做不区分大小写匹配。
if !strings.EqualFold(strings.TrimSpace(tCtx.Name), "task") {
return endpoint, nil
}
// 已在 task 执行链中:拒绝继续委派,直接报错让上层快速终止。
if ctx != nil {
if v, ok := ctx.Value(nestedTaskCtxKey{}).(bool); ok && v {
return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
// Important: return a tool result text (not an error) to avoid hard-stopping the whole multi-agent run.
// The nested task is still prevented from spawning another sub-agent, so recursion is avoided.
_ = argumentsInJSON
_ = opts
return "Nested task delegation is forbidden (already inside a sub-agent delegation chain) to avoid infinite delegation. Please continue the work using the current agent's tools.", nil
}, nil
}
}
// 标记当前 task 调用链,确保子代理内的再次 task 调用能检测到嵌套。
return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
ctx2 := ctx
if ctx2 == nil {
ctx2 = context.Background()
}
ctx2 = context.WithValue(ctx2, nestedTaskCtxKey{}, true)
return endpoint(ctx2, argumentsInJSON, opts...)
}, nil
}
File diff suppressed because it is too large Load Diff
@@ -1,51 +0,0 @@
package multiagent
import (
"fmt"
"strings"
"github.com/cloudwego/eino/schema"
)
// maxToolCallRecoveryAttempts 含首次运行:首次 + 自动重试次数。
// 例如为 3 表示最多共 3 次完整 DeepAgent 运行(2 次失败后各追加一条纠错提示)。
// 该常量同时用于 JSON 参数错误和工具执行错误(如子代理名称不存在)的恢复重试。
const maxToolCallRecoveryAttempts = 5
// toolCallArgumentsJSONRetryHint 追加在用户消息后,提示模型输出合法 JSON 工具参数(部分云厂商会在流式阶段校验 arguments)。
func toolCallArgumentsJSONRetryHint() *schema.Message {
return schema.UserMessage(`[系统提示] 上一次输出中工具调用的 function.arguments 不是合法 JSON接口已拒绝请重新生成每个 tool call arguments 必须是完整可解析的 JSON 对象字符串键名用双引号无多余逗号括号配对不要输出截断或不完整的 JSON
[System] Your previous tool call used invalid JSON in function.arguments and was rejected by the API. Regenerate with strictly valid JSON objects only (double-quoted keys, matched braces, no trailing commas).`)
}
// toolCallArgumentsJSONRecoveryTimelineMessage 供 eino_recovery 事件落库与前端时间线展示。
func toolCallArgumentsJSONRecoveryTimelineMessage(attempt int) string {
return fmt.Sprintf(
"接口拒绝了无效的工具参数 JSON。已向对话追加系统提示并要求模型重新生成合法的 function.arguments。"+
"当前为第 %d/%d 轮完整运行。\n\n"+
"The API rejected invalid JSON in tool arguments. A system hint was appended. This is full run %d of %d.",
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
)
}
// isRecoverableToolCallArgumentsJSONError 判断是否为「工具参数非合法 JSON」类流式错误,可通过追加提示后重跑一轮。
func isRecoverableToolCallArgumentsJSONError(err error) bool {
if err == nil {
return false
}
s := strings.ToLower(err.Error())
if !strings.Contains(s, "json") {
return false
}
if strings.Contains(s, "function.arguments") || strings.Contains(s, "function arguments") {
return true
}
if strings.Contains(s, "invalidparameter") && strings.Contains(s, "json") {
return true
}
if strings.Contains(s, "must be in json format") {
return true
}
return false
}
@@ -1,17 +0,0 @@
package multiagent
import (
"errors"
"testing"
)
func TestIsRecoverableToolCallArgumentsJSONError(t *testing.T) {
yes := errors.New(`failed to receive stream chunk: error, <400> InternalError.Algo.InvalidParameter: The "function.arguments" parameter of the code model must be in JSON format.`)
if !isRecoverableToolCallArgumentsJSONError(yes) {
t.Fatal("expected recoverable for function.arguments + JSON")
}
no := errors.New("unrelated network failure")
if isRecoverableToolCallArgumentsJSONError(no) {
t.Fatal("expected not recoverable")
}
}
@@ -1,131 +0,0 @@
package multiagent
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/cloudwego/eino/compose"
)
// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches
// specific recoverable errors from tool execution (JSON parse errors, tool-not-found,
// etc.) and converts them into soft errors: nil error + descriptive error content
// returned to the LLM. This allows the model to self-correct within the same
// iteration rather than crashing the entire graph and requiring a full replay.
//
// Without this middleware, a JSON parse failure in any tool's InvokableRun propagates
// as a hard error through the Eino ToolsNode → [NodeRunError] → ev.Err, which
// either triggers the full-replay retry loop (expensive) or terminates the run
// entirely once retries are exhausted. With it, the LLM simply sees an error message
// in the tool result and can adjust its next tool call accordingly.
func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware {
return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
output, err := next(ctx, input)
if err == nil {
return output, nil
}
if !isSoftRecoverableToolError(err) {
return output, err
}
// Convert the hard error into a soft error: the LLM will see this
// message as the tool's output and can self-correct.
msg := buildSoftRecoveryMessage(input.Name, input.Arguments, err)
return &compose.ToolOutput{Result: msg}, nil
}
}
}
// isSoftRecoverableToolError determines whether a tool execution error should be
// silently converted to a tool-result message rather than crashing the graph.
func isSoftRecoverableToolError(err error) bool {
if err == nil {
return false
}
s := strings.ToLower(err.Error())
// JSON unmarshal/parse failures — the model generated truncated or malformed arguments.
if isJSONRelatedError(s) {
return true
}
// Sub-agent type not found (from deep/task_tool.go)
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
return true
}
// Tool not found in ToolsNode indexes
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
return true
}
return false
}
// isJSONRelatedError checks whether an error string indicates a JSON parsing problem.
func isJSONRelatedError(lower string) bool {
if !strings.Contains(lower, "json") {
return false
}
jsonIndicators := []string{
"unexpected end of json",
"unmarshal",
"invalid character",
"cannot unmarshal",
"invalid tool arguments",
"failed to unmarshal",
"must be in json format",
"unexpected eof",
}
for _, ind := range jsonIndicators {
if strings.Contains(lower, ind) {
return true
}
}
return false
}
// buildSoftRecoveryMessage creates a bilingual error message that the LLM can act on.
func buildSoftRecoveryMessage(toolName, arguments string, err error) string {
// Truncate arguments preview to avoid flooding the context.
argPreview := arguments
if len(argPreview) > 300 {
argPreview = argPreview[:300] + "... (truncated)"
}
// Try to determine if it's specifically a JSON parse error for a friendlier message.
errStr := err.Error()
var jsonErr *json.SyntaxError
isJSONErr := strings.Contains(strings.ToLower(errStr), "json") ||
strings.Contains(strings.ToLower(errStr), "unmarshal")
_ = jsonErr // suppress unused
if isJSONErr {
return fmt.Sprintf(
"[Tool Error] The arguments for tool '%s' are not valid JSON and could not be parsed.\n"+
"Error: %s\n"+
"Arguments received: %s\n\n"+
"Please fix the JSON (ensure double-quoted keys, matched braces/brackets, no trailing commas, "+
"no truncation) and call the tool again.\n\n"+
"[工具错误] 工具 '%s' 的参数不是合法 JSON,无法解析。\n"+
"错误:%s\n"+
"收到的参数:%s\n\n"+
"请修正 JSON(确保双引号键名、括号配对、无尾部逗号、无截断),然后重新调用工具。",
toolName, errStr, argPreview,
toolName, errStr, argPreview,
)
}
return fmt.Sprintf(
"[Tool Error] Tool '%s' execution failed: %s\n"+
"Arguments: %s\n\n"+
"Please review the available tools and their expected arguments, then retry.\n\n"+
"[工具错误] 工具 '%s' 执行失败:%s\n"+
"参数:%s\n\n"+
"请检查可用工具及其参数要求,然后重试。",
toolName, errStr, argPreview,
toolName, errStr, argPreview,
)
}
@@ -1,166 +0,0 @@
package multiagent
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/cloudwego/eino/compose"
)
func TestIsSoftRecoverableToolError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "unexpected end of JSON input",
err: errors.New("unexpected end of JSON input"),
expected: true,
},
{
name: "failed to unmarshal task tool input json",
err: errors.New("failed to unmarshal task tool input json: unexpected end of JSON input"),
expected: true,
},
{
name: "invalid tool arguments JSON",
err: errors.New("invalid tool arguments JSON: unexpected end of JSON input"),
expected: true,
},
{
name: "json invalid character",
err: errors.New(`invalid character '}' looking for beginning of value in JSON`),
expected: true,
},
{
name: "subagent type not found",
err: errors.New("subagent type recon_agent not found"),
expected: true,
},
{
name: "tool not found",
err: errors.New("tool nmap_scan not found in toolsNode indexes"),
expected: true,
},
{
name: "unrelated network error",
err: errors.New("connection refused"),
expected: false,
},
{
name: "context cancelled",
err: context.Canceled,
expected: false,
},
{
name: "real json unmarshal error",
err: func() error {
var v map[string]interface{}
return json.Unmarshal([]byte(`{"key": `), &v)
}(),
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isSoftRecoverableToolError(tt.err)
if got != tt.expected {
t.Errorf("isSoftRecoverableToolError(%v) = %v, want %v", tt.err, got, tt.expected)
}
})
}
}
func TestSoftRecoveryToolCallMiddleware_PassesThrough(t *testing.T) {
mw := softRecoveryToolCallMiddleware()
called := false
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
called = true
return &compose.ToolOutput{Result: "success"}, nil
}
wrapped := mw(next)
out, err := wrapped(context.Background(), &compose.ToolInput{
Name: "test_tool",
Arguments: `{"key": "value"}`,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !called {
t.Fatal("next endpoint was not called")
}
if out.Result != "success" {
t.Fatalf("expected 'success', got %q", out.Result)
}
}
func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) {
mw := softRecoveryToolCallMiddleware()
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
return nil, errors.New("failed to unmarshal task tool input json: unexpected end of JSON input")
}
wrapped := mw(next)
out, err := wrapped(context.Background(), &compose.ToolInput{
Name: "task",
Arguments: `{"subagent_type": "recon`,
})
if err != nil {
t.Fatalf("expected nil error (soft recovery), got: %v", err)
}
if out == nil || out.Result == "" {
t.Fatal("expected non-empty recovery message")
}
if !containsAll(out.Result, "[Tool Error]", "task", "JSON") {
t.Fatalf("recovery message missing expected content: %s", out.Result)
}
}
func TestSoftRecoveryToolCallMiddleware_PropagatesNonRecoverable(t *testing.T) {
mw := softRecoveryToolCallMiddleware()
origErr := errors.New("connection timeout to remote server")
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
return nil, origErr
}
wrapped := mw(next)
_, err := wrapped(context.Background(), &compose.ToolInput{
Name: "test_tool",
Arguments: `{}`,
})
if err == nil {
t.Fatal("expected error to propagate for non-recoverable errors")
}
if err != origErr {
t.Fatalf("expected original error, got: %v", err)
}
}
func containsAll(s string, subs ...string) bool {
for _, sub := range subs {
if !contains(s, sub) {
return false
}
}
return true
}
func contains(s, sub string) bool {
return len(s) >= len(sub) && searchString(s, sub)
}
func searchString(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}
@@ -1,76 +0,0 @@
package multiagent
import (
"fmt"
"strings"
"github.com/cloudwego/eino/schema"
)
// isRecoverableToolExecutionError detects tool-level execution errors that can be
// recovered by retrying with a corrective hint. These errors originate from eino
// framework internals (e.g. task_tool.go, tool_node.go) when the LLM produces
// invalid tool calls such as non-existent sub-agent types, malformed JSON arguments,
// or unregistered tool names.
func isRecoverableToolExecutionError(err error) bool {
if err == nil {
return false
}
s := strings.ToLower(err.Error())
// Sub-agent type not found (from deep/task_tool.go)
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
return true
}
// Tool not found in toolsNode indexes (from compose/tool_node.go, when UnknownToolsHandler is nil)
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
return true
}
// Invalid tool arguments JSON (from einomcp/mcp_tools.go or eino internals)
if strings.Contains(s, "invalid tool arguments json") {
return true
}
// Failed to unmarshal task tool input json (from deep/task_tool.go)
if strings.Contains(s, "failed to unmarshal") && strings.Contains(s, "json") {
return true
}
// Generic tool call stream/invoke failure wrapping the above
if (strings.Contains(s, "failed to stream tool call") || strings.Contains(s, "failed to invoke tool")) &&
(strings.Contains(s, "not found") || strings.Contains(s, "json") || strings.Contains(s, "unmarshal")) {
return true
}
return false
}
// toolExecutionRetryHint returns a user message appended to the conversation to prompt
// the LLM to correct its tool call after a tool execution error.
func toolExecutionRetryHint() *schema.Message {
return schema.UserMessage(`[System] Your previous tool call failed because:
- The tool or sub-agent name you used does not exist, OR
- The tool call arguments were not valid JSON.
Please carefully review the available tools and sub-agents listed in your context, use only exact registered names (case-sensitive), and ensure all arguments are well-formed JSON objects. Then retry your action.
[系统提示] 上一次工具调用失败可能原因
- 你使用的工具名或子代理名称不存在
- 工具调用参数不是合法 JSON
请仔细检查上下文中列出的可用工具和子代理名称须完全匹配区分大小写确保所有参数均为合法的 JSON 对象然后重新执行`)
}
// toolExecutionRecoveryTimelineMessage returns a message for the eino_recovery event
// displayed in the UI timeline when a tool execution error triggers a retry.
func toolExecutionRecoveryTimelineMessage(attempt int) string {
return fmt.Sprintf(
"工具调用执行失败(工具/子代理名称不存在或参数 JSON 无效)。已向对话追加纠错提示并要求模型重新生成。"+
"当前为第 %d/%d 轮完整运行。\n\n"+
"Tool call execution failed (unknown tool/sub-agent name or invalid JSON arguments). "+
"A corrective hint was appended. This is full run %d of %d.",
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
)
}
File diff suppressed because it is too large Load Diff
-493
View File
@@ -1,493 +0,0 @@
package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
// Client 统一封装与OpenAI兼容模型交互的HTTP客户端。
type Client struct {
httpClient *http.Client
config *config.OpenAIConfig
logger *zap.Logger
}
// APIError 表示OpenAI接口返回的非200错误。
type APIError struct {
StatusCode int
Body string
}
func (e *APIError) Error() string {
return fmt.Sprintf("openai api error: status=%d body=%s", e.StatusCode, e.Body)
}
// NewClient 创建一个新的OpenAI客户端。
func NewClient(cfg *config.OpenAIConfig, httpClient *http.Client, logger *zap.Logger) *Client {
if httpClient == nil {
httpClient = http.DefaultClient
}
if logger == nil {
logger = zap.NewNop()
}
return &Client{
httpClient: httpClient,
config: cfg,
logger: logger,
}
}
// UpdateConfig 动态更新OpenAI配置。
func (c *Client) UpdateConfig(cfg *config.OpenAIConfig) {
c.config = cfg
}
// ChatCompletion 调用 /chat/completions 接口。
func (c *Client) ChatCompletion(ctx context.Context, payload interface{}, out interface{}) error {
if c == nil {
return fmt.Errorf("openai client is not initialized")
}
if c.config == nil {
return fmt.Errorf("openai config is nil")
}
if strings.TrimSpace(c.config.APIKey) == "" {
return fmt.Errorf("openai api key is empty")
}
if c.isClaude() {
return c.claudeChatCompletion(ctx, payload, out)
}
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal openai payload: %w", err)
}
c.logger.Debug("sending OpenAI chat completion request",
zap.Int("payloadSizeKB", len(body)/1024))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("build openai request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
requestStart := time.Now()
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("call openai api: %w", err)
}
defer resp.Body.Close()
bodyChan := make(chan []byte, 1)
errChan := make(chan error, 1)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
errChan <- err
return
}
bodyChan <- responseBody
}()
var respBody []byte
select {
case respBody = <-bodyChan:
case err := <-errChan:
return fmt.Errorf("read openai response: %w", err)
case <-ctx.Done():
return fmt.Errorf("read openai response timeout: %w", ctx.Err())
case <-time.After(25 * time.Minute):
return fmt.Errorf("read openai response timeout (25m)")
}
c.logger.Debug("received OpenAI response",
zap.Int("status", resp.StatusCode),
zap.Duration("duration", time.Since(requestStart)),
zap.Int("responseSizeKB", len(respBody)/1024),
)
if resp.StatusCode != http.StatusOK {
c.logger.Warn("OpenAI chat completion returned non-200",
zap.Int("status", resp.StatusCode),
zap.String("body", string(respBody)),
)
return &APIError{
StatusCode: resp.StatusCode,
Body: string(respBody),
}
}
if out != nil {
if err := json.Unmarshal(respBody, out); err != nil {
c.logger.Error("failed to unmarshal OpenAI response",
zap.Error(err),
zap.String("body", string(respBody)),
)
return fmt.Errorf("unmarshal openai response: %w", err)
}
}
return nil
}
// ChatCompletionStream 调用 /chat/completions 的流式模式(stream=true),并在每个 delta 到达时回调 onDelta。
// 返回最终拼接的 content(只拼 content delta;工具调用 delta 未做处理)。
func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{}, onDelta func(delta string) error) (string, error) {
if c == nil {
return "", fmt.Errorf("openai client is not initialized")
}
if c.config == nil {
return "", fmt.Errorf("openai config is nil")
}
if strings.TrimSpace(c.config.APIKey) == "" {
return "", fmt.Errorf("openai api key is empty")
}
if c.isClaude() {
return c.claudeChatCompletionStream(ctx, payload, onDelta)
}
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
body, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("marshal openai payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("build openai request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
requestStart := time.Now()
resp, err := c.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("call openai api: %w", err)
}
defer resp.Body.Close()
// 非200:读完 body 返回
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return "", &APIError{
StatusCode: resp.StatusCode,
Body: string(respBody),
}
}
type streamDelta struct {
// OpenAI 兼容流式通常使用 content;但部分兼容实现可能用 text。
Content string `json:"content,omitempty"`
Text string `json:"text,omitempty"`
}
type streamChoice struct {
Delta streamDelta `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type streamResponse struct {
ID string `json:"id,omitempty"`
Choices []streamChoice `json:"choices"`
Error *struct {
Message string `json:"message"`
Type string `json:"type"`
} `json:"error,omitempty"`
}
reader := bufio.NewReader(resp.Body)
var full strings.Builder
// 典型 SSE 结构:
// data: {...}\n\n
// data: [DONE]\n\n
for {
line, readErr := reader.ReadString('\n')
if readErr != nil {
if readErr == io.EOF {
break
}
return full.String(), fmt.Errorf("read openai stream: %w", readErr)
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
if !strings.HasPrefix(trimmed, "data:") {
continue
}
dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if dataStr == "[DONE]" {
break
}
var chunk streamResponse
if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
// 解析失败跳过(兼容各种兼容层的差异)
continue
}
if chunk.Error != nil && strings.TrimSpace(chunk.Error.Message) != "" {
return full.String(), fmt.Errorf("openai stream error: %s", chunk.Error.Message)
}
if len(chunk.Choices) == 0 {
continue
}
delta := chunk.Choices[0].Delta.Content
if delta == "" {
delta = chunk.Choices[0].Delta.Text
}
if delta == "" {
continue
}
full.WriteString(delta)
if onDelta != nil {
if err := onDelta(delta); err != nil {
return full.String(), err
}
}
}
c.logger.Debug("received OpenAI stream completion",
zap.Duration("duration", time.Since(requestStart)),
zap.Int("contentLen", full.Len()),
)
return full.String(), nil
}
// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。
type StreamToolCall struct {
Index int
ID string
Type string
FunctionName string
FunctionArgsStr string
}
// ChatCompletionStreamWithToolCalls 流式模式:同时把 content delta 实时回调,并在结束后返回 tool_calls 和 finish_reason。
func (c *Client) ChatCompletionStreamWithToolCalls(
ctx context.Context,
payload interface{},
onContentDelta func(delta string) error,
) (string, []StreamToolCall, string, error) {
if c == nil {
return "", nil, "", fmt.Errorf("openai client is not initialized")
}
if c.config == nil {
return "", nil, "", fmt.Errorf("openai config is nil")
}
if strings.TrimSpace(c.config.APIKey) == "" {
return "", nil, "", fmt.Errorf("openai api key is empty")
}
if c.isClaude() {
return c.claudeChatCompletionStreamWithToolCalls(ctx, payload, onContentDelta)
}
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
if baseURL == "" {
baseURL = "https://api.openai.com/v1"
}
body, err := json.Marshal(payload)
if err != nil {
return "", nil, "", fmt.Errorf("marshal openai payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(body))
if err != nil {
return "", nil, "", fmt.Errorf("build openai request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
requestStart := time.Now()
resp, err := c.httpClient.Do(req)
if err != nil {
return "", nil, "", fmt.Errorf("call openai api: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return "", nil, "", &APIError{
StatusCode: resp.StatusCode,
Body: string(respBody),
}
}
// delta tool_calls 的增量结构
type toolCallFunctionDelta struct {
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
type toolCallDelta struct {
Index int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function toolCallFunctionDelta `json:"function,omitempty"`
}
type streamDelta2 struct {
Content string `json:"content,omitempty"`
Text string `json:"text,omitempty"`
ToolCalls []toolCallDelta `json:"tool_calls,omitempty"`
}
type streamChoice2 struct {
Delta streamDelta2 `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type streamResponse2 struct {
Choices []streamChoice2 `json:"choices"`
Error *struct {
Message string `json:"message"`
Type string `json:"type"`
} `json:"error,omitempty"`
}
type toolCallAccum struct {
id string
typ string
name string
args strings.Builder
}
toolCallAccums := make(map[int]*toolCallAccum)
reader := bufio.NewReader(resp.Body)
var full strings.Builder
finishReason := ""
for {
line, readErr := reader.ReadString('\n')
if readErr != nil {
if readErr == io.EOF {
break
}
return full.String(), nil, finishReason, fmt.Errorf("read openai stream: %w", readErr)
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
if !strings.HasPrefix(trimmed, "data:") {
continue
}
dataStr := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
if dataStr == "[DONE]" {
break
}
var chunk streamResponse2
if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
// 兼容:解析失败跳过
continue
}
if chunk.Error != nil && strings.TrimSpace(chunk.Error.Message) != "" {
return full.String(), nil, finishReason, fmt.Errorf("openai stream error: %s", chunk.Error.Message)
}
if len(chunk.Choices) == 0 {
continue
}
choice := chunk.Choices[0]
if choice.FinishReason != nil && strings.TrimSpace(*choice.FinishReason) != "" {
finishReason = strings.TrimSpace(*choice.FinishReason)
}
delta := choice.Delta
content := delta.Content
if content == "" {
content = delta.Text
}
if content != "" {
full.WriteString(content)
if onContentDelta != nil {
if err := onContentDelta(content); err != nil {
return full.String(), nil, finishReason, err
}
}
}
if len(delta.ToolCalls) > 0 {
for _, tc := range delta.ToolCalls {
acc, ok := toolCallAccums[tc.Index]
if !ok {
acc = &toolCallAccum{}
toolCallAccums[tc.Index] = acc
}
if tc.ID != "" {
acc.id = tc.ID
}
if tc.Type != "" {
acc.typ = tc.Type
}
if tc.Function.Name != "" {
acc.name = tc.Function.Name
}
if tc.Function.Arguments != "" {
acc.args.WriteString(tc.Function.Arguments)
}
}
}
}
// 组装 tool calls
indices := make([]int, 0, len(toolCallAccums))
for idx := range toolCallAccums {
indices = append(indices, idx)
}
// 手写简单排序(避免额外 import)
for i := 0; i < len(indices); i++ {
for j := i + 1; j < len(indices); j++ {
if indices[j] < indices[i] {
indices[i], indices[j] = indices[j], indices[i]
}
}
}
toolCalls := make([]StreamToolCall, 0, len(indices))
for _, idx := range indices {
acc := toolCallAccums[idx]
tc := StreamToolCall{
Index: idx,
ID: acc.id,
Type: acc.typ,
FunctionName: acc.name,
FunctionArgsStr: acc.args.String(),
}
toolCalls = append(toolCalls, tc)
}
c.logger.Debug("received OpenAI stream completion (tool_calls)",
zap.Duration("duration", time.Since(requestStart)),
zap.Int("contentLen", full.Len()),
zap.Int("toolCalls", len(toolCalls)),
zap.String("finishReason", finishReason),
)
if strings.TrimSpace(finishReason) == "" {
finishReason = "stop"
}
return full.String(), toolCalls, finishReason, nil
}
-6
View File
@@ -1,6 +0,0 @@
package robot
// MessageHandler 供飞书/钉钉长连接调用的消息处理接口(由 handler.RobotHandler 实现)
type MessageHandler interface {
HandleMessage(platform, userID, text string) string
}
-137
View File
@@ -1,137 +0,0 @@
package robot
import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"time"
"cyberstrike-ai/internal/config"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
dingutils "github.com/open-dingtalk/dingtalk-stream-sdk-go/utils"
"go.uber.org/zap"
)
const (
dingReconnectInitial = 5 * time.Second // 首次重连间隔
dingReconnectMax = 60 * time.Second // 最大重连间隔
)
// StartDing 启动钉钉 Stream 长连接(无需公网),收到消息后调用 handler 并通过 SessionWebhook 回复。
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
func StartDing(ctx context.Context, cfg config.RobotDingtalkConfig, h MessageHandler, logger *zap.Logger) {
if !cfg.Enabled || cfg.ClientID == "" || cfg.ClientSecret == "" {
return
}
go runDingLoop(ctx, cfg, h, logger)
}
// runDingLoop 循环维持钉钉长连接:断开且 ctx 未取消时按退避间隔重连。
func runDingLoop(ctx context.Context, cfg config.RobotDingtalkConfig, h MessageHandler, logger *zap.Logger) {
backoff := dingReconnectInitial
for {
streamClient := client.NewStreamClient(
client.WithAppCredential(client.NewAppCredentialConfig(cfg.ClientID, cfg.ClientSecret)),
client.WithSubscription(dingutils.SubscriptionTypeKCallback, "/v1.0/im/bot/messages/get",
chatbot.NewDefaultChatBotFrameHandler(func(ctx context.Context, msg *chatbot.BotCallbackDataModel) ([]byte, error) {
go handleDingMessage(ctx, msg, h, logger)
return nil, nil
}).OnEventReceived),
)
logger.Info("钉钉 Stream 正在连接…", zap.String("client_id", cfg.ClientID))
err := streamClient.Start(ctx)
if ctx.Err() != nil {
logger.Info("钉钉 Stream 已按配置重启关闭")
return
}
if err != nil {
logger.Warn("钉钉 Stream 长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
}
select {
case <-ctx.Done():
return
case <-time.After(backoff):
// 下次重连间隔递增,上限 60 秒,避免频繁重试
if backoff < dingReconnectMax {
backoff *= 2
if backoff > dingReconnectMax {
backoff = dingReconnectMax
}
}
}
}
}
func handleDingMessage(ctx context.Context, msg *chatbot.BotCallbackDataModel, h MessageHandler, logger *zap.Logger) {
if msg == nil || msg.SessionWebhook == "" {
return
}
content := ""
if msg.Text.Content != "" {
content = strings.TrimSpace(msg.Text.Content)
}
if content == "" && msg.Msgtype == "richText" {
if cMap, ok := msg.Content.(map[string]interface{}); ok {
if rich, ok := cMap["richText"].([]interface{}); ok {
for _, c := range rich {
if m, ok := c.(map[string]interface{}); ok {
if txt, ok := m["text"].(string); ok {
content = strings.TrimSpace(txt)
break
}
}
}
}
}
}
if content == "" {
logger.Debug("钉钉消息内容为空,已忽略", zap.String("msgtype", msg.Msgtype))
return
}
logger.Info("钉钉收到消息", zap.String("sender", msg.SenderId), zap.String("content", content))
userID := msg.SenderId
if userID == "" {
userID = msg.ConversationId
}
reply := h.HandleMessage("dingtalk", userID, content)
// 使用 markdown 类型以便正确展示标题、列表、代码块等格式
title := reply
if idx := strings.IndexAny(reply, "\n"); idx > 0 {
title = strings.TrimSpace(reply[:idx])
}
if len(title) > 50 {
title = title[:50] + "…"
}
if title == "" {
title = "回复"
}
body := map[string]interface{}{
"msgtype": "markdown",
"markdown": map[string]string{
"title": title,
"text": reply,
},
}
bodyBytes, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, msg.SessionWebhook, bytes.NewReader(bodyBytes))
if err != nil {
logger.Warn("钉钉构造回复请求失败", zap.Error(err))
return
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.Warn("钉钉回复请求失败", zap.Error(err))
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.Warn("钉钉回复非 200", zap.Int("status", resp.StatusCode))
return
}
logger.Debug("钉钉回复成功", zap.String("content_preview", reply))
}
-111
View File
@@ -1,111 +0,0 @@
package robot
import (
"context"
"encoding/json"
"strings"
"time"
"cyberstrike-ai/internal/config"
lark "github.com/larksuite/oapi-sdk-go/v3"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
"go.uber.org/zap"
)
const (
larkReconnectInitial = 5 * time.Second // 首次重连间隔
larkReconnectMax = 60 * time.Second // 最大重连间隔
)
type larkTextContent struct {
Text string `json:"text"`
}
// StartLark 启动飞书长连接(无需公网),收到消息后调用 handler 并回复。
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
func StartLark(ctx context.Context, cfg config.RobotLarkConfig, h MessageHandler, logger *zap.Logger) {
if !cfg.Enabled || cfg.AppID == "" || cfg.AppSecret == "" {
return
}
go runLarkLoop(ctx, cfg, h, logger)
}
// runLarkLoop 循环维持飞书长连接:断开且 ctx 未取消时按退避间隔重连。
func runLarkLoop(ctx context.Context, cfg config.RobotLarkConfig, h MessageHandler, logger *zap.Logger) {
backoff := larkReconnectInitial
for {
larkClient := lark.NewClient(cfg.AppID, cfg.AppSecret)
eventHandler := dispatcher.NewEventDispatcher("", "").OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
go handleLarkMessage(ctx, event, h, larkClient, logger)
return nil
})
wsClient := larkws.NewClient(cfg.AppID, cfg.AppSecret,
larkws.WithEventHandler(eventHandler),
larkws.WithLogLevel(larkcore.LogLevelInfo),
)
logger.Info("飞书长连接正在连接…", zap.String("app_id", cfg.AppID))
err := wsClient.Start(ctx)
if ctx.Err() != nil {
logger.Info("飞书长连接已按配置重启关闭")
return
}
if err != nil {
logger.Warn("飞书长连接断开(如睡眠/断网),将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
}
select {
case <-ctx.Done():
return
case <-time.After(backoff):
if backoff < larkReconnectMax {
backoff *= 2
if backoff > larkReconnectMax {
backoff = larkReconnectMax
}
}
}
}
}
func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, h MessageHandler, client *lark.Client, logger *zap.Logger) {
if event == nil || event.Event == nil || event.Event.Message == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
return
}
msg := event.Event.Message
msgType := larkcore.StringValue(msg.MessageType)
if msgType != larkim.MsgTypeText {
logger.Debug("飞书暂仅处理文本消息", zap.String("msg_type", msgType))
return
}
var textBody larkTextContent
if err := json.Unmarshal([]byte(larkcore.StringValue(msg.Content)), &textBody); err != nil {
logger.Warn("飞书消息 Content 解析失败", zap.Error(err))
return
}
text := strings.TrimSpace(textBody.Text)
if text == "" {
return
}
userID := ""
if event.Event.Sender.SenderId.UserId != nil {
userID = *event.Event.Sender.SenderId.UserId
}
messageID := larkcore.StringValue(msg.MessageId)
reply := h.HandleMessage("lark", userID, text)
contentBytes, _ := json.Marshal(larkTextContent{Text: reply})
_, err := client.Im.Message.Reply(ctx, larkim.NewReplyMessageReqBuilder().
MessageId(messageID).
Body(larkim.NewReplyMessageReqBodyBuilder().
MsgType(larkim.MsgTypeText).
Content(string(contentBytes)).
Build()).
Build())
if err != nil {
logger.Warn("飞书回复失败", zap.String("message_id", messageID), zap.Error(err))
return
}
logger.Debug("飞书已回复", zap.String("message_id", messageID))
}
-132
View File
@@ -1,132 +0,0 @@
package security
import (
"errors"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// Predefined errors for authentication operations.
var (
ErrInvalidPassword = errors.New("invalid password")
)
// Session represents an authenticated user session.
type Session struct {
Token string
ExpiresAt time.Time
}
// AuthManager manages password-based authentication and session lifecycle.
type AuthManager struct {
password string
sessionDuration time.Duration
mu sync.RWMutex
sessions map[string]Session
}
// NewAuthManager creates a new AuthManager instance.
func NewAuthManager(password string, sessionDurationHours int) (*AuthManager, error) {
if strings.TrimSpace(password) == "" {
return nil, errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
return &AuthManager{
password: password,
sessionDuration: time.Duration(sessionDurationHours) * time.Hour,
sessions: make(map[string]Session),
}, nil
}
// Authenticate validates the password and creates a new session.
func (a *AuthManager) Authenticate(password string) (string, time.Time, error) {
if password != a.password {
return "", time.Time{}, ErrInvalidPassword
}
token := uuid.NewString()
expiresAt := time.Now().Add(a.sessionDuration)
a.mu.Lock()
a.sessions[token] = Session{
Token: token,
ExpiresAt: expiresAt,
}
a.mu.Unlock()
return token, expiresAt, nil
}
// ValidateToken checks whether the provided token is still valid.
func (a *AuthManager) ValidateToken(token string) (Session, bool) {
if strings.TrimSpace(token) == "" {
return Session{}, false
}
a.mu.RLock()
session, ok := a.sessions[token]
a.mu.RUnlock()
if !ok {
return Session{}, false
}
if time.Now().After(session.ExpiresAt) {
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
return Session{}, false
}
return session, true
}
// CheckPassword verifies whether the provided password matches the current password.
func (a *AuthManager) CheckPassword(password string) bool {
a.mu.RLock()
defer a.mu.RUnlock()
return password == a.password
}
// RevokeToken invalidates the specified token.
func (a *AuthManager) RevokeToken(token string) {
if strings.TrimSpace(token) == "" {
return
}
a.mu.Lock()
delete(a.sessions, token)
a.mu.Unlock()
}
// SessionDurationHours returns the configured session duration in hours.
func (a *AuthManager) SessionDurationHours() int {
return int(a.sessionDuration / time.Hour)
}
// UpdateConfig updates the password and session duration, revoking existing sessions.
func (a *AuthManager) UpdateConfig(password string, sessionDurationHours int) error {
password = strings.TrimSpace(password)
if password == "" {
return errors.New("auth password must be configured")
}
if sessionDurationHours <= 0 {
sessionDurationHours = 12
}
a.mu.Lock()
defer a.mu.Unlock()
a.password = password
a.sessionDuration = time.Duration(sessionDurationHours) * time.Hour
a.sessions = make(map[string]Session)
return nil
}
-51
View File
@@ -1,51 +0,0 @@
package security
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
const (
ContextAuthTokenKey = "authToken"
ContextSessionExpiry = "authSessionExpiry"
)
// AuthMiddleware enforces authentication on protected routes.
func AuthMiddleware(manager *AuthManager) gin.HandlerFunc {
return func(c *gin.Context) {
token := extractTokenFromRequest(c)
session, ok := manager.ValidateToken(token)
if !ok {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "未授权访问,请先登录",
})
return
}
c.Set(ContextAuthTokenKey, session.Token)
c.Set(ContextSessionExpiry, session.ExpiresAt)
c.Next()
}
}
func extractTokenFromRequest(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
if len(authHeader) > 7 && strings.EqualFold(authHeader[0:7], "Bearer ") {
return strings.TrimSpace(authHeader[7:])
}
return strings.TrimSpace(authHeader)
}
if token := c.Query("token"); token != "" {
return strings.TrimSpace(token)
}
if cookie, err := c.Cookie("auth_token"); err == nil {
return strings.TrimSpace(cookie)
}
return ""
}
File diff suppressed because it is too large Load Diff
-268
View File
@@ -1,268 +0,0 @@
package security
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/storage"
"go.uber.org/zap"
)
// setupTestExecutor 创建测试用的执行器
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
cfg := &config.SecurityConfig{
Tools: []config.ToolConfig{},
}
executor := NewExecutor(cfg, mcpServer, logger)
return executor, mcpServer
}
// setupTestStorage 创建测试用的存储
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
logger := zap.NewNop()
storage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
return storage
}
func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
executor, _ := setupTestExecutor(t)
testStorage := setupTestStorage(t)
executor.SetResultStorage(testStorage)
// 准备测试数据
executionID := "test_exec_001"
toolName := "nmap_scan"
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
// 保存测试结果
err := testStorage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存测试结果失败: %v", err)
}
ctx := context.Background()
// 测试1: 基本查询(第一页)
args := map[string]interface{}{
"execution_id": executionID,
"page": float64(1),
"limit": float64(2),
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if toolResult.IsError {
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
}
// 验证结果包含预期内容
resultText := toolResult.Content[0].Text
if !strings.Contains(resultText, executionID) {
t.Errorf("结果中应该包含执行ID: %s", executionID)
}
if !strings.Contains(resultText, "第 1/") {
t.Errorf("结果中应该包含分页信息")
}
// 测试2: 搜索功能
args2 := map[string]interface{}{
"execution_id": executionID,
"search": "error",
"page": float64(1),
"limit": float64(10),
}
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
if err != nil {
t.Fatalf("执行搜索失败: %v", err)
}
if toolResult2.IsError {
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
}
resultText2 := toolResult2.Content[0].Text
if !strings.Contains(resultText2, "error") {
t.Errorf("搜索结果中应该包含关键词: error")
}
// 测试3: 过滤功能
args3 := map[string]interface{}{
"execution_id": executionID,
"filter": "Port",
"page": float64(1),
"limit": float64(10),
}
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
if err != nil {
t.Fatalf("执行过滤失败: %v", err)
}
if toolResult3.IsError {
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
}
resultText3 := toolResult3.Content[0].Text
if !strings.Contains(resultText3, "Port") {
t.Errorf("过滤结果中应该包含关键词: Port")
}
// 测试4: 缺少必需参数
args4 := map[string]interface{}{
"page": float64(1),
}
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult4.IsError {
t.Fatal("缺少execution_id应该返回错误")
}
// 测试5: 不存在的执行ID
args5 := map[string]interface{}{
"execution_id": "nonexistent_id",
"page": float64(1),
}
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult5.IsError {
t.Fatal("不存在的执行ID应该返回错误")
}
}
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
executor, _ := setupTestExecutor(t)
ctx := context.Background()
args := map[string]interface{}{
"test": "value",
}
// 测试未知的内部工具类型
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
if err != nil {
t.Fatalf("执行内部工具失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未知的工具类型应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
t.Errorf("错误消息应该包含'未知的内部工具类型'")
}
}
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
executor, _ := setupTestExecutor(t)
// 不设置存储,测试未初始化的情况
ctx := context.Background()
args := map[string]interface{}{
"execution_id": "test_id",
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未初始化的存储应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
t.Errorf("错误消息应该包含'结果存储未初始化'")
}
}
func TestPaginateLines(t *testing.T) {
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
// 测试第一页
page := paginateLines(lines, 1, 2)
if page.Page != 1 {
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
}
if page.Limit != 2 {
t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit)
}
if page.TotalLines != 5 {
t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines)
}
if page.TotalPages != 3 {
t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages)
}
if len(page.Lines) != 2 {
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
}
// 测试第二页
page2 := paginateLines(lines, 2, 2)
if len(page2.Lines) != 2 {
t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines))
}
if page2.Lines[0] != "Line 3" {
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
}
// 测试最后一页
page3 := paginateLines(lines, 3, 2)
if len(page3.Lines) != 1 {
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
}
// 测试超出范围的页码(应该返回最后一页)
page4 := paginateLines(lines, 4, 2)
if page4.Page != 3 {
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page)
}
if len(page4.Lines) != 1 {
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
}
// 测试无效页码(小于1
page0 := paginateLines(lines, 0, 2)
if page0.Page != 1 {
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
}
// 测试空列表
emptyPage := paginateLines([]string{}, 1, 10)
if emptyPage.TotalLines != 0 {
t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines)
}
if len(emptyPage.Lines) != 0 {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
}
}
-274
View File
@@ -1,274 +0,0 @@
package skills
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"go.uber.org/zap"
)
// Manager Skills管理器
type Manager struct {
skillsDir string
logger *zap.Logger
skills map[string]*cachedSkill // 缓存已加载的skills(含文件状态)
mu sync.RWMutex // 保护skills map的并发访问
}
type cachedSkill struct {
skill *Skill
filePath string
modTime int64
}
// Skill Skill定义
type Skill struct {
Name string // Skill名称
Description string // Skill描述
Content string // Skill内容(从SKILL.md中提取)
Path string // Skill路径
}
// NewManager 创建新的Skills管理器
func NewManager(skillsDir string, logger *zap.Logger) *Manager {
return &Manager{
skillsDir: skillsDir,
logger: logger,
skills: make(map[string]*cachedSkill),
}
}
// LoadSkill 加载单个skill
func (m *Manager) LoadSkill(skillName string) (*Skill, error) {
// 构建skill路径
skillPath := filepath.Join(m.skillsDir, skillName)
// 检查目录是否存在
if _, err := os.Stat(skillPath); os.IsNotExist(err) {
m.InvalidateSkill(skillName)
return nil, fmt.Errorf("skill %s not found", skillName)
}
// 查找skill文件并读取文件状态
skillFile, err := m.resolveSkillFile(skillPath)
if err != nil {
m.InvalidateSkill(skillName)
return nil, err
}
fileInfo, err := os.Stat(skillFile)
if err != nil {
m.InvalidateSkill(skillName)
return nil, fmt.Errorf("failed to stat skill file: %w", err)
}
modTime := fileInfo.ModTime().UnixNano()
// 先尝试读锁命中缓存(文件路径和修改时间都未变化)
m.mu.RLock()
if cached, exists := m.skills[skillName]; exists &&
cached.filePath == skillFile &&
cached.modTime == modTime {
m.mu.RUnlock()
return cached.skill, nil
}
m.mu.RUnlock()
// 读取skill文件
content, err := os.ReadFile(skillFile)
if err != nil {
return nil, fmt.Errorf("failed to read skill file: %w", err)
}
// 解析skill内容
skill := m.parseSkillContent(string(content), skillName, skillPath)
// 使用写锁更新缓存
m.mu.Lock()
m.skills[skillName] = &cachedSkill{
skill: skill,
filePath: skillFile,
modTime: modTime,
}
m.mu.Unlock()
return skill, nil
}
// LoadSkills 批量加载skills
func (m *Manager) LoadSkills(skillNames []string) ([]*Skill, error) {
var skills []*Skill
var errors []string
for _, name := range skillNames {
skill, err := m.LoadSkill(name)
if err != nil {
errors = append(errors, fmt.Sprintf("failed to load skill %s: %v", name, err))
m.logger.Warn("加载skill失败", zap.String("skill", name), zap.Error(err))
continue
}
skills = append(skills, skill)
}
if len(errors) > 0 && len(skills) == 0 {
return nil, fmt.Errorf("failed to load any skills: %s", strings.Join(errors, "; "))
}
return skills, nil
}
// ListSkills 列出所有可用的skills
func (m *Manager) ListSkills() ([]string, error) {
if _, err := os.Stat(m.skillsDir); os.IsNotExist(err) {
return []string{}, nil
}
entries, err := os.ReadDir(m.skillsDir)
if err != nil {
return nil, fmt.Errorf("failed to read skills directory: %w", err)
}
var skills []string
for _, entry := range entries {
if !entry.IsDir() {
continue
}
skillName := entry.Name()
// 检查是否有SKILL.md文件
skillFile := filepath.Join(m.skillsDir, skillName, "SKILL.md")
if _, err := os.Stat(skillFile); err == nil {
skills = append(skills, skillName)
continue
}
// 尝试其他可能的文件名
alternatives := []string{
filepath.Join(m.skillsDir, skillName, "skill.md"),
filepath.Join(m.skillsDir, skillName, "README.md"),
filepath.Join(m.skillsDir, skillName, "readme.md"),
}
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
skills = append(skills, skillName)
break
}
}
}
return skills, nil
}
func (m *Manager) resolveSkillFile(skillPath string) (string, error) {
// 优先标准文件名
skillFile := filepath.Join(skillPath, "SKILL.md")
if _, err := os.Stat(skillFile); err == nil {
return skillFile, nil
}
// 兼容历史文件名
alternatives := []string{
filepath.Join(skillPath, "skill.md"),
filepath.Join(skillPath, "README.md"),
filepath.Join(skillPath, "readme.md"),
}
for _, alt := range alternatives {
if _, err := os.Stat(alt); err == nil {
return alt, nil
}
}
return "", fmt.Errorf("skill file not found for %s", filepath.Base(skillPath))
}
// InvalidateSkill 使指定skill缓存失效
func (m *Manager) InvalidateSkill(skillName string) {
m.mu.Lock()
delete(m.skills, skillName)
m.mu.Unlock()
}
// InvalidateAll 清空全部skill缓存
func (m *Manager) InvalidateAll() {
m.mu.Lock()
m.skills = make(map[string]*cachedSkill)
m.mu.Unlock()
}
// parseSkillContent 解析skill内容
// 支持YAML front matter格式,类似goskills
func (m *Manager) parseSkillContent(content, skillName, skillPath string) *Skill {
skill := &Skill{
Name: skillName,
Path: skillPath,
}
// 检查是否有YAML front matter
if strings.HasPrefix(content, "---") {
parts := strings.SplitN(content, "---", 3)
if len(parts) >= 3 {
// 解析front matter(简单实现,只提取name和description
frontMatter := parts[1]
lines := strings.Split(frontMatter, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "name:") {
name := strings.TrimSpace(strings.TrimPrefix(line, "name:"))
name = strings.Trim(name, `"'"`)
if name != "" {
skill.Name = name
}
} else if strings.HasPrefix(line, "description:") {
desc := strings.TrimSpace(strings.TrimPrefix(line, "description:"))
desc = strings.Trim(desc, `"'"`)
skill.Description = desc
}
}
// 剩余部分是内容
if len(parts) == 3 {
skill.Content = strings.TrimSpace(parts[2])
}
} else {
// 没有front matter,整个内容就是skill内容
skill.Content = content
}
} else {
// 没有front matter,整个内容就是skill内容
skill.Content = content
}
// 如果内容为空,使用描述作为内容
if skill.Content == "" {
skill.Content = skill.Description
}
return skill
}
// GetSkillContent 获取skill的完整内容(用于注入到系统提示词)
func (m *Manager) GetSkillContent(skillNames []string) (string, error) {
skills, err := m.LoadSkills(skillNames)
if err != nil {
return "", err
}
if len(skills) == 0 {
return "", nil
}
var builder strings.Builder
builder.WriteString("## 可用Skills\n\n")
builder.WriteString("在执行任务前,请仔细阅读以下skills内容,这些内容包含了相关的专业知识和方法:\n\n")
for _, skill := range skills {
builder.WriteString(fmt.Sprintf("### Skill: %s\n", skill.Name))
if skill.Description != "" {
builder.WriteString(fmt.Sprintf("**描述**: %s\n\n", skill.Description))
}
builder.WriteString(skill.Content)
builder.WriteString("\n\n---\n\n")
}
return builder.String(), nil
}
-201
View File
@@ -1,201 +0,0 @@
package skills
import (
"context"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
// RegisterSkillsTool 注册Skills工具到MCP服务器
func RegisterSkillsTool(
mcpServer *mcp.Server,
manager *Manager,
logger *zap.Logger,
) {
RegisterSkillsToolWithStorage(mcpServer, manager, nil, logger)
}
// RegisterSkillsToolWithStorage 注册Skills工具到MCP服务器(带存储支持)
func RegisterSkillsToolWithStorage(
mcpServer *mcp.Server,
manager *Manager,
storage SkillStatsStorage,
logger *zap.Logger,
) {
// 注册第一个工具:获取所有可用的skills列表
listSkillsTool := mcp.Tool{
Name: builtin.ToolListSkills,
Description: "获取所有可用的skills列表。Skills是专业知识文档,可以在执行任务前阅读以获取相关专业知识。使用此工具可以查看系统中所有可用的skills,然后使用read_skill工具读取特定skill的内容。",
ShortDescription: "获取所有可用的skills列表",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
"required": []string{},
},
}
listSkillsHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
skills, err := manager.ListSkills()
if err != nil {
logger.Error("获取skills列表失败", zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("获取skills列表失败: %v", err),
},
},
IsError: true,
}, nil
}
if len(skills) == 0 {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "当前没有可用的skills。\n\nSkills是专业知识文档,可以在执行任务前阅读以获取相关专业知识。你可以在skills目录下创建新的skill。",
},
},
IsError: false,
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("共有 %d 个可用的skills\n\n", len(skills)))
for i, skill := range skills {
result.WriteString(fmt.Sprintf("%d. %s\n", i+1, skill))
}
result.WriteString("\n使用 read_skill 工具可以读取特定skill的详细内容。\n")
result.WriteString("例如:read_skill(skill_name=\"sql-injection-testing\")")
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: result.String(),
},
},
IsError: false,
}, nil
}
mcpServer.RegisterTool(listSkillsTool, listSkillsHandler)
logger.Info("注册skills列表工具成功")
// 注册第二个工具:读取特定skill的内容
readSkillTool := mcp.Tool{
Name: builtin.ToolReadSkill,
Description: "读取指定skill的详细内容。Skills是专业知识文档,包含测试方法、工具使用、最佳实践等。在执行相关任务前,可以调用此工具读取相关skill的内容,以获取专业知识和指导。",
ShortDescription: "读取指定skill的详细内容",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"skill_name": map[string]interface{}{
"type": "string",
"description": "要读取的skill名称(必需)。可以使用list_skills工具获取所有可用的skill名称。",
},
},
"required": []string{"skill_name"},
},
}
readSkillHandler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
skillName, ok := args["skill_name"].(string)
if !ok || skillName == "" {
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: "错误: skill_name 参数必需且不能为空。请使用list_skills工具获取所有可用的skill名称。",
},
},
IsError: true,
}, nil
}
skill, err := manager.LoadSkill(skillName)
failed := err != nil
now := time.Now()
// 记录调用统计
if storage != nil {
totalCalls := 1
successCalls := 0
failedCalls := 0
if failed {
failedCalls = 1
} else {
successCalls = 1
}
if err := storage.UpdateSkillStats(skillName, totalCalls, successCalls, failedCalls, &now); err != nil {
logger.Warn("保存Skills统计信息失败", zap.String("skill", skillName), zap.Error(err))
} else {
logger.Info("Skills统计信息已更新",
zap.String("skill", skillName),
zap.Int("totalCalls", totalCalls),
zap.Int("successCalls", successCalls),
zap.Int("failedCalls", failedCalls))
}
} else {
logger.Warn("Skills统计存储未配置,无法记录调用统计", zap.String("skill", skillName))
}
if err != nil {
logger.Warn("读取skill失败", zap.String("skill", skillName), zap.Error(err))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: fmt.Sprintf("读取skill失败: %v\n\n请使用list_skills工具确认skill名称是否正确。", err),
},
},
IsError: true,
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("## Skill: %s\n\n", skill.Name))
if skill.Description != "" {
result.WriteString(fmt.Sprintf("**描述**: %s\n\n", skill.Description))
}
result.WriteString("---\n\n")
result.WriteString(skill.Content)
result.WriteString("\n\n---\n\n")
result.WriteString(fmt.Sprintf("*Skill路径: %s*", skill.Path))
return &mcp.ToolResult{
Content: []mcp.Content{
{
Type: "text",
Text: result.String(),
},
},
IsError: false,
}, nil
}
mcpServer.RegisterTool(readSkillTool, readSkillHandler)
logger.Info("注册skill读取工具成功")
}
// SkillStatsStorage Skills统计存储接口
type SkillStatsStorage interface {
UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error
LoadSkillStats() (map[string]*SkillStats, error)
}
// SkillStats Skills统计信息
type SkillStats struct {
SkillName string
TotalCalls int
SuccessCalls int
FailedCalls int
LastCallTime *time.Time
}
-297
View File
@@ -1,297 +0,0 @@
package storage
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"go.uber.org/zap"
)
// ResultStorage 结果存储接口
type ResultStorage interface {
// SaveResult 保存工具执行结果
SaveResult(executionID string, toolName string, result string) error
// GetResult 获取完整结果
GetResult(executionID string) (string, error)
// GetResultPage 分页获取结果
GetResultPage(executionID string, page int, limit int) (*ResultPage, error)
// SearchResult 搜索结果
// useRegex: 如果为 true,将 keyword 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
// FilterResult 过滤结果
// useRegex: 如果为 true,将 filter 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
// GetResultMetadata 获取结果元信息
GetResultMetadata(executionID string) (*ResultMetadata, error)
// GetResultPath 获取结果文件路径
GetResultPath(executionID string) string
// DeleteResult 删除结果
DeleteResult(executionID string) error
}
// ResultPage 分页结果
type ResultPage struct {
Lines []string `json:"lines"`
Page int `json:"page"`
Limit int `json:"limit"`
TotalLines int `json:"total_lines"`
TotalPages int `json:"total_pages"`
}
// ResultMetadata 结果元信息
type ResultMetadata struct {
ExecutionID string `json:"execution_id"`
ToolName string `json:"tool_name"`
TotalSize int `json:"total_size"`
TotalLines int `json:"total_lines"`
CreatedAt time.Time `json:"created_at"`
}
// FileResultStorage 基于文件的结果存储实现
type FileResultStorage struct {
baseDir string
logger *zap.Logger
mu sync.RWMutex
}
// NewFileResultStorage 创建新的文件结果存储
func NewFileResultStorage(baseDir string, logger *zap.Logger) (*FileResultStorage, error) {
// 确保目录存在
if err := os.MkdirAll(baseDir, 0755); err != nil {
return nil, fmt.Errorf("创建存储目录失败: %w", err)
}
return &FileResultStorage{
baseDir: baseDir,
logger: logger,
}, nil
}
// getResultPath 获取结果文件路径
func (s *FileResultStorage) getResultPath(executionID string) string {
return filepath.Join(s.baseDir, executionID+".txt")
}
// getMetadataPath 获取元数据文件路径
func (s *FileResultStorage) getMetadataPath(executionID string) string {
return filepath.Join(s.baseDir, executionID+".meta.json")
}
// SaveResult 保存工具执行结果
func (s *FileResultStorage) SaveResult(executionID string, toolName string, result string) error {
s.mu.Lock()
defer s.mu.Unlock()
// 保存结果文件
resultPath := s.getResultPath(executionID)
if err := os.WriteFile(resultPath, []byte(result), 0644); err != nil {
return fmt.Errorf("保存结果文件失败: %w", err)
}
// 计算统计信息
lines := strings.Split(result, "\n")
metadata := &ResultMetadata{
ExecutionID: executionID,
ToolName: toolName,
TotalSize: len(result),
TotalLines: len(lines),
CreatedAt: time.Now(),
}
// 保存元数据
metadataPath := s.getMetadataPath(executionID)
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("序列化元数据失败: %w", err)
}
if err := os.WriteFile(metadataPath, metadataJSON, 0644); err != nil {
return fmt.Errorf("保存元数据文件失败: %w", err)
}
s.logger.Info("保存工具执行结果",
zap.String("executionID", executionID),
zap.String("toolName", toolName),
zap.Int("size", len(result)),
zap.Int("lines", len(lines)),
)
return nil
}
// GetResult 获取完整结果
func (s *FileResultStorage) GetResult(executionID string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
resultPath := s.getResultPath(executionID)
data, err := os.ReadFile(resultPath)
if err != nil {
if os.IsNotExist(err) {
return "", fmt.Errorf("结果不存在: %s", executionID)
}
return "", fmt.Errorf("读取结果文件失败: %w", err)
}
return string(data), nil
}
// GetResultMetadata 获取结果元信息
func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetadata, error) {
s.mu.RLock()
defer s.mu.RUnlock()
metadataPath := s.getMetadataPath(executionID)
data, err := os.ReadFile(metadataPath)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("结果不存在: %s", executionID)
}
return nil, fmt.Errorf("读取元数据文件失败: %w", err)
}
var metadata ResultMetadata
if err := json.Unmarshal(data, &metadata); err != nil {
return nil, fmt.Errorf("解析元数据失败: %w", err)
}
return &metadata, nil
}
// GetResultPage 分页获取结果
func (s *FileResultStorage) GetResultPage(executionID string, page int, limit int) (*ResultPage, error) {
s.mu.RLock()
defer s.mu.RUnlock()
// 获取完整结果
result, err := s.GetResult(executionID)
if err != nil {
return nil, err
}
// 分割为行
lines := strings.Split(result, "\n")
totalLines := len(lines)
// 计算分页
totalPages := (totalLines + limit - 1) / limit
if page < 1 {
page = 1
}
if page > totalPages && totalPages > 0 {
page = totalPages
}
// 计算起始和结束索引
start := (page - 1) * limit
end := start + limit
if end > totalLines {
end = totalLines
}
// 提取指定页的行
var pageLines []string
if start < totalLines {
pageLines = lines[start:end]
} else {
pageLines = []string{}
}
return &ResultPage{
Lines: pageLines,
Page: page,
Limit: limit,
TotalLines: totalLines,
TotalPages: totalPages,
}, nil
}
// SearchResult 搜索结果
func (s *FileResultStorage) SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
// 获取完整结果
result, err := s.GetResult(executionID)
if err != nil {
return nil, err
}
// 如果使用正则表达式,先编译正则
var regex *regexp.Regexp
if useRegex {
compiledRegex, err := regexp.Compile(keyword)
if err != nil {
return nil, fmt.Errorf("无效的正则表达式: %w", err)
}
regex = compiledRegex
}
// 分割为行并搜索
lines := strings.Split(result, "\n")
var matchedLines []string
for _, line := range lines {
var matched bool
if useRegex {
matched = regex.MatchString(line)
} else {
matched = strings.Contains(line, keyword)
}
if matched {
matchedLines = append(matchedLines, line)
}
}
return matchedLines, nil
}
// FilterResult 过滤结果
func (s *FileResultStorage) FilterResult(executionID string, filter string, useRegex bool) ([]string, error) {
// 过滤和搜索逻辑相同,都是查找包含关键词的行
return s.SearchResult(executionID, filter, useRegex)
}
// GetResultPath 获取结果文件路径
func (s *FileResultStorage) GetResultPath(executionID string) string {
return s.getResultPath(executionID)
}
// DeleteResult 删除结果
func (s *FileResultStorage) DeleteResult(executionID string) error {
s.mu.Lock()
defer s.mu.Unlock()
resultPath := s.getResultPath(executionID)
metadataPath := s.getMetadataPath(executionID)
// 删除结果文件
if err := os.Remove(resultPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("删除结果文件失败: %w", err)
}
// 删除元数据文件
if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("删除元数据文件失败: %w", err)
}
s.logger.Info("删除工具执行结果",
zap.String("executionID", executionID),
)
return nil
}
-453
View File
@@ -1,453 +0,0 @@
package storage
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"go.uber.org/zap"
)
// setupTestStorage 创建测试用的存储实例
func setupTestStorage(t *testing.T) (*FileResultStorage, string) {
tmpDir := filepath.Join(os.TempDir(), "test_result_storage_"+time.Now().Format("20060102_150405"))
logger := zap.NewNop()
storage, err := NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
return storage, tmpDir
}
// cleanupTestStorage 清理测试数据
func cleanupTestStorage(t *testing.T, tmpDir string) {
if err := os.RemoveAll(tmpDir); err != nil {
t.Logf("清理测试目录失败: %v", err)
}
}
func TestNewFileResultStorage(t *testing.T) {
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
defer cleanupTestStorage(t, tmpDir)
logger := zap.NewNop()
storage, err := NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建存储失败: %v", err)
}
if storage == nil {
t.Fatal("存储实例为nil")
}
// 验证目录已创建
if _, err := os.Stat(tmpDir); os.IsNotExist(err) {
t.Fatal("存储目录未创建")
}
}
func TestFileResultStorage_SaveResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_001"
toolName := "nmap_scan"
result := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 验证结果文件存在
resultPath := filepath.Join(tmpDir, executionID+".txt")
if _, err := os.Stat(resultPath); os.IsNotExist(err) {
t.Fatal("结果文件未创建")
}
// 验证元数据文件存在
metadataPath := filepath.Join(tmpDir, executionID+".meta.json")
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
t.Fatal("元数据文件未创建")
}
}
func TestFileResultStorage_GetResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_002"
toolName := "test_tool"
expectedResult := "Test result content\nLine 2\nLine 3"
// 先保存结果
err := storage.SaveResult(executionID, toolName, expectedResult)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 获取结果
result, err := storage.GetResult(executionID)
if err != nil {
t.Fatalf("获取结果失败: %v", err)
}
if result != expectedResult {
t.Errorf("结果不匹配。期望: %q, 实际: %q", expectedResult, result)
}
// 测试不存在的执行ID
_, err = storage.GetResult("nonexistent_id")
if err == nil {
t.Fatal("应该返回错误")
}
}
func TestFileResultStorage_GetResultMetadata(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_003"
toolName := "test_tool"
result := "Line 1\nLine 2\nLine 3"
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 获取元数据
metadata, err := storage.GetResultMetadata(executionID)
if err != nil {
t.Fatalf("获取元数据失败: %v", err)
}
if metadata.ExecutionID != executionID {
t.Errorf("执行ID不匹配。期望: %s, 实际: %s", executionID, metadata.ExecutionID)
}
if metadata.ToolName != toolName {
t.Errorf("工具名称不匹配。期望: %s, 实际: %s", toolName, metadata.ToolName)
}
if metadata.TotalSize != len(result) {
t.Errorf("总大小不匹配。期望: %d, 实际: %d", len(result), metadata.TotalSize)
}
expectedLines := len(strings.Split(result, "\n"))
if metadata.TotalLines != expectedLines {
t.Errorf("总行数不匹配。期望: %d, 实际: %d", expectedLines, metadata.TotalLines)
}
// 验证创建时间在合理范围内
now := time.Now()
if metadata.CreatedAt.After(now) || metadata.CreatedAt.Before(now.Add(-time.Second)) {
t.Errorf("创建时间不在合理范围内: %v", metadata.CreatedAt)
}
}
func TestFileResultStorage_GetResultPage(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_004"
toolName := "test_tool"
// 创建包含10行的结果
lines := make([]string, 10)
for i := 0; i < 10; i++ {
lines[i] = fmt.Sprintf("Line %d", i+1)
}
result := strings.Join(lines, "\n")
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 测试第一页(每页3行)
page, err := storage.GetResultPage(executionID, 1, 3)
if err != nil {
t.Fatalf("获取第一页失败: %v", err)
}
if page.Page != 1 {
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
}
if page.Limit != 3 {
t.Errorf("每页行数不匹配。期望: 3, 实际: %d", page.Limit)
}
if page.TotalLines != 10 {
t.Errorf("总行数不匹配。期望: 10, 实际: %d", page.TotalLines)
}
if page.TotalPages != 4 {
t.Errorf("总页数不匹配。期望: 4, 实际: %d", page.TotalPages)
}
if len(page.Lines) != 3 {
t.Errorf("第一页行数不匹配。期望: 3, 实际: %d", len(page.Lines))
}
if page.Lines[0] != "Line 1" {
t.Errorf("第一行内容不匹配。期望: Line 1, 实际: %s", page.Lines[0])
}
// 测试第二页
page2, err := storage.GetResultPage(executionID, 2, 3)
if err != nil {
t.Fatalf("获取第二页失败: %v", err)
}
if len(page2.Lines) != 3 {
t.Errorf("第二页行数不匹配。期望: 3, 实际: %d", len(page2.Lines))
}
if page2.Lines[0] != "Line 4" {
t.Errorf("第二页第一行内容不匹配。期望: Line 4, 实际: %s", page2.Lines[0])
}
// 测试最后一页(可能不满一页)
page4, err := storage.GetResultPage(executionID, 4, 3)
if err != nil {
t.Fatalf("获取第四页失败: %v", err)
}
if len(page4.Lines) != 1 {
t.Errorf("第四页行数不匹配。期望: 1, 实际: %d", len(page4.Lines))
}
// 测试超出范围的页码(应该返回最后一页)
page5, err := storage.GetResultPage(executionID, 5, 3)
if err != nil {
t.Fatalf("获取第五页失败: %v", err)
}
// 超出范围的页码会被修正为最后一页,所以应该返回最后一页的内容
if page5.Page != 4 {
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 4, 实际: %d", page5.Page)
}
// 最后一页应该只有1行
if len(page5.Lines) != 1 {
t.Errorf("最后一页应该只有1行。实际: %d行", len(page5.Lines))
}
}
func TestFileResultStorage_SearchResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_005"
toolName := "test_tool"
result := "Line 1: error occurred\nLine 2: success\nLine 3: error again\nLine 4: ok"
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 搜索包含"error"的行(简单字符串匹配)
matchedLines, err := storage.SearchResult(executionID, "error", false)
if err != nil {
t.Fatalf("搜索失败: %v", err)
}
if len(matchedLines) != 2 {
t.Errorf("搜索结果数量不匹配。期望: 2, 实际: %d", len(matchedLines))
}
// 验证搜索结果内容
for i, line := range matchedLines {
if !strings.Contains(line, "error") {
t.Errorf("搜索结果第%d行不包含关键词: %s", i+1, line)
}
}
// 测试搜索不存在的关键词
noMatch, err := storage.SearchResult(executionID, "nonexistent", false)
if err != nil {
t.Fatalf("搜索失败: %v", err)
}
if len(noMatch) != 0 {
t.Errorf("搜索不存在的关键词应该返回空结果。实际: %d行", len(noMatch))
}
// 测试正则表达式搜索
regexMatched, err := storage.SearchResult(executionID, "error.*again", true)
if err != nil {
t.Fatalf("正则搜索失败: %v", err)
}
if len(regexMatched) != 1 {
t.Errorf("正则搜索结果数量不匹配。期望: 1, 实际: %d", len(regexMatched))
}
}
func TestFileResultStorage_FilterResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_006"
toolName := "test_tool"
result := "Line 1: warning message\nLine 2: info message\nLine 3: warning again\nLine 4: debug message"
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 过滤包含"warning"的行(简单字符串匹配)
filteredLines, err := storage.FilterResult(executionID, "warning", false)
if err != nil {
t.Fatalf("过滤失败: %v", err)
}
if len(filteredLines) != 2 {
t.Errorf("过滤结果数量不匹配。期望: 2, 实际: %d", len(filteredLines))
}
// 验证过滤结果内容
for i, line := range filteredLines {
if !strings.Contains(line, "warning") {
t.Errorf("过滤结果第%d行不包含关键词: %s", i+1, line)
}
}
}
func TestFileResultStorage_DeleteResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_007"
toolName := "test_tool"
result := "Test result"
// 保存结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存结果失败: %v", err)
}
// 验证文件存在
resultPath := filepath.Join(tmpDir, executionID+".txt")
metadataPath := filepath.Join(tmpDir, executionID+".meta.json")
if _, err := os.Stat(resultPath); os.IsNotExist(err) {
t.Fatal("结果文件不存在")
}
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
t.Fatal("元数据文件不存在")
}
// 删除结果
err = storage.DeleteResult(executionID)
if err != nil {
t.Fatalf("删除结果失败: %v", err)
}
// 验证文件已删除
if _, err := os.Stat(resultPath); !os.IsNotExist(err) {
t.Fatal("结果文件未被删除")
}
if _, err := os.Stat(metadataPath); !os.IsNotExist(err) {
t.Fatal("元数据文件未被删除")
}
// 测试删除不存在的执行ID(应该不报错)
err = storage.DeleteResult("nonexistent_id")
if err != nil {
t.Errorf("删除不存在的执行ID不应该报错: %v", err)
}
}
func TestFileResultStorage_ConcurrentAccess(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
// 并发保存多个结果
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(id int) {
executionID := fmt.Sprintf("test_exec_%d", id)
toolName := "test_tool"
result := fmt.Sprintf("Result %d\nLine 2\nLine 3", id)
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Errorf("并发保存失败 (ID: %s): %v", executionID, err)
}
// 并发读取
_, err = storage.GetResult(executionID)
if err != nil {
t.Errorf("并发读取失败 (ID: %s): %v", executionID, err)
}
done <- true
}(i)
}
// 等待所有goroutine完成
for i := 0; i < 10; i++ {
<-done
}
}
func TestFileResultStorage_LargeResult(t *testing.T) {
storage, tmpDir := setupTestStorage(t)
defer cleanupTestStorage(t, tmpDir)
executionID := "test_exec_large"
toolName := "test_tool"
// 创建大结果(1000行)
lines := make([]string, 1000)
for i := 0; i < 1000; i++ {
lines[i] = fmt.Sprintf("Line %d: This is a test line with some content", i+1)
}
result := strings.Join(lines, "\n")
// 保存大结果
err := storage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存大结果失败: %v", err)
}
// 验证元数据
metadata, err := storage.GetResultMetadata(executionID)
if err != nil {
t.Fatalf("获取元数据失败: %v", err)
}
if metadata.TotalLines != 1000 {
t.Errorf("总行数不匹配。期望: 1000, 实际: %d", metadata.TotalLines)
}
// 测试分页查询大结果
page, err := storage.GetResultPage(executionID, 1, 100)
if err != nil {
t.Fatalf("获取第一页失败: %v", err)
}
if page.TotalPages != 10 {
t.Errorf("总页数不匹配。期望: 10, 实际: %d", page.TotalPages)
}
if len(page.Lines) != 100 {
t.Errorf("第一页行数不匹配。期望: 100, 实际: %d", len(page.Lines))
}
}