mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-01 23:35:18 +02:00
Add files via upload
This commit is contained in:
@@ -0,0 +1,933 @@
|
||||
package attackchain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Builder 攻击链构建器
|
||||
type Builder struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
openAIClient *openai.Client
|
||||
openAIConfig *config.OpenAIConfig
|
||||
tokenCounter agent.TokenCounter
|
||||
maxTokens int // 最大tokens限制,默认100000
|
||||
}
|
||||
|
||||
// Node 攻击链节点(使用database包的类型)
|
||||
type Node = database.AttackChainNode
|
||||
|
||||
// Edge 攻击链边(使用database包的类型)
|
||||
type Edge = database.AttackChainEdge
|
||||
|
||||
// Chain 完整的攻击链
|
||||
type Chain struct {
|
||||
Nodes []Node `json:"nodes"`
|
||||
Edges []Edge `json:"edges"`
|
||||
}
|
||||
|
||||
// NewBuilder 创建新的攻击链构建器
|
||||
func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.Logger) *Builder {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
httpClient := &http.Client{Timeout: 5 * time.Minute, Transport: transport}
|
||||
|
||||
// 优先使用配置文件中的统一 Token 上限(config.yaml -> openai.max_total_tokens)
|
||||
maxTokens := 0
|
||||
if openAIConfig != nil && openAIConfig.MaxTotalTokens > 0 {
|
||||
maxTokens = openAIConfig.MaxTotalTokens
|
||||
} else if openAIConfig != nil {
|
||||
// 如果未显式配置 max_total_tokens,则根据模型设置一个合理的默认值
|
||||
model := strings.ToLower(openAIConfig.Model)
|
||||
if strings.Contains(model, "gpt-4") {
|
||||
maxTokens = 128000 // gpt-4通常支持128k
|
||||
} else if strings.Contains(model, "gpt-3.5") {
|
||||
maxTokens = 16000 // gpt-3.5-turbo通常支持16k
|
||||
} else if strings.Contains(model, "deepseek") {
|
||||
maxTokens = 131072 // deepseek-chat通常支持131k
|
||||
} else {
|
||||
maxTokens = 100000 // 兜底默认值
|
||||
}
|
||||
} else {
|
||||
// 没有 OpenAI 配置时使用兜底值,避免为 0
|
||||
maxTokens = 100000
|
||||
}
|
||||
|
||||
return &Builder{
|
||||
db: db,
|
||||
logger: logger,
|
||||
openAIClient: openai.NewClient(openAIConfig, httpClient, logger),
|
||||
openAIConfig: openAIConfig,
|
||||
tokenCounter: agent.NewTikTokenCounter(),
|
||||
maxTokens: maxTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出)
|
||||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||||
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
||||
|
||||
// 0. 首先检查是否有实际的工具执行记录
|
||||
messages, err := b.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取对话消息失败: %w", err)
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
b.logger.Info("对话中没有数据", zap.String("conversationId", conversationID))
|
||||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||||
}
|
||||
|
||||
// 检查是否有实际的工具执行:assistant 的 mcp_execution_ids,或过程详情中的 tool_call/tool_result
|
||||
//(多代理下若 MCP 未返回 execution_id,IDs 可能为空,但工具已通过 Eino 执行并写入 process_details)
|
||||
hasToolExecutions := false
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||||
if len(messages[i].MCPExecutionIDs) > 0 {
|
||||
hasToolExecutions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasToolExecutions {
|
||||
if pdOK, err := b.db.ConversationHasToolProcessDetails(conversationID); err != nil {
|
||||
b.logger.Warn("查询过程详情判定工具执行失败", zap.Error(err))
|
||||
} else if pdOK {
|
||||
hasToolExecutions = true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查任务是否被取消(通过检查最后一条assistant消息内容或process_details)
|
||||
taskCancelled := false
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||||
content := strings.ToLower(messages[i].Content)
|
||||
if strings.Contains(content, "取消") || strings.Contains(content, "cancelled") {
|
||||
taskCancelled = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果任务被取消且没有实际工具执行,返回空攻击链
|
||||
if taskCancelled && !hasToolExecutions {
|
||||
b.logger.Info("任务已取消且没有实际工具执行,返回空攻击链",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Bool("taskCancelled", taskCancelled),
|
||||
zap.Bool("hasToolExecutions", hasToolExecutions))
|
||||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||||
}
|
||||
|
||||
// 如果没有实际工具执行,也返回空攻击链(避免AI编造)
|
||||
if !hasToolExecutions {
|
||||
b.logger.Info("没有实际工具执行记录,返回空攻击链",
|
||||
zap.String("conversationId", conversationID))
|
||||
return &Chain{Nodes: []Node{}, Edges: []Edge{}}, nil
|
||||
}
|
||||
|
||||
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
|
||||
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID)
|
||||
if err != nil {
|
||||
b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err))
|
||||
// 继续使用原来的逻辑
|
||||
reactInputJSON = ""
|
||||
modelOutput = ""
|
||||
}
|
||||
|
||||
// var userInput string
|
||||
var reactInputFinal string
|
||||
var dataSource string // 记录数据来源
|
||||
|
||||
// 如果成功获取到保存的ReAct数据,直接使用
|
||||
if reactInputJSON != "" && modelOutput != "" {
|
||||
// 计算 ReAct 输入的哈希值,用于追踪
|
||||
hash := sha256.Sum256([]byte(reactInputJSON))
|
||||
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
|
||||
|
||||
// 统计消息数量
|
||||
var messageCount int
|
||||
var tempMessages []interface{}
|
||||
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
|
||||
messageCount = len(tempMessages)
|
||||
}
|
||||
|
||||
dataSource = "database_last_react_input"
|
||||
b.logger.Info("使用保存的ReAct数据构建攻击链",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
||||
zap.Int("messageCount", messageCount),
|
||||
zap.String("reactInputHash", reactInputHash),
|
||||
zap.Int("modelOutputSize", len(modelOutput)))
|
||||
|
||||
// 从保存的ReAct输入(JSON格式)中提取用户输入
|
||||
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
||||
|
||||
// 将JSON格式的messages转换为可读格式
|
||||
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
|
||||
} else {
|
||||
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
||||
dataSource = "messages_table"
|
||||
b.logger.Info("从消息历史构建ReAct数据",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("messageCount", len(messages)))
|
||||
|
||||
// 提取用户输入(最后一条user消息)
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "user") {
|
||||
// userInput = messages[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 提取最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
reactInputFinal = b.buildReActInput(messages)
|
||||
|
||||
// 提取大模型最后的输出(最后一条assistant消息)
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||||
modelOutput = messages[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 多代理:保存的 last_react_input 可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理「最后一轮 ReAct」对齐)
|
||||
hasMCPOnAssistant := false
|
||||
var lastAssistantID string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "assistant") {
|
||||
lastAssistantID = messages[i].ID
|
||||
if len(messages[i].MCPExecutionIDs) > 0 {
|
||||
hasMCPOnAssistant = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if lastAssistantID != "" {
|
||||
pdHasTools, _ := b.db.ConversationHasToolProcessDetails(conversationID)
|
||||
if pdHasTools && !(hasMCPOnAssistant && reactInputContainsToolTrace(reactInputJSON)) {
|
||||
detailsMap, err := b.db.GetProcessDetailsByConversation(conversationID)
|
||||
if err != nil {
|
||||
b.logger.Warn("加载过程详情用于攻击链失败", zap.Error(err))
|
||||
} else if dets := detailsMap[lastAssistantID]; len(dets) > 0 {
|
||||
extra := b.formatProcessDetailsForAttackChain(dets)
|
||||
if strings.TrimSpace(extra) != "" {
|
||||
reactInputFinal = reactInputFinal + "\n\n## 执行过程与工具记录(含多代理编排与子任务)\n\n" + extra
|
||||
b.logger.Info("攻击链输入已补充过程详情",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("messageId", lastAssistantID),
|
||||
zap.Int("detailEvents", len(dets)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 构建简化的prompt,一次性传递给大模型
|
||||
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
|
||||
// fmt.Println(prompt)
|
||||
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
||||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("AI生成失败: %w", err)
|
||||
}
|
||||
|
||||
// 7. 解析JSON并生成节点/边ID(前端需要有效的ID)
|
||||
chainData, err := b.parseChainJSON(chainJSON)
|
||||
if err != nil {
|
||||
// 如果解析失败,返回空链,让前端处理错误
|
||||
b.logger.Warn("解析攻击链JSON失败", zap.Error(err), zap.String("raw_json", chainJSON))
|
||||
return &Chain{
|
||||
Nodes: []Node{},
|
||||
Edges: []Edge{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
b.logger.Info("攻击链构建完成",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("nodes", len(chainData.Nodes)),
|
||||
zap.Int("edges", len(chainData.Edges)))
|
||||
|
||||
// 保存到数据库(供后续加载使用)
|
||||
if err := b.saveChain(conversationID, chainData.Nodes, chainData.Edges); err != nil {
|
||||
b.logger.Warn("保存攻击链到数据库失败", zap.Error(err))
|
||||
// 即使保存失败,也返回数据给前端
|
||||
}
|
||||
|
||||
// 直接返回,不做任何处理和校验
|
||||
return chainData, nil
|
||||
}
|
||||
|
||||
// reactInputContainsToolTrace 判断保存的 ReAct JSON 是否包含可解析的工具调用轨迹(单代理完整保存时为 true)。
|
||||
func reactInputContainsToolTrace(reactInputJSON string) bool {
|
||||
s := strings.TrimSpace(reactInputJSON)
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(s, "tool_calls") ||
|
||||
strings.Contains(s, "tool_call_id") ||
|
||||
strings.Contains(s, `"role":"tool"`) ||
|
||||
strings.Contains(s, `"role": "tool"`)
|
||||
}
|
||||
|
||||
// formatProcessDetailsForAttackChain 将最后一轮助手的过程详情格式化为攻击链分析的输入(覆盖多代理下 last_react_input 不完整的情况)。
|
||||
func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessDetail) string {
|
||||
if len(details) == 0 {
|
||||
return ""
|
||||
}
|
||||
var sb strings.Builder
|
||||
for _, d := range details {
|
||||
// 目标:以主 agent(编排器)视角输出整轮迭代
|
||||
// - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理)
|
||||
// - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程
|
||||
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "planning" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析 data(JSON string),用于识别 einoRole / toolName 等
|
||||
var dataMap map[string]interface{}
|
||||
if strings.TrimSpace(d.Data) != "" {
|
||||
_ = json.Unmarshal([]byte(d.Data), &dataMap)
|
||||
}
|
||||
einoRole := ""
|
||||
if v, ok := dataMap["einoRole"]; ok {
|
||||
einoRole = strings.ToLower(strings.TrimSpace(fmt.Sprint(v)))
|
||||
}
|
||||
toolName := ""
|
||||
if v, ok := dataMap["toolName"]; ok {
|
||||
toolName = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
|
||||
// 1) 编排器的工具调用/结果:保留(这是“主 agent 调了什么工具”)
|
||||
if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration" || d.EventType == "eino_recovery") && einoRole == "orchestrator" {
|
||||
sb.WriteString("[")
|
||||
sb.WriteString(d.EventType)
|
||||
sb.WriteString("] ")
|
||||
sb.WriteString(strings.TrimSpace(d.Message))
|
||||
sb.WriteString("\n")
|
||||
if strings.TrimSpace(d.Data) != "" {
|
||||
sb.WriteString(d.Data)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
// 2) 子代理调度:tool_call(toolName=="task") 代表编排器把子任务派发出去;保留(只需任务,不要子代理推理)
|
||||
if d.EventType == "tool_call" && strings.EqualFold(toolName, "task") {
|
||||
sb.WriteString("[dispatch_subagent_task] ")
|
||||
sb.WriteString(strings.TrimSpace(d.Message))
|
||||
sb.WriteString("\n")
|
||||
if strings.TrimSpace(d.Data) != "" {
|
||||
sb.WriteString(d.Data)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
// 3) 子代理最终回复:保留(只保留最终输出,不保留分析过程)
|
||||
if d.EventType == "eino_agent_reply" && einoRole == "sub" {
|
||||
sb.WriteString("[subagent_final_reply] ")
|
||||
sb.WriteString(strings.TrimSpace(d.Message))
|
||||
sb.WriteString("\n")
|
||||
// data 里含 einoAgent 等元信息,保留有助于追踪“哪个子代理说的”
|
||||
if strings.TrimSpace(d.Data) != "" {
|
||||
sb.WriteString(d.Data)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
// 其他事件默认丢弃,避免把子代理工具细节/推理塞进 prompt,偏离“主 agent 一轮迭代”的视角。
|
||||
}
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// buildReActInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
func (b *Builder) buildReActInput(messages []database.Message) string {
|
||||
var builder strings.Builder
|
||||
for _, msg := range messages {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// extractUserInputFromReActInput 从保存的ReAct输入(JSON格式的messages数组)中提取最后一条用户输入
|
||||
// func (b *Builder) extractUserInputFromReActInput(reactInputJSON string) string {
|
||||
// // reactInputJSON是JSON格式的ChatMessage数组,需要解析
|
||||
// var messages []map[string]interface{}
|
||||
// if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
// b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// // 从后往前查找最后一条user消息
|
||||
// for i := len(messages) - 1; i >= 0; i-- {
|
||||
// if role, ok := messages[i]["role"].(string); ok && strings.EqualFold(role, "user") {
|
||||
// if content, ok := messages[i]["content"].(string); ok {
|
||||
// return content
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
||||
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
|
||||
var messages []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
return reactInputJSON // 如果解析失败,返回原始JSON
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
for _, msg := range messages {
|
||||
role, _ := msg["role"].(string)
|
||||
content, _ := msg["content"].(string)
|
||||
|
||||
// 处理assistant消息:提取tool_calls信息
|
||||
if role == "assistant" {
|
||||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
|
||||
// 如果有文本内容,先显示
|
||||
if content != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||||
}
|
||||
// 详细显示每个工具调用
|
||||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
|
||||
for i, toolCall := range toolCalls {
|
||||
if tc, ok := toolCall.(map[string]interface{}); ok {
|
||||
toolCallID, _ := tc["id"].(string)
|
||||
if funcData, ok := tc["function"].(map[string]interface{}); ok {
|
||||
toolName, _ := funcData["name"].(string)
|
||||
arguments, _ := funcData["arguments"].(string)
|
||||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||||
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
|
||||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
|
||||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
|
||||
}
|
||||
}
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 处理tool消息:显示tool_call_id和完整内容
|
||||
if role == "tool" {
|
||||
toolCallID, _ := msg["tool_call_id"].(string)
|
||||
if toolCallID != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 其他消息类型(system, user等)正常显示
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// buildSimplePrompt 构建简化的prompt
|
||||
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。
|
||||
|
||||
## 核心目标
|
||||
|
||||
构建一个能够讲述完整攻击故事的攻击链让学习者能够:
|
||||
1. 理解渗透测试的完整流程和思维逻辑(从目标识别到漏洞发现的每一步)
|
||||
2. 学习如何从失败中获取线索并调整策略
|
||||
3. 掌握工具使用的实际效果和局限性
|
||||
4. 理解漏洞发现和利用的因果关系
|
||||
|
||||
**关键原则**:完整性优先。必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而遗漏重要信息。
|
||||
|
||||
## 构建流程(按此顺序思考)
|
||||
|
||||
### 第一步:理解上下文
|
||||
仔细分析ReAct输入中的工具调用序列和大模型输出,识别:
|
||||
- 测试目标(IP、域名、URL等)
|
||||
- 实际执行的工具和参数
|
||||
- 工具返回的关键信息(成功结果、错误信息、超时等)
|
||||
- AI的分析和决策过程
|
||||
|
||||
### 第二步:提取关键节点
|
||||
从工具执行记录中提取有意义的节点,**确保不遗漏任何关键步骤**:
|
||||
- **target节点**:每个独立的测试目标创建一个target节点
|
||||
- **action节点**:每个有意义的工具执行创建一个action节点(包括提供线索的失败、成功的信息收集、漏洞验证等)
|
||||
- **vulnerability节点**:每个真实确认的漏洞创建一个vulnerability节点
|
||||
- **完整性检查**:对照ReAct输入中的工具调用序列,确保每个有意义的工具执行都被包含在攻击链中
|
||||
|
||||
### 第三步:构建逻辑关系(树状结构)
|
||||
**重要:必须构建树状结构,而不是简单的线性链。**
|
||||
按照因果关系连接节点,形成树状图(因为是单agent执行,所以可以不按照时间顺序):
|
||||
- **分支结构**:一个节点可以有多个后续节点(例如:端口扫描发现多个端口后,可以同时进行多个不同的测试)
|
||||
- **汇聚结构**:多个节点可以指向同一个节点(例如:多个不同的测试都发现了同一个漏洞)
|
||||
- 识别哪些action是基于前面action的结果而执行的
|
||||
- 识别哪些vulnerability是由哪些action发现的
|
||||
- 识别失败节点如何为后续成功提供线索
|
||||
- **避免线性链**:不要将所有节点连成一条线,应该根据实际的并行测试和分支探索构建树状结构
|
||||
|
||||
### 第四步:优化和精简
|
||||
- **完整性检查**:确保所有有意义的工具执行都被包含,不要遗漏关键步骤
|
||||
- **合并规则**:只合并真正相似或重复的action节点(如多次相同工具的相似调用)
|
||||
- **删除规则**:只删除完全无价值的失败节点(完全无输出、纯系统错误、重复的相同失败)
|
||||
- **重要提醒**:宁可保留更多节点,也不要遗漏关键步骤。攻击链必须完整展现渗透测试过程
|
||||
- 确保攻击链逻辑连贯,能够讲述完整故事
|
||||
|
||||
## 节点类型详解
|
||||
|
||||
### target(目标节点)
|
||||
- **用途**:标识测试目标
|
||||
- **创建规则**:每个独立目标(不同IP/域名)创建一个target节点
|
||||
- **多目标处理**:不同目标的节点不相互连接,各自形成独立的子图
|
||||
- **metadata.target**:精确记录目标标识(IP地址、域名、URL等)
|
||||
|
||||
### action(行动节点)
|
||||
- **用途**:记录工具执行和AI分析结果
|
||||
- **标签规则**:
|
||||
* 15-25个汉字,动宾结构
|
||||
* 成功节点:描述执行结果(如"扫描端口发现80/443/8080"、"目录扫描发现/admin路径")
|
||||
* 失败节点:描述失败原因(如"尝试SQL注入(被WAF拦截)"、"端口扫描超时(目标不可达)")
|
||||
- **ai_analysis要求**:
|
||||
* 成功节点:总结工具执行的关键发现,说明这些发现的意义
|
||||
* 失败节点:必须说明失败原因、获得的线索、这些线索如何指引后续行动
|
||||
* 不超过150字,要具体、有信息量
|
||||
- **findings要求**:
|
||||
* 提取工具返回结果中的关键信息点
|
||||
* 每个finding应该是独立的、有价值的信息片段
|
||||
* 成功节点:列出关键发现(如["80端口开放", "443端口开放", "HTTP服务为Apache 2.4"])
|
||||
* 失败节点:列出失败线索(如["WAF拦截", "返回403", "检测到Cloudflare"])
|
||||
- **status标记**:
|
||||
* 成功节点:不设置或设为"success"
|
||||
* 提供线索的失败节点:必须设为"failed_insight"
|
||||
- **risk_score**:始终为0(action节点不评估风险)
|
||||
|
||||
### vulnerability(漏洞节点)
|
||||
- **用途**:记录真实确认的安全漏洞
|
||||
- **创建规则**:
|
||||
* 必须是真实确认的漏洞,不是所有发现都是漏洞
|
||||
* 需要明确的漏洞证据(如SQL注入返回数据库错误、XSS成功执行等)
|
||||
- **risk_score规则**:
|
||||
* critical(90-100):可导致系统完全沦陷(RCE、SQL注入导致数据泄露等)
|
||||
* high(80-89):可导致敏感信息泄露或权限提升
|
||||
* medium(60-79):存在安全风险但影响有限
|
||||
* low(40-59):轻微安全问题
|
||||
- **metadata要求**:
|
||||
* vulnerability_type:漏洞类型(SQL注入、XSS、RCE等)
|
||||
* description:详细描述漏洞位置、原理、影响
|
||||
* severity:critical/high/medium/low
|
||||
* location:精确的漏洞位置(URL、参数、文件路径等)
|
||||
|
||||
## 节点过滤和合并规则
|
||||
|
||||
### 必须保留的失败节点
|
||||
以下失败情况必须创建节点,因为它们提供了有价值的线索:
|
||||
- 工具返回明确的错误信息(权限错误、连接拒绝、认证失败等)
|
||||
- 超时或连接失败(可能表明防火墙、网络隔离等)
|
||||
- WAF/防火墙拦截(返回403、406等,表明存在防护机制)
|
||||
- 工具未安装或配置错误(但执行了调用)
|
||||
- 目标不可达(DNS解析失败、网络不通等)
|
||||
|
||||
### 应该删除的失败节点
|
||||
以下情况不应创建节点:
|
||||
- 完全无输出的工具调用
|
||||
- 纯系统错误(与目标无关,如本地环境问题)
|
||||
- 重复的相同失败(多次相同错误只保留第一次)
|
||||
|
||||
### 节点合并规则
|
||||
以下情况应合并节点:
|
||||
- 同一工具的多次相似调用(如多次nmap扫描不同端口范围,合并为一个"端口扫描"节点)
|
||||
- 同一目标的多个相似探测(如多个目录扫描工具,合并为一个"目录扫描"节点)
|
||||
|
||||
### 节点数量控制
|
||||
- **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制数量而删除重要节点
|
||||
- **建议范围**:单目标通常8-15个节点,但如果实际执行步骤较多,可以适当增加(最多20个节点)
|
||||
- **优先保留**:关键成功步骤、提供线索的失败、发现的漏洞、重要的信息收集步骤
|
||||
- **可以合并**:同一工具的多次相似调用(如多次nmap扫描不同端口范围,合并为一个"端口扫描"节点)
|
||||
- **可以删除**:完全无输出的工具调用、纯系统错误、重复的相同失败(多次相同错误只保留第一次)
|
||||
- **重要原则**:宁可节点稍多,也不要遗漏关键步骤。攻击链必须能够完整展现渗透测试的完整过程
|
||||
|
||||
## 边的类型和权重
|
||||
|
||||
### 边的类型
|
||||
- **leads_to**:表示"导致"或"引导到",用于action→action、target→action
|
||||
* 例如:端口扫描 → 目录扫描(因为发现了80端口,所以进行目录扫描)
|
||||
- **discovers**:表示"发现",**专门用于action→vulnerability**
|
||||
* 例如:SQL注入测试 → SQL注入漏洞
|
||||
* **重要**:所有action→vulnerability的边都必须使用discovers类型,即使多个action都指向同一个vulnerability,也应该统一使用discovers
|
||||
- **enables**:表示"使能"或"促成",**仅用于vulnerability→vulnerability、action→action(当后续行动依赖前面结果时)**
|
||||
* 例如:信息泄露漏洞 → 权限提升漏洞(通过信息泄露获得的信息促成了权限提升)
|
||||
* **重要**:enables不能用于action→vulnerability,action→vulnerability必须使用discovers
|
||||
|
||||
### 边的权重
|
||||
- **权重1-2**:弱关联(如初步探测到进一步探测)
|
||||
- **权重3-4**:中等关联(如发现端口到服务识别)
|
||||
- **权重5-7**:强关联(如发现漏洞、关键信息泄露)
|
||||
- **权重8-10**:极强关联(如漏洞利用成功、权限提升)
|
||||
|
||||
### DAG结构要求(有向无环图)
|
||||
**关键:必须确保生成的是真正的DAG(有向无环图),不能有任何循环。**
|
||||
|
||||
- **节点编号规则**:节点id从"node_1"开始递增(node_1, node_2, node_3...)
|
||||
- **边的方向规则**:所有边的source节点id必须严格小于target节点id(source < target),这是确保无环的关键
|
||||
* 例如:node_1 → node_2 ✓(正确)
|
||||
* 例如:node_2 → node_1 ✗(错误,会形成环)
|
||||
* 例如:node_3 → node_5 ✓(正确)
|
||||
- **无环验证**:在输出JSON前,必须检查所有边,确保没有任何一条边的source >= target
|
||||
- **无孤立节点**:确保每个节点至少有一条边连接(除了可能的根节点)
|
||||
- **DAG结构特点**:
|
||||
* 一个节点可以有多个后续节点(分支),例如:node_2(端口扫描)可以同时连接到node_3、node_4、node_5等多个节点
|
||||
* 多个节点可以汇聚到一个节点(汇聚),例如:node_3、node_4、node_5都指向node_6(漏洞节点)
|
||||
* 避免将所有节点连成一条线,应该根据实际的并行测试和分支探索构建DAG结构
|
||||
- **拓扑排序验证**:如果按照节点id从小到大排序,所有边都应该从左指向右(从上指向下),这样就能保证无环
|
||||
|
||||
## 攻击链逻辑连贯性要求
|
||||
|
||||
构建的攻击链应该能够回答以下问题:
|
||||
1. **起点**:测试从哪里开始?(target节点)
|
||||
2. **探索过程**:如何逐步收集信息?(action节点序列)
|
||||
3. **失败与调整**:遇到障碍时如何调整策略?(failed_insight节点)
|
||||
4. **关键发现**:发现了哪些重要信息?(action的findings)
|
||||
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability)
|
||||
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
|
||||
|
||||
## 最后一轮ReAct输入
|
||||
|
||||
%s
|
||||
|
||||
## 大模型输出
|
||||
|
||||
%s
|
||||
|
||||
## 输出格式
|
||||
|
||||
严格按照以下JSON格式输出,不要添加任何其他文字:
|
||||
|
||||
**重要:示例展示的是树状结构,注意node_2(端口扫描)同时连接到多个后续节点(node_3、node_4),形成分支结构。**
|
||||
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node_1",
|
||||
"type": "target",
|
||||
"label": "测试目标: example.com",
|
||||
"risk_score": 40,
|
||||
"metadata": {
|
||||
"target": "example.com"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "node_2",
|
||||
"type": "action",
|
||||
"label": "扫描端口发现80/443/8080",
|
||||
"risk_score": 0,
|
||||
"metadata": {
|
||||
"tool_name": "nmap",
|
||||
"tool_intent": "端口扫描",
|
||||
"ai_analysis": "使用nmap对目标进行端口扫描,发现80、443、8080端口开放。80端口运行HTTP服务,443端口运行HTTPS服务,8080端口可能为管理后台。这些开放端口为后续Web应用测试提供了入口。",
|
||||
"findings": ["80端口开放", "443端口开放", "8080端口开放", "HTTP服务为Apache 2.4"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "node_3",
|
||||
"type": "action",
|
||||
"label": "目录扫描发现/admin后台",
|
||||
"risk_score": 0,
|
||||
"metadata": {
|
||||
"tool_name": "dirsearch",
|
||||
"tool_intent": "目录扫描",
|
||||
"ai_analysis": "使用dirsearch对目标进行目录扫描,发现/admin目录存在且可访问。该目录可能为管理后台,是重要的测试目标。",
|
||||
"findings": ["/admin目录存在", "返回200状态码", "疑似管理后台"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "node_4",
|
||||
"type": "action",
|
||||
"label": "识别Web服务为Apache 2.4",
|
||||
"risk_score": 0,
|
||||
"metadata": {
|
||||
"tool_name": "whatweb",
|
||||
"tool_intent": "Web服务识别",
|
||||
"ai_analysis": "识别出目标运行Apache 2.4服务器,这为后续的漏洞测试提供了重要信息。",
|
||||
"findings": ["Apache 2.4", "PHP版本信息"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "node_5",
|
||||
"type": "action",
|
||||
"label": "尝试SQL注入(被WAF拦截)",
|
||||
"risk_score": 0,
|
||||
"metadata": {
|
||||
"tool_name": "sqlmap",
|
||||
"tool_intent": "SQL注入检测",
|
||||
"ai_analysis": "对/login.php进行SQL注入测试时被WAF拦截,返回403错误。错误信息显示检测到Cloudflare防护。这表明目标部署了WAF,需要调整测试策略。",
|
||||
"findings": ["WAF拦截", "返回403", "检测到Cloudflare", "目标部署WAF"],
|
||||
"status": "failed_insight"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "node_6",
|
||||
"type": "vulnerability",
|
||||
"label": "SQL注入漏洞",
|
||||
"risk_score": 85,
|
||||
"metadata": {
|
||||
"vulnerability_type": "SQL注入",
|
||||
"description": "在/admin/login.php的username参数发现SQL注入漏洞,可通过注入payload绕过登录验证,直接获取管理员权限。漏洞返回数据库错误信息,确认存在注入点。",
|
||||
"severity": "high",
|
||||
"location": "/admin/login.php?username="
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": "node_1",
|
||||
"target": "node_2",
|
||||
"type": "leads_to",
|
||||
"weight": 3
|
||||
},
|
||||
{
|
||||
"source": "node_2",
|
||||
"target": "node_3",
|
||||
"type": "leads_to",
|
||||
"weight": 4
|
||||
},
|
||||
{
|
||||
"source": "node_2",
|
||||
"target": "node_4",
|
||||
"type": "leads_to",
|
||||
"weight": 3
|
||||
},
|
||||
{
|
||||
"source": "node_3",
|
||||
"target": "node_5",
|
||||
"type": "leads_to",
|
||||
"weight": 4
|
||||
},
|
||||
{
|
||||
"source": "node_5",
|
||||
"target": "node_6",
|
||||
"type": "discovers",
|
||||
"weight": 7
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## 重要提醒
|
||||
|
||||
1. **严禁杜撰**:只使用ReAct输入中实际执行的工具和实际返回的结果。如无实际数据,返回空的nodes和edges数组。
|
||||
2. **DAG结构必须**:必须构建真正的DAG(有向无环图),不能有任何循环。所有边的source节点id必须严格小于target节点id(source < target)。
|
||||
3. **拓扑顺序**:节点应该按照逻辑顺序编号,target节点通常是node_1,后续的action节点按执行顺序递增,vulnerability节点在最后。
|
||||
4. **完整性优先**:必须包含所有有意义的工具执行和关键步骤,不要为了控制节点数量而删除重要节点。攻击链必须能够完整展现从目标识别到漏洞发现的完整过程。
|
||||
5. **逻辑连贯**:确保攻击链能够讲述一个完整、连贯的渗透测试故事,包括所有关键步骤和决策点。
|
||||
6. **教育价值**:优先保留有教育意义的节点,帮助学习者理解渗透测试思维和完整流程。
|
||||
7. **准确性**:所有节点信息必须基于实际数据,不要推测或假设。
|
||||
8. **完整性检查**:确保每个节点都有必要的metadata字段,每条边都有正确的source和target,没有孤立节点,没有循环。
|
||||
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
||||
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
|
||||
|
||||
现在开始分析并构建攻击链:`, reactInput, modelOutput)
|
||||
}
|
||||
|
||||
// saveChain 保存攻击链到数据库
|
||||
func (b *Builder) saveChain(conversationID string, nodes []Node, edges []Edge) error {
|
||||
// 先删除旧的攻击链数据
|
||||
if err := b.db.DeleteAttackChain(conversationID); err != nil {
|
||||
b.logger.Warn("删除旧攻击链失败", zap.Error(err))
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
metadataJSON, _ := json.Marshal(node.Metadata)
|
||||
if err := b.db.SaveAttackChainNode(conversationID, node.ID, node.Type, node.Label, "", string(metadataJSON), node.RiskScore); err != nil {
|
||||
b.logger.Warn("保存攻击链节点失败", zap.String("nodeId", node.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 保存边
|
||||
for _, edge := range edges {
|
||||
if err := b.db.SaveAttackChainEdge(conversationID, edge.ID, edge.Source, edge.Target, edge.Type, edge.Weight); err != nil {
|
||||
b.logger.Warn("保存攻击链边失败", zap.String("edgeId", edge.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadChainFromDatabase 从数据库加载攻击链
|
||||
func (b *Builder) LoadChainFromDatabase(conversationID string) (*Chain, error) {
|
||||
nodes, err := b.db.LoadAttackChainNodes(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载攻击链节点失败: %w", err)
|
||||
}
|
||||
|
||||
edges, err := b.db.LoadAttackChainEdges(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载攻击链边失败: %w", err)
|
||||
}
|
||||
|
||||
return &Chain{
|
||||
Nodes: nodes,
|
||||
Edges: edges,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// callAIForChainGeneration 调用AI生成攻击链
|
||||
func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (string, error) {
|
||||
requestBody := map[string]interface{}{
|
||||
"model": b.openAIConfig.Model,
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的安全测试分析师,擅长构建攻击链图。请严格按照JSON格式返回攻击链数据。",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
},
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 8000,
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if b.openAIClient == nil {
|
||||
return "", fmt.Errorf("OpenAI客户端未初始化")
|
||||
}
|
||||
if err := b.openAIClient.ChatCompletion(ctx, requestBody, &apiResponse); err != nil {
|
||||
var apiErr *openai.APIError
|
||||
if errors.As(err, &apiErr) {
|
||||
bodyStr := strings.ToLower(apiErr.Body)
|
||||
if strings.Contains(bodyStr, "context") || strings.Contains(bodyStr, "length") || strings.Contains(bodyStr, "too long") {
|
||||
return "", fmt.Errorf("context length exceeded")
|
||||
}
|
||||
} else if strings.Contains(strings.ToLower(err.Error()), "context") || strings.Contains(strings.ToLower(err.Error()), "length") {
|
||||
return "", fmt.Errorf("context length exceeded")
|
||||
}
|
||||
return "", fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return "", fmt.Errorf("API未返回有效响应")
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(apiResponse.Choices[0].Message.Content)
|
||||
// 尝试提取JSON(可能包含markdown代码块)
|
||||
content = strings.TrimPrefix(content, "```json")
|
||||
content = strings.TrimPrefix(content, "```")
|
||||
content = strings.TrimSuffix(content, "```")
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// ChainJSON 攻击链JSON结构
|
||||
type ChainJSON struct {
|
||||
Nodes []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Label string `json:"label"`
|
||||
RiskScore int `json:"risk_score"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
} `json:"nodes"`
|
||||
Edges []struct {
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target"`
|
||||
Type string `json:"type"`
|
||||
Weight int `json:"weight"`
|
||||
} `json:"edges"`
|
||||
}
|
||||
|
||||
// parseChainJSON 解析攻击链JSON
|
||||
func (b *Builder) parseChainJSON(chainJSON string) (*Chain, error) {
|
||||
var chainData ChainJSON
|
||||
if err := json.Unmarshal([]byte(chainJSON), &chainData); err != nil {
|
||||
return nil, fmt.Errorf("解析JSON失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建节点ID映射(AI返回的ID -> 新的UUID)
|
||||
nodeIDMap := make(map[string]string)
|
||||
|
||||
// 转换为Chain结构
|
||||
nodes := make([]Node, 0, len(chainData.Nodes))
|
||||
for _, n := range chainData.Nodes {
|
||||
// 生成新的UUID节点ID
|
||||
newNodeID := fmt.Sprintf("node_%s", uuid.New().String())
|
||||
nodeIDMap[n.ID] = newNodeID
|
||||
|
||||
node := Node{
|
||||
ID: newNodeID,
|
||||
Type: n.Type,
|
||||
Label: n.Label,
|
||||
RiskScore: n.RiskScore,
|
||||
Metadata: n.Metadata,
|
||||
}
|
||||
if node.Metadata == nil {
|
||||
node.Metadata = make(map[string]interface{})
|
||||
}
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
// 转换边
|
||||
edges := make([]Edge, 0, len(chainData.Edges))
|
||||
for _, e := range chainData.Edges {
|
||||
sourceID, ok := nodeIDMap[e.Source]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
targetID, ok := nodeIDMap[e.Target]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 生成边的ID(前端需要)
|
||||
edgeID := fmt.Sprintf("edge_%s", uuid.New().String())
|
||||
|
||||
edges = append(edges, Edge{
|
||||
ID: edgeID,
|
||||
Source: sourceID,
|
||||
Target: targetID,
|
||||
Type: e.Type,
|
||||
Weight: e.Weight,
|
||||
})
|
||||
}
|
||||
|
||||
return &Chain{
|
||||
Nodes: nodes,
|
||||
Edges: edges,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 以下所有方法已不再使用,已删除以简化代码
|
||||
@@ -0,0 +1,877 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Version string `yaml:"version,omitempty" json:"version,omitempty"` // 前端显示的版本号,如 v1.3.3
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
MCP MCPConfig `yaml:"mcp"`
|
||||
OpenAI OpenAIConfig `yaml:"openai"`
|
||||
FOFA FofaConfig `yaml:"fofa,omitempty" json:"fofa,omitempty"`
|
||||
Agent AgentConfig `yaml:"agent"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置
|
||||
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
|
||||
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
|
||||
SkillsDir string `yaml:"skills_dir,omitempty" json:"skills_dir,omitempty"` // Skills配置文件目录
|
||||
AgentsDir string `yaml:"agents_dir,omitempty" json:"agents_dir,omitempty"` // 多代理子 Agent Markdown 定义目录(*.md,YAML front matter)
|
||||
MultiAgent MultiAgentConfig `yaml:"multi_agent,omitempty" json:"multi_agent,omitempty"`
|
||||
}
|
||||
|
||||
// MultiAgentConfig 基于 CloudWeGo Eino DeepAgent 的多代理编排(与单 Agent /agent-loop 并存)。
|
||||
type MultiAgentConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
DefaultMode string `yaml:"default_mode" json:"default_mode"` // single | multi,供前端默认展示
|
||||
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // Deep 主代理最大推理轮次
|
||||
SubAgentMaxIterations int `yaml:"sub_agent_max_iterations" json:"sub_agent_max_iterations"`
|
||||
WithoutGeneralSubAgent bool `yaml:"without_general_sub_agent" json:"without_general_sub_agent"`
|
||||
WithoutWriteTodos bool `yaml:"without_write_todos" json:"without_write_todos"`
|
||||
OrchestratorInstruction string `yaml:"orchestrator_instruction" json:"orchestrator_instruction"`
|
||||
SubAgents []MultiAgentSubConfig `yaml:"sub_agents" json:"sub_agents"`
|
||||
// EinoSkills configures CloudWeGo Eino ADK skill middleware + optional local filesystem/execute on DeepAgent.
|
||||
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
||||
}
|
||||
|
||||
// MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools.
|
||||
type MultiAgentEinoSkillsConfig struct {
|
||||
// Disable skips skill middleware (and does not attach local FS tools for Deep).
|
||||
Disable bool `yaml:"disable" json:"disable"`
|
||||
// FilesystemTools registers read_file/glob/grep/write/edit/execute (eino-ext local backend). Nil/omitted = true.
|
||||
FilesystemTools *bool `yaml:"filesystem_tools,omitempty" json:"filesystem_tools,omitempty"`
|
||||
// SkillToolName overrides the default Eino tool name "skill".
|
||||
SkillToolName string `yaml:"skill_tool_name,omitempty" json:"skill_tool_name,omitempty"`
|
||||
}
|
||||
|
||||
// EinoSkillFilesystemToolsEffective returns whether Deep/sub-agents should attach local filesystem + streaming shell.
|
||||
func (c MultiAgentEinoSkillsConfig) EinoSkillFilesystemToolsEffective() bool {
|
||||
if c.FilesystemTools != nil {
|
||||
return *c.FilesystemTools
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// MultiAgentSubConfig 子代理(Eino ChatModelAgent),由 DeepAgent 通过 task 工具调度。
|
||||
type MultiAgentSubConfig struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
Instruction string `yaml:"instruction" json:"instruction"`
|
||||
BindRole string `yaml:"bind_role,omitempty" json:"bind_role,omitempty"` // 可选:关联主配置 roles 中的角色名;未配 role_tools 时沿用该角色的 tools,并把 skills 写入指令提示
|
||||
RoleTools []string `yaml:"role_tools" json:"role_tools"` // 与单 Agent 角色工具相同 key;空表示全部工具(bind_role 可补全 tools)
|
||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||
Kind string `yaml:"kind,omitempty" json:"kind,omitempty"` // 仅 Markdown:kind=orchestrator 表示 Deep 主代理(与 orchestrator.md 二选一约定)
|
||||
}
|
||||
|
||||
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
||||
type MultiAgentPublic struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
DefaultMode string `json:"default_mode"`
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
SubAgentCount int `json:"sub_agent_count"`
|
||||
}
|
||||
|
||||
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
||||
type MultiAgentAPIUpdate struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
DefaultMode string `json:"default_mode"`
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
}
|
||||
|
||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
|
||||
type RobotsConfig struct {
|
||||
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
|
||||
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
|
||||
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
|
||||
}
|
||||
|
||||
// RobotWecomConfig 企业微信机器人配置
|
||||
type RobotWecomConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Token string `yaml:"token" json:"token"` // 回调 URL 校验 Token
|
||||
EncodingAESKey string `yaml:"encoding_aes_key" json:"encoding_aes_key"` // EncodingAESKey
|
||||
CorpID string `yaml:"corp_id" json:"corp_id"` // 企业 ID
|
||||
Secret string `yaml:"secret" json:"secret"` // 应用 Secret
|
||||
AgentID int64 `yaml:"agent_id" json:"agent_id"` // 应用 AgentId
|
||||
}
|
||||
|
||||
// RobotDingtalkConfig 钉钉机器人配置
|
||||
type RobotDingtalkConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey)
|
||||
ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret
|
||||
}
|
||||
|
||||
// RobotLarkConfig 飞书机器人配置
|
||||
type RobotLarkConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID
|
||||
AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret
|
||||
VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选)
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Output string `yaml:"output"`
|
||||
}
|
||||
|
||||
type MCPConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
AuthHeader string `yaml:"auth_header,omitempty"` // 鉴权 header 名,留空表示不鉴权
|
||||
AuthHeaderValue string `yaml:"auth_header_value,omitempty"` // 鉴权 header 值,需与请求中该 header 一致
|
||||
}
|
||||
|
||||
type OpenAIConfig struct {
|
||||
Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` // API 提供商: "openai"(默认) 或 "claude",claude 时自动桥接为 Anthropic Messages API
|
||||
APIKey string `yaml:"api_key" json:"api_key"`
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Model string `yaml:"model" json:"model"`
|
||||
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type FofaConfig struct {
|
||||
// Email 为 FOFA 账号邮箱;APIKey 为 FOFA API Key(建议使用只读权限的 Key)
|
||||
Email string `yaml:"email,omitempty" json:"email,omitempty"`
|
||||
APIKey string `yaml:"api_key,omitempty" json:"api_key,omitempty"`
|
||||
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://fofa.info/api/v1/search/all
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
Tools []ToolConfig `yaml:"tools,omitempty"` // 向后兼容:支持在主配置文件中定义工具
|
||||
ToolsDir string `yaml:"tools_dir,omitempty"` // 工具配置文件目录(新方式)
|
||||
ToolDescriptionMode string `yaml:"tool_description_mode,omitempty"` // 工具描述模式: "short" | "full",默认 short
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string `yaml:"path"` // 会话数据库路径
|
||||
KnowledgeDBPath string `yaml:"knowledge_db_path,omitempty"` // 知识库数据库路径(可选,为空则使用会话数据库)
|
||||
}
|
||||
|
||||
type AgentConfig struct {
|
||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||
LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB
|
||||
ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp
|
||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
Password string `yaml:"password" json:"password"`
|
||||
SessionDurationHours int `yaml:"session_duration_hours" json:"session_duration_hours"`
|
||||
GeneratedPassword string `yaml:"-" json:"-"`
|
||||
GeneratedPasswordPersisted bool `yaml:"-" json:"-"`
|
||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// ExternalMCPConfig 外部MCP配置
|
||||
type ExternalMCPConfig struct {
|
||||
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
// ExternalMCPServerConfig 外部MCP服务器配置
|
||||
type ExternalMCPServerConfig struct {
|
||||
// stdio模式配置
|
||||
Command string `yaml:"command,omitempty" json:"command,omitempty"`
|
||||
Args []string `yaml:"args,omitempty" json:"args,omitempty"`
|
||||
Env map[string]string `yaml:"env,omitempty" json:"env,omitempty"` // 环境变量(用于stdio模式)
|
||||
|
||||
// HTTP模式配置
|
||||
Transport string `yaml:"transport,omitempty" json:"transport,omitempty"` // "stdio" | "sse" | "http"(Streamable) | "simple_http"(自建/简单POST端点,如本机 http://127.0.0.1:8081/mcp)
|
||||
URL string `yaml:"url,omitempty" json:"url,omitempty"`
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` // HTTP/SSE 请求头(如 x-api-key)
|
||||
|
||||
// 通用配置
|
||||
Description string `yaml:"description,omitempty" json:"description,omitempty"`
|
||||
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` // 超时时间(秒)
|
||||
ExternalMCPEnable bool `yaml:"external_mcp_enable,omitempty" json:"external_mcp_enable,omitempty"` // 是否启用外部MCP
|
||||
ToolEnabled map[string]bool `yaml:"tool_enabled,omitempty" json:"tool_enabled,omitempty"` // 每个工具的启用状态(工具名称 -> 是否启用)
|
||||
|
||||
// 向后兼容字段(已废弃,保留用于读取旧配置)
|
||||
Enabled bool `yaml:"enabled,omitempty" json:"enabled,omitempty"` // 已废弃,使用 external_mcp_enable
|
||||
Disabled bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` // 已废弃,使用 external_mcp_enable
|
||||
}
|
||||
type ToolConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Command string `yaml:"command"`
|
||||
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
|
||||
ShortDescription string `yaml:"short_description,omitempty"` // 简短描述(用于工具列表,减少token消耗)
|
||||
Description string `yaml:"description"` // 详细描述(用于工具文档)
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
|
||||
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
|
||||
AllowedExitCodes []int `yaml:"allowed_exit_codes,omitempty"` // 允许的退出码列表(某些工具在成功时也返回非零退出码)
|
||||
}
|
||||
|
||||
// ParameterConfig 参数配置
|
||||
type ParameterConfig struct {
|
||||
Name string `yaml:"name"` // 参数名称
|
||||
Type string `yaml:"type"` // 参数类型: string, int, bool, array
|
||||
Description string `yaml:"description"` // 参数描述
|
||||
Required bool `yaml:"required,omitempty"` // 是否必需
|
||||
Default interface{} `yaml:"default,omitempty"` // 默认值
|
||||
ItemType string `yaml:"item_type,omitempty"` // 当 type 为 array 时,数组元素类型,如 string, number, object
|
||||
Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
|
||||
Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始)
|
||||
Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
|
||||
Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
|
||||
Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
if cfg.Auth.SessionDurationHours <= 0 {
|
||||
cfg.Auth.SessionDurationHours = 12
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cfg.Auth.Password) == "" {
|
||||
password, err := generateStrongPassword(24)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成默认密码失败: %w", err)
|
||||
}
|
||||
|
||||
cfg.Auth.Password = password
|
||||
cfg.Auth.GeneratedPassword = password
|
||||
|
||||
if err := PersistAuthPassword(path, password); err != nil {
|
||||
cfg.Auth.GeneratedPasswordPersisted = false
|
||||
cfg.Auth.GeneratedPasswordPersistErr = err.Error()
|
||||
} else {
|
||||
cfg.Auth.GeneratedPasswordPersisted = true
|
||||
}
|
||||
}
|
||||
|
||||
// 如果配置了工具目录,从目录加载工具配置
|
||||
if cfg.Security.ToolsDir != "" {
|
||||
configDir := filepath.Dir(path)
|
||||
toolsDir := cfg.Security.ToolsDir
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(toolsDir) {
|
||||
toolsDir = filepath.Join(configDir, toolsDir)
|
||||
}
|
||||
|
||||
tools, err := LoadToolsFromDir(toolsDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从工具目录加载工具配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 合并工具配置:目录中的工具优先,主配置中的工具作为补充
|
||||
existingTools := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
existingTools[tool.Name] = true
|
||||
}
|
||||
|
||||
// 添加主配置中不存在于目录中的工具(向后兼容)
|
||||
for _, tool := range cfg.Security.Tools {
|
||||
if !existingTools[tool.Name] {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
}
|
||||
|
||||
cfg.Security.Tools = tools
|
||||
}
|
||||
|
||||
// 迁移外部MCP配置:将旧的 enabled/disabled 字段迁移到 external_mcp_enable
|
||||
if cfg.ExternalMCP.Servers != nil {
|
||||
for name, serverCfg := range cfg.ExternalMCP.Servers {
|
||||
// 如果已经设置了 external_mcp_enable,跳过迁移
|
||||
// 否则从 enabled/disabled 字段迁移
|
||||
// 注意:由于 ExternalMCPEnable 是 bool 类型,零值为 false,所以需要检查是否真的设置了
|
||||
// 这里我们通过检查旧的 enabled/disabled 字段来判断是否需要迁移
|
||||
if serverCfg.Disabled {
|
||||
// 旧配置使用 disabled,迁移到 external_mcp_enable
|
||||
serverCfg.ExternalMCPEnable = false
|
||||
} else if serverCfg.Enabled {
|
||||
// 旧配置使用 enabled,迁移到 external_mcp_enable
|
||||
serverCfg.ExternalMCPEnable = true
|
||||
} else {
|
||||
// 都没有设置,默认为启用
|
||||
serverCfg.ExternalMCPEnable = true
|
||||
}
|
||||
cfg.ExternalMCP.Servers[name] = serverCfg
|
||||
}
|
||||
}
|
||||
|
||||
// 从角色目录加载角色配置
|
||||
if cfg.RolesDir != "" {
|
||||
configDir := filepath.Dir(path)
|
||||
rolesDir := cfg.RolesDir
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
roles, err := LoadRolesFromDir(rolesDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从角色目录加载角色配置失败: %w", err)
|
||||
}
|
||||
|
||||
cfg.Roles = roles
|
||||
} else {
|
||||
// 如果未配置 roles_dir,初始化为空 map
|
||||
if cfg.Roles == nil {
|
||||
cfg.Roles = make(map[string]RoleConfig)
|
||||
}
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func generateStrongPassword(length int) (string, error) {
|
||||
if length <= 0 {
|
||||
length = 24
|
||||
}
|
||||
|
||||
bytesLen := length
|
||||
randomBytes := make([]byte, bytesLen)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
password := base64.RawURLEncoding.EncodeToString(randomBytes)
|
||||
if len(password) > length {
|
||||
password = password[:length]
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func PersistAuthPassword(path, password string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
inAuthBlock := false
|
||||
authIndent := -1
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if !inAuthBlock {
|
||||
if strings.HasPrefix(trimmed, "auth:") {
|
||||
inAuthBlock = true
|
||||
authIndent = len(line) - len(strings.TrimLeft(line, " "))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
leadingSpaces := len(line) - len(strings.TrimLeft(line, " "))
|
||||
if leadingSpaces <= authIndent {
|
||||
// 离开 auth 块
|
||||
inAuthBlock = false
|
||||
authIndent = -1
|
||||
// 继续寻找其它 auth 块(理论上没有)
|
||||
if strings.HasPrefix(trimmed, "auth:") {
|
||||
inAuthBlock = true
|
||||
authIndent = leadingSpaces
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(strings.TrimSpace(line), "password:") {
|
||||
prefix := line[:len(line)-len(strings.TrimLeft(line, " "))]
|
||||
comment := ""
|
||||
if idx := strings.Index(line, "#"); idx >= 0 {
|
||||
comment = strings.TrimRight(line[idx:], " ")
|
||||
}
|
||||
|
||||
newLine := fmt.Sprintf("%spassword: %s", prefix, password)
|
||||
if comment != "" {
|
||||
if !strings.HasPrefix(comment, " ") {
|
||||
newLine += " "
|
||||
}
|
||||
newLine += comment
|
||||
}
|
||||
lines[i] = newLine
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644)
|
||||
}
|
||||
|
||||
func PrintGeneratedPasswordWarning(password string, persisted bool, persistErr string) {
|
||||
if strings.TrimSpace(password) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if persisted {
|
||||
fmt.Println("[CyberStrikeAI] ✅ 已为您自动生成并写入 Web 登录密码。")
|
||||
} else {
|
||||
if persistErr != "" {
|
||||
fmt.Printf("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码: %s\n", persistErr)
|
||||
} else {
|
||||
fmt.Println("[CyberStrikeAI] ⚠️ 无法自动写入配置文件中的密码。")
|
||||
}
|
||||
fmt.Println("请手动将以下随机密码写入 config.yaml 的 auth.password:")
|
||||
}
|
||||
|
||||
fmt.Println("----------------------------------------------------------------")
|
||||
fmt.Println("CyberStrikeAI Auto-Generated Web Password")
|
||||
fmt.Printf("Password: %s\n", password)
|
||||
fmt.Println("WARNING: Anyone with this password can fully control CyberStrikeAI.")
|
||||
fmt.Println("Please store it securely and change it in config.yaml as soon as possible.")
|
||||
fmt.Println("警告:持有此密码的人将拥有对 CyberStrikeAI 的完全控制权限。")
|
||||
fmt.Println("请妥善保管,并尽快在 config.yaml 中修改 auth.password!")
|
||||
fmt.Println("----------------------------------------------------------------")
|
||||
}
|
||||
|
||||
// generateRandomToken 生成用于 MCP 鉴权的随机字符串(64 位十六进制)
|
||||
func generateRandomToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// persistMCPAuth 将 MCP 的 auth_header / auth_header_value 写回配置文件
|
||||
func persistMCPAuth(path string, mcp *MCPConfig) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(data), "\n")
|
||||
inMcpBlock := false
|
||||
mcpIndent := -1
|
||||
|
||||
for i, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if !inMcpBlock {
|
||||
if strings.HasPrefix(trimmed, "mcp:") {
|
||||
inMcpBlock = true
|
||||
mcpIndent = len(line) - len(strings.TrimLeft(line, " "))
|
||||
}
|
||||
continue
|
||||
}
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
continue
|
||||
}
|
||||
leadingSpaces := len(line) - len(strings.TrimLeft(line, " "))
|
||||
if leadingSpaces <= mcpIndent {
|
||||
inMcpBlock = false
|
||||
mcpIndent = -1
|
||||
if strings.HasPrefix(trimmed, "mcp:") {
|
||||
inMcpBlock = true
|
||||
mcpIndent = leadingSpaces
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := line[:leadingSpaces]
|
||||
rest := strings.TrimSpace(line[leadingSpaces:])
|
||||
comment := ""
|
||||
if idx := strings.Index(line, "#"); idx >= 0 {
|
||||
comment = strings.TrimRight(line[idx:], " ")
|
||||
}
|
||||
withComment := ""
|
||||
if comment != "" {
|
||||
if !strings.HasPrefix(comment, " ") {
|
||||
withComment = " "
|
||||
}
|
||||
withComment += comment
|
||||
}
|
||||
|
||||
if strings.HasPrefix(rest, "auth_header_value:") {
|
||||
lines[i] = fmt.Sprintf("%sauth_header_value: %q%s", prefix, mcp.AuthHeaderValue, withComment)
|
||||
} else if strings.HasPrefix(rest, "auth_header:") {
|
||||
lines[i] = fmt.Sprintf("%sauth_header: %q%s", prefix, mcp.AuthHeader, withComment)
|
||||
}
|
||||
}
|
||||
|
||||
return os.WriteFile(path, []byte(strings.Join(lines, "\n")), 0644)
|
||||
}
|
||||
|
||||
// EnsureMCPAuth 在 MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
|
||||
func EnsureMCPAuth(path string, cfg *Config) error {
|
||||
if !cfg.MCP.Enabled || strings.TrimSpace(cfg.MCP.AuthHeaderValue) != "" {
|
||||
return nil
|
||||
}
|
||||
token, err := generateRandomToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("生成 MCP 鉴权密钥失败: %w", err)
|
||||
}
|
||||
cfg.MCP.AuthHeaderValue = token
|
||||
if strings.TrimSpace(cfg.MCP.AuthHeader) == "" {
|
||||
cfg.MCP.AuthHeader = "X-MCP-Token"
|
||||
}
|
||||
return persistMCPAuth(path, &cfg.MCP)
|
||||
}
|
||||
|
||||
// PrintMCPConfigJSON 向终端输出 MCP 配置的 JSON,可直接复制到 Cursor / Claude Code 的 mcp 配置中使用
|
||||
func PrintMCPConfigJSON(mcp MCPConfig) {
|
||||
if !mcp.Enabled {
|
||||
return
|
||||
}
|
||||
hostForURL := strings.TrimSpace(mcp.Host)
|
||||
if hostForURL == "" || hostForURL == "0.0.0.0" {
|
||||
hostForURL = "localhost"
|
||||
}
|
||||
url := fmt.Sprintf("http://%s:%d/mcp", hostForURL, mcp.Port)
|
||||
headers := map[string]string{}
|
||||
if mcp.AuthHeader != "" {
|
||||
headers[mcp.AuthHeader] = mcp.AuthHeaderValue
|
||||
}
|
||||
serverEntry := map[string]interface{}{
|
||||
"url": url,
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
serverEntry["headers"] = headers
|
||||
}
|
||||
// Claude Code 需要 type: "http"
|
||||
serverEntry["type"] = "http"
|
||||
out := map[string]interface{}{
|
||||
"mcpServers": map[string]interface{}{
|
||||
"cyberstrike-ai": serverEntry,
|
||||
},
|
||||
}
|
||||
b, _ := json.MarshalIndent(out, "", " ")
|
||||
fmt.Println("[CyberStrikeAI] MCP 配置(可复制到 Cursor / Claude Code 使用):")
|
||||
fmt.Println(" Cursor: 放入 ~/.cursor/mcp.json 的 mcpServers,或项目 .cursor/mcp.json")
|
||||
fmt.Println(" Claude Code: 放入 .mcp.json 或 ~/.claude.json 的 mcpServers")
|
||||
fmt.Println("----------------------------------------------------------------")
|
||||
fmt.Println(string(b))
|
||||
fmt.Println("----------------------------------------------------------------")
|
||||
}
|
||||
|
||||
// LoadToolsFromDir 从目录加载所有工具配置文件
|
||||
func LoadToolsFromDir(dir string) ([]ToolConfig, error) {
|
||||
var tools []ToolConfig
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return tools, nil // 目录不存在时返回空列表,不报错
|
||||
}
|
||||
|
||||
// 读取目录中的所有 .yaml 和 .yml 文件
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取工具目录失败: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(dir, name)
|
||||
tool, err := LoadToolFromFile(filePath)
|
||||
if err != nil {
|
||||
// 记录错误但继续加载其他文件
|
||||
fmt.Printf("警告: 加载工具配置文件 %s 失败: %v\n", filePath, err)
|
||||
continue
|
||||
}
|
||||
|
||||
tools = append(tools, *tool)
|
||||
}
|
||||
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
// LoadToolFromFile 从单个文件加载工具配置
|
||||
func LoadToolFromFile(path string) (*ToolConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取文件失败: %w", err)
|
||||
}
|
||||
|
||||
var tool ToolConfig
|
||||
if err := yaml.Unmarshal(data, &tool); err != nil {
|
||||
return nil, fmt.Errorf("解析工具配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证必需字段
|
||||
if tool.Name == "" {
|
||||
return nil, fmt.Errorf("工具名称不能为空")
|
||||
}
|
||||
if tool.Command == "" {
|
||||
return nil, fmt.Errorf("工具命令不能为空")
|
||||
}
|
||||
|
||||
return &tool, nil
|
||||
}
|
||||
|
||||
// LoadRolesFromDir 从目录加载所有角色配置文件
|
||||
func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) {
|
||||
roles := make(map[string]RoleConfig)
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return roles, nil // 目录不存在时返回空map,不报错
|
||||
}
|
||||
|
||||
// 读取目录中的所有 .yaml 和 .yml 文件
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(dir, name)
|
||||
role, err := LoadRoleFromFile(filePath)
|
||||
if err != nil {
|
||||
// 记录错误但继续加载其他文件
|
||||
fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 使用角色名称作为key
|
||||
roleName := role.Name
|
||||
if roleName == "" {
|
||||
// 如果角色名称为空,使用文件名(去掉扩展名)作为名称
|
||||
roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml")
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
roles[roleName] = *role
|
||||
}
|
||||
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// LoadRoleFromFile 从单个文件加载角色配置
|
||||
func LoadRoleFromFile(path string) (*RoleConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取文件失败: %w", err)
|
||||
}
|
||||
|
||||
var role RoleConfig
|
||||
if err := yaml.Unmarshal(data, &role); err != nil {
|
||||
return nil, fmt.Errorf("解析角色配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符
|
||||
// Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换
|
||||
if role.Icon != "" {
|
||||
icon := role.Icon
|
||||
// 去除可能的引号
|
||||
icon = strings.Trim(icon, `"`)
|
||||
|
||||
// 检查是否是 Unicode 转义格式 \U0001F3C6(8位十六进制)或 \uXXXX(4位十六进制)
|
||||
if len(icon) >= 3 && icon[0] == '\\' {
|
||||
if icon[1] == 'U' && len(icon) >= 10 {
|
||||
// \U0001F3C6 格式(8位十六进制)
|
||||
if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil {
|
||||
role.Icon = string(rune(codePoint))
|
||||
}
|
||||
} else if icon[1] == 'u' && len(icon) >= 6 {
|
||||
// \uXXXX 格式(4位十六进制)
|
||||
if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil {
|
||||
role.Icon = string(rune(codePoint))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证必需字段
|
||||
if role.Name == "" {
|
||||
// 如果名称为空,尝试从文件名获取
|
||||
baseName := filepath.Base(path)
|
||||
role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml")
|
||||
}
|
||||
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
},
|
||||
Log: LogConfig{
|
||||
Level: "info",
|
||||
Output: "stdout",
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Enabled: true,
|
||||
Host: "0.0.0.0",
|
||||
Port: 8081,
|
||||
},
|
||||
OpenAI: OpenAIConfig{
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Model: "gpt-4",
|
||||
MaxTotalTokens: 120000,
|
||||
},
|
||||
Agent: AgentConfig{
|
||||
MaxIterations: 30, // 默认最大迭代次数
|
||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
|
||||
ToolsDir: "tools", // 默认工具目录
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Path: "data/conversations.db",
|
||||
KnowledgeDBPath: "data/knowledge.db", // 默认知识库数据库路径
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
SessionDurationHours: 12,
|
||||
},
|
||||
Knowledge: KnowledgeConfig{
|
||||
Enabled: true,
|
||||
BasePath: "knowledge_base",
|
||||
Embedding: EmbeddingConfig{
|
||||
Provider: "openai",
|
||||
Model: "text-embedding-3-small",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
},
|
||||
Retrieval: RetrievalConfig{
|
||||
TopK: 5,
|
||||
SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检
|
||||
},
|
||||
Indexing: IndexingConfig{
|
||||
ChunkStrategy: "markdown_then_recursive",
|
||||
RequestTimeoutSeconds: 120,
|
||||
ChunkSize: 768, // 增加到 768,更好的上下文保持
|
||||
ChunkOverlap: 50,
|
||||
MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额
|
||||
BatchSize: 64,
|
||||
PreferSourceFile: false,
|
||||
MaxRPM: 100, // 默认 100 RPM,避免 429 错误
|
||||
RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM
|
||||
MaxRetries: 3,
|
||||
RetryDelayMs: 1000,
|
||||
SubIndexes: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// KnowledgeConfig 知识库配置
|
||||
type KnowledgeConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用知识检索
|
||||
BasePath string `yaml:"base_path" json:"base_path"` // 知识库路径
|
||||
Embedding EmbeddingConfig `yaml:"embedding" json:"embedding"`
|
||||
Retrieval RetrievalConfig `yaml:"retrieval" json:"retrieval"`
|
||||
Indexing IndexingConfig `yaml:"indexing,omitempty" json:"indexing,omitempty"` // 索引构建配置
|
||||
}
|
||||
|
||||
// IndexingConfig 索引构建配置(用于控制知识库索引构建时的行为)
|
||||
type IndexingConfig struct {
|
||||
// ChunkStrategy: "markdown_then_recursive"(默认,Eino Markdown 标题切分后再递归切)或 "recursive"(仅递归切分)
|
||||
ChunkStrategy string `yaml:"chunk_strategy,omitempty" json:"chunk_strategy,omitempty"`
|
||||
// RequestTimeoutSeconds 嵌入 HTTP 客户端超时(秒),0 表示使用默认 120
|
||||
RequestTimeoutSeconds int `yaml:"request_timeout_seconds,omitempty" json:"request_timeout_seconds,omitempty"`
|
||||
// 分块配置
|
||||
ChunkSize int `yaml:"chunk_size,omitempty" json:"chunk_size,omitempty"` // 每个块的最大 token 数(估算),默认 512
|
||||
ChunkOverlap int `yaml:"chunk_overlap,omitempty" json:"chunk_overlap,omitempty"` // 块之间的重叠 token 数,默认 50
|
||||
MaxChunksPerItem int `yaml:"max_chunks_per_item,omitempty" json:"max_chunks_per_item,omitempty"` // 单个知识项的最大块数量,0 表示不限制
|
||||
|
||||
// PreferSourceFile 为 true 时优先用 Eino FileLoader 从 file_path 读原文再索引(与库内 content 不一致时以磁盘为准)
|
||||
PreferSourceFile bool `yaml:"prefer_source_file,omitempty" json:"prefer_source_file,omitempty"`
|
||||
|
||||
// 速率限制配置(用于避免 API 速率限制)
|
||||
RateLimitDelayMs int `yaml:"rate_limit_delay_ms,omitempty" json:"rate_limit_delay_ms,omitempty"` // 请求间隔时间(毫秒),0 表示不使用固定延迟
|
||||
MaxRPM int `yaml:"max_rpm,omitempty" json:"max_rpm,omitempty"` // 每分钟最大请求数,0 表示不限制
|
||||
|
||||
// 重试配置(用于处理临时错误)
|
||||
MaxRetries int `yaml:"max_retries,omitempty" json:"max_retries,omitempty"` // 最大重试次数,默认 3
|
||||
RetryDelayMs int `yaml:"retry_delay_ms,omitempty" json:"retry_delay_ms,omitempty"` // 重试间隔(毫秒),默认 1000
|
||||
|
||||
// BatchSize 嵌入批大小(SQLite 索引写入),0 表示默认 64
|
||||
BatchSize int `yaml:"batch_size,omitempty" json:"batch_size,omitempty"`
|
||||
// SubIndexes 传入 Eino indexer.WithSubIndexes(逻辑分区标记,随 Document 元数据传递)
|
||||
SubIndexes []string `yaml:"sub_indexes,omitempty" json:"sub_indexes,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingConfig 嵌入配置
|
||||
type EmbeddingConfig struct {
|
||||
Provider string `yaml:"provider" json:"provider"` // 嵌入模型提供商
|
||||
Model string `yaml:"model" json:"model"` // 模型名称
|
||||
BaseURL string `yaml:"base_url" json:"base_url"` // API Base URL
|
||||
APIKey string `yaml:"api_key" json:"api_key"` // API Key(从OpenAI配置继承)
|
||||
}
|
||||
|
||||
// PostRetrieveConfig 检索后处理:固定对正文做规范化去重(最佳实践)、上下文预算截断;PrefetchTopK 用于多取候选再收敛到 top_k。
|
||||
type PostRetrieveConfig struct {
|
||||
// PrefetchTopK 向量检索阶段最多保留的候选数(余弦序),应 ≥ top_k,0 表示与 top_k 相同;上限见知识库包内常量。
|
||||
PrefetchTopK int `yaml:"prefetch_top_k,omitempty" json:"prefetch_top_k,omitempty"`
|
||||
// MaxContextChars 返回文档内容总 Unicode 字符数上限(整段 chunk,不截断半段);0 表示不限制。
|
||||
MaxContextChars int `yaml:"max_context_chars,omitempty" json:"max_context_chars,omitempty"`
|
||||
// MaxContextTokens 返回文档内容总 token 上限(tiktoken,按嵌入模型名映射,失败则 cl100k_base);0 表示不限制。
|
||||
MaxContextTokens int `yaml:"max_context_tokens,omitempty" json:"max_context_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// RetrievalConfig 检索配置
|
||||
type RetrievalConfig struct {
|
||||
TopK int `yaml:"top_k" json:"top_k"` // 检索Top-K
|
||||
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 余弦相似度阈值
|
||||
// SubIndexFilter 非空时仅保留 sub_indexes 含该标签(逗号分隔之一)的行;sub_indexes 为空的旧行仍返回。
|
||||
SubIndexFilter string `yaml:"sub_index_filter,omitempty" json:"sub_index_filter,omitempty"`
|
||||
// PostRetrieve 检索后处理(去重、预算截断);重排通过代码注入 [knowledge.DocumentReranker]。
|
||||
PostRetrieve PostRetrieveConfig `yaml:"post_retrieve,omitempty" json:"post_retrieve,omitempty"`
|
||||
}
|
||||
|
||||
// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代)
|
||||
// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig
|
||||
type RolesConfig struct {
|
||||
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"`
|
||||
}
|
||||
|
||||
// RoleConfig 单个角色配置
|
||||
type RoleConfig struct {
|
||||
Name string `yaml:"name" json:"name"` // 角色名称
|
||||
Description string `yaml:"description" json:"description"` // 角色描述
|
||||
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
|
||||
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
|
||||
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName")
|
||||
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
|
||||
Skills []string `yaml:"skills,omitempty" json:"skills,omitempty"` // 关联的skills列表(skill名称列表,在执行任务前会读取这些skills的内容)
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AttackChainNode 攻击链节点
|
||||
type AttackChainNode struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // tool, vulnerability, target, exploit
|
||||
Label string `json:"label"`
|
||||
ToolExecutionID string `json:"tool_execution_id,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
RiskScore int `json:"risk_score"`
|
||||
}
|
||||
|
||||
// AttackChainEdge 攻击链边
|
||||
type AttackChainEdge struct {
|
||||
ID string `json:"id"`
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target"`
|
||||
Type string `json:"type"` // leads_to, exploits, enables, depends_on
|
||||
Weight int `json:"weight"`
|
||||
}
|
||||
|
||||
// SaveAttackChainNode 保存攻击链节点
|
||||
func (db *DB) SaveAttackChainNode(conversationID, nodeID, nodeType, nodeName, toolExecutionID, metadata string, riskScore int) error {
|
||||
var toolExecID sql.NullString
|
||||
if toolExecutionID != "" {
|
||||
toolExecID = sql.NullString{String: toolExecutionID, Valid: true}
|
||||
}
|
||||
|
||||
var metadataJSON sql.NullString
|
||||
if metadata != "" {
|
||||
metadataJSON = sql.NullString{String: metadata, Valid: true}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT OR REPLACE INTO attack_chain_nodes
|
||||
(id, conversation_id, node_type, node_name, tool_execution_id, metadata, risk_score, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
|
||||
`
|
||||
|
||||
_, err := db.Exec(query, nodeID, conversationID, nodeType, nodeName, toolExecID, metadataJSON, riskScore)
|
||||
if err != nil {
|
||||
db.logger.Error("保存攻击链节点失败", zap.Error(err), zap.String("nodeId", nodeID))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveAttackChainEdge 保存攻击链边
|
||||
func (db *DB) SaveAttackChainEdge(conversationID, edgeID, sourceNodeID, targetNodeID, edgeType string, weight int) error {
|
||||
query := `
|
||||
INSERT OR REPLACE INTO attack_chain_edges
|
||||
(id, conversation_id, source_node_id, target_node_id, edge_type, weight, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
|
||||
`
|
||||
|
||||
_, err := db.Exec(query, edgeID, conversationID, sourceNodeID, targetNodeID, edgeType, weight)
|
||||
if err != nil {
|
||||
db.logger.Error("保存攻击链边失败", zap.Error(err), zap.String("edgeId", edgeID))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAttackChainNodes 加载攻击链节点
|
||||
func (db *DB) LoadAttackChainNodes(conversationID string) ([]AttackChainNode, error) {
|
||||
query := `
|
||||
SELECT id, node_type, node_name, tool_execution_id, metadata, risk_score
|
||||
FROM attack_chain_nodes
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询攻击链节点失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var nodes []AttackChainNode
|
||||
for rows.Next() {
|
||||
var node AttackChainNode
|
||||
var toolExecID sql.NullString
|
||||
var metadataJSON sql.NullString
|
||||
|
||||
err := rows.Scan(&node.ID, &node.Type, &node.Label, &toolExecID, &metadataJSON, &node.RiskScore)
|
||||
if err != nil {
|
||||
db.logger.Warn("扫描攻击链节点失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if toolExecID.Valid {
|
||||
node.ToolExecutionID = toolExecID.String
|
||||
}
|
||||
|
||||
if metadataJSON.Valid && metadataJSON.String != "" {
|
||||
if err := json.Unmarshal([]byte(metadataJSON.String), &node.Metadata); err != nil {
|
||||
db.logger.Warn("解析节点元数据失败", zap.Error(err))
|
||||
node.Metadata = make(map[string]interface{})
|
||||
}
|
||||
} else {
|
||||
node.Metadata = make(map[string]interface{})
|
||||
}
|
||||
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// LoadAttackChainEdges 加载攻击链边
|
||||
func (db *DB) LoadAttackChainEdges(conversationID string) ([]AttackChainEdge, error) {
|
||||
query := `
|
||||
SELECT id, source_node_id, target_node_id, edge_type, weight
|
||||
FROM attack_chain_edges
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY created_at ASC
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询攻击链边失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var edges []AttackChainEdge
|
||||
for rows.Next() {
|
||||
var edge AttackChainEdge
|
||||
|
||||
err := rows.Scan(&edge.ID, &edge.Source, &edge.Target, &edge.Type, &edge.Weight)
|
||||
if err != nil {
|
||||
db.logger.Warn("扫描攻击链边失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
edges = append(edges, edge)
|
||||
}
|
||||
|
||||
return edges, nil
|
||||
}
|
||||
|
||||
// DeleteAttackChain 删除对话的攻击链数据
|
||||
func (db *DB) DeleteAttackChain(conversationID string) error {
|
||||
// 先删除边(因为有外键约束)
|
||||
_, err := db.Exec("DELETE FROM attack_chain_edges WHERE conversation_id = ?", conversationID)
|
||||
if err != nil {
|
||||
db.logger.Warn("删除攻击链边失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 再删除节点
|
||||
_, err = db.Exec("DELETE FROM attack_chain_nodes WHERE conversation_id = ?", conversationID)
|
||||
if err != nil {
|
||||
db.logger.Error("删除攻击链节点失败", zap.Error(err), zap.String("conversationId", conversationID))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,537 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// BatchTaskQueueRow 批量任务队列数据库行
|
||||
type BatchTaskQueueRow struct {
|
||||
ID string
|
||||
Title sql.NullString
|
||||
Role sql.NullString
|
||||
AgentMode sql.NullString
|
||||
ScheduleMode sql.NullString
|
||||
CronExpr sql.NullString
|
||||
NextRunAt sql.NullTime
|
||||
ScheduleEnabled sql.NullInt64
|
||||
LastScheduleTriggerAt sql.NullTime
|
||||
LastScheduleError sql.NullString
|
||||
LastRunError sql.NullString
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
StartedAt sql.NullTime
|
||||
CompletedAt sql.NullTime
|
||||
CurrentIndex int
|
||||
}
|
||||
|
||||
// BatchTaskRow 批量任务数据库行
|
||||
type BatchTaskRow struct {
|
||||
ID string
|
||||
QueueID string
|
||||
Message string
|
||||
ConversationID sql.NullString
|
||||
Status string
|
||||
StartedAt sql.NullTime
|
||||
CompletedAt sql.NullTime
|
||||
Error sql.NullString
|
||||
Result sql.NullString
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (db *DB) CreateBatchQueue(
|
||||
queueID string,
|
||||
title string,
|
||||
role string,
|
||||
agentMode string,
|
||||
scheduleMode string,
|
||||
cronExpr string,
|
||||
nextRunAt *time.Time,
|
||||
tasks []map[string]interface{},
|
||||
) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开始事务失败: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
now := time.Now()
|
||||
var nextRunAtValue interface{}
|
||||
if nextRunAt != nil {
|
||||
nextRunAtValue = *nextRunAt
|
||||
}
|
||||
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, "pending", now, 0,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||
}
|
||||
|
||||
// 插入任务
|
||||
for _, task := range tasks {
|
||||
taskID, ok := task["id"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
message, ok := task["message"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
|
||||
taskID, queueID, message, "pending",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetBatchQueue 获取批量任务队列
|
||||
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
queueID,
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列失败: %w", err)
|
||||
}
|
||||
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
if parseErr != nil {
|
||||
// 尝试其他时间格式
|
||||
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
|
||||
if parseErr != nil {
|
||||
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
|
||||
parsedTime = time.Now()
|
||||
}
|
||||
}
|
||||
row.CreatedAt = parsedTime
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// GetAllBatchQueues 获取所有批量任务队列
|
||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var queues []*BatchTaskQueueRow
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
if parseErr != nil {
|
||||
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
|
||||
if parseErr != nil {
|
||||
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
|
||||
parsedTime = time.Now()
|
||||
}
|
||||
}
|
||||
row.CreatedAt = parsedTime
|
||||
queues = append(queues, &row)
|
||||
}
|
||||
|
||||
return queues, nil
|
||||
}
|
||||
|
||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
if status != "" && status != "all" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
|
||||
// 关键字搜索(搜索队列ID和标题)
|
||||
if keyword != "" {
|
||||
query += " AND (id LIKE ? OR title LIKE ?)"
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var queues []*BatchTaskQueueRow
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
if parseErr != nil {
|
||||
parsedTime, parseErr = time.Parse(time.RFC3339, createdAt)
|
||||
if parseErr != nil {
|
||||
db.logger.Warn("解析创建时间失败", zap.String("createdAt", createdAt), zap.Error(parseErr))
|
||||
parsedTime = time.Now()
|
||||
}
|
||||
}
|
||||
row.CreatedAt = parsedTime
|
||||
queues = append(queues, &row)
|
||||
}
|
||||
|
||||
return queues, nil
|
||||
}
|
||||
|
||||
// CountBatchQueues 统计批量任务队列总数(支持筛选条件)
|
||||
func (db *DB) CountBatchQueues(status, keyword string) (int, error) {
|
||||
query := "SELECT COUNT(*) FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
if status != "" && status != "all" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
|
||||
// 关键字搜索(搜索队列ID和标题)
|
||||
if keyword != "" {
|
||||
query += " AND (id LIKE ? OR title LIKE ?)"
|
||||
args = append(args, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
var count int
|
||||
err := db.QueryRow(query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("统计批量任务队列总数失败: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetBatchTasks 获取批量任务队列的所有任务
|
||||
func (db *DB) GetBatchTasks(queueID string) ([]*BatchTaskRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, queue_id, message, conversation_id, status, started_at, completed_at, error, result FROM batch_tasks WHERE queue_id = ? ORDER BY id",
|
||||
queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tasks []*BatchTaskRow
|
||||
for rows.Next() {
|
||||
var task BatchTaskRow
|
||||
if err := rows.Scan(
|
||||
&task.ID, &task.QueueID, &task.Message, &task.ConversationID,
|
||||
&task.Status, &task.StartedAt, &task.CompletedAt, &task.Error, &task.Result,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务失败: %w", err)
|
||||
}
|
||||
tasks = append(tasks, &task)
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueStatus 更新批量任务队列状态
|
||||
func (db *DB) UpdateBatchQueueStatus(queueID, status string) error {
|
||||
var err error
|
||||
now := time.Now()
|
||||
|
||||
if status == "running" {
|
||||
_, err = db.Exec(
|
||||
"UPDATE batch_task_queues SET status = ?, started_at = COALESCE(started_at, ?) WHERE id = ?",
|
||||
status, now, queueID,
|
||||
)
|
||||
} else if status == "completed" || status == "cancelled" {
|
||||
_, err = db.Exec(
|
||||
"UPDATE batch_task_queues SET status = ?, completed_at = COALESCE(completed_at, ?) WHERE id = ?",
|
||||
status, now, queueID,
|
||||
)
|
||||
} else {
|
||||
_, err = db.Exec(
|
||||
"UPDATE batch_task_queues SET status = ? WHERE id = ?",
|
||||
status, queueID,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务队列状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchTaskStatus 更新批量任务状态
|
||||
func (db *DB) UpdateBatchTaskStatus(queueID, taskID, status string, conversationID, result, errorMsg string) error {
|
||||
var err error
|
||||
now := time.Now()
|
||||
|
||||
// 构建更新语句
|
||||
var updates []string
|
||||
var args []interface{}
|
||||
|
||||
updates = append(updates, "status = ?")
|
||||
args = append(args, status)
|
||||
|
||||
if conversationID != "" {
|
||||
updates = append(updates, "conversation_id = ?")
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
|
||||
if result != "" {
|
||||
updates = append(updates, "result = ?")
|
||||
args = append(args, result)
|
||||
}
|
||||
|
||||
if errorMsg != "" {
|
||||
updates = append(updates, "error = ?")
|
||||
args = append(args, errorMsg)
|
||||
}
|
||||
|
||||
if status == "running" {
|
||||
updates = append(updates, "started_at = COALESCE(started_at, ?)")
|
||||
args = append(args, now)
|
||||
}
|
||||
|
||||
if status == "completed" || status == "failed" || status == "cancelled" {
|
||||
updates = append(updates, "completed_at = COALESCE(completed_at, ?)")
|
||||
args = append(args, now)
|
||||
}
|
||||
|
||||
args = append(args, queueID, taskID)
|
||||
|
||||
// 构建SQL语句
|
||||
sql := "UPDATE batch_tasks SET "
|
||||
for i, update := range updates {
|
||||
if i > 0 {
|
||||
sql += ", "
|
||||
}
|
||||
sql += update
|
||||
}
|
||||
sql += " WHERE queue_id = ? AND id = ?"
|
||||
|
||||
_, err = db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueCurrentIndex 更新批量任务队列的当前索引
|
||||
func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET current_index = ? WHERE id = ?",
|
||||
currentIndex, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务队列当前索引失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式
|
||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?",
|
||||
title, role, agentMode, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueSchedule 更新批量任务队列调度相关信息
|
||||
func (db *DB) UpdateBatchQueueSchedule(queueID, scheduleMode, cronExpr string, nextRunAt *time.Time) error {
|
||||
var nextRunAtValue interface{}
|
||||
if nextRunAt != nil {
|
||||
nextRunAtValue = *nextRunAt
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET schedule_mode = ?, cron_expr = ?, next_run_at = ? WHERE id = ?",
|
||||
scheduleMode, cronExpr, nextRunAtValue, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务调度配置失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueScheduleEnabled 是否允许 Cron 自动触发(手工「开始执行」不受影响)
|
||||
func (db *DB) UpdateBatchQueueScheduleEnabled(queueID string, enabled bool) error {
|
||||
v := 0
|
||||
if enabled {
|
||||
v = 1
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET schedule_enabled = ? WHERE id = ?",
|
||||
v, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务调度开关失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordBatchQueueScheduledTriggerStart 记录一次由调度触发的开始时间并清空调度层错误
|
||||
func (db *DB) RecordBatchQueueScheduledTriggerStart(queueID string, at time.Time) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET last_schedule_trigger_at = ?, last_schedule_error = NULL WHERE id = ?",
|
||||
at, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("记录调度触发时间失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetBatchQueueLastScheduleError 调度启动失败等原因(如状态不允许、重置失败)
|
||||
func (db *DB) SetBatchQueueLastScheduleError(queueID, msg string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET last_schedule_error = ? WHERE id = ?",
|
||||
msg, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("写入调度错误信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetBatchQueueLastRunError 最近一轮执行中出现的子任务失败摘要(空串表示清空)
|
||||
func (db *DB) SetBatchQueueLastRunError(queueID, msg string) error {
|
||||
var v interface{}
|
||||
if strings.TrimSpace(msg) == "" {
|
||||
v = nil
|
||||
} else {
|
||||
v = msg
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET last_run_error = ? WHERE id = ?",
|
||||
v, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("写入最近运行错误失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetBatchQueueForRerun 重置队列和任务状态用于下一轮调度执行
|
||||
func (db *DB) ResetBatchQueueForRerun(queueID string) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开始事务失败: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec(
|
||||
"UPDATE batch_task_queues SET status = ?, current_index = 0, started_at = NULL, completed_at = NULL, last_run_error = NULL, last_schedule_error = NULL WHERE id = ?",
|
||||
"pending", queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("重置批量任务队列状态失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(
|
||||
"UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ?",
|
||||
"pending", queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("重置批量任务状态失败: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// UpdateBatchTaskMessage 更新批量任务消息
|
||||
func (db *DB) UpdateBatchTaskMessage(queueID, taskID, message string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_tasks SET message = ? WHERE queue_id = ? AND id = ?",
|
||||
message, queueID, taskID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务消息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddBatchTask 添加任务到批量任务队列
|
||||
func (db *DB) AddBatchTask(queueID, taskID, message string) error {
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO batch_tasks (id, queue_id, message, status) VALUES (?, ?, ?, ?)",
|
||||
taskID, queueID, message, "pending",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加批量任务失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelPendingBatchTasks 批量取消队列中所有 pending 状态的任务(单条 SQL)
|
||||
func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_tasks SET status = ?, completed_at = ? WHERE queue_id = ? AND status = ?",
|
||||
"cancelled", completedAt, queueID, "pending",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("批量取消 pending 任务失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteBatchTask 删除批量任务
|
||||
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
|
||||
_, err := db.Exec(
|
||||
"DELETE FROM batch_tasks WHERE queue_id = ? AND id = ?",
|
||||
queueID, taskID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除批量任务失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteBatchQueue 删除批量任务队列
|
||||
func (db *DB) DeleteBatchQueue(queueID string) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开始事务失败: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 删除任务(外键会自动级联删除)
|
||||
_, err = tx.Exec("DELETE FROM batch_tasks WHERE queue_id = ?", queueID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除批量任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除队列
|
||||
_, err = tx.Exec("DELETE FROM batch_task_queues WHERE id = ?", queueID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除批量任务队列失败: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
@@ -0,0 +1,758 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Conversation 对话
|
||||
type Conversation struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Pinned bool `json:"pinned"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
// Message 消息
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (db *DB) CreateConversation(title string) (*Conversation, error) {
|
||||
return db.CreateConversationWithWebshell("", title)
|
||||
}
|
||||
|
||||
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
|
||||
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
var err error
|
||||
if webshellConnectionID != "" {
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at, webshell_connection_id) VALUES (?, ?, ?, ?, ?)",
|
||||
id, title, now, now, webshellConnectionID,
|
||||
)
|
||||
} else {
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
|
||||
id, title, now, now,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
}
|
||||
|
||||
return &Conversation{
|
||||
ID: id,
|
||||
Title: title,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
|
||||
func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conversation, error) {
|
||||
if connectionID == "" {
|
||||
return nil, fmt.Errorf("connectionID is empty")
|
||||
}
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC LIMIT 1",
|
||||
connectionID,
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
conv.Pinned = pinned != 0
|
||||
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt); e == nil {
|
||||
conv.CreatedAt = t
|
||||
} else if t, e := time.Parse("2006-01-02 15:04:05", createdAt); e == nil {
|
||||
conv.CreatedAt = t
|
||||
} else {
|
||||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil {
|
||||
conv.UpdatedAt = t
|
||||
} else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil {
|
||||
conv.UpdatedAt = t
|
||||
} else {
|
||||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
messages, err := db.GetMessages(conv.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||||
}
|
||||
conv.Messages = messages
|
||||
|
||||
// 加载过程详情并附加到对应消息(与 GetConversation 一致,便于刷新后仍可查看执行过程)
|
||||
processDetailsMap, err := db.GetProcessDetailsByConversation(conv.ID)
|
||||
if err != nil {
|
||||
db.logger.Warn("加载过程详情失败", zap.Error(err))
|
||||
processDetailsMap = make(map[string][]ProcessDetail)
|
||||
}
|
||||
for i := range conv.Messages {
|
||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||
detailsJSON := make([]map[string]interface{}, len(details))
|
||||
for j, detail := range details {
|
||||
var data interface{}
|
||||
if detail.Data != "" {
|
||||
if err := json.Unmarshal([]byte(detail.Data), &data); err != nil {
|
||||
db.logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
detailsJSON[j] = map[string]interface{}{
|
||||
"id": detail.ID,
|
||||
"messageId": detail.MessageID,
|
||||
"conversationId": detail.ConversationID,
|
||||
"eventType": detail.EventType,
|
||||
"message": detail.Message,
|
||||
"data": data,
|
||||
"createdAt": detail.CreatedAt,
|
||||
}
|
||||
}
|
||||
conv.Messages[i].ProcessDetails = detailsJSON
|
||||
}
|
||||
}
|
||||
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// WebShellConversationItem 用于侧边栏列表,不含消息
|
||||
type WebShellConversationItem struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// ListConversationsByWebshellConnectionID 列出该 WebShell 连接下的所有对话(按更新时间倒序),供侧边栏展示
|
||||
func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]WebShellConversationItem, error) {
|
||||
if connectionID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, updated_at FROM conversations WHERE webshell_connection_id = ? ORDER BY updated_at DESC",
|
||||
connectionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询对话列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var list []WebShellConversationItem
|
||||
for rows.Next() {
|
||||
var item WebShellConversationItem
|
||||
var updatedAt string
|
||||
if err := rows.Scan(&item.ID, &item.Title, &updatedAt); err != nil {
|
||||
continue
|
||||
}
|
||||
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil {
|
||||
item.UpdatedAt = t
|
||||
} else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil {
|
||||
item.UpdatedAt = t
|
||||
} else {
|
||||
item.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
list = append(list, item)
|
||||
}
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
id,
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
conv.Pinned = pinned != 0
|
||||
|
||||
// 加载消息
|
||||
messages, err := db.GetMessages(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||||
}
|
||||
conv.Messages = messages
|
||||
|
||||
// 加载过程详情(按消息ID分组)
|
||||
processDetailsMap, err := db.GetProcessDetailsByConversation(id)
|
||||
if err != nil {
|
||||
db.logger.Warn("加载过程详情失败", zap.Error(err))
|
||||
processDetailsMap = make(map[string][]ProcessDetail)
|
||||
}
|
||||
|
||||
// 将过程详情附加到对应的消息上
|
||||
for i := range conv.Messages {
|
||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||
// 将ProcessDetail转换为JSON格式,以便前端使用
|
||||
detailsJSON := make([]map[string]interface{}, len(details))
|
||||
for j, detail := range details {
|
||||
var data interface{}
|
||||
if detail.Data != "" {
|
||||
if err := json.Unmarshal([]byte(detail.Data), &data); err != nil {
|
||||
db.logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
detailsJSON[j] = map[string]interface{}{
|
||||
"id": detail.ID,
|
||||
"messageId": detail.MessageID,
|
||||
"conversationId": detail.ConversationID,
|
||||
"eventType": detail.EventType,
|
||||
"message": detail.Message,
|
||||
"data": data,
|
||||
"createdAt": detail.CreatedAt,
|
||||
}
|
||||
}
|
||||
conv.Messages[i].ProcessDetails = detailsJSON
|
||||
}
|
||||
}
|
||||
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// GetConversationLite 获取对话(轻量版):包含 messages,但不加载 process_details。
|
||||
// 用于历史会话快速切换,避免一次性把大体量过程详情灌到前端导致卡顿。
|
||||
func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, pinned, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
id,
|
||||
).Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
conv.Pinned = pinned != 0
|
||||
|
||||
// 加载消息(不加载 process_details)
|
||||
messages, err := db.GetMessages(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||||
}
|
||||
conv.Messages = messages
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// ListConversations 列出所有对话
|
||||
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
if search != "" {
|
||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||
searchPattern := "%" + search + "%"
|
||||
rows, err = db.Query(
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at
|
||||
FROM conversations c
|
||||
WHERE c.title LIKE ?
|
||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
||||
ORDER BY c.updated_at DESC
|
||||
LIMIT ? OFFSET ?`,
|
||||
searchPattern, searchPattern, limit, offset,
|
||||
)
|
||||
} else {
|
||||
rows, err = db.Query(
|
||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
limit, offset,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询对话列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conversations []*Conversation
|
||||
for rows.Next() {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
conv.Pinned = pinned != 0
|
||||
|
||||
conversations = append(conversations, &conv)
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// UpdateConversationTitle 更新对话标题
|
||||
func (db *DB) UpdateConversationTitle(id, title string) error {
|
||||
// 注意:不更新 updated_at,因为重命名操作不应该改变对话的更新时间
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET title = ? WHERE id = ?",
|
||||
title, id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新对话标题失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateConversationTime 更新对话时间
|
||||
func (db *DB) UpdateConversationTime(id string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET updated_at = ? WHERE id = ?",
|
||||
time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新对话时间失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteConversation 删除对话及其所有相关数据
|
||||
// 由于数据库外键约束设置了 ON DELETE CASCADE,删除对话时会自动删除:
|
||||
// - messages(消息)
|
||||
// - process_details(过程详情)
|
||||
// - attack_chain_nodes(攻击链节点)
|
||||
// - attack_chain_edges(攻击链边)
|
||||
// - vulnerabilities(漏洞)
|
||||
// - conversation_group_mappings(分组映射)
|
||||
// 注意:knowledge_retrieval_logs 使用 ON DELETE SET NULL,记录会保留但 conversation_id 会被设为 NULL
|
||||
func (db *DB) DeleteConversation(id string) error {
|
||||
// 显式删除知识检索日志(虽然外键是SET NULL,但为了彻底清理,我们手动删除)
|
||||
_, err := db.Exec("DELETE FROM knowledge_retrieval_logs WHERE conversation_id = ?", id)
|
||||
if err != nil {
|
||||
db.logger.Warn("删除知识检索日志失败", zap.String("conversationId", id), zap.Error(err))
|
||||
// 不返回错误,继续删除对话
|
||||
}
|
||||
|
||||
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
||||
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话失败: %w", err)
|
||||
}
|
||||
|
||||
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveReActData 保存最后一轮ReAct的输入和输出
|
||||
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
|
||||
reactInput, reactOutput, time.Now(), conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存ReAct数据失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReActData 获取最后一轮ReAct的输入和输出
|
||||
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) {
|
||||
var input, output sql.NullString
|
||||
err = db.QueryRow(
|
||||
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
|
||||
conversationID,
|
||||
).Scan(&input, &output)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", "", fmt.Errorf("对话不存在")
|
||||
}
|
||||
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err)
|
||||
}
|
||||
|
||||
if input.Valid {
|
||||
reactInput = input.String
|
||||
}
|
||||
if output.Valid {
|
||||
reactOutput = output.String
|
||||
}
|
||||
|
||||
return reactInput, reactOutput, nil
|
||||
}
|
||||
|
||||
// ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。
|
||||
func (db *DB) ConversationHasToolProcessDetails(conversationID string) (bool, error) {
|
||||
var n int
|
||||
err := db.QueryRow(
|
||||
`SELECT COUNT(*) FROM process_details WHERE conversation_id = ? AND event_type IN ('tool_call', 'tool_result')`,
|
||||
conversationID,
|
||||
).Scan(&n)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("查询过程详情失败: %w", err)
|
||||
}
|
||||
return n > 0, nil
|
||||
}
|
||||
|
||||
// AddMessage 添加消息
|
||||
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
|
||||
id := uuid.New().String()
|
||||
|
||||
var mcpIDsJSON string
|
||||
if len(mcpExecutionIDs) > 0 {
|
||||
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||||
if err != nil {
|
||||
db.logger.Warn("序列化MCP执行ID失败", zap.Error(err))
|
||||
} else {
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, mcpIDsJSON, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新对话时间
|
||||
if err := db.UpdateConversationTime(conversationID); err != nil {
|
||||
db.logger.Warn("更新对话时间失败", zap.Error(err))
|
||||
}
|
||||
|
||||
message := &Message{
|
||||
ID: id,
|
||||
ConversationID: conversationID,
|
||||
Role: role,
|
||||
Content: content,
|
||||
MCPExecutionIDs: mcpExecutionIDs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// GetMessages 获取对话的所有消息
|
||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err != nil {
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err != nil {
|
||||
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
// 解析MCP执行ID
|
||||
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
|
||||
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
|
||||
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。
|
||||
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。
|
||||
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
|
||||
idx := -1
|
||||
for i := range msgs {
|
||||
if msgs[i].ID == anchorID {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
return 0, 0, fmt.Errorf("message not found")
|
||||
}
|
||||
start = idx
|
||||
for start > 0 && msgs[start].Role != "user" {
|
||||
start--
|
||||
}
|
||||
if start < len(msgs) && msgs[start].Role != "user" {
|
||||
start = 0
|
||||
}
|
||||
end = len(msgs)
|
||||
for i := start + 1; i < len(msgs); i++ {
|
||||
if msgs[i].Role == "user" {
|
||||
end = i
|
||||
break
|
||||
}
|
||||
}
|
||||
return start, end, nil
|
||||
}
|
||||
|
||||
// DeleteConversationTurn 删除锚点所在轮次的全部消息(用户提问 + 该轮助手回复等),并清空 last_react_*,避免与消息表不一致。
|
||||
func (db *DB) DeleteConversationTurn(conversationID, anchorMessageID string) (deletedIDs []string, err error) {
|
||||
msgs, err := db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
start, end, err := turnSliceRange(msgs, anchorMessageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if start >= end {
|
||||
return nil, fmt.Errorf("empty turn range")
|
||||
}
|
||||
deletedIDs = make([]string, 0, end-start)
|
||||
for i := start; i < end; i++ {
|
||||
deletedIDs = append(deletedIDs, msgs[i].ID)
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
ph := strings.Repeat("?,", len(deletedIDs))
|
||||
ph = ph[:len(ph)-1]
|
||||
args := make([]interface{}, 0, 1+len(deletedIDs))
|
||||
args = append(args, conversationID)
|
||||
for _, id := range deletedIDs {
|
||||
args = append(args, id)
|
||||
}
|
||||
res, err := tx.Exec(
|
||||
"DELETE FROM messages WHERE conversation_id = ? AND id IN ("+ph+")",
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete messages: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int(n) != len(deletedIDs) {
|
||||
return nil, fmt.Errorf("deleted count mismatch")
|
||||
}
|
||||
|
||||
_, err = tx.Exec(
|
||||
`UPDATE conversations SET last_react_input = NULL, last_react_output = NULL, updated_at = ? WHERE id = ?`,
|
||||
time.Now(), conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("clear react data: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
db.logger.Info("conversation turn deleted",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Strings("deletedMessageIds", deletedIDs),
|
||||
zap.Int("count", len(deletedIDs)),
|
||||
)
|
||||
return deletedIDs, nil
|
||||
}
|
||||
|
||||
// ProcessDetail 过程详情事件
|
||||
type ProcessDetail struct {
|
||||
ID string `json:"id"`
|
||||
MessageID string `json:"messageId"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"` // JSON格式的数据
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// AddProcessDetail 添加过程详情事件
|
||||
func (db *DB) AddProcessDetail(messageID, conversationID, eventType, message string, data interface{}) error {
|
||||
id := uuid.New().String()
|
||||
|
||||
var dataJSON string
|
||||
if data != nil {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
db.logger.Warn("序列化过程详情数据失败", zap.Error(err))
|
||||
} else {
|
||||
dataJSON = string(jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO process_details (id, message_id, conversation_id, event_type, message, data, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, messageID, conversationID, eventType, message, dataJSON, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加过程详情失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProcessDetails 获取消息的过程详情
|
||||
func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC",
|
||||
messageID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询过程详情失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var details []ProcessDetail
|
||||
for rows.Next() {
|
||||
var detail ProcessDetail
|
||||
var createdAt string
|
||||
|
||||
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描过程详情失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err != nil {
|
||||
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err != nil {
|
||||
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
details = append(details, detail)
|
||||
}
|
||||
|
||||
return details, nil
|
||||
}
|
||||
|
||||
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
|
||||
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询过程详情失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
detailsMap := make(map[string][]ProcessDetail)
|
||||
for rows.Next() {
|
||||
var detail ProcessDetail
|
||||
var createdAt string
|
||||
|
||||
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描过程详情失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err != nil {
|
||||
detail.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err != nil {
|
||||
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
detailsMap[detail.MessageID] = append(detailsMap[detail.MessageID], detail)
|
||||
}
|
||||
|
||||
return detailsMap, nil
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTurnSliceRange(t *testing.T) {
|
||||
mk := func(id, role string) Message {
|
||||
return Message{ID: id, Role: role}
|
||||
}
|
||||
msgs := []Message{
|
||||
mk("u1", "user"),
|
||||
mk("a1", "assistant"),
|
||||
mk("u2", "user"),
|
||||
mk("a2", "assistant"),
|
||||
}
|
||||
cases := []struct {
|
||||
anchor string
|
||||
start int
|
||||
end int
|
||||
}{
|
||||
{"u1", 0, 2},
|
||||
{"a1", 0, 2},
|
||||
{"u2", 2, 4},
|
||||
{"a2", 2, 4},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
s, e, err := turnSliceRange(msgs, tc.anchor)
|
||||
if err != nil {
|
||||
t.Fatalf("anchor %s: %v", tc.anchor, err)
|
||||
}
|
||||
if s != tc.start || e != tc.end {
|
||||
t.Fatalf("anchor %s: got [%d,%d) want [%d,%d)", tc.anchor, s, e, tc.start, tc.end)
|
||||
}
|
||||
}
|
||||
if _, _, err := turnSliceRange(msgs, "nope"); err == nil {
|
||||
t.Fatal("expected error for missing id")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,809 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DB 数据库连接
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接
|
||||
func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: db,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// 初始化表
|
||||
if err := database.initTables(); err != nil {
|
||||
return nil, fmt.Errorf("初始化表失败: %w", err)
|
||||
}
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// initTables 初始化数据库表
|
||||
func (db *DB) initTables() error {
|
||||
// 创建对话表
|
||||
createConversationsTable := `
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL,
|
||||
last_react_input TEXT,
|
||||
last_react_output TEXT
|
||||
);`
|
||||
|
||||
// 创建消息表
|
||||
createMessagesTable := `
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
mcp_execution_ids TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建过程详情表
|
||||
createProcessDetailsTable := `
|
||||
CREATE TABLE IF NOT EXISTS process_details (
|
||||
id TEXT PRIMARY KEY,
|
||||
message_id TEXT NOT NULL,
|
||||
conversation_id TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
message TEXT,
|
||||
data TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建工具执行记录表
|
||||
createToolExecutionsTable := `
|
||||
CREATE TABLE IF NOT EXISTS tool_executions (
|
||||
id TEXT PRIMARY KEY,
|
||||
tool_name TEXT NOT NULL,
|
||||
arguments TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
result TEXT,
|
||||
error TEXT,
|
||||
start_time DATETIME NOT NULL,
|
||||
end_time DATETIME,
|
||||
duration_ms INTEGER,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
// 创建工具统计表
|
||||
createToolStatsTable := `
|
||||
CREATE TABLE IF NOT EXISTS tool_stats (
|
||||
tool_name TEXT PRIMARY KEY,
|
||||
total_calls INTEGER NOT NULL DEFAULT 0,
|
||||
success_calls INTEGER NOT NULL DEFAULT 0,
|
||||
failed_calls INTEGER NOT NULL DEFAULT 0,
|
||||
last_call_time DATETIME,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
// 创建Skills统计表
|
||||
createSkillStatsTable := `
|
||||
CREATE TABLE IF NOT EXISTS skill_stats (
|
||||
skill_name TEXT PRIMARY KEY,
|
||||
total_calls INTEGER NOT NULL DEFAULT 0,
|
||||
success_calls INTEGER NOT NULL DEFAULT 0,
|
||||
failed_calls INTEGER NOT NULL DEFAULT 0,
|
||||
last_call_time DATETIME,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
// 创建攻击链节点表
|
||||
createAttackChainNodesTable := `
|
||||
CREATE TABLE IF NOT EXISTS attack_chain_nodes (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
node_type TEXT NOT NULL,
|
||||
node_name TEXT NOT NULL,
|
||||
tool_execution_id TEXT,
|
||||
metadata TEXT,
|
||||
risk_score INTEGER DEFAULT 0,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (tool_execution_id) REFERENCES tool_executions(id) ON DELETE SET NULL
|
||||
);`
|
||||
|
||||
// 创建攻击链边表
|
||||
createAttackChainEdgesTable := `
|
||||
CREATE TABLE IF NOT EXISTS attack_chain_edges (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
source_node_id TEXT NOT NULL,
|
||||
target_node_id TEXT NOT NULL,
|
||||
edge_type TEXT NOT NULL,
|
||||
weight INTEGER DEFAULT 1,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (source_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (target_node_id) REFERENCES attack_chain_nodes(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建知识检索日志表(保留在会话数据库中,因为有外键关联)
|
||||
createKnowledgeRetrievalLogsTable := `
|
||||
CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT,
|
||||
message_id TEXT,
|
||||
query TEXT NOT NULL,
|
||||
risk_type TEXT,
|
||||
retrieved_items TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE SET NULL,
|
||||
FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE SET NULL
|
||||
);`
|
||||
|
||||
// 创建对话分组表
|
||||
createConversationGroupsTable := `
|
||||
CREATE TABLE IF NOT EXISTS conversation_groups (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
icon TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`
|
||||
|
||||
// 创建对话分组映射表
|
||||
createConversationGroupMappingsTable := `
|
||||
CREATE TABLE IF NOT EXISTS conversation_group_mappings (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
group_id TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (group_id) REFERENCES conversation_groups(id) ON DELETE CASCADE,
|
||||
UNIQUE(conversation_id, group_id)
|
||||
);`
|
||||
|
||||
// 创建漏洞表
|
||||
createVulnerabilitiesTable := `
|
||||
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT,
|
||||
severity TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'open',
|
||||
vulnerability_type TEXT,
|
||||
target TEXT,
|
||||
proof TEXT,
|
||||
impact TEXT,
|
||||
recommendation TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建批量任务队列表
|
||||
createBatchTaskQueuesTable := `
|
||||
CREATE TABLE IF NOT EXISTS batch_task_queues (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT,
|
||||
role TEXT,
|
||||
agent_mode TEXT NOT NULL DEFAULT 'single',
|
||||
schedule_mode TEXT NOT NULL DEFAULT 'manual',
|
||||
cron_expr TEXT,
|
||||
next_run_at DATETIME,
|
||||
schedule_enabled INTEGER NOT NULL DEFAULT 1,
|
||||
last_schedule_trigger_at DATETIME,
|
||||
last_schedule_error TEXT,
|
||||
last_run_error TEXT,
|
||||
status TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
current_index INTEGER NOT NULL DEFAULT 0
|
||||
);`
|
||||
|
||||
// 创建批量任务表
|
||||
createBatchTasksTable := `
|
||||
CREATE TABLE IF NOT EXISTS batch_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
queue_id TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
conversation_id TEXT,
|
||||
status TEXT NOT NULL,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
error TEXT,
|
||||
result TEXT,
|
||||
FOREIGN KEY (queue_id) REFERENCES batch_task_queues(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建 WebShell 连接表
|
||||
createWebshellConnectionsTable := `
|
||||
CREATE TABLE IF NOT EXISTS webshell_connections (
|
||||
id TEXT PRIMARY KEY,
|
||||
url TEXT NOT NULL,
|
||||
password TEXT NOT NULL DEFAULT '',
|
||||
type TEXT NOT NULL DEFAULT 'php',
|
||||
method TEXT NOT NULL DEFAULT 'post',
|
||||
cmd_param TEXT NOT NULL DEFAULT '',
|
||||
remark TEXT NOT NULL DEFAULT '',
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
// 创建 WebShell 连接扩展状态表(前端工作区/终端状态持久化)
|
||||
createWebshellConnectionStatesTable := `
|
||||
CREATE TABLE IF NOT EXISTS webshell_connection_states (
|
||||
connection_id TEXT PRIMARY KEY,
|
||||
state_json TEXT NOT NULL DEFAULT '{}',
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (connection_id) REFERENCES webshell_connections(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建索引
|
||||
createIndexes := `
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_process_details_message_id ON process_details(message_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_process_details_conversation_id ON process_details(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_tool_executions_tool_name ON tool_executions(tool_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_tool_executions_start_time ON tool_executions(start_time);
|
||||
CREATE INDEX IF NOT EXISTS idx_tool_executions_status ON tool_executions(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_chain_nodes_conversation ON attack_chain_nodes(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_chain_edges_conversation ON attack_chain_edges(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_chain_edges_source ON attack_chain_edges(source_node_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_chain_edges_target ON attack_chain_edges(target_node_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_conversation ON conversation_group_mappings(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_created_at ON batch_task_queues(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
|
||||
CREATE INDEX IF NOT EXISTS idx_webshell_connections_created_at ON webshell_connections(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_webshell_connection_states_updated_at ON webshell_connection_states(updated_at);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
return fmt.Errorf("创建conversations表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createMessagesTable); err != nil {
|
||||
return fmt.Errorf("创建messages表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createProcessDetailsTable); err != nil {
|
||||
return fmt.Errorf("创建process_details表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createToolExecutionsTable); err != nil {
|
||||
return fmt.Errorf("创建tool_executions表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createToolStatsTable); err != nil {
|
||||
return fmt.Errorf("创建tool_stats表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createSkillStatsTable); err != nil {
|
||||
return fmt.Errorf("创建skill_stats表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createAttackChainNodesTable); err != nil {
|
||||
return fmt.Errorf("创建attack_chain_nodes表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createAttackChainEdgesTable); err != nil {
|
||||
return fmt.Errorf("创建attack_chain_edges表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil {
|
||||
return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createConversationGroupsTable); err != nil {
|
||||
return fmt.Errorf("创建conversation_groups表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createConversationGroupMappingsTable); err != nil {
|
||||
return fmt.Errorf("创建conversation_group_mappings表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
||||
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createBatchTaskQueuesTable); err != nil {
|
||||
return fmt.Errorf("创建batch_task_queues表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createBatchTasksTable); err != nil {
|
||||
return fmt.Errorf("创建batch_tasks表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createWebshellConnectionsTable); err != nil {
|
||||
return fmt.Errorf("创建webshell_connections表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createWebshellConnectionStatesTable); err != nil {
|
||||
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
|
||||
}
|
||||
|
||||
// 为已有表添加新字段(如果不存在)- 必须在创建索引之前
|
||||
if err := db.migrateConversationsTable(); err != nil {
|
||||
db.logger.Warn("迁移conversations表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateConversationGroupsTable(); err != nil {
|
||||
db.logger.Warn("迁移conversation_groups表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateConversationGroupMappingsTable(); err != nil {
|
||||
db.logger.Warn("迁移conversation_group_mappings表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateBatchTaskQueuesTable(); err != nil {
|
||||
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createIndexes); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
|
||||
db.logger.Info("数据库表初始化完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateConversationsTable 迁移conversations表,添加新字段
|
||||
func (db *DB) migrateConversationsTable() error {
|
||||
// 检查last_react_input字段是否存在
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_input'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误(SQLite错误信息可能不同)
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加last_react_input字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_input TEXT"); err != nil {
|
||||
db.logger.Warn("添加last_react_input字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查last_react_output字段是否存在
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='last_react_output'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加last_react_output字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN last_react_output TEXT"); err != nil {
|
||||
db.logger.Warn("添加last_react_output字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查pinned字段是否存在
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='pinned'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加pinned字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil {
|
||||
db.logger.Warn("添加pinned字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 webshell_connection_id 字段是否存在(WebShell AI 助手对话关联)
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversations') WHERE name='webshell_connection_id'").Scan(&count)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE conversations ADD COLUMN webshell_connection_id TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加webshell_connection_id字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE conversations ADD COLUMN webshell_connection_id TEXT"); err != nil {
|
||||
db.logger.Warn("添加webshell_connection_id字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateConversationGroupsTable 迁移conversation_groups表,添加新字段
|
||||
func (db *DB) migrateConversationGroupsTable() error {
|
||||
// 检查pinned字段是否存在
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversation_groups') WHERE name='pinned'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE conversation_groups ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加pinned字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE conversation_groups ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil {
|
||||
db.logger.Warn("添加pinned字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateConversationGroupMappingsTable 迁移conversation_group_mappings表,添加新字段
|
||||
func (db *DB) migrateConversationGroupMappingsTable() error {
|
||||
// 检查pinned字段是否存在
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('conversation_group_mappings') WHERE name='pinned'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE conversation_group_mappings ADD COLUMN pinned INTEGER DEFAULT 0"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加pinned字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE conversation_group_mappings ADD COLUMN pinned INTEGER DEFAULT 0"); err != nil {
|
||||
db.logger.Warn("添加pinned字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateBatchTaskQueuesTable 迁移batch_task_queues表,补充新字段
|
||||
func (db *DB) migrateBatchTaskQueuesTable() error {
|
||||
// 检查title字段是否存在
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='title'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加title字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN title TEXT"); err != nil {
|
||||
db.logger.Warn("添加title字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查role字段是否存在
|
||||
var roleCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='role'").Scan(&roleCount)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); addErr != nil {
|
||||
// 如果字段已存在,忽略错误
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加role字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if roleCount == 0 {
|
||||
// 字段不存在,添加它
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN role TEXT"); err != nil {
|
||||
db.logger.Warn("添加role字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查agent_mode字段是否存在
|
||||
var agentModeCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='agent_mode'").Scan(&agentModeCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加agent_mode字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if agentModeCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN agent_mode TEXT NOT NULL DEFAULT 'single'"); err != nil {
|
||||
db.logger.Warn("添加agent_mode字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查schedule_mode字段是否存在
|
||||
var scheduleModeCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_mode'").Scan(&scheduleModeCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加schedule_mode字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if scheduleModeCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_mode TEXT NOT NULL DEFAULT 'manual'"); err != nil {
|
||||
db.logger.Warn("添加schedule_mode字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查cron_expr字段是否存在
|
||||
var cronExprCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='cron_expr'").Scan(&cronExprCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加cron_expr字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if cronExprCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN cron_expr TEXT"); err != nil {
|
||||
db.logger.Warn("添加cron_expr字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查next_run_at字段是否存在
|
||||
var nextRunAtCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='next_run_at'").Scan(&nextRunAtCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加next_run_at字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if nextRunAtCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN next_run_at DATETIME"); err != nil {
|
||||
db.logger.Warn("添加next_run_at字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// schedule_enabled:0=暂停 Cron 自动调度,1=允许(手工执行不受影响)
|
||||
var scheduleEnCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='schedule_enabled'").Scan(&scheduleEnCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加schedule_enabled字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if scheduleEnCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN schedule_enabled INTEGER NOT NULL DEFAULT 1"); err != nil {
|
||||
db.logger.Warn("添加schedule_enabled字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
var lastTrigCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_trigger_at'").Scan(&lastTrigCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if lastTrigCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_trigger_at DATETIME"); err != nil {
|
||||
db.logger.Warn("添加last_schedule_trigger_at字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
var lastSchedErrCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_schedule_error'").Scan(&lastSchedErrCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加last_schedule_error字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if lastSchedErrCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_schedule_error TEXT"); err != nil {
|
||||
db.logger.Warn("添加last_schedule_error字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
var lastRunErrCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='last_run_error'").Scan(&lastRunErrCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加last_run_error字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if lastRunErrCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN last_run_error TEXT"); err != nil {
|
||||
db.logger.Warn("添加last_run_error字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
|
||||
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开知识库数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("连接知识库数据库失败: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: sqlDB,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// 初始化知识库表
|
||||
if err := database.initKnowledgeTables(); err != nil {
|
||||
return nil, fmt.Errorf("初始化知识库表失败: %w", err)
|
||||
}
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// initKnowledgeTables 初始化知识库数据库表(只包含知识库相关的表)
|
||||
func (db *DB) initKnowledgeTables() error {
|
||||
// 创建知识库项表
|
||||
createKnowledgeBaseItemsTable := `
|
||||
CREATE TABLE IF NOT EXISTS knowledge_base_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
category TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
file_path TEXT NOT NULL,
|
||||
content TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`
|
||||
|
||||
// 创建知识库向量表
|
||||
createKnowledgeEmbeddingsTable := `
|
||||
CREATE TABLE IF NOT EXISTS knowledge_embeddings (
|
||||
id TEXT PRIMARY KEY,
|
||||
item_id TEXT NOT NULL,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
chunk_text TEXT NOT NULL,
|
||||
embedding TEXT NOT NULL,
|
||||
sub_indexes TEXT NOT NULL DEFAULT '',
|
||||
embedding_model TEXT NOT NULL DEFAULT '',
|
||||
embedding_dim INTEGER NOT NULL DEFAULT 0,
|
||||
created_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (item_id) REFERENCES knowledge_base_items(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建知识检索日志表(在独立知识库数据库中,不使用外键约束,因为conversations和messages表可能不在这个数据库中)
|
||||
createKnowledgeRetrievalLogsTable := `
|
||||
CREATE TABLE IF NOT EXISTS knowledge_retrieval_logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT,
|
||||
message_id TEXT,
|
||||
query TEXT NOT NULL,
|
||||
risk_type TEXT,
|
||||
retrieved_items TEXT,
|
||||
created_at DATETIME NOT NULL
|
||||
);`
|
||||
|
||||
// 创建索引
|
||||
createIndexes := `
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_items_category ON knowledge_base_items(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_embeddings_item_id ON knowledge_embeddings(item_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_conversation ON knowledge_retrieval_logs(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_message ON knowledge_retrieval_logs(message_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createKnowledgeBaseItemsTable); err != nil {
|
||||
return fmt.Errorf("创建knowledge_base_items表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createKnowledgeEmbeddingsTable); err != nil {
|
||||
return fmt.Errorf("创建knowledge_embeddings表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createKnowledgeRetrievalLogsTable); err != nil {
|
||||
return fmt.Errorf("创建knowledge_retrieval_logs表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createIndexes); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
|
||||
if err := db.migrateKnowledgeEmbeddingsColumns(); err != nil {
|
||||
return fmt.Errorf("迁移 knowledge_embeddings 列失败: %w", err)
|
||||
}
|
||||
|
||||
db.logger.Info("知识库数据库表初始化完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateKnowledgeEmbeddingsColumns 为已有库补充 sub_indexes、embedding_model、embedding_dim。
|
||||
func (db *DB) migrateKnowledgeEmbeddingsColumns() error {
|
||||
var n int
|
||||
if err := db.QueryRow(`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='knowledge_embeddings'`).Scan(&n); err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
migrations := []struct {
|
||||
col string
|
||||
stmt string
|
||||
}{
|
||||
{"sub_indexes", `ALTER TABLE knowledge_embeddings ADD COLUMN sub_indexes TEXT NOT NULL DEFAULT ''`},
|
||||
{"embedding_model", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_model TEXT NOT NULL DEFAULT ''`},
|
||||
{"embedding_dim", `ALTER TABLE knowledge_embeddings ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0`},
|
||||
}
|
||||
for _, m := range migrations {
|
||||
var colCount int
|
||||
q := `SELECT COUNT(*) FROM pragma_table_info('knowledge_embeddings') WHERE name = ?`
|
||||
if err := db.QueryRow(q, m.col).Scan(&colCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if colCount > 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := db.Exec(m.stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
}
|
||||
@@ -0,0 +1,449 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ConversationGroup 对话分组
|
||||
type ConversationGroup struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
Pinned bool `json:"pinned"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// GroupExistsByName 检查分组名称是否已存在
|
||||
func (db *DB) GroupExistsByName(name string, excludeID string) (bool, error) {
|
||||
var count int
|
||||
var err error
|
||||
|
||||
if excludeID != "" {
|
||||
err = db.QueryRow(
|
||||
"SELECT COUNT(*) FROM conversation_groups WHERE name = ? AND id != ?",
|
||||
name, excludeID,
|
||||
).Scan(&count)
|
||||
} else {
|
||||
err = db.QueryRow(
|
||||
"SELECT COUNT(*) FROM conversation_groups WHERE name = ?",
|
||||
name,
|
||||
).Scan(&count)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("检查分组名称失败: %w", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// CreateGroup 创建分组
|
||||
func (db *DB) CreateGroup(name, icon string) (*ConversationGroup, error) {
|
||||
// 检查名称是否已存在
|
||||
exists, err := db.GroupExistsByName(name, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, fmt.Errorf("分组名称已存在")
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
if icon == "" {
|
||||
icon = "📁"
|
||||
}
|
||||
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversation_groups (id, name, icon, pinned, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
id, name, icon, 0, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建分组失败: %w", err)
|
||||
}
|
||||
|
||||
return &ConversationGroup{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Icon: icon,
|
||||
Pinned: false,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListGroups 列出所有分组
|
||||
func (db *DB) ListGroups() ([]*ConversationGroup, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups ORDER BY COALESCE(pinned, 0) DESC, created_at ASC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分组列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var groups []*ConversationGroup
|
||||
for rows.Next() {
|
||||
var group ConversationGroup
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
if err := rows.Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描分组失败: %w", err)
|
||||
}
|
||||
|
||||
group.Pinned = pinned != 0
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
groups = append(groups, &group)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// GetGroup 获取分组
|
||||
func (db *DB) GetGroup(id string) (*ConversationGroup, error) {
|
||||
var group ConversationGroup
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
|
||||
err := db.QueryRow(
|
||||
"SELECT id, name, icon, COALESCE(pinned, 0), created_at, updated_at FROM conversation_groups WHERE id = ?",
|
||||
id,
|
||||
).Scan(&group.ID, &group.Name, &group.Icon, &pinned, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("分组不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询分组失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
group.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
group.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
group.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
group.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
group.Pinned = pinned != 0
|
||||
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
// UpdateGroup 更新分组
|
||||
func (db *DB) UpdateGroup(id, name, icon string) error {
|
||||
// 检查名称是否已存在(排除当前分组)
|
||||
exists, err := db.GroupExistsByName(name, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return fmt.Errorf("分组名称已存在")
|
||||
}
|
||||
|
||||
_, err = db.Exec(
|
||||
"UPDATE conversation_groups SET name = ?, icon = ?, updated_at = ? WHERE id = ?",
|
||||
name, icon, time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新分组失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup 删除分组
|
||||
func (db *DB) DeleteGroup(id string) error {
|
||||
_, err := db.Exec("DELETE FROM conversation_groups WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除分组失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddConversationToGroup 将对话添加到分组
|
||||
// 注意:一个对话只能属于一个分组,所以在添加新分组之前,会先删除该对话的所有旧分组关联
|
||||
func (db *DB) AddConversationToGroup(conversationID, groupID string) error {
|
||||
// 先删除该对话的所有旧分组关联,确保一个对话只属于一个分组
|
||||
_, err := db.Exec(
|
||||
"DELETE FROM conversation_group_mappings WHERE conversation_id = ?",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话旧分组关联失败: %w", err)
|
||||
}
|
||||
|
||||
// 然后插入新的分组关联
|
||||
id := uuid.New().String()
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversation_group_mappings (id, conversation_id, group_id, created_at) VALUES (?, ?, ?, ?)",
|
||||
id, conversationID, groupID, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加对话到分组失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveConversationFromGroup 从分组中移除对话
|
||||
func (db *DB) RemoveConversationFromGroup(conversationID, groupID string) error {
|
||||
_, err := db.Exec(
|
||||
"DELETE FROM conversation_group_mappings WHERE conversation_id = ? AND group_id = ?",
|
||||
conversationID, groupID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("从分组中移除对话失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConversationsByGroup 获取分组中的所有对话
|
||||
func (db *DB) GetConversationsByGroup(groupID string) ([]*Conversation, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
|
||||
FROM conversations c
|
||||
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
|
||||
WHERE cgm.group_id = ?
|
||||
ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC`,
|
||||
groupID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分组对话失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conversations []*Conversation
|
||||
for rows.Next() {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
var groupPinned int
|
||||
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
|
||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
conv.Pinned = pinned != 0
|
||||
|
||||
conversations = append(conversations, &conv)
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// SearchConversationsByGroup 搜索分组中的对话(按标题和消息内容模糊匹配)
|
||||
func (db *DB) SearchConversationsByGroup(groupID string, searchQuery string) ([]*Conversation, error) {
|
||||
// 构建SQL查询,支持按标题和消息内容搜索
|
||||
// 使用 DISTINCT 避免因为一个对话有多条匹配消息而重复
|
||||
query := `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, COALESCE(cgm.pinned, 0) as group_pinned
|
||||
FROM conversations c
|
||||
INNER JOIN conversation_group_mappings cgm ON c.id = cgm.conversation_id
|
||||
WHERE cgm.group_id = ?`
|
||||
|
||||
args := []interface{}{groupID}
|
||||
|
||||
// 如果有搜索关键词,添加标题和消息内容搜索条件
|
||||
if searchQuery != "" {
|
||||
searchPattern := "%" + searchQuery + "%"
|
||||
// 搜索标题或消息内容
|
||||
// 使用 LEFT JOIN 连接消息表,这样即使没有消息的对话也能被搜索到(通过标题)
|
||||
query += ` AND (
|
||||
LOWER(c.title) LIKE LOWER(?)
|
||||
OR EXISTS (
|
||||
SELECT 1 FROM messages m
|
||||
WHERE m.conversation_id = c.id
|
||||
AND LOWER(m.content) LIKE LOWER(?)
|
||||
)
|
||||
)`
|
||||
args = append(args, searchPattern, searchPattern)
|
||||
}
|
||||
|
||||
query += " ORDER BY COALESCE(cgm.pinned, 0) DESC, c.updated_at DESC"
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("搜索分组对话失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conversations []*Conversation
|
||||
for rows.Next() {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
var pinned int
|
||||
var groupPinned int
|
||||
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &pinned, &createdAt, &updatedAt, &groupPinned); err != nil {
|
||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
conv.Pinned = pinned != 0
|
||||
|
||||
conversations = append(conversations, &conv)
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// GetGroupByConversation 获取对话所属的分组
|
||||
func (db *DB) GetGroupByConversation(conversationID string) (string, error) {
|
||||
var groupID string
|
||||
err := db.QueryRow(
|
||||
"SELECT group_id FROM conversation_group_mappings WHERE conversation_id = ? LIMIT 1",
|
||||
conversationID,
|
||||
).Scan(&groupID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil // 没有分组
|
||||
}
|
||||
return "", fmt.Errorf("查询对话分组失败: %w", err)
|
||||
}
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// UpdateConversationPinned 更新对话置顶状态
|
||||
func (db *DB) UpdateConversationPinned(id string, pinned bool) error {
|
||||
pinnedValue := 0
|
||||
if pinned {
|
||||
pinnedValue = 1
|
||||
}
|
||||
// 注意:不更新 updated_at,因为置顶操作不应该改变对话的更新时间
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET pinned = ? WHERE id = ?",
|
||||
pinnedValue, id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新对话置顶状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateGroupPinned 更新分组置顶状态
|
||||
func (db *DB) UpdateGroupPinned(id string, pinned bool) error {
|
||||
pinnedValue := 0
|
||||
if pinned {
|
||||
pinnedValue = 1
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversation_groups SET pinned = ?, updated_at = ? WHERE id = ?",
|
||||
pinnedValue, time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新分组置顶状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupMapping 分组映射关系
|
||||
type GroupMapping struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
GroupID string `json:"groupId"`
|
||||
}
|
||||
|
||||
// GetAllGroupMappings 批量获取所有分组映射(消除 N+1 查询)
|
||||
func (db *DB) GetAllGroupMappings() ([]GroupMapping, error) {
|
||||
rows, err := db.Query("SELECT conversation_id, group_id FROM conversation_group_mappings")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分组映射失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var mappings []GroupMapping
|
||||
for rows.Next() {
|
||||
var m GroupMapping
|
||||
if err := rows.Scan(&m.ConversationID, &m.GroupID); err != nil {
|
||||
return nil, fmt.Errorf("扫描分组映射失败: %w", err)
|
||||
}
|
||||
mappings = append(mappings, m)
|
||||
}
|
||||
|
||||
if mappings == nil {
|
||||
mappings = []GroupMapping{}
|
||||
}
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态
|
||||
func (db *DB) UpdateConversationPinnedInGroup(conversationID, groupID string, pinned bool) error {
|
||||
pinnedValue := 0
|
||||
if pinned {
|
||||
pinnedValue = 1
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversation_group_mappings SET pinned = ? WHERE conversation_id = ? AND group_id = ?",
|
||||
pinnedValue, conversationID, groupID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新分组对话置顶状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,537 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SaveToolExecution 保存工具执行记录
|
||||
func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
|
||||
argsJSON, err := json.Marshal(exec.Arguments)
|
||||
if err != nil {
|
||||
db.logger.Warn("序列化执行参数失败", zap.Error(err))
|
||||
argsJSON = []byte("{}")
|
||||
}
|
||||
|
||||
var resultJSON sql.NullString
|
||||
if exec.Result != nil {
|
||||
resultBytes, err := json.Marshal(exec.Result)
|
||||
if err != nil {
|
||||
db.logger.Warn("序列化执行结果失败", zap.Error(err))
|
||||
} else {
|
||||
resultJSON = sql.NullString{String: string(resultBytes), Valid: true}
|
||||
}
|
||||
}
|
||||
|
||||
var errorText sql.NullString
|
||||
if exec.Error != "" {
|
||||
errorText = sql.NullString{String: exec.Error, Valid: true}
|
||||
}
|
||||
|
||||
var endTime sql.NullTime
|
||||
if exec.EndTime != nil {
|
||||
endTime = sql.NullTime{Time: *exec.EndTime, Valid: true}
|
||||
}
|
||||
|
||||
var durationMs sql.NullInt64
|
||||
if exec.Duration > 0 {
|
||||
durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT OR REPLACE INTO tool_executions
|
||||
(id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err = db.Exec(query,
|
||||
exec.ID,
|
||||
exec.ToolName,
|
||||
string(argsJSON),
|
||||
exec.Status,
|
||||
resultJSON,
|
||||
errorText,
|
||||
exec.StartTime,
|
||||
endTime,
|
||||
durationMs,
|
||||
time.Now(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountToolExecutions 统计工具执行记录总数
|
||||
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
|
||||
query := `SELECT COUNT(*) FROM tool_executions`
|
||||
args := []interface{}{}
|
||||
conditions := []string{}
|
||||
if status != "" {
|
||||
conditions = append(conditions, "status = ?")
|
||||
args = append(args, status)
|
||||
}
|
||||
if toolName != "" {
|
||||
// 支持部分匹配(模糊搜索),不区分大小写
|
||||
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
||||
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
||||
}
|
||||
if len(conditions) > 0 {
|
||||
query += ` WHERE ` + conditions[0]
|
||||
for i := 1; i < len(conditions); i++ {
|
||||
query += ` AND ` + conditions[i]
|
||||
}
|
||||
}
|
||||
var count int
|
||||
err := db.QueryRow(query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// LoadToolExecutions 加载所有工具执行记录(支持分页)
|
||||
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
|
||||
return db.LoadToolExecutionsWithPagination(0, 1000, "", "")
|
||||
}
|
||||
|
||||
// LoadToolExecutionsWithPagination 分页加载工具执行记录
|
||||
// limit: 最大返回记录数,0 表示使用默认值 1000
|
||||
// offset: 跳过的记录数,用于分页
|
||||
// status: 状态筛选,空字符串表示不过滤
|
||||
// toolName: 工具名称筛选,空字符串表示不过滤
|
||||
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
|
||||
if limit <= 0 {
|
||||
limit = 1000 // 默认限制
|
||||
}
|
||||
if limit > 10000 {
|
||||
limit = 10000 // 最大限制,防止一次性加载过多数据
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
||||
FROM tool_executions
|
||||
`
|
||||
args := []interface{}{}
|
||||
conditions := []string{}
|
||||
if status != "" {
|
||||
conditions = append(conditions, "status = ?")
|
||||
args = append(args, status)
|
||||
}
|
||||
if toolName != "" {
|
||||
// 支持部分匹配(模糊搜索),不区分大小写
|
||||
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
||||
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
||||
}
|
||||
if len(conditions) > 0 {
|
||||
query += ` WHERE ` + conditions[0]
|
||||
for i := 1; i < len(conditions); i++ {
|
||||
query += ` AND ` + conditions[i]
|
||||
}
|
||||
}
|
||||
query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var executions []*mcp.ToolExecution
|
||||
for rows.Next() {
|
||||
var exec mcp.ToolExecution
|
||||
var argsJSON string
|
||||
var resultJSON sql.NullString
|
||||
var errorText sql.NullString
|
||||
var endTime sql.NullTime
|
||||
var durationMs sql.NullInt64
|
||||
|
||||
err := rows.Scan(
|
||||
&exec.ID,
|
||||
&exec.ToolName,
|
||||
&argsJSON,
|
||||
&exec.Status,
|
||||
&resultJSON,
|
||||
&errorText,
|
||||
&exec.StartTime,
|
||||
&endTime,
|
||||
&durationMs,
|
||||
)
|
||||
if err != nil {
|
||||
db.logger.Warn("加载执行记录失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析参数
|
||||
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
||||
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
||||
exec.Arguments = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// 解析结果
|
||||
if resultJSON.Valid && resultJSON.String != "" {
|
||||
var result mcp.ToolResult
|
||||
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
||||
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
||||
} else {
|
||||
exec.Result = &result
|
||||
}
|
||||
}
|
||||
|
||||
// 设置错误
|
||||
if errorText.Valid {
|
||||
exec.Error = errorText.String
|
||||
}
|
||||
|
||||
// 设置结束时间
|
||||
if endTime.Valid {
|
||||
exec.EndTime = &endTime.Time
|
||||
}
|
||||
|
||||
// 设置持续时间
|
||||
if durationMs.Valid {
|
||||
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||||
}
|
||||
|
||||
executions = append(executions, &exec)
|
||||
}
|
||||
|
||||
return executions, nil
|
||||
}
|
||||
|
||||
// GetToolExecution 根据ID获取单条工具执行记录
|
||||
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
||||
query := `
|
||||
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
||||
FROM tool_executions
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
row := db.QueryRow(query, id)
|
||||
|
||||
var exec mcp.ToolExecution
|
||||
var argsJSON string
|
||||
var resultJSON sql.NullString
|
||||
var errorText sql.NullString
|
||||
var endTime sql.NullTime
|
||||
var durationMs sql.NullInt64
|
||||
|
||||
err := row.Scan(
|
||||
&exec.ID,
|
||||
&exec.ToolName,
|
||||
&argsJSON,
|
||||
&exec.Status,
|
||||
&resultJSON,
|
||||
&errorText,
|
||||
&exec.StartTime,
|
||||
&endTime,
|
||||
&durationMs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
||||
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
||||
exec.Arguments = make(map[string]interface{})
|
||||
}
|
||||
|
||||
if resultJSON.Valid && resultJSON.String != "" {
|
||||
var result mcp.ToolResult
|
||||
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
||||
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
||||
} else {
|
||||
exec.Result = &result
|
||||
}
|
||||
}
|
||||
|
||||
if errorText.Valid {
|
||||
exec.Error = errorText.String
|
||||
}
|
||||
|
||||
if endTime.Valid {
|
||||
exec.EndTime = &endTime.Time
|
||||
}
|
||||
|
||||
if durationMs.Valid {
|
||||
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||||
}
|
||||
|
||||
return &exec, nil
|
||||
}
|
||||
|
||||
// DeleteToolExecution 删除工具执行记录
|
||||
func (db *DB) DeleteToolExecution(id string) error {
|
||||
query := `DELETE FROM tool_executions WHERE id = ?`
|
||||
_, err := db.Exec(query, id)
|
||||
if err != nil {
|
||||
db.logger.Error("删除工具执行记录失败", zap.Error(err), zap.String("executionId", id))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteToolExecutions 批量删除工具执行记录
|
||||
func (db *DB) DeleteToolExecutions(ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 构建 IN 查询的占位符
|
||||
placeholders := make([]string, len(ids))
|
||||
args := make([]interface{}, len(ids))
|
||||
for i, id := range ids {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
query := `DELETE FROM tool_executions WHERE id IN (` + strings.Join(placeholders, ",") + `)`
|
||||
_, err := db.Exec(query, args...)
|
||||
if err != nil {
|
||||
db.logger.Error("批量删除工具执行记录失败", zap.Error(err), zap.Int("count", len(ids)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetToolExecutionsByIds 根据ID列表获取工具执行记录(用于批量删除前获取统计信息)
|
||||
func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*mcp.ToolExecution{}, nil
|
||||
}
|
||||
|
||||
// 构建 IN 查询的占位符
|
||||
placeholders := make([]string, len(ids))
|
||||
args := make([]interface{}, len(ids))
|
||||
for i, id := range ids {
|
||||
placeholders[i] = "?"
|
||||
args[i] = id
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
||||
FROM tool_executions
|
||||
WHERE id IN (` + strings.Join(placeholders, ",") + `)
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var executions []*mcp.ToolExecution
|
||||
for rows.Next() {
|
||||
var exec mcp.ToolExecution
|
||||
var argsJSON string
|
||||
var resultJSON sql.NullString
|
||||
var errorText sql.NullString
|
||||
var endTime sql.NullTime
|
||||
var durationMs sql.NullInt64
|
||||
|
||||
err := rows.Scan(
|
||||
&exec.ID,
|
||||
&exec.ToolName,
|
||||
&argsJSON,
|
||||
&exec.Status,
|
||||
&resultJSON,
|
||||
&errorText,
|
||||
&exec.StartTime,
|
||||
&endTime,
|
||||
&durationMs,
|
||||
)
|
||||
if err != nil {
|
||||
db.logger.Warn("加载执行记录失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析参数
|
||||
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
||||
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
||||
exec.Arguments = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// 解析结果
|
||||
if resultJSON.Valid && resultJSON.String != "" {
|
||||
var result mcp.ToolResult
|
||||
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
||||
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
||||
} else {
|
||||
exec.Result = &result
|
||||
}
|
||||
}
|
||||
|
||||
// 设置错误
|
||||
if errorText.Valid {
|
||||
exec.Error = errorText.String
|
||||
}
|
||||
|
||||
// 设置结束时间
|
||||
if endTime.Valid {
|
||||
exec.EndTime = &endTime.Time
|
||||
}
|
||||
|
||||
// 设置持续时间
|
||||
if durationMs.Valid {
|
||||
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||||
}
|
||||
|
||||
executions = append(executions, &exec)
|
||||
}
|
||||
|
||||
return executions, nil
|
||||
}
|
||||
|
||||
// SaveToolStats 保存工具统计信息
|
||||
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
||||
var lastCallTime sql.NullTime
|
||||
if stats.LastCallTime != nil {
|
||||
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT OR REPLACE INTO tool_stats
|
||||
(tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := db.Exec(query,
|
||||
toolName,
|
||||
stats.TotalCalls,
|
||||
stats.SuccessCalls,
|
||||
stats.FailedCalls,
|
||||
lastCallTime,
|
||||
time.Now(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadToolStats 加载所有工具统计信息
|
||||
func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) {
|
||||
query := `
|
||||
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
|
||||
FROM tool_stats
|
||||
`
|
||||
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
stats := make(map[string]*mcp.ToolStats)
|
||||
for rows.Next() {
|
||||
var stat mcp.ToolStats
|
||||
var lastCallTime sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&stat.ToolName,
|
||||
&stat.TotalCalls,
|
||||
&stat.SuccessCalls,
|
||||
&stat.FailedCalls,
|
||||
&lastCallTime,
|
||||
)
|
||||
if err != nil {
|
||||
db.logger.Warn("加载统计信息失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if lastCallTime.Valid {
|
||||
stat.LastCallTime = &lastCallTime.Time
|
||||
}
|
||||
|
||||
stats[stat.ToolName] = &stat
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// UpdateToolStats 更新工具统计信息(累加模式)
|
||||
func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
|
||||
var lastCallTimeSQL sql.NullTime
|
||||
if lastCallTime != nil {
|
||||
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(tool_name) DO UPDATE SET
|
||||
total_calls = total_calls + ?,
|
||||
success_calls = success_calls + ?,
|
||||
failed_calls = failed_calls + ?,
|
||||
last_call_time = COALESCE(?, last_call_time),
|
||||
updated_at = ?
|
||||
`
|
||||
|
||||
_, err := db.Exec(query,
|
||||
toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
||||
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
|
||||
// 如果统计信息变为0,则删除该统计记录
|
||||
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
|
||||
// 先更新统计信息
|
||||
query := `
|
||||
UPDATE tool_stats SET
|
||||
total_calls = CASE WHEN total_calls - ? < 0 THEN 0 ELSE total_calls - ? END,
|
||||
success_calls = CASE WHEN success_calls - ? < 0 THEN 0 ELSE success_calls - ? END,
|
||||
failed_calls = CASE WHEN failed_calls - ? < 0 THEN 0 ELSE failed_calls - ? END,
|
||||
updated_at = ?
|
||||
WHERE tool_name = ?
|
||||
`
|
||||
|
||||
_, err := db.Exec(query, totalCalls, totalCalls, successCalls, successCalls, failedCalls, failedCalls, time.Now(), toolName)
|
||||
if err != nil {
|
||||
db.logger.Error("减少工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查更新后的 total_calls 是否为 0,如果是则删除该统计记录
|
||||
checkQuery := `SELECT total_calls FROM tool_stats WHERE tool_name = ?`
|
||||
var newTotalCalls int
|
||||
err = db.QueryRow(checkQuery, toolName).Scan(&newTotalCalls)
|
||||
if err != nil {
|
||||
// 如果查询失败(记录不存在),直接返回
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果 total_calls 为 0,删除该统计记录
|
||||
if newTotalCalls == 0 {
|
||||
deleteQuery := `DELETE FROM tool_stats WHERE tool_name = ?`
|
||||
_, err = db.Exec(deleteQuery, toolName)
|
||||
if err != nil {
|
||||
db.logger.Warn("删除零统计记录失败", zap.Error(err), zap.String("toolName", toolName))
|
||||
// 不返回错误,因为主要操作(更新统计)已成功
|
||||
} else {
|
||||
db.logger.Info("已删除零统计记录", zap.String("toolName", toolName))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SkillStats Skills统计信息
|
||||
type SkillStats struct {
|
||||
SkillName string
|
||||
TotalCalls int
|
||||
SuccessCalls int
|
||||
FailedCalls int
|
||||
LastCallTime *time.Time
|
||||
}
|
||||
|
||||
// SaveSkillStats 保存Skills统计信息
|
||||
func (db *DB) SaveSkillStats(skillName string, stats *SkillStats) error {
|
||||
var lastCallTime sql.NullTime
|
||||
if stats.LastCallTime != nil {
|
||||
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT OR REPLACE INTO skill_stats
|
||||
(skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := db.Exec(query,
|
||||
skillName,
|
||||
stats.TotalCalls,
|
||||
stats.SuccessCalls,
|
||||
stats.FailedCalls,
|
||||
lastCallTime,
|
||||
time.Now(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
db.logger.Error("保存Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadSkillStats 加载所有Skills统计信息
|
||||
func (db *DB) LoadSkillStats() (map[string]*SkillStats, error) {
|
||||
query := `
|
||||
SELECT skill_name, total_calls, success_calls, failed_calls, last_call_time
|
||||
FROM skill_stats
|
||||
`
|
||||
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
stats := make(map[string]*SkillStats)
|
||||
for rows.Next() {
|
||||
var stat SkillStats
|
||||
var lastCallTime sql.NullTime
|
||||
|
||||
err := rows.Scan(
|
||||
&stat.SkillName,
|
||||
&stat.TotalCalls,
|
||||
&stat.SuccessCalls,
|
||||
&stat.FailedCalls,
|
||||
&lastCallTime,
|
||||
)
|
||||
if err != nil {
|
||||
db.logger.Warn("加载Skills统计信息失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if lastCallTime.Valid {
|
||||
stat.LastCallTime = &lastCallTime.Time
|
||||
}
|
||||
|
||||
stats[stat.SkillName] = &stat
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// UpdateSkillStats 更新Skills统计信息(累加模式)
|
||||
func (db *DB) UpdateSkillStats(skillName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
|
||||
var lastCallTimeSQL sql.NullTime
|
||||
if lastCallTime != nil {
|
||||
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO skill_stats (skill_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(skill_name) DO UPDATE SET
|
||||
total_calls = total_calls + ?,
|
||||
success_calls = success_calls + ?,
|
||||
failed_calls = failed_calls + ?,
|
||||
last_call_time = COALESCE(?, last_call_time),
|
||||
updated_at = ?
|
||||
`
|
||||
|
||||
_, err := db.Exec(query,
|
||||
skillName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
||||
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
db.logger.Error("更新Skills统计信息失败", zap.Error(err), zap.String("skillName", skillName))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearSkillStats 清空所有Skills统计信息
|
||||
func (db *DB) ClearSkillStats() error {
|
||||
query := `DELETE FROM skill_stats`
|
||||
_, err := db.Exec(query)
|
||||
if err != nil {
|
||||
db.logger.Error("清空Skills统计信息失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
db.logger.Info("已清空所有Skills统计信息")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearSkillStatsByName 清空指定skill的统计信息
|
||||
func (db *DB) ClearSkillStatsByName(skillName string) error {
|
||||
query := `DELETE FROM skill_stats WHERE skill_name = ?`
|
||||
_, err := db.Exec(query, skillName)
|
||||
if err != nil {
|
||||
db.logger.Error("清空指定skill统计信息失败", zap.Error(err), zap.String("skillName", skillName))
|
||||
return err
|
||||
}
|
||||
db.logger.Info("已清空指定skill统计信息", zap.String("skillName", skillName))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Vulnerability 漏洞
|
||||
type Vulnerability struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"` // critical, high, medium, low, info
|
||||
Status string `json:"status"` // open, confirmed, fixed, false_positive
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// CreateVulnerability 创建漏洞
|
||||
func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
|
||||
if vuln.ID == "" {
|
||||
vuln.ID = uuid.New().String()
|
||||
}
|
||||
if vuln.Status == "" {
|
||||
vuln.Status = "open"
|
||||
}
|
||||
now := time.Now()
|
||||
if vuln.CreatedAt.IsZero() {
|
||||
vuln.CreatedAt = now
|
||||
}
|
||||
vuln.UpdatedAt = now
|
||||
|
||||
query := `
|
||||
INSERT INTO vulnerabilities (
|
||||
id, conversation_id, title, description, severity, status,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description,
|
||||
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
|
||||
vuln.Proof, vuln.Impact, vuln.Recommendation,
|
||||
vuln.CreatedAt, vuln.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建漏洞失败: %w", err)
|
||||
}
|
||||
|
||||
return vuln, nil
|
||||
}
|
||||
|
||||
// GetVulnerability 获取漏洞
|
||||
func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
var vuln Vulnerability
|
||||
query := `
|
||||
SELECT id, conversation_id, title, description, severity, status,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
created_at, updated_at
|
||||
FROM vulnerabilities
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
err := db.QueryRow(query, id).Scan(
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
||||
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
|
||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||
&vuln.CreatedAt, &vuln.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("漏洞不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("获取漏洞失败: %w", err)
|
||||
}
|
||||
|
||||
return &vuln, nil
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) {
|
||||
query := `
|
||||
SELECT id, conversation_id, title, description, severity, status,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
created_at, updated_at
|
||||
FROM vulnerabilities
|
||||
WHERE 1=1
|
||||
`
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, severity)
|
||||
}
|
||||
if status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询漏洞列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var vulnerabilities []*Vulnerability
|
||||
for rows.Next() {
|
||||
var vuln Vulnerability
|
||||
err := rows.Scan(
|
||||
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
|
||||
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
|
||||
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
|
||||
&vuln.CreatedAt, &vuln.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
db.logger.Warn("扫描漏洞记录失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
vulnerabilities = append(vulnerabilities, &vuln)
|
||||
}
|
||||
|
||||
return vulnerabilities, nil
|
||||
}
|
||||
|
||||
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
||||
func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) {
|
||||
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, severity)
|
||||
}
|
||||
if status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
|
||||
var count int
|
||||
err := db.QueryRow(query, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("统计漏洞总数失败: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// UpdateVulnerability 更新漏洞
|
||||
func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
|
||||
vuln.UpdatedAt = time.Now()
|
||||
|
||||
query := `
|
||||
UPDATE vulnerabilities
|
||||
SET title = ?, description = ?, severity = ?, status = ?,
|
||||
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
|
||||
recommendation = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := db.Exec(
|
||||
query,
|
||||
vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
|
||||
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
|
||||
vuln.Recommendation, vuln.UpdatedAt, id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新漏洞失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteVulnerability 删除漏洞
|
||||
func (db *DB) DeleteVulnerability(id string) error {
|
||||
_, err := db.Exec("DELETE FROM vulnerabilities WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除漏洞失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVulnerabilityStats 获取漏洞统计
|
||||
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) {
|
||||
stats := make(map[string]interface{})
|
||||
|
||||
// 总漏洞数
|
||||
var totalCount int
|
||||
query := "SELECT COUNT(*) FROM vulnerabilities"
|
||||
args := []interface{}{}
|
||||
if conversationID != "" {
|
||||
query += " WHERE conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
err := db.QueryRow(query, args...).Scan(&totalCount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
|
||||
}
|
||||
stats["total"] = totalCount
|
||||
|
||||
// 按严重程度统计
|
||||
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities"
|
||||
if conversationID != "" {
|
||||
severityQuery += " WHERE conversation_id = ?"
|
||||
}
|
||||
severityQuery += " GROUP BY severity"
|
||||
|
||||
rows, err := db.Query(severityQuery, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取严重程度统计失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
severityStats := make(map[string]int)
|
||||
for rows.Next() {
|
||||
var severity string
|
||||
var count int
|
||||
if err := rows.Scan(&severity, &count); err != nil {
|
||||
continue
|
||||
}
|
||||
severityStats[severity] = count
|
||||
}
|
||||
stats["by_severity"] = severityStats
|
||||
|
||||
// 按状态统计
|
||||
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities"
|
||||
if conversationID != "" {
|
||||
statusQuery += " WHERE conversation_id = ?"
|
||||
}
|
||||
statusQuery += " GROUP BY status"
|
||||
|
||||
rows, err = db.Query(statusQuery, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取状态统计失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
statusStats := make(map[string]int)
|
||||
for rows.Next() {
|
||||
var status string
|
||||
var count int
|
||||
if err := rows.Scan(&status, &count); err != nil {
|
||||
continue
|
||||
}
|
||||
statusStats[status] = count
|
||||
}
|
||||
stats["by_status"] = statusStats
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// WebShellConnection WebShell 连接配置
|
||||
type WebShellConnection struct {
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
Password string `json:"password"`
|
||||
Type string `json:"type"`
|
||||
Method string `json:"method"`
|
||||
CmdParam string `json:"cmdParam"`
|
||||
Remark string `json:"remark"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// GetWebshellConnectionState 获取连接关联的持久化状态 JSON,不存在时返回 "{}"
|
||||
func (db *DB) GetWebshellConnectionState(connectionID string) (string, error) {
|
||||
var stateJSON string
|
||||
err := db.QueryRow(`SELECT state_json FROM webshell_connection_states WHERE connection_id = ?`, connectionID).Scan(&stateJSON)
|
||||
if err == sql.ErrNoRows {
|
||||
return "{}", nil
|
||||
}
|
||||
if err != nil {
|
||||
db.logger.Error("查询 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID))
|
||||
return "", err
|
||||
}
|
||||
if stateJSON == "" {
|
||||
stateJSON = "{}"
|
||||
}
|
||||
return stateJSON, nil
|
||||
}
|
||||
|
||||
// UpsertWebshellConnectionState 保存连接关联的持久化状态 JSON
|
||||
func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) error {
|
||||
if stateJSON == "" {
|
||||
stateJSON = "{}"
|
||||
}
|
||||
query := `
|
||||
INSERT INTO webshell_connection_states (connection_id, state_json, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(connection_id) DO UPDATE SET
|
||||
state_json = excluded.state_json,
|
||||
updated_at = excluded.updated_at
|
||||
`
|
||||
if _, err := db.Exec(query, connectionID, stateJSON, time.Now()); err != nil {
|
||||
db.logger.Error("保存 WebShell 连接状态失败", zap.Error(err), zap.String("connectionID", connectionID))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序
|
||||
func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
|
||||
query := `
|
||||
SELECT id, url, password, type, method, cmd_param, remark, created_at
|
||||
FROM webshell_connections
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
db.logger.Error("查询 WebShell 连接列表失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var list []WebShellConnection
|
||||
for rows.Next() {
|
||||
var c WebShellConnection
|
||||
err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt)
|
||||
if err != nil {
|
||||
db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
list = append(list, c)
|
||||
}
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// GetWebshellConnection 根据 ID 获取一条连接
|
||||
func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
|
||||
query := `
|
||||
SELECT id, url, password, type, method, cmd_param, remark, created_at
|
||||
FROM webshell_connections WHERE id = ?
|
||||
`
|
||||
var c WebShellConnection
|
||||
err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
db.logger.Error("查询 WebShell 连接失败", zap.Error(err), zap.String("id", id))
|
||||
return nil, err
|
||||
}
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// CreateWebshellConnection 创建 WebShell 连接
|
||||
func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
|
||||
query := `
|
||||
INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
_, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.CreatedAt)
|
||||
if err != nil {
|
||||
db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateWebshellConnection 更新 WebShell 连接
|
||||
func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error {
|
||||
query := `
|
||||
UPDATE webshell_connections
|
||||
SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.ID)
|
||||
if err != nil {
|
||||
db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
|
||||
return err
|
||||
}
|
||||
affected, _ := result.RowsAffected()
|
||||
if affected == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteWebshellConnection 删除 WebShell 连接
|
||||
func (db *DB) DeleteWebshellConnection(id string) error {
|
||||
result, err := db.Exec(`DELETE FROM webshell_connections WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
db.logger.Error("删除 WebShell 连接失败", zap.Error(err), zap.String("id", id))
|
||||
return err
|
||||
}
|
||||
affected, _ := result.RowsAffected()
|
||||
if affected == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
*zap.Logger
|
||||
}
|
||||
|
||||
func New(level, output string) *Logger {
|
||||
var zapLevel zapcore.Level
|
||||
switch level {
|
||||
case "debug":
|
||||
zapLevel = zapcore.DebugLevel
|
||||
case "info":
|
||||
zapLevel = zapcore.InfoLevel
|
||||
case "warn":
|
||||
zapLevel = zapcore.WarnLevel
|
||||
case "error":
|
||||
zapLevel = zapcore.ErrorLevel
|
||||
default:
|
||||
zapLevel = zapcore.InfoLevel
|
||||
}
|
||||
|
||||
config := zap.NewProductionConfig()
|
||||
config.Level = zap.NewAtomicLevelAt(zapLevel)
|
||||
config.EncoderConfig.TimeKey = "timestamp"
|
||||
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
|
||||
var writeSyncer zapcore.WriteSyncer
|
||||
if output == "stdout" {
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
} else {
|
||||
file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
} else {
|
||||
writeSyncer = zapcore.AddSync(file)
|
||||
}
|
||||
}
|
||||
|
||||
core := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(config.EncoderConfig),
|
||||
writeSyncer,
|
||||
zapLevel,
|
||||
)
|
||||
|
||||
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
|
||||
|
||||
return &Logger{Logger: logger}
|
||||
}
|
||||
|
||||
func (l *Logger) Fatal(msg string, fields ...interface{}) {
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for _, f := range fields {
|
||||
switch v := f.(type) {
|
||||
case error:
|
||||
zapFields = append(zapFields, zap.Error(v))
|
||||
default:
|
||||
zapFields = append(zapFields, zap.Any("field", v))
|
||||
}
|
||||
}
|
||||
l.Logger.Fatal(msg, zapFields...)
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package builtin
|
||||
|
||||
// 内置工具名称常量
|
||||
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
|
||||
const (
|
||||
// 漏洞管理工具
|
||||
ToolRecordVulnerability = "record_vulnerability"
|
||||
|
||||
// 知识库工具
|
||||
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
|
||||
ToolSearchKnowledgeBase = "search_knowledge_base"
|
||||
|
||||
// WebShell 助手工具(AI 在 WebShell 管理 - AI 助手 中使用)
|
||||
ToolWebshellExec = "webshell_exec"
|
||||
ToolWebshellFileList = "webshell_file_list"
|
||||
ToolWebshellFileRead = "webshell_file_read"
|
||||
ToolWebshellFileWrite = "webshell_file_write"
|
||||
|
||||
// WebShell 连接管理工具(用于通过 MCP 管理 webshell 连接)
|
||||
ToolManageWebshellList = "manage_webshell_list"
|
||||
ToolManageWebshellAdd = "manage_webshell_add"
|
||||
ToolManageWebshellUpdate = "manage_webshell_update"
|
||||
ToolManageWebshellDelete = "manage_webshell_delete"
|
||||
ToolManageWebshellTest = "manage_webshell_test"
|
||||
|
||||
// 批量任务队列(与 Web 端批量任务一致,供模型创建/启停/查询队列)
|
||||
ToolBatchTaskList = "batch_task_list"
|
||||
ToolBatchTaskGet = "batch_task_get"
|
||||
ToolBatchTaskCreate = "batch_task_create"
|
||||
ToolBatchTaskStart = "batch_task_start"
|
||||
ToolBatchTaskRerun = "batch_task_rerun"
|
||||
ToolBatchTaskPause = "batch_task_pause"
|
||||
ToolBatchTaskDelete = "batch_task_delete"
|
||||
ToolBatchTaskUpdateMetadata = "batch_task_update_metadata"
|
||||
ToolBatchTaskUpdateSchedule = "batch_task_update_schedule"
|
||||
ToolBatchTaskScheduleEnabled = "batch_task_schedule_enabled"
|
||||
ToolBatchTaskAdd = "batch_task_add_task"
|
||||
ToolBatchTaskUpdate = "batch_task_update_task"
|
||||
ToolBatchTaskRemove = "batch_task_remove_task"
|
||||
)
|
||||
|
||||
// IsBuiltinTool 检查工具名称是否是内置工具
|
||||
func IsBuiltinTool(toolName string) bool {
|
||||
switch toolName {
|
||||
case ToolRecordVulnerability,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
ToolWebshellExec,
|
||||
ToolWebshellFileList,
|
||||
ToolWebshellFileRead,
|
||||
ToolWebshellFileWrite,
|
||||
ToolManageWebshellList,
|
||||
ToolManageWebshellAdd,
|
||||
ToolManageWebshellUpdate,
|
||||
ToolManageWebshellDelete,
|
||||
ToolManageWebshellTest,
|
||||
ToolBatchTaskList,
|
||||
ToolBatchTaskGet,
|
||||
ToolBatchTaskCreate,
|
||||
ToolBatchTaskStart,
|
||||
ToolBatchTaskRerun,
|
||||
ToolBatchTaskPause,
|
||||
ToolBatchTaskDelete,
|
||||
ToolBatchTaskUpdateMetadata,
|
||||
ToolBatchTaskUpdateSchedule,
|
||||
ToolBatchTaskScheduleEnabled,
|
||||
ToolBatchTaskAdd,
|
||||
ToolBatchTaskUpdate,
|
||||
ToolBatchTaskRemove:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllBuiltinTools 返回所有内置工具名称列表
|
||||
func GetAllBuiltinTools() []string {
|
||||
return []string{
|
||||
ToolRecordVulnerability,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
ToolWebshellExec,
|
||||
ToolWebshellFileList,
|
||||
ToolWebshellFileRead,
|
||||
ToolWebshellFileWrite,
|
||||
ToolManageWebshellList,
|
||||
ToolManageWebshellAdd,
|
||||
ToolManageWebshellUpdate,
|
||||
ToolManageWebshellDelete,
|
||||
ToolManageWebshellTest,
|
||||
ToolBatchTaskList,
|
||||
ToolBatchTaskGet,
|
||||
ToolBatchTaskCreate,
|
||||
ToolBatchTaskStart,
|
||||
ToolBatchTaskRerun,
|
||||
ToolBatchTaskPause,
|
||||
ToolBatchTaskDelete,
|
||||
ToolBatchTaskUpdateMetadata,
|
||||
ToolBatchTaskUpdateSchedule,
|
||||
ToolBatchTaskScheduleEnabled,
|
||||
ToolBatchTaskAdd,
|
||||
ToolBatchTaskUpdate,
|
||||
ToolBatchTaskRemove,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,551 @@
|
||||
// Package mcp 外部 MCP 客户端 - 基于官方 go-sdk 实现,保证协议兼容性
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
clientName = "CyberStrikeAI"
|
||||
clientVersion = "1.0.0"
|
||||
)
|
||||
|
||||
// sdkClient 基于官方 MCP Go SDK 的外部 MCP 客户端,实现 ExternalMCPClient 接口
|
||||
type sdkClient struct {
|
||||
session *mcp.ClientSession
|
||||
client *mcp.Client
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
status string // "disconnected", "connecting", "connected", "error"
|
||||
}
|
||||
|
||||
// newSDKClientFromSession 用已连接成功的 session 构造(供 createSDKClient 内部使用)
|
||||
func newSDKClientFromSession(session *mcp.ClientSession, client *mcp.Client, logger *zap.Logger) *sdkClient {
|
||||
return &sdkClient{
|
||||
session: session,
|
||||
client: client,
|
||||
logger: logger,
|
||||
status: "connected",
|
||||
}
|
||||
}
|
||||
|
||||
// lazySDKClient 延迟连接:Initialize() 时才调用官方 SDK 建立连接,对外实现 ExternalMCPClient
|
||||
type lazySDKClient struct {
|
||||
serverCfg config.ExternalMCPServerConfig
|
||||
logger *zap.Logger
|
||||
inner ExternalMCPClient // 连接成功后为 *sdkClient
|
||||
mu sync.RWMutex
|
||||
status string
|
||||
}
|
||||
|
||||
func newLazySDKClient(serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) *lazySDKClient {
|
||||
return &lazySDKClient{
|
||||
serverCfg: serverCfg,
|
||||
logger: logger,
|
||||
status: "connecting",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.inner != nil {
|
||||
return c.inner.GetStatus()
|
||||
}
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) IsConnected() bool {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner != nil {
|
||||
return inner.IsConnected()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) Initialize(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.inner != nil {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
inner, err := createSDKClient(ctx, c.serverCfg, c.logger)
|
||||
if err != nil {
|
||||
c.setStatus("error")
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.inner = inner
|
||||
c.mu.Unlock()
|
||||
c.setStatus("connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
return inner.ListTools(ctx)
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
c.mu.RLock()
|
||||
inner := c.inner
|
||||
c.mu.RUnlock()
|
||||
if inner == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
return inner.CallTool(ctx, name, args)
|
||||
}
|
||||
|
||||
func (c *lazySDKClient) Close() error {
|
||||
c.mu.Lock()
|
||||
inner := c.inner
|
||||
c.inner = nil
|
||||
c.mu.Unlock()
|
||||
c.setStatus("disconnected")
|
||||
if inner != nil {
|
||||
return inner.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *sdkClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *sdkClient) IsConnected() bool {
|
||||
return c.GetStatus() == "connected"
|
||||
}
|
||||
|
||||
func (c *sdkClient) Initialize(ctx context.Context) error {
|
||||
// sdkClient 由 createSDKClient 在 Connect 成功后才创建,因此 Initialize 时已经连接
|
||||
// 此方法仅用于满足 ExternalMCPClient 接口,实际连接在 createSDKClient 中完成
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
if c.session == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
res, err := c.session.ListTools(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return sdkToolsToOur(res.Tools), nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
if c.session == nil {
|
||||
return nil, fmt.Errorf("未连接")
|
||||
}
|
||||
params := &mcp.CallToolParams{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
}
|
||||
res, err := c.session.CallTool(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sdkCallToolResultToOurs(res), nil
|
||||
}
|
||||
|
||||
func (c *sdkClient) Close() error {
|
||||
c.setStatus("disconnected")
|
||||
if c.session != nil {
|
||||
err := c.session.Close()
|
||||
c.session = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sdkToolsToOur 将 SDK 的 []*mcp.Tool 转为我们的 []Tool
|
||||
func sdkToolsToOur(tools []*mcp.Tool) []Tool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]Tool, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
schema := make(map[string]interface{})
|
||||
if t.InputSchema != nil {
|
||||
// SDK InputSchema 可能为 *jsonschema.Schema 或 map,统一转为 map
|
||||
if m, ok := t.InputSchema.(map[string]interface{}); ok {
|
||||
schema = m
|
||||
} else {
|
||||
_ = json.Unmarshal(mustJSON(t.InputSchema), &schema)
|
||||
}
|
||||
}
|
||||
desc := t.Description
|
||||
shortDesc := desc
|
||||
if t.Annotations != nil && t.Annotations.Title != "" {
|
||||
shortDesc = t.Annotations.Title
|
||||
}
|
||||
out = append(out, Tool{
|
||||
Name: t.Name,
|
||||
Description: desc,
|
||||
ShortDescription: shortDesc,
|
||||
InputSchema: schema,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// sdkCallToolResultToOurs 将 SDK 的 *mcp.CallToolResult 转为我们的 *ToolResult
|
||||
func sdkCallToolResultToOurs(res *mcp.CallToolResult) *ToolResult {
|
||||
if res == nil {
|
||||
return &ToolResult{Content: []Content{}}
|
||||
}
|
||||
content := sdkContentToOurs(res.Content)
|
||||
return &ToolResult{
|
||||
Content: content,
|
||||
IsError: res.IsError,
|
||||
}
|
||||
}
|
||||
|
||||
func sdkContentToOurs(list []mcp.Content) []Content {
|
||||
if len(list) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]Content, 0, len(list))
|
||||
for _, c := range list {
|
||||
switch v := c.(type) {
|
||||
case *mcp.TextContent:
|
||||
out = append(out, Content{Type: "text", Text: v.Text})
|
||||
default:
|
||||
out = append(out, Content{Type: "text", Text: fmt.Sprintf("%v", c)})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mustJSON(v interface{}) []byte {
|
||||
b, _ := json.Marshal(v)
|
||||
return b
|
||||
}
|
||||
|
||||
// simpleHTTPClient 简单 JSON-RPC over HTTP:每次请求一次 POST、响应在 body。实现 ExternalMCPClient。
|
||||
// 用于自建 MCP(如 http://127.0.0.1:8081/mcp)或其它仅支持简单 POST 的端点。
|
||||
type simpleHTTPClient struct {
|
||||
url string
|
||||
client *http.Client
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
status string
|
||||
}
|
||||
|
||||
func newSimpleHTTPClient(ctx context.Context, url string, timeout time.Duration, headers map[string]string, logger *zap.Logger) (ExternalMCPClient, error) {
|
||||
c := &simpleHTTPClient{
|
||||
url: url,
|
||||
client: httpClientWithTimeoutAndHeaders(timeout, headers),
|
||||
logger: logger,
|
||||
status: "connecting",
|
||||
}
|
||||
if err := c.initialize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.status = "connected"
|
||||
c.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) setStatus(s string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.status = s
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) GetStatus() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) IsConnected() bool {
|
||||
return c.GetStatus() == "connected"
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) Initialize(context.Context) error {
|
||||
return nil // 已在 newSimpleHTTPClient 中完成
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) initialize(ctx context.Context) error {
|
||||
params := InitializeRequest{
|
||||
ProtocolVersion: ProtocolVersion,
|
||||
Capabilities: make(map[string]interface{}),
|
||||
ClientInfo: ClientInfo{Name: clientName, Version: clientVersion},
|
||||
}
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
req := &Message{
|
||||
ID: MessageID{value: "1"},
|
||||
Method: "initialize",
|
||||
Version: "2.0",
|
||||
Params: paramsJSON,
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("initialize: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
// 发送 notifications/initialized(协议要求)
|
||||
notify := &Message{
|
||||
ID: MessageID{value: nil},
|
||||
Method: "notifications/initialized",
|
||||
Version: "2.0",
|
||||
Params: json.RawMessage("{}"),
|
||||
}
|
||||
_ = c.sendNotification(notify)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) sendRequest(ctx context.Context, msg *Message) (*Message, error) {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(b))
|
||||
}
|
||||
var out Message
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) sendNotification(msg *Message) error {
|
||||
body, _ := json.Marshal(msg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpReq, _ := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(body))
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) ListTools(ctx context.Context) ([]Tool, error) {
|
||||
req := &Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/list",
|
||||
Version: "2.0",
|
||||
Params: json.RawMessage("{}"),
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("tools/list: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
var listResp ListToolsResponse
|
||||
if err := json.Unmarshal(resp.Result, &listResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return listResp.Tools, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error) {
|
||||
params := CallToolRequest{Name: name, Arguments: args}
|
||||
paramsJSON, _ := json.Marshal(params)
|
||||
req := &Message{
|
||||
ID: MessageID{value: uuid.New().String()},
|
||||
Method: "tools/call",
|
||||
Version: "2.0",
|
||||
Params: paramsJSON,
|
||||
}
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("tools/call: %s (code %d)", resp.Error.Message, resp.Error.Code)
|
||||
}
|
||||
var callResp CallToolResponse
|
||||
if err := json.Unmarshal(resp.Result, &callResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ToolResult{Content: callResp.Content, IsError: callResp.IsError}, nil
|
||||
}
|
||||
|
||||
func (c *simpleHTTPClient) Close() error {
|
||||
c.setStatus("disconnected")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSDKClient 根据配置创建并连接外部 MCP 客户端(使用官方 SDK),返回实现 ExternalMCPClient 的 *sdkClient
|
||||
// 若连接失败返回 (nil, error)。ctx 用于连接超时与取消。
|
||||
func createSDKClient(ctx context.Context, serverCfg config.ExternalMCPServerConfig, logger *zap.Logger) (ExternalMCPClient, error) {
|
||||
timeout := time.Duration(serverCfg.Timeout) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
transport := serverCfg.Transport
|
||||
if transport == "" {
|
||||
if serverCfg.Command != "" {
|
||||
transport = "stdio"
|
||||
} else if serverCfg.URL != "" {
|
||||
transport = "http"
|
||||
} else {
|
||||
return nil, fmt.Errorf("配置缺少 command 或 url")
|
||||
}
|
||||
}
|
||||
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: clientName,
|
||||
Version: clientVersion,
|
||||
}, nil)
|
||||
|
||||
var t mcp.Transport
|
||||
switch transport {
|
||||
case "stdio":
|
||||
if serverCfg.Command == "" {
|
||||
return nil, fmt.Errorf("stdio 模式需要配置 command")
|
||||
}
|
||||
// 必须用 exec.Command 而非 CommandContext:doConnect 返回后 ctx 会被 cancel,
|
||||
// 若用 CommandContext(ctx) 会立刻杀掉子进程,导致 ListTools 等后续请求失败、显示 0 工具
|
||||
cmd := exec.Command(serverCfg.Command, serverCfg.Args...)
|
||||
if len(serverCfg.Env) > 0 {
|
||||
cmd.Env = append(cmd.Env, envMapToSlice(serverCfg.Env)...)
|
||||
}
|
||||
t = &mcp.CommandTransport{Command: cmd}
|
||||
case "sse":
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("sse 模式需要配置 url")
|
||||
}
|
||||
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
|
||||
t = &mcp.SSEClientTransport{
|
||||
Endpoint: serverCfg.URL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
case "http":
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("http 模式需要配置 url")
|
||||
}
|
||||
httpClient := httpClientWithTimeoutAndHeaders(timeout, serverCfg.Headers)
|
||||
t = &mcp.StreamableClientTransport{
|
||||
Endpoint: serverCfg.URL,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
case "simple_http":
|
||||
// 简单 JSON-RPC HTTP:每次请求一次 POST、响应在 body。用于自建 MCP 或兼容旧端点(如 http://127.0.0.1:8081/mcp)
|
||||
if serverCfg.URL == "" {
|
||||
return nil, fmt.Errorf("simple_http 模式需要配置 url")
|
||||
}
|
||||
return newSimpleHTTPClient(ctx, serverCfg.URL, timeout, serverCfg.Headers, logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的传输模式: %s", transport)
|
||||
}
|
||||
|
||||
session, err := client.Connect(ctx, t, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接失败: %w", err)
|
||||
}
|
||||
|
||||
return newSDKClientFromSession(session, client, logger), nil
|
||||
}
|
||||
|
||||
func envMapToSlice(env map[string]string) []string {
|
||||
m := make(map[string]string)
|
||||
for _, s := range os.Environ() {
|
||||
if i := strings.IndexByte(s, '='); i > 0 {
|
||||
m[s[:i]] = s[i+1:]
|
||||
}
|
||||
}
|
||||
for k, v := range env {
|
||||
m[k] = v
|
||||
}
|
||||
out := make([]string, 0, len(m))
|
||||
for k, v := range m {
|
||||
out = append(out, k+"="+v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func httpClientWithTimeoutAndHeaders(timeout time.Duration, headers map[string]string) *http.Client {
|
||||
transport := http.DefaultTransport
|
||||
if len(headers) > 0 {
|
||||
transport = &headerRoundTripper{
|
||||
headers: headers,
|
||||
base: http.DefaultTransport,
|
||||
}
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
type headerRoundTripper struct {
|
||||
headers map[string]string
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (h *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for k, v := range h.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
return h.base.RoundTrip(req)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,239 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestExternalMCPManager_AddOrUpdateConfig(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
// 测试添加stdio配置
|
||||
stdioCfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Args: []string{"/path/to/script.py"},
|
||||
Transport: "stdio",
|
||||
Description: "Test stdio MCP",
|
||||
Timeout: 30,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
err := manager.AddOrUpdateConfig("test-stdio", stdioCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("添加stdio配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试添加HTTP配置
|
||||
httpCfg := config.ExternalMCPServerConfig{
|
||||
Transport: "http",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Description: "Test HTTP MCP",
|
||||
Timeout: 30,
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
err = manager.AddOrUpdateConfig("test-http", httpCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("添加HTTP配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证配置已保存
|
||||
configs := manager.GetConfigs()
|
||||
if len(configs) != 2 {
|
||||
t.Fatalf("期望2个配置,实际%d个", len(configs))
|
||||
}
|
||||
|
||||
if configs["test-stdio"].Command != stdioCfg.Command {
|
||||
t.Errorf("stdio配置命令不匹配")
|
||||
}
|
||||
|
||||
if configs["test-http"].URL != httpCfg.URL {
|
||||
t.Errorf("HTTP配置URL不匹配")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_RemoveConfig(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Transport: "stdio",
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
manager.AddOrUpdateConfig("test-remove", cfg)
|
||||
|
||||
// 移除配置
|
||||
err := manager.RemoveConfig("test-remove")
|
||||
if err != nil {
|
||||
t.Fatalf("移除配置失败: %v", err)
|
||||
}
|
||||
|
||||
configs := manager.GetConfigs()
|
||||
if _, exists := configs["test-remove"]; exists {
|
||||
t.Error("配置应该已被移除")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_GetStats(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
// 添加多个配置
|
||||
manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
})
|
||||
|
||||
manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: true,
|
||||
})
|
||||
|
||||
manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Enabled: false,
|
||||
Disabled: true, // 明确设置为禁用
|
||||
})
|
||||
|
||||
stats := manager.GetStats()
|
||||
|
||||
if stats["total"].(int) != 3 {
|
||||
t.Errorf("期望总数3,实际%d", stats["total"])
|
||||
}
|
||||
|
||||
if stats["enabled"].(int) != 2 {
|
||||
t.Errorf("期望启用数2,实际%d", stats["enabled"])
|
||||
}
|
||||
|
||||
if stats["disabled"].(int) != 1 {
|
||||
t.Errorf("期望停用数1,实际%d", stats["disabled"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_LoadConfigs(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
externalMCPConfig := config.ExternalMCPConfig{
|
||||
Servers: map[string]config.ExternalMCPServerConfig{
|
||||
"loaded1": {
|
||||
Command: "python3",
|
||||
Enabled: true,
|
||||
},
|
||||
"loaded2": {
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager.LoadConfigs(&externalMCPConfig)
|
||||
|
||||
configs := manager.GetConfigs()
|
||||
if len(configs) != 2 {
|
||||
t.Fatalf("期望2个配置,实际%d个", len(configs))
|
||||
}
|
||||
|
||||
if configs["loaded1"].Command != "python3" {
|
||||
t.Error("配置1加载失败")
|
||||
}
|
||||
|
||||
if configs["loaded2"].URL != "http://127.0.0.1:8081/mcp" {
|
||||
t.Error("配置2加载失败")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLazySDKClient_InitializeFails 验证无效配置时 SDK 客户端 Initialize 失败并设置 error 状态
|
||||
func TestLazySDKClient_InitializeFails(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
// 使用不存在的 HTTP 地址,Initialize 应失败
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Transport: "http",
|
||||
URL: "http://127.0.0.1:19999/nonexistent",
|
||||
Timeout: 2,
|
||||
}
|
||||
c := newLazySDKClient(cfg, logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := c.Initialize(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when connecting to invalid server")
|
||||
}
|
||||
if c.GetStatus() != "error" {
|
||||
t.Errorf("expected status error, got %s", c.GetStatus())
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_StartStopClient(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
// 添加一个禁用的配置
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Transport: "stdio",
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
manager.AddOrUpdateConfig("test-start-stop", cfg)
|
||||
|
||||
// 尝试启动(可能会失败,因为没有真实的服务器)
|
||||
err := manager.StartClient("test-start-stop")
|
||||
if err != nil {
|
||||
t.Logf("启动失败(可能是没有服务器): %v", err)
|
||||
}
|
||||
|
||||
// 停止
|
||||
err = manager.StopClient("test-start-stop")
|
||||
if err != nil {
|
||||
t.Fatalf("停止失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证配置已更新为禁用
|
||||
configs := manager.GetConfigs()
|
||||
if configs["test-start-stop"].Enabled {
|
||||
t.Error("配置应该已被禁用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_CallTool(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
// 测试调用不存在的工具
|
||||
_, _, err := manager.CallTool(context.Background(), "nonexistent::tool", map[string]interface{}{})
|
||||
if err == nil {
|
||||
t.Error("应该返回错误")
|
||||
}
|
||||
|
||||
// 测试无效的工具名称格式
|
||||
_, _, err = manager.CallTool(context.Background(), "invalid-tool-name", map[string]interface{}{})
|
||||
if err == nil {
|
||||
t.Error("应该返回错误(无效格式)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExternalMCPManager_GetAllTools(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
ctx := context.Background()
|
||||
tools, err := manager.GetAllTools(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("获取工具列表失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果没有连接的客户端,应该返回空列表
|
||||
if len(tools) != 0 {
|
||||
t.Logf("获取到%d个工具", len(tools))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,295 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExternalMCPClient 外部 MCP 客户端接口(由 client_sdk.go 基于官方 SDK 实现)
|
||||
type ExternalMCPClient interface {
|
||||
Initialize(ctx context.Context) error
|
||||
ListTools(ctx context.Context) ([]Tool, error)
|
||||
CallTool(ctx context.Context, name string, args map[string]interface{}) (*ToolResult, error)
|
||||
Close() error
|
||||
IsConnected() bool
|
||||
GetStatus() string
|
||||
}
|
||||
|
||||
// MCP消息类型
|
||||
const (
|
||||
MessageTypeRequest = "request"
|
||||
MessageTypeResponse = "response"
|
||||
MessageTypeError = "error"
|
||||
MessageTypeNotify = "notify"
|
||||
)
|
||||
|
||||
// MCP协议版本
|
||||
const ProtocolVersion = "2024-11-05"
|
||||
|
||||
// MessageID 表示JSON-RPC 2.0的id字段,可以是字符串、数字或null
|
||||
type MessageID struct {
|
||||
value interface{}
|
||||
}
|
||||
|
||||
// UnmarshalJSON 自定义反序列化,支持字符串、数字和null
|
||||
func (m *MessageID) UnmarshalJSON(data []byte) error {
|
||||
// 尝试解析为null
|
||||
if string(data) == "null" {
|
||||
m.value = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试解析为字符串
|
||||
var str string
|
||||
if err := json.Unmarshal(data, &str); err == nil {
|
||||
m.value = str
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试解析为数字
|
||||
var num json.Number
|
||||
if err := json.Unmarshal(data, &num); err == nil {
|
||||
m.value = num
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid id type")
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义序列化
|
||||
func (m MessageID) MarshalJSON() ([]byte, error) {
|
||||
if m.value == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(m.value)
|
||||
}
|
||||
|
||||
// String 返回字符串表示
|
||||
func (m MessageID) String() string {
|
||||
if m.value == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%v", m.value)
|
||||
}
|
||||
|
||||
// Value 返回原始值
|
||||
func (m MessageID) Value() interface{} {
|
||||
return m.value
|
||||
}
|
||||
|
||||
// Message 表示MCP消息(符合JSON-RPC 2.0规范)
|
||||
type Message struct {
|
||||
ID MessageID `json:"id,omitempty"`
|
||||
Type string `json:"-"` // 内部使用,不序列化到JSON
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
Version string `json:"jsonrpc,omitempty"` // JSON-RPC 2.0 版本标识
|
||||
}
|
||||
|
||||
// Error 表示MCP错误
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Tool 表示MCP工具定义
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"` // 详细描述
|
||||
ShortDescription string `json:"shortDescription,omitempty"` // 简短描述(用于工具列表,减少token消耗)
|
||||
InputSchema map[string]interface{} `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// ToolCall 表示工具调用
|
||||
type ToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// ToolResult 表示工具执行结果
|
||||
type ToolResult struct {
|
||||
Content []Content `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
// Content 表示内容
|
||||
type Content struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// InitializeRequest 初始化请求
|
||||
type InitializeRequest struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]interface{} `json:"capabilities"`
|
||||
ClientInfo ClientInfo `json:"clientInfo"`
|
||||
}
|
||||
|
||||
// ClientInfo 客户端信息
|
||||
type ClientInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// InitializeResponse 初始化响应
|
||||
type InitializeResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo ServerInfo `json:"serverInfo"`
|
||||
}
|
||||
|
||||
// ServerCapabilities 服务器能力
|
||||
type ServerCapabilities struct {
|
||||
Tools map[string]interface{} `json:"tools,omitempty"`
|
||||
Prompts map[string]interface{} `json:"prompts,omitempty"`
|
||||
Resources map[string]interface{} `json:"resources,omitempty"`
|
||||
Sampling map[string]interface{} `json:"sampling,omitempty"`
|
||||
}
|
||||
|
||||
// ServerInfo 服务器信息
|
||||
type ServerInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// ListToolsRequest 列出工具请求
|
||||
type ListToolsRequest struct{}
|
||||
|
||||
// ListToolsResponse 列出工具响应
|
||||
type ListToolsResponse struct {
|
||||
Tools []Tool `json:"tools"`
|
||||
}
|
||||
|
||||
// ListPromptsResponse 列出提示词响应
|
||||
type ListPromptsResponse struct {
|
||||
Prompts []Prompt `json:"prompts"`
|
||||
}
|
||||
|
||||
// ListResourcesResponse 列出资源响应
|
||||
type ListResourcesResponse struct {
|
||||
Resources []Resource `json:"resources"`
|
||||
}
|
||||
|
||||
// CallToolRequest 调用工具请求
|
||||
type CallToolRequest struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// CallToolResponse 调用工具响应
|
||||
type CallToolResponse struct {
|
||||
Content []Content `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
// ToolExecution 工具执行记录
|
||||
type ToolExecution struct {
|
||||
ID string `json:"id"`
|
||||
ToolName string `json:"toolName"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
Result *ToolResult `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartTime time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
// ToolStats 工具统计信息
|
||||
type ToolStats struct {
|
||||
ToolName string `json:"toolName"`
|
||||
TotalCalls int `json:"totalCalls"`
|
||||
SuccessCalls int `json:"successCalls"`
|
||||
FailedCalls int `json:"failedCalls"`
|
||||
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
|
||||
}
|
||||
|
||||
// Prompt 提示词模板
|
||||
type Prompt struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Arguments []PromptArgument `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// PromptArgument 提示词参数
|
||||
type PromptArgument struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Required bool `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
// GetPromptRequest 获取提示词请求
|
||||
type GetPromptRequest struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// GetPromptResponse 获取提示词响应
|
||||
type GetPromptResponse struct {
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// PromptMessage 提示词消息
|
||||
type PromptMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// Resource 资源
|
||||
type Resource struct {
|
||||
URI string `json:"uri"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
}
|
||||
|
||||
// ReadResourceRequest 读取资源请求
|
||||
type ReadResourceRequest struct {
|
||||
URI string `json:"uri"`
|
||||
}
|
||||
|
||||
// ReadResourceResponse 读取资源响应
|
||||
type ReadResourceResponse struct {
|
||||
Contents []ResourceContent `json:"contents"`
|
||||
}
|
||||
|
||||
// ResourceContent 资源内容
|
||||
type ResourceContent struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingRequest 采样请求
|
||||
type SamplingRequest struct {
|
||||
Messages []SamplingMessage `json:"messages"`
|
||||
Model string `json:"model,omitempty"`
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingMessage 采样消息
|
||||
type SamplingMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// SamplingResponse 采样响应
|
||||
type SamplingResponse struct {
|
||||
Content []SamplingContent `json:"content"`
|
||||
Model string `json:"model,omitempty"`
|
||||
StopReason string `json:"stopReason,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingContent 采样内容
|
||||
type SamplingContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/adk/middlewares/filesystem"
|
||||
"github.com/cloudwego/eino/adk/middlewares/skill"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// prepareEinoSkills builds Eino official skill backend + middleware, and a shared local disk backend
|
||||
// for skill discovery and (optionally) filesystem/execute tools. Returns nils when disabled or dir missing.
|
||||
func prepareEinoSkills(
|
||||
ctx context.Context,
|
||||
skillsDir string,
|
||||
ma *config.MultiAgentConfig,
|
||||
logger *zap.Logger,
|
||||
) (loc *localbk.Local, skillMW adk.ChatModelAgentMiddleware, fsTools bool, err error) {
|
||||
if ma == nil || ma.EinoSkills.Disable {
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
root := strings.TrimSpace(skillsDir)
|
||||
if root == "" {
|
||||
if logger != nil {
|
||||
logger.Warn("eino skills: skills_dir empty, skip")
|
||||
}
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
abs, err := filepath.Abs(root)
|
||||
if err != nil {
|
||||
return nil, nil, false, fmt.Errorf("skills_dir abs: %w", err)
|
||||
}
|
||||
if st, err := os.Stat(abs); err != nil || !st.IsDir() {
|
||||
if logger != nil {
|
||||
logger.Warn("eino skills: directory missing, skip", zap.String("dir", abs), zap.Error(err))
|
||||
}
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
loc, err = localbk.NewBackend(ctx, &localbk.Config{})
|
||||
if err != nil {
|
||||
return nil, nil, false, fmt.Errorf("eino local backend: %w", err)
|
||||
}
|
||||
|
||||
skillBE, err := skill.NewBackendFromFilesystem(ctx, &skill.BackendFromFilesystemConfig{
|
||||
Backend: loc,
|
||||
BaseDir: abs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, false, fmt.Errorf("eino skill filesystem backend: %w", err)
|
||||
}
|
||||
|
||||
sc := &skill.Config{Backend: skillBE}
|
||||
if name := strings.TrimSpace(ma.EinoSkills.SkillToolName); name != "" {
|
||||
sc.SkillToolName = &name
|
||||
}
|
||||
skillMW, err = skill.NewMiddleware(ctx, sc)
|
||||
if err != nil {
|
||||
return nil, nil, false, fmt.Errorf("eino skill middleware: %w", err)
|
||||
}
|
||||
|
||||
fsTools = ma.EinoSkills.EinoSkillFilesystemToolsEffective()
|
||||
return loc, skillMW, fsTools, nil
|
||||
}
|
||||
|
||||
// subAgentFilesystemMiddleware returns filesystem middleware for a sub-agent when Deep itself
|
||||
// does not set Backend (fsTools false on orchestrator) but we still want tools on subs — not used;
|
||||
// when orchestrator has Backend, builtin FS is only on outer agent; subs need explicit FS for parity.
|
||||
func subAgentFilesystemMiddleware(ctx context.Context, loc *localbk.Local) (adk.ChatModelAgentMiddleware, error) {
|
||||
if loc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
||||
Backend: loc,
|
||||
StreamingShell: loc,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// einoSummarizeUserInstruction 与单 Agent MemoryCompressor 目标一致:压缩时保留渗透关键信息。
|
||||
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
||||
|
||||
必须保留:已确认漏洞与攻击路径、工具输出中的核心发现、凭证与认证细节、架构与薄弱点、当前进度、失败尝试与死路、策略决策。
|
||||
保留精确技术细节(URL、路径、参数、Payload、版本号、报错原文可摘要但要点不丢)。
|
||||
将冗长扫描输出概括为结论;重复发现合并表述。
|
||||
|
||||
输出须使后续代理能无缝继续同一授权测试任务。`
|
||||
|
||||
// newEinoSummarizationMiddleware 使用 Eino ADK Summarization 中间件(见 https://www.cloudwego.io/zh/docs/eino/core_modules/eino_adk/eino_adk_chatmodelagentmiddleware/middleware_summarization/)。
|
||||
// 触发阈值与单 Agent MemoryCompressor 一致:当估算 token 超过 openai.max_total_tokens 的 90% 时摘要。
|
||||
func newEinoSummarizationMiddleware(
|
||||
ctx context.Context,
|
||||
summaryModel model.BaseChatModel,
|
||||
appCfg *config.Config,
|
||||
logger *zap.Logger,
|
||||
) (adk.ChatModelAgentMiddleware, error) {
|
||||
if summaryModel == nil || appCfg == nil {
|
||||
return nil, fmt.Errorf("multiagent: summarization 需要 model 与配置")
|
||||
}
|
||||
maxTotal := appCfg.OpenAI.MaxTotalTokens
|
||||
if maxTotal <= 0 {
|
||||
maxTotal = 120000
|
||||
}
|
||||
trigger := int(float64(maxTotal) * 0.9)
|
||||
if trigger < 4096 {
|
||||
trigger = maxTotal
|
||||
if trigger < 4096 {
|
||||
trigger = 4096
|
||||
}
|
||||
}
|
||||
preserveMax := trigger / 3
|
||||
if preserveMax < 2048 {
|
||||
preserveMax = 2048
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(appCfg.OpenAI.Model)
|
||||
if modelName == "" {
|
||||
modelName = "gpt-4o"
|
||||
}
|
||||
|
||||
mw, err := summarization.New(ctx, &summarization.Config{
|
||||
Model: summaryModel,
|
||||
Trigger: &summarization.TriggerCondition{
|
||||
ContextTokens: trigger,
|
||||
},
|
||||
TokenCounter: einoSummarizationTokenCounter(modelName),
|
||||
UserInstruction: einoSummarizeUserInstruction,
|
||||
EmitInternalEvents: false,
|
||||
PreserveUserMessages: &summarization.PreserveUserMessages{
|
||||
Enabled: true,
|
||||
MaxTokens: preserveMax,
|
||||
},
|
||||
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
|
||||
if logger == nil {
|
||||
return nil
|
||||
}
|
||||
logger.Info("eino summarization 已压缩上下文",
|
||||
zap.Int("messages_before", len(before.Messages)),
|
||||
zap.Int("messages_after", len(after.Messages)),
|
||||
zap.Int("max_total_tokens", maxTotal),
|
||||
zap.Int("trigger_context_tokens", trigger),
|
||||
)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("summarization.New: %w", err)
|
||||
}
|
||||
return mw, nil
|
||||
}
|
||||
|
||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||
tc := agent.NewTikTokenCounter()
|
||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||
var sb strings.Builder
|
||||
for _, msg := range input.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(string(msg.Role))
|
||||
sb.WriteByte('\n')
|
||||
if msg.Content != "" {
|
||||
sb.WriteString(msg.Content)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if msg.ReasoningContent != "" {
|
||||
sb.WriteString(msg.ReasoningContent)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
|
||||
sb.Write(b)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
for _, part := range msg.UserInputMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText && part.Text != "" {
|
||||
sb.WriteString(part.Text)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tl := range input.Tools {
|
||||
if tl == nil {
|
||||
continue
|
||||
}
|
||||
cp := *tl
|
||||
cp.Extra = nil
|
||||
if text, err := sonic.MarshalString(cp); err == nil {
|
||||
sb.WriteString(text)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
text := sb.String()
|
||||
n, err := tc.Count(openAIModel, text)
|
||||
if err != nil {
|
||||
return (len(text) + 3) / 4, nil
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
)
|
||||
|
||||
// noNestedTaskMiddleware 禁止在已经处于 task(sub-agent) 执行链中再次调用 task,
|
||||
// 避免子代理再次委派子代理造成的无限委派/递归。
|
||||
//
|
||||
// 通过在 ctx 中设置临时标记来实现嵌套检测:外层 task 调用会先标记 ctx,
|
||||
// 子代理内再调用 task 时会命中该标记并拒绝。
|
||||
type noNestedTaskMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
}
|
||||
|
||||
type nestedTaskCtxKey struct{}
|
||||
|
||||
func newNoNestedTaskMiddleware() adk.ChatModelAgentMiddleware {
|
||||
return &noNestedTaskMiddleware{}
|
||||
}
|
||||
|
||||
func (m *noNestedTaskMiddleware) WrapInvokableToolCall(
|
||||
ctx context.Context,
|
||||
endpoint adk.InvokableToolCallEndpoint,
|
||||
tCtx *adk.ToolContext,
|
||||
) (adk.InvokableToolCallEndpoint, error) {
|
||||
if tCtx == nil || strings.TrimSpace(tCtx.Name) == "" {
|
||||
return endpoint, nil
|
||||
}
|
||||
// Deep 内置 task 工具名固定为 "task";为兼容可能的大小写/空白,仅做不区分大小写匹配。
|
||||
if !strings.EqualFold(strings.TrimSpace(tCtx.Name), "task") {
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
// 已在 task 执行链中:拒绝继续委派,直接报错让上层快速终止。
|
||||
if ctx != nil {
|
||||
if v, ok := ctx.Value(nestedTaskCtxKey{}).(bool); ok && v {
|
||||
return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
// Important: return a tool result text (not an error) to avoid hard-stopping the whole multi-agent run.
|
||||
// The nested task is still prevented from spawning another sub-agent, so recursion is avoided.
|
||||
_ = argumentsInJSON
|
||||
_ = opts
|
||||
return "Nested task delegation is forbidden (already inside a sub-agent delegation chain) to avoid infinite delegation. Please continue the work using the current agent's tools.", nil
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 标记当前 task 调用链,确保子代理内的再次 task 调用能检测到嵌套。
|
||||
return func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
ctx2 := ctx
|
||||
if ctx2 == nil {
|
||||
ctx2 = context.Background()
|
||||
}
|
||||
ctx2 = context.WithValue(ctx2, nestedTaskCtxKey{}, true)
|
||||
return endpoint(ctx2, argumentsInJSON, opts...)
|
||||
}, nil
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,51 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// maxToolCallRecoveryAttempts 含首次运行:首次 + 自动重试次数。
|
||||
// 例如为 3 表示最多共 3 次完整 DeepAgent 运行(2 次失败后各追加一条纠错提示)。
|
||||
// 该常量同时用于 JSON 参数错误和工具执行错误(如子代理名称不存在)的恢复重试。
|
||||
const maxToolCallRecoveryAttempts = 5
|
||||
|
||||
// toolCallArgumentsJSONRetryHint 追加在用户消息后,提示模型输出合法 JSON 工具参数(部分云厂商会在流式阶段校验 arguments)。
|
||||
func toolCallArgumentsJSONRetryHint() *schema.Message {
|
||||
return schema.UserMessage(`[系统提示] 上一次输出中,工具调用的 function.arguments 不是合法 JSON,接口已拒绝。请重新生成:每个 tool call 的 arguments 必须是完整、可解析的 JSON 对象字符串(键名用双引号,无多余逗号,括号配对)。不要输出截断或不完整的 JSON。
|
||||
|
||||
[System] Your previous tool call used invalid JSON in function.arguments and was rejected by the API. Regenerate with strictly valid JSON objects only (double-quoted keys, matched braces, no trailing commas).`)
|
||||
}
|
||||
|
||||
// toolCallArgumentsJSONRecoveryTimelineMessage 供 eino_recovery 事件落库与前端时间线展示。
|
||||
func toolCallArgumentsJSONRecoveryTimelineMessage(attempt int) string {
|
||||
return fmt.Sprintf(
|
||||
"接口拒绝了无效的工具参数 JSON。已向对话追加系统提示并要求模型重新生成合法的 function.arguments。"+
|
||||
"当前为第 %d/%d 轮完整运行。\n\n"+
|
||||
"The API rejected invalid JSON in tool arguments. A system hint was appended. This is full run %d of %d.",
|
||||
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
|
||||
)
|
||||
}
|
||||
|
||||
// isRecoverableToolCallArgumentsJSONError 判断是否为「工具参数非合法 JSON」类流式错误,可通过追加提示后重跑一轮。
|
||||
func isRecoverableToolCallArgumentsJSONError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
if !strings.Contains(s, "json") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(s, "function.arguments") || strings.Contains(s, "function arguments") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(s, "invalidparameter") && strings.Contains(s, "json") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(s, "must be in json format") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsRecoverableToolCallArgumentsJSONError(t *testing.T) {
|
||||
yes := errors.New(`failed to receive stream chunk: error, <400> InternalError.Algo.InvalidParameter: The "function.arguments" parameter of the code model must be in JSON format.`)
|
||||
if !isRecoverableToolCallArgumentsJSONError(yes) {
|
||||
t.Fatal("expected recoverable for function.arguments + JSON")
|
||||
}
|
||||
no := errors.New("unrelated network failure")
|
||||
if isRecoverableToolCallArgumentsJSONError(no) {
|
||||
t.Fatal("expected not recoverable")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches
|
||||
// specific recoverable errors from tool execution (JSON parse errors, tool-not-found,
|
||||
// etc.) and converts them into soft errors: nil error + descriptive error content
|
||||
// returned to the LLM. This allows the model to self-correct within the same
|
||||
// iteration rather than crashing the entire graph and requiring a full replay.
|
||||
//
|
||||
// Without this middleware, a JSON parse failure in any tool's InvokableRun propagates
|
||||
// as a hard error through the Eino ToolsNode → [NodeRunError] → ev.Err, which
|
||||
// either triggers the full-replay retry loop (expensive) or terminates the run
|
||||
// entirely once retries are exhausted. With it, the LLM simply sees an error message
|
||||
// in the tool result and can adjust its next tool call accordingly.
|
||||
func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
|
||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
output, err := next(ctx, input)
|
||||
if err == nil {
|
||||
return output, nil
|
||||
}
|
||||
if !isSoftRecoverableToolError(err) {
|
||||
return output, err
|
||||
}
|
||||
// Convert the hard error into a soft error: the LLM will see this
|
||||
// message as the tool's output and can self-correct.
|
||||
msg := buildSoftRecoveryMessage(input.Name, input.Arguments, err)
|
||||
return &compose.ToolOutput{Result: msg}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isSoftRecoverableToolError determines whether a tool execution error should be
|
||||
// silently converted to a tool-result message rather than crashing the graph.
|
||||
func isSoftRecoverableToolError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
|
||||
// JSON unmarshal/parse failures — the model generated truncated or malformed arguments.
|
||||
if isJSONRelatedError(s) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Sub-agent type not found (from deep/task_tool.go)
|
||||
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Tool not found in ToolsNode indexes
|
||||
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isJSONRelatedError checks whether an error string indicates a JSON parsing problem.
|
||||
func isJSONRelatedError(lower string) bool {
|
||||
if !strings.Contains(lower, "json") {
|
||||
return false
|
||||
}
|
||||
jsonIndicators := []string{
|
||||
"unexpected end of json",
|
||||
"unmarshal",
|
||||
"invalid character",
|
||||
"cannot unmarshal",
|
||||
"invalid tool arguments",
|
||||
"failed to unmarshal",
|
||||
"must be in json format",
|
||||
"unexpected eof",
|
||||
}
|
||||
for _, ind := range jsonIndicators {
|
||||
if strings.Contains(lower, ind) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// buildSoftRecoveryMessage creates a bilingual error message that the LLM can act on.
|
||||
func buildSoftRecoveryMessage(toolName, arguments string, err error) string {
|
||||
// Truncate arguments preview to avoid flooding the context.
|
||||
argPreview := arguments
|
||||
if len(argPreview) > 300 {
|
||||
argPreview = argPreview[:300] + "... (truncated)"
|
||||
}
|
||||
|
||||
// Try to determine if it's specifically a JSON parse error for a friendlier message.
|
||||
errStr := err.Error()
|
||||
var jsonErr *json.SyntaxError
|
||||
isJSONErr := strings.Contains(strings.ToLower(errStr), "json") ||
|
||||
strings.Contains(strings.ToLower(errStr), "unmarshal")
|
||||
_ = jsonErr // suppress unused
|
||||
|
||||
if isJSONErr {
|
||||
return fmt.Sprintf(
|
||||
"[Tool Error] The arguments for tool '%s' are not valid JSON and could not be parsed.\n"+
|
||||
"Error: %s\n"+
|
||||
"Arguments received: %s\n\n"+
|
||||
"Please fix the JSON (ensure double-quoted keys, matched braces/brackets, no trailing commas, "+
|
||||
"no truncation) and call the tool again.\n\n"+
|
||||
"[工具错误] 工具 '%s' 的参数不是合法 JSON,无法解析。\n"+
|
||||
"错误:%s\n"+
|
||||
"收到的参数:%s\n\n"+
|
||||
"请修正 JSON(确保双引号键名、括号配对、无尾部逗号、无截断),然后重新调用工具。",
|
||||
toolName, errStr, argPreview,
|
||||
toolName, errStr, argPreview,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"[Tool Error] Tool '%s' execution failed: %s\n"+
|
||||
"Arguments: %s\n\n"+
|
||||
"Please review the available tools and their expected arguments, then retry.\n\n"+
|
||||
"[工具错误] 工具 '%s' 执行失败:%s\n"+
|
||||
"参数:%s\n\n"+
|
||||
"请检查可用工具及其参数要求,然后重试。",
|
||||
toolName, errStr, argPreview,
|
||||
toolName, errStr, argPreview,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
)
|
||||
|
||||
func TestIsSoftRecoverableToolError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unexpected end of JSON input",
|
||||
err: errors.New("unexpected end of JSON input"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "failed to unmarshal task tool input json",
|
||||
err: errors.New("failed to unmarshal task tool input json: unexpected end of JSON input"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "invalid tool arguments JSON",
|
||||
err: errors.New("invalid tool arguments JSON: unexpected end of JSON input"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "json invalid character",
|
||||
err: errors.New(`invalid character '}' looking for beginning of value in JSON`),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "subagent type not found",
|
||||
err: errors.New("subagent type recon_agent not found"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tool not found",
|
||||
err: errors.New("tool nmap_scan not found in toolsNode indexes"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "unrelated network error",
|
||||
err: errors.New("connection refused"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "context cancelled",
|
||||
err: context.Canceled,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "real json unmarshal error",
|
||||
err: func() error {
|
||||
var v map[string]interface{}
|
||||
return json.Unmarshal([]byte(`{"key": `), &v)
|
||||
}(),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isSoftRecoverableToolError(tt.err)
|
||||
if got != tt.expected {
|
||||
t.Errorf("isSoftRecoverableToolError(%v) = %v, want %v", tt.err, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryToolCallMiddleware_PassesThrough(t *testing.T) {
|
||||
mw := softRecoveryToolCallMiddleware()
|
||||
called := false
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
called = true
|
||||
return &compose.ToolOutput{Result: "success"}, nil
|
||||
}
|
||||
wrapped := mw(next)
|
||||
out, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "test_tool",
|
||||
Arguments: `{"key": "value"}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("next endpoint was not called")
|
||||
}
|
||||
if out.Result != "success" {
|
||||
t.Fatalf("expected 'success', got %q", out.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) {
|
||||
mw := softRecoveryToolCallMiddleware()
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
return nil, errors.New("failed to unmarshal task tool input json: unexpected end of JSON input")
|
||||
}
|
||||
wrapped := mw(next)
|
||||
out, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "task",
|
||||
Arguments: `{"subagent_type": "recon`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft recovery), got: %v", err)
|
||||
}
|
||||
if out == nil || out.Result == "" {
|
||||
t.Fatal("expected non-empty recovery message")
|
||||
}
|
||||
if !containsAll(out.Result, "[Tool Error]", "task", "JSON") {
|
||||
t.Fatalf("recovery message missing expected content: %s", out.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryToolCallMiddleware_PropagatesNonRecoverable(t *testing.T) {
|
||||
mw := softRecoveryToolCallMiddleware()
|
||||
origErr := errors.New("connection timeout to remote server")
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
return nil, origErr
|
||||
}
|
||||
wrapped := mw(next)
|
||||
_, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "test_tool",
|
||||
Arguments: `{}`,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error to propagate for non-recoverable errors")
|
||||
}
|
||||
if err != origErr {
|
||||
t.Fatalf("expected original error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func containsAll(s string, subs ...string) bool {
|
||||
for _, sub := range subs {
|
||||
if !contains(s, sub) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func contains(s, sub string) bool {
|
||||
return len(s) >= len(sub) && searchString(s, sub)
|
||||
}
|
||||
|
||||
func searchString(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// isRecoverableToolExecutionError detects tool-level execution errors that can be
|
||||
// recovered by retrying with a corrective hint. These errors originate from eino
|
||||
// framework internals (e.g. task_tool.go, tool_node.go) when the LLM produces
|
||||
// invalid tool calls such as non-existent sub-agent types, malformed JSON arguments,
|
||||
// or unregistered tool names.
|
||||
func isRecoverableToolExecutionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := strings.ToLower(err.Error())
|
||||
|
||||
// Sub-agent type not found (from deep/task_tool.go)
|
||||
if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Tool not found in toolsNode indexes (from compose/tool_node.go, when UnknownToolsHandler is nil)
|
||||
if strings.Contains(s, "tool") && strings.Contains(s, "not found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Invalid tool arguments JSON (from einomcp/mcp_tools.go or eino internals)
|
||||
if strings.Contains(s, "invalid tool arguments json") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Failed to unmarshal task tool input json (from deep/task_tool.go)
|
||||
if strings.Contains(s, "failed to unmarshal") && strings.Contains(s, "json") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Generic tool call stream/invoke failure wrapping the above
|
||||
if (strings.Contains(s, "failed to stream tool call") || strings.Contains(s, "failed to invoke tool")) &&
|
||||
(strings.Contains(s, "not found") || strings.Contains(s, "json") || strings.Contains(s, "unmarshal")) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// toolExecutionRetryHint returns a user message appended to the conversation to prompt
|
||||
// the LLM to correct its tool call after a tool execution error.
|
||||
func toolExecutionRetryHint() *schema.Message {
|
||||
return schema.UserMessage(`[System] Your previous tool call failed because:
|
||||
- The tool or sub-agent name you used does not exist, OR
|
||||
- The tool call arguments were not valid JSON.
|
||||
|
||||
Please carefully review the available tools and sub-agents listed in your context, use only exact registered names (case-sensitive), and ensure all arguments are well-formed JSON objects. Then retry your action.
|
||||
|
||||
[系统提示] 上一次工具调用失败,可能原因:
|
||||
- 你使用的工具名或子代理名称不存在;
|
||||
- 工具调用参数不是合法 JSON。
|
||||
|
||||
请仔细检查上下文中列出的可用工具和子代理名称(须完全匹配、区分大小写),确保所有参数均为合法的 JSON 对象,然后重新执行。`)
|
||||
}
|
||||
|
||||
// toolExecutionRecoveryTimelineMessage returns a message for the eino_recovery event
|
||||
// displayed in the UI timeline when a tool execution error triggers a retry.
|
||||
func toolExecutionRecoveryTimelineMessage(attempt int) string {
|
||||
return fmt.Sprintf(
|
||||
"工具调用执行失败(工具/子代理名称不存在或参数 JSON 无效)。已向对话追加纠错提示并要求模型重新生成。"+
|
||||
"当前为第 %d/%d 轮完整运行。\n\n"+
|
||||
"Tool call execution failed (unknown tool/sub-agent name or invalid JSON arguments). "+
|
||||
"A corrective hint was appended. This is full run %d of %d.",
|
||||
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package skillpackage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var reH2 = regexp.MustCompile(`(?m)^##\s+(.+)$`)
|
||||
|
||||
const summaryContentRunes = 6000
|
||||
|
||||
type markdownSection struct {
|
||||
Heading string
|
||||
Title string
|
||||
Content string
|
||||
}
|
||||
|
||||
func splitMarkdownSections(body string) []markdownSection {
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return nil
|
||||
}
|
||||
idxs := reH2.FindAllStringIndex(body, -1)
|
||||
titles := reH2.FindAllStringSubmatch(body, -1)
|
||||
if len(idxs) == 0 {
|
||||
return []markdownSection{{
|
||||
Heading: "",
|
||||
Title: "_body",
|
||||
Content: body,
|
||||
}}
|
||||
}
|
||||
var out []markdownSection
|
||||
for i := range idxs {
|
||||
title := strings.TrimSpace(titles[i][1])
|
||||
start := idxs[i][0]
|
||||
end := len(body)
|
||||
if i+1 < len(idxs) {
|
||||
end = idxs[i+1][0]
|
||||
}
|
||||
chunk := strings.TrimSpace(body[start:end])
|
||||
out = append(out, markdownSection{
|
||||
Heading: "## " + title,
|
||||
Title: title,
|
||||
Content: chunk,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func deriveSections(body string) []SkillSection {
|
||||
md := splitMarkdownSections(body)
|
||||
out := make([]SkillSection, 0, len(md))
|
||||
for _, ms := range md {
|
||||
if ms.Title == "_body" {
|
||||
continue
|
||||
}
|
||||
out = append(out, SkillSection{
|
||||
ID: slugifySectionID(ms.Title),
|
||||
Title: ms.Title,
|
||||
Heading: ms.Heading,
|
||||
Level: 2,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func slugifySectionID(title string) string {
|
||||
title = strings.TrimSpace(strings.ToLower(title))
|
||||
if title == "" {
|
||||
return "section"
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, r := range title {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
case r == ' ', r == '-', r == '_':
|
||||
b.WriteRune('-')
|
||||
}
|
||||
}
|
||||
s := strings.Trim(b.String(), "-")
|
||||
if s == "" {
|
||||
return "section"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func findSectionContent(sections []markdownSection, sec string) string {
|
||||
sec = strings.TrimSpace(sec)
|
||||
if sec == "" {
|
||||
return ""
|
||||
}
|
||||
want := strings.ToLower(sec)
|
||||
for _, s := range sections {
|
||||
if strings.EqualFold(slugifySectionID(s.Title), want) || strings.EqualFold(s.Title, sec) {
|
||||
return s.Content
|
||||
}
|
||||
if strings.EqualFold(strings.ReplaceAll(s.Title, " ", "-"), want) {
|
||||
return s.Content
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func buildSummaryMarkdown(name, description string, tags []string, scripts []SkillScriptInfo, sections []SkillSection, body string) string {
|
||||
var b strings.Builder
|
||||
if description != "" {
|
||||
b.WriteString(description)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
if len(tags) > 0 {
|
||||
b.WriteString("**Tags**: ")
|
||||
b.WriteString(strings.Join(tags, ", "))
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
if len(scripts) > 0 {
|
||||
b.WriteString("### Bundled scripts\n\n")
|
||||
for _, sc := range scripts {
|
||||
line := "- `" + sc.RelPath + "`"
|
||||
if sc.Description != "" {
|
||||
line += " — " + sc.Description
|
||||
}
|
||||
b.WriteString(line)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if len(sections) > 0 {
|
||||
b.WriteString("### Sections\n\n")
|
||||
for _, sec := range sections {
|
||||
line := "- **" + sec.ID + "**"
|
||||
if sec.Title != "" && sec.Title != sec.ID {
|
||||
line += ": " + sec.Title
|
||||
}
|
||||
b.WriteString(line)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
mdSecs := splitMarkdownSections(body)
|
||||
preview := body
|
||||
if len(mdSecs) > 0 && mdSecs[0].Title != "_body" {
|
||||
preview = mdSecs[0].Content
|
||||
}
|
||||
b.WriteString("### Preview (SKILL.md)\n\n")
|
||||
b.WriteString(truncateRunes(strings.TrimSpace(preview), summaryContentRunes))
|
||||
b.WriteString("\n\n---\n\n_(Summary for admin UI. Agents use Eino `skill` tool for full SKILL.md progressive loading.)_")
|
||||
if name != "" {
|
||||
b.WriteString(fmt.Sprintf("\n\n_Skill name: %s_", name))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func truncateRunes(s string, max int) string {
|
||||
if max <= 0 || s == "" {
|
||||
return s
|
||||
}
|
||||
r := []rune(s)
|
||||
if len(r) <= max {
|
||||
return s
|
||||
}
|
||||
return string(r[:max]) + "…"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package skillpackage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ExtractSkillMDFrontMatterYAML returns the YAML source inside the first --- ... --- block and the markdown body.
|
||||
func ExtractSkillMDFrontMatterYAML(raw []byte) (fmYAML string, body string, err error) {
|
||||
text := strings.TrimPrefix(string(raw), "\ufeff")
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "", "", fmt.Errorf("SKILL.md is empty")
|
||||
}
|
||||
lines := strings.Split(text, "\n")
|
||||
if len(lines) < 2 || strings.TrimSpace(lines[0]) != "---" {
|
||||
return "", "", fmt.Errorf("SKILL.md must start with YAML front matter (---) per Agent Skills standard")
|
||||
}
|
||||
var fmLines []string
|
||||
i := 1
|
||||
for i < len(lines) {
|
||||
if strings.TrimSpace(lines[i]) == "---" {
|
||||
break
|
||||
}
|
||||
fmLines = append(fmLines, lines[i])
|
||||
i++
|
||||
}
|
||||
if i >= len(lines) {
|
||||
return "", "", fmt.Errorf("SKILL.md: front matter must end with a line containing only ---")
|
||||
}
|
||||
body = strings.Join(lines[i+1:], "\n")
|
||||
body = strings.TrimSpace(body)
|
||||
fmYAML = strings.Join(fmLines, "\n")
|
||||
return fmYAML, body, nil
|
||||
}
|
||||
|
||||
// ParseSkillMD parses SKILL.md YAML head + body.
|
||||
func ParseSkillMD(raw []byte) (*SkillManifest, string, error) {
|
||||
fmYAML, body, err := ExtractSkillMDFrontMatterYAML(raw)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
var m SkillManifest
|
||||
if err := yaml.Unmarshal([]byte(fmYAML), &m); err != nil {
|
||||
return nil, "", fmt.Errorf("SKILL.md front matter: %w", err)
|
||||
}
|
||||
return &m, body, nil
|
||||
}
|
||||
|
||||
type skillFrontMatterExport struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
License string `yaml:"license,omitempty"`
|
||||
Compatibility string `yaml:"compatibility,omitempty"`
|
||||
Metadata map[string]any `yaml:"metadata,omitempty"`
|
||||
AllowedTools string `yaml:"allowed-tools,omitempty"`
|
||||
}
|
||||
|
||||
// BuildSkillMD serializes SKILL.md per agentskills.io.
|
||||
func BuildSkillMD(m *SkillManifest, body string) ([]byte, error) {
|
||||
if m == nil {
|
||||
return nil, fmt.Errorf("nil manifest")
|
||||
}
|
||||
fm := skillFrontMatterExport{
|
||||
Name: strings.TrimSpace(m.Name),
|
||||
Description: strings.TrimSpace(m.Description),
|
||||
License: strings.TrimSpace(m.License),
|
||||
Compatibility: strings.TrimSpace(m.Compatibility),
|
||||
AllowedTools: strings.TrimSpace(m.AllowedTools),
|
||||
}
|
||||
if len(m.Metadata) > 0 {
|
||||
fm.Metadata = m.Metadata
|
||||
}
|
||||
head, err := yaml.Marshal(&fm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := strings.TrimSpace(string(head))
|
||||
out := "---\n" + s + "\n---\n\n" + strings.TrimSpace(body) + "\n"
|
||||
return []byte(out), nil
|
||||
}
|
||||
|
||||
func manifestTags(m *SkillManifest) []string {
|
||||
if m == nil || m.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
var out []string
|
||||
if raw, ok := m.Metadata["tags"]; ok {
|
||||
switch v := raw.(type) {
|
||||
case []any:
|
||||
for _, x := range v {
|
||||
if s, ok := x.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
out = append(out, v...)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func versionFromMetadata(m *SkillManifest) string {
|
||||
if m == nil || m.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := m.Metadata["version"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package skillpackage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
maxPackageFiles = 4000
|
||||
maxPackageDepth = 24
|
||||
maxScriptsDepth = 24
|
||||
defaultMaxRead = 10 << 20
|
||||
)
|
||||
|
||||
// SafeRelPath resolves rel inside root (no ..).
|
||||
func SafeRelPath(root, rel string) (string, error) {
|
||||
rel = strings.TrimSpace(rel)
|
||||
rel = filepath.ToSlash(rel)
|
||||
rel = strings.TrimPrefix(rel, "/")
|
||||
if rel == "" || rel == "." {
|
||||
return "", fmt.Errorf("empty resource path")
|
||||
}
|
||||
if strings.Contains(rel, "..") {
|
||||
return "", fmt.Errorf("invalid path %q", rel)
|
||||
}
|
||||
abs := filepath.Join(root, filepath.FromSlash(rel))
|
||||
cleanRoot := filepath.Clean(root)
|
||||
cleanAbs := filepath.Clean(abs)
|
||||
relOut, err := filepath.Rel(cleanRoot, cleanAbs)
|
||||
if err != nil || relOut == ".." || strings.HasPrefix(relOut, ".."+string(filepath.Separator)) {
|
||||
return "", fmt.Errorf("path escapes skill directory: %q", rel)
|
||||
}
|
||||
return cleanAbs, nil
|
||||
}
|
||||
|
||||
// ListPackageFiles lists files under a skill directory.
|
||||
func ListPackageFiles(skillsRoot, skillID string) ([]PackageFileInfo, error) {
|
||||
root := SkillDir(skillsRoot, skillID)
|
||||
if _, err := ResolveSKILLPath(root); err != nil {
|
||||
return nil, fmt.Errorf("skill %q: %w", skillID, err)
|
||||
}
|
||||
var out []PackageFileInfo
|
||||
err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rel, e := filepath.Rel(root, path)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
if rel == "." {
|
||||
return nil
|
||||
}
|
||||
depth := strings.Count(rel, string(os.PathSeparator))
|
||||
if depth > maxPackageDepth {
|
||||
if d.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if strings.HasPrefix(d.Name(), ".") {
|
||||
if d.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if len(out) >= maxPackageFiles {
|
||||
return fmt.Errorf("skill package exceeds %d files", maxPackageFiles)
|
||||
}
|
||||
fi, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out = append(out, PackageFileInfo{
|
||||
Path: filepath.ToSlash(rel),
|
||||
Size: fi.Size(),
|
||||
IsDir: d.IsDir(),
|
||||
})
|
||||
return nil
|
||||
})
|
||||
return out, err
|
||||
}
|
||||
|
||||
// ReadPackageFile reads a file relative to the skill package.
|
||||
func ReadPackageFile(skillsRoot, skillID, relPath string, maxBytes int64) ([]byte, error) {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = defaultMaxRead
|
||||
}
|
||||
root := SkillDir(skillsRoot, skillID)
|
||||
abs, err := SafeRelPath(root, relPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fi, err := os.Stat(abs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fi.IsDir() {
|
||||
return nil, fmt.Errorf("path is a directory")
|
||||
}
|
||||
if fi.Size() > maxBytes {
|
||||
return readFileHead(abs, maxBytes)
|
||||
}
|
||||
return os.ReadFile(abs)
|
||||
}
|
||||
|
||||
// WritePackageFile writes a file inside the skill package.
|
||||
func WritePackageFile(skillsRoot, skillID, relPath string, content []byte) error {
|
||||
root := SkillDir(skillsRoot, skillID)
|
||||
if _, err := ResolveSKILLPath(root); err != nil {
|
||||
return fmt.Errorf("skill %q: %w", skillID, err)
|
||||
}
|
||||
abs, err := SafeRelPath(root, relPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(abs), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(abs, content, 0644)
|
||||
}
|
||||
|
||||
func readFileHead(path string, max int64) ([]byte, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
buf := make([]byte, max)
|
||||
n, err := f.Read(buf)
|
||||
if err != nil && n == 0 {
|
||||
return nil, err
|
||||
}
|
||||
return buf[:n], nil
|
||||
}
|
||||
|
||||
func listScripts(skillsRoot, skillID string) ([]SkillScriptInfo, error) {
|
||||
root := filepath.Join(SkillDir(skillsRoot, skillID), "scripts")
|
||||
st, err := os.Stat(root)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if !st.IsDir() {
|
||||
return nil, nil
|
||||
}
|
||||
var out []SkillScriptInfo
|
||||
err = filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rel, e := filepath.Rel(root, path)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
if rel == "." {
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
if strings.HasPrefix(d.Name(), ".") {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if strings.Count(rel, string(os.PathSeparator)) >= maxScriptsDepth {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if strings.HasPrefix(d.Name(), ".") {
|
||||
return nil
|
||||
}
|
||||
relSkill := filepath.Join("scripts", rel)
|
||||
full := filepath.Join(root, rel)
|
||||
fi, err := os.Stat(full)
|
||||
if err != nil || fi.IsDir() {
|
||||
return nil
|
||||
}
|
||||
out = append(out, SkillScriptInfo{
|
||||
Name: filepath.Base(rel),
|
||||
RelPath: filepath.ToSlash(relSkill),
|
||||
Size: fi.Size(),
|
||||
})
|
||||
return nil
|
||||
})
|
||||
return out, err
|
||||
}
|
||||
|
||||
func countNonDirFiles(files []PackageFileInfo) int {
|
||||
n := 0
|
||||
for _, f := range files {
|
||||
if !f.IsDir && f.Path != "SKILL.md" {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package skillpackage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SkillDir returns the absolute path to a skill package directory.
|
||||
func SkillDir(skillsRoot, skillID string) string {
|
||||
return filepath.Join(skillsRoot, skillID)
|
||||
}
|
||||
|
||||
// ResolveSKILLPath returns SKILL.md path or error if missing.
|
||||
func ResolveSKILLPath(skillPath string) (string, error) {
|
||||
md := filepath.Join(skillPath, "SKILL.md")
|
||||
if st, err := os.Stat(md); err != nil || st.IsDir() {
|
||||
return "", fmt.Errorf("missing SKILL.md in %q (Agent Skills standard)", filepath.Base(skillPath))
|
||||
}
|
||||
return md, nil
|
||||
}
|
||||
|
||||
// SkillsRootFromConfig resolves cfg.SkillsDir relative to the config file directory.
|
||||
func SkillsRootFromConfig(skillsDir string, configPath string) string {
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
return skillsDir
|
||||
}
|
||||
|
||||
// DirLister satisfies handler.SkillsManager for role UI (lists package directory names).
|
||||
type DirLister struct {
|
||||
SkillsRoot string
|
||||
}
|
||||
|
||||
// ListSkills implements the role handler dependency.
|
||||
func (d DirLister) ListSkills() ([]string, error) {
|
||||
return ListSkillDirNames(d.SkillsRoot)
|
||||
}
|
||||
|
||||
// ListSkillDirNames returns subdirectory names under skillsRoot that contain SKILL.md.
|
||||
func ListSkillDirNames(skillsRoot string) ([]string, error) {
|
||||
if _, err := os.Stat(skillsRoot); os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
entries, err := os.ReadDir(skillsRoot)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read skills directory: %w", err)
|
||||
}
|
||||
var names []string
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() || strings.HasPrefix(entry.Name(), ".") {
|
||||
continue
|
||||
}
|
||||
skillPath := filepath.Join(skillsRoot, entry.Name())
|
||||
if _, err := ResolveSKILLPath(skillPath); err == nil {
|
||||
names = append(names, entry.Name())
|
||||
}
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package skillpackage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ListSkillSummaries scans skillsRoot and returns index rows for the admin API.
|
||||
func ListSkillSummaries(skillsRoot string) ([]SkillSummary, error) {
|
||||
names, err := ListSkillDirNames(skillsRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Strings(names)
|
||||
out := make([]SkillSummary, 0, len(names))
|
||||
for _, dirName := range names {
|
||||
su, err := loadSummary(skillsRoot, dirName)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, su)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func loadSummary(skillsRoot, dirName string) (SkillSummary, error) {
|
||||
skillPath := SkillDir(skillsRoot, dirName)
|
||||
mdPath, err := ResolveSKILLPath(skillPath)
|
||||
if err != nil {
|
||||
return SkillSummary{}, err
|
||||
}
|
||||
raw, err := os.ReadFile(mdPath)
|
||||
if err != nil {
|
||||
return SkillSummary{}, err
|
||||
}
|
||||
man, _, err := ParseSkillMD(raw)
|
||||
if err != nil {
|
||||
return SkillSummary{}, err
|
||||
}
|
||||
if err := ValidateAgentSkillManifestInPackage(man, dirName); err != nil {
|
||||
return SkillSummary{}, err
|
||||
}
|
||||
fi, err := os.Stat(mdPath)
|
||||
if err != nil {
|
||||
return SkillSummary{}, err
|
||||
}
|
||||
pfiles, err := ListPackageFiles(skillsRoot, dirName)
|
||||
if err != nil {
|
||||
return SkillSummary{}, err
|
||||
}
|
||||
nFiles := 0
|
||||
for _, p := range pfiles {
|
||||
if !p.IsDir {
|
||||
nFiles++
|
||||
}
|
||||
}
|
||||
scripts, err := listScripts(skillsRoot, dirName)
|
||||
if err != nil {
|
||||
return SkillSummary{}, err
|
||||
}
|
||||
ver := versionFromMetadata(man)
|
||||
return SkillSummary{
|
||||
ID: dirName,
|
||||
DirName: dirName,
|
||||
Name: man.Name,
|
||||
Description: man.Description,
|
||||
Version: ver,
|
||||
Path: skillPath,
|
||||
Tags: manifestTags(man),
|
||||
ScriptCount: len(scripts),
|
||||
FileCount: nFiles,
|
||||
FileSize: fi.Size(),
|
||||
ModTime: fi.ModTime().Format("2006-01-02 15:04:05"),
|
||||
Progressive: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadOptions mirrors legacy API query params for the web admin.
|
||||
type LoadOptions struct {
|
||||
Depth string // summary | full
|
||||
Section string
|
||||
}
|
||||
|
||||
// LoadSkill returns manifest + body + package listing for admin.
|
||||
func LoadSkill(skillsRoot, skillID string, opt LoadOptions) (*SkillView, error) {
|
||||
skillPath := SkillDir(skillsRoot, skillID)
|
||||
mdPath, err := ResolveSKILLPath(skillPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, err := os.ReadFile(mdPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
man, body, err := ParseSkillMD(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateAgentSkillManifestInPackage(man, skillID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pfiles, err := ListPackageFiles(skillsRoot, skillID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scripts, err := listScripts(skillsRoot, skillID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(scripts, func(i, j int) bool { return scripts[i].RelPath < scripts[j].RelPath })
|
||||
sections := deriveSections(body)
|
||||
ver := versionFromMetadata(man)
|
||||
v := &SkillView{
|
||||
DirName: skillID,
|
||||
Name: man.Name,
|
||||
Description: man.Description,
|
||||
Content: body,
|
||||
Path: skillPath,
|
||||
Version: ver,
|
||||
Tags: manifestTags(man),
|
||||
Scripts: scripts,
|
||||
Sections: sections,
|
||||
PackageFiles: pfiles,
|
||||
}
|
||||
depth := strings.ToLower(strings.TrimSpace(opt.Depth))
|
||||
if depth == "" {
|
||||
depth = "full"
|
||||
}
|
||||
sec := strings.TrimSpace(opt.Section)
|
||||
if sec != "" {
|
||||
mds := splitMarkdownSections(body)
|
||||
chunk := findSectionContent(mds, sec)
|
||||
if chunk == "" {
|
||||
v.Content = fmt.Sprintf("_(section %q not found in SKILL.md for skill %s)_", sec, skillID)
|
||||
} else {
|
||||
v.Content = chunk
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
if depth == "summary" {
|
||||
v.Content = buildSummaryMarkdown(man.Name, man.Description, v.Tags, scripts, sections, body)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// ReadScriptText returns file content as string (for HTTP resource_path).
|
||||
func ReadScriptText(skillsRoot, skillID, relPath string, maxBytes int64) (string, error) {
|
||||
b, err := ReadPackageFile(skillsRoot, skillID, relPath, maxBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
// Package skillpackage provides filesystem-backed Agent Skills layout (SKILL.md + package files)
|
||||
// for HTTP admin APIs. Runtime discovery and progressive loading for agents use Eino ADK skill middleware.
|
||||
package skillpackage
|
||||
|
||||
// SkillManifest is parsed from SKILL.md front matter (https://agentskills.io/specification.md).
|
||||
type SkillManifest struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
License string `yaml:"license,omitempty"`
|
||||
Compatibility string `yaml:"compatibility,omitempty"`
|
||||
Metadata map[string]any `yaml:"metadata,omitempty"`
|
||||
AllowedTools string `yaml:"allowed-tools,omitempty"`
|
||||
}
|
||||
|
||||
// SkillSummary is API metadata for one skill directory.
|
||||
type SkillSummary struct {
|
||||
ID string `json:"id"`
|
||||
DirName string `json:"dir_name"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Version string `json:"version"`
|
||||
Path string `json:"path"`
|
||||
Tags []string `json:"tags"`
|
||||
Triggers []string `json:"triggers,omitempty"`
|
||||
ScriptCount int `json:"script_count"`
|
||||
FileCount int `json:"file_count"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
ModTime string `json:"mod_time"`
|
||||
Progressive bool `json:"progressive"`
|
||||
}
|
||||
|
||||
// SkillScriptInfo describes a file under scripts/.
|
||||
type SkillScriptInfo struct {
|
||||
Name string `json:"name"`
|
||||
RelPath string `json:"rel_path"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// SkillSection is derived from ## headings in SKILL.md.
|
||||
type SkillSection struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Heading string `json:"heading"`
|
||||
Level int `json:"level"`
|
||||
}
|
||||
|
||||
// PackageFileInfo describes one file inside a package.
|
||||
type PackageFileInfo struct {
|
||||
Path string `json:"path"`
|
||||
Size int64 `json:"size"`
|
||||
IsDir bool `json:"is_dir,omitempty"`
|
||||
}
|
||||
|
||||
// SkillView is a loaded package for admin / API.
|
||||
type SkillView struct {
|
||||
DirName string `json:"dir_name"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
Path string `json:"path"`
|
||||
Version string `json:"version"`
|
||||
Tags []string `json:"tags"`
|
||||
Scripts []SkillScriptInfo `json:"scripts,omitempty"`
|
||||
Sections []SkillSection `json:"sections,omitempty"`
|
||||
PackageFiles []PackageFileInfo `json:"package_files,omitempty"`
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package skillpackage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var agentSkillsSpecFrontMatterKeys = map[string]struct{}{
|
||||
"name": {}, "description": {}, "license": {}, "compatibility": {},
|
||||
"metadata": {}, "allowed-tools": {},
|
||||
}
|
||||
|
||||
// ValidateAgentSkillManifest enforces Agent Skills rules for name and description.
|
||||
func ValidateAgentSkillManifest(m *SkillManifest) error {
|
||||
if m == nil {
|
||||
return fmt.Errorf("skill manifest is nil")
|
||||
}
|
||||
if strings.TrimSpace(m.Name) == "" {
|
||||
return fmt.Errorf("SKILL.md front matter: name is required")
|
||||
}
|
||||
if strings.TrimSpace(m.Description) == "" {
|
||||
return fmt.Errorf("SKILL.md front matter: description is required")
|
||||
}
|
||||
if utf8.RuneCountInString(m.Name) > 64 {
|
||||
return fmt.Errorf("name exceeds 64 characters (Agent Skills limit)")
|
||||
}
|
||||
if utf8.RuneCountInString(m.Description) > 1024 {
|
||||
return fmt.Errorf("description exceeds 1024 characters (Agent Skills limit)")
|
||||
}
|
||||
if m.Name != strings.ToLower(m.Name) {
|
||||
return fmt.Errorf("name must be lowercase (Agent Skills)")
|
||||
}
|
||||
for _, r := range m.Name {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-') {
|
||||
return fmt.Errorf("name must contain only lowercase letters, numbers, hyphens (Agent Skills)")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(m.Name, "-") || strings.HasSuffix(m.Name, "-") {
|
||||
return fmt.Errorf("name must not start or end with a hyphen (Agent Skills spec)")
|
||||
}
|
||||
if strings.Contains(m.Name, "--") {
|
||||
return fmt.Errorf("name must not contain consecutive hyphens (Agent Skills spec)")
|
||||
}
|
||||
lname := strings.ToLower(m.Name)
|
||||
if strings.Contains(lname, "anthropic") || strings.Contains(lname, "claude") {
|
||||
return fmt.Errorf("name must not contain reserved words anthropic or claude")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAgentSkillManifestInPackage checks manifest and that name matches package directory.
|
||||
func ValidateAgentSkillManifestInPackage(m *SkillManifest, packageDirName string) error {
|
||||
if err := ValidateAgentSkillManifest(m); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(packageDirName) == "" {
|
||||
return nil
|
||||
}
|
||||
if m.Name != packageDirName {
|
||||
return fmt.Errorf("SKILL.md name %q must match directory name %q (Agent Skills spec)", m.Name, packageDirName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateOfficialFrontMatterTopLevelKeys rejects keys not in the open spec.
|
||||
func ValidateOfficialFrontMatterTopLevelKeys(fmYAML string) error {
|
||||
var top map[string]interface{}
|
||||
if err := yaml.Unmarshal([]byte(fmYAML), &top); err != nil {
|
||||
return fmt.Errorf("SKILL.md front matter: %w", err)
|
||||
}
|
||||
for k := range top {
|
||||
if _, ok := agentSkillsSpecFrontMatterKeys[k]; !ok {
|
||||
return fmt.Errorf("SKILL.md front matter: unsupported key %q (allowed: name, description, license, compatibility, metadata, allowed-tools — see https://agentskills.io/specification.md)", k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSkillMDPackage validates SKILL.md bytes for writes.
|
||||
func ValidateSkillMDPackage(raw []byte, packageDirName string) error {
|
||||
fmYAML, body, err := ExtractSkillMDFrontMatterYAML(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateOfficialFrontMatterTopLevelKeys(fmYAML); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(body) == "" {
|
||||
return fmt.Errorf("SKILL.md: markdown body after front matter must not be empty")
|
||||
}
|
||||
var fm SkillManifest
|
||||
if err := yaml.Unmarshal([]byte(fmYAML), &fm); err != nil {
|
||||
return fmt.Errorf("SKILL.md front matter: %w", err)
|
||||
}
|
||||
if c := strings.TrimSpace(fm.Compatibility); c != "" && utf8.RuneCountInString(c) > 500 {
|
||||
return fmt.Errorf("compatibility exceeds 500 characters (Agent Skills spec)")
|
||||
}
|
||||
return ValidateAgentSkillManifestInPackage(&fm, packageDirName)
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ResultStorage 结果存储接口
|
||||
type ResultStorage interface {
|
||||
// SaveResult 保存工具执行结果
|
||||
SaveResult(executionID string, toolName string, result string) error
|
||||
|
||||
// GetResult 获取完整结果
|
||||
GetResult(executionID string) (string, error)
|
||||
|
||||
// GetResultPage 分页获取结果
|
||||
GetResultPage(executionID string, page int, limit int) (*ResultPage, error)
|
||||
|
||||
// SearchResult 搜索结果
|
||||
// useRegex: 如果为 true,将 keyword 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配
|
||||
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
|
||||
|
||||
// FilterResult 过滤结果
|
||||
// useRegex: 如果为 true,将 filter 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配
|
||||
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
|
||||
|
||||
// GetResultMetadata 获取结果元信息
|
||||
GetResultMetadata(executionID string) (*ResultMetadata, error)
|
||||
|
||||
// GetResultPath 获取结果文件路径
|
||||
GetResultPath(executionID string) string
|
||||
|
||||
// DeleteResult 删除结果
|
||||
DeleteResult(executionID string) error
|
||||
}
|
||||
|
||||
// ResultPage 分页结果
|
||||
type ResultPage struct {
|
||||
Lines []string `json:"lines"`
|
||||
Page int `json:"page"`
|
||||
Limit int `json:"limit"`
|
||||
TotalLines int `json:"total_lines"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// ResultMetadata 结果元信息
|
||||
type ResultMetadata struct {
|
||||
ExecutionID string `json:"execution_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
TotalSize int `json:"total_size"`
|
||||
TotalLines int `json:"total_lines"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// FileResultStorage 基于文件的结果存储实现
|
||||
type FileResultStorage struct {
|
||||
baseDir string
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFileResultStorage 创建新的文件结果存储
|
||||
func NewFileResultStorage(baseDir string, logger *zap.Logger) (*FileResultStorage, error) {
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(baseDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建存储目录失败: %w", err)
|
||||
}
|
||||
|
||||
return &FileResultStorage{
|
||||
baseDir: baseDir,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getResultPath 获取结果文件路径
|
||||
func (s *FileResultStorage) getResultPath(executionID string) string {
|
||||
return filepath.Join(s.baseDir, executionID+".txt")
|
||||
}
|
||||
|
||||
// getMetadataPath 获取元数据文件路径
|
||||
func (s *FileResultStorage) getMetadataPath(executionID string) string {
|
||||
return filepath.Join(s.baseDir, executionID+".meta.json")
|
||||
}
|
||||
|
||||
// SaveResult 保存工具执行结果
|
||||
func (s *FileResultStorage) SaveResult(executionID string, toolName string, result string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 保存结果文件
|
||||
resultPath := s.getResultPath(executionID)
|
||||
if err := os.WriteFile(resultPath, []byte(result), 0644); err != nil {
|
||||
return fmt.Errorf("保存结果文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 计算统计信息
|
||||
lines := strings.Split(result, "\n")
|
||||
metadata := &ResultMetadata{
|
||||
ExecutionID: executionID,
|
||||
ToolName: toolName,
|
||||
TotalSize: len(result),
|
||||
TotalLines: len(lines),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 保存元数据
|
||||
metadataPath := s.getMetadataPath(executionID)
|
||||
metadataJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化元数据失败: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(metadataPath, metadataJSON, 0644); err != nil {
|
||||
return fmt.Errorf("保存元数据文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("保存工具执行结果",
|
||||
zap.String("executionID", executionID),
|
||||
zap.String("toolName", toolName),
|
||||
zap.Int("size", len(result)),
|
||||
zap.Int("lines", len(lines)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetResult 获取完整结果
|
||||
func (s *FileResultStorage) GetResult(executionID string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
resultPath := s.getResultPath(executionID)
|
||||
data, err := os.ReadFile(resultPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("结果不存在: %s", executionID)
|
||||
}
|
||||
return "", fmt.Errorf("读取结果文件失败: %w", err)
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// GetResultMetadata 获取结果元信息
|
||||
func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetadata, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
metadataPath := s.getMetadataPath(executionID)
|
||||
data, err := os.ReadFile(metadataPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("结果不存在: %s", executionID)
|
||||
}
|
||||
return nil, fmt.Errorf("读取元数据文件失败: %w", err)
|
||||
}
|
||||
|
||||
var metadata ResultMetadata
|
||||
if err := json.Unmarshal(data, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("解析元数据失败: %w", err)
|
||||
}
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// GetResultPage 分页获取结果
|
||||
func (s *FileResultStorage) GetResultPage(executionID string, page int, limit int) (*ResultPage, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// 获取完整结果
|
||||
result, err := s.GetResult(executionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 分割为行
|
||||
lines := strings.Split(result, "\n")
|
||||
totalLines := len(lines)
|
||||
|
||||
// 计算分页
|
||||
totalPages := (totalLines + limit - 1) / limit
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if page > totalPages && totalPages > 0 {
|
||||
page = totalPages
|
||||
}
|
||||
|
||||
// 计算起始和结束索引
|
||||
start := (page - 1) * limit
|
||||
end := start + limit
|
||||
if end > totalLines {
|
||||
end = totalLines
|
||||
}
|
||||
|
||||
// 提取指定页的行
|
||||
var pageLines []string
|
||||
if start < totalLines {
|
||||
pageLines = lines[start:end]
|
||||
} else {
|
||||
pageLines = []string{}
|
||||
}
|
||||
|
||||
return &ResultPage{
|
||||
Lines: pageLines,
|
||||
Page: page,
|
||||
Limit: limit,
|
||||
TotalLines: totalLines,
|
||||
TotalPages: totalPages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SearchResult 搜索结果
|
||||
func (s *FileResultStorage) SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// 获取完整结果
|
||||
result, err := s.GetResult(executionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果使用正则表达式,先编译正则
|
||||
var regex *regexp.Regexp
|
||||
if useRegex {
|
||||
compiledRegex, err := regexp.Compile(keyword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效的正则表达式: %w", err)
|
||||
}
|
||||
regex = compiledRegex
|
||||
}
|
||||
|
||||
// 分割为行并搜索
|
||||
lines := strings.Split(result, "\n")
|
||||
var matchedLines []string
|
||||
|
||||
for _, line := range lines {
|
||||
var matched bool
|
||||
if useRegex {
|
||||
matched = regex.MatchString(line)
|
||||
} else {
|
||||
matched = strings.Contains(line, keyword)
|
||||
}
|
||||
|
||||
if matched {
|
||||
matchedLines = append(matchedLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return matchedLines, nil
|
||||
}
|
||||
|
||||
// FilterResult 过滤结果
|
||||
func (s *FileResultStorage) FilterResult(executionID string, filter string, useRegex bool) ([]string, error) {
|
||||
// 过滤和搜索逻辑相同,都是查找包含关键词的行
|
||||
return s.SearchResult(executionID, filter, useRegex)
|
||||
}
|
||||
|
||||
// GetResultPath 获取结果文件路径
|
||||
func (s *FileResultStorage) GetResultPath(executionID string) string {
|
||||
return s.getResultPath(executionID)
|
||||
}
|
||||
|
||||
// DeleteResult 删除结果
|
||||
func (s *FileResultStorage) DeleteResult(executionID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
resultPath := s.getResultPath(executionID)
|
||||
metadataPath := s.getMetadataPath(executionID)
|
||||
|
||||
// 删除结果文件
|
||||
if err := os.Remove(resultPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("删除结果文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除元数据文件
|
||||
if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("删除元数据文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("删除工具执行结果",
|
||||
zap.String("executionID", executionID),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// setupTestStorage 创建测试用的存储实例
|
||||
func setupTestStorage(t *testing.T) (*FileResultStorage, string) {
|
||||
tmpDir := filepath.Join(os.TempDir(), "test_result_storage_"+time.Now().Format("20060102_150405"))
|
||||
logger := zap.NewNop()
|
||||
|
||||
storage, err := NewFileResultStorage(tmpDir, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("创建测试存储失败: %v", err)
|
||||
}
|
||||
|
||||
return storage, tmpDir
|
||||
}
|
||||
|
||||
// cleanupTestStorage 清理测试数据
|
||||
func cleanupTestStorage(t *testing.T, tmpDir string) {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Logf("清理测试目录失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileResultStorage(t *testing.T) {
|
||||
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
logger := zap.NewNop()
|
||||
storage, err := NewFileResultStorage(tmpDir, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("创建存储失败: %v", err)
|
||||
}
|
||||
|
||||
if storage == nil {
|
||||
t.Fatal("存储实例为nil")
|
||||
}
|
||||
|
||||
// 验证目录已创建
|
||||
if _, err := os.Stat(tmpDir); os.IsNotExist(err) {
|
||||
t.Fatal("存储目录未创建")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_SaveResult(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
executionID := "test_exec_001"
|
||||
toolName := "nmap_scan"
|
||||
result := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
||||
|
||||
err := storage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证结果文件存在
|
||||
resultPath := filepath.Join(tmpDir, executionID+".txt")
|
||||
if _, err := os.Stat(resultPath); os.IsNotExist(err) {
|
||||
t.Fatal("结果文件未创建")
|
||||
}
|
||||
|
||||
// 验证元数据文件存在
|
||||
metadataPath := filepath.Join(tmpDir, executionID+".meta.json")
|
||||
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
|
||||
t.Fatal("元数据文件未创建")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_GetResult(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
executionID := "test_exec_002"
|
||||
toolName := "test_tool"
|
||||
expectedResult := "Test result content\nLine 2\nLine 3"
|
||||
|
||||
// 先保存结果
|
||||
err := storage.SaveResult(executionID, toolName, expectedResult)
|
||||
if err != nil {
|
||||
t.Fatalf("保存结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取结果
|
||||
result, err := storage.GetResult(executionID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取结果失败: %v", err)
|
||||
}
|
||||
|
||||
if result != expectedResult {
|
||||
t.Errorf("结果不匹配。期望: %q, 实际: %q", expectedResult, result)
|
||||
}
|
||||
|
||||
// 测试不存在的执行ID
|
||||
_, err = storage.GetResult("nonexistent_id")
|
||||
if err == nil {
|
||||
t.Fatal("应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_GetResultMetadata(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
executionID := "test_exec_003"
|
||||
toolName := "test_tool"
|
||||
result := "Line 1\nLine 2\nLine 3"
|
||||
|
||||
// 保存结果
|
||||
err := storage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取元数据
|
||||
metadata, err := storage.GetResultMetadata(executionID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取元数据失败: %v", err)
|
||||
}
|
||||
|
||||
if metadata.ExecutionID != executionID {
|
||||
t.Errorf("执行ID不匹配。期望: %s, 实际: %s", executionID, metadata.ExecutionID)
|
||||
}
|
||||
|
||||
if metadata.ToolName != toolName {
|
||||
t.Errorf("工具名称不匹配。期望: %s, 实际: %s", toolName, metadata.ToolName)
|
||||
}
|
||||
|
||||
if metadata.TotalSize != len(result) {
|
||||
t.Errorf("总大小不匹配。期望: %d, 实际: %d", len(result), metadata.TotalSize)
|
||||
}
|
||||
|
||||
expectedLines := len(strings.Split(result, "\n"))
|
||||
if metadata.TotalLines != expectedLines {
|
||||
t.Errorf("总行数不匹配。期望: %d, 实际: %d", expectedLines, metadata.TotalLines)
|
||||
}
|
||||
|
||||
// 验证创建时间在合理范围内
|
||||
now := time.Now()
|
||||
if metadata.CreatedAt.After(now) || metadata.CreatedAt.Before(now.Add(-time.Second)) {
|
||||
t.Errorf("创建时间不在合理范围内: %v", metadata.CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_GetResultPage(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
executionID := "test_exec_004"
|
||||
toolName := "test_tool"
|
||||
// 创建包含10行的结果
|
||||
lines := make([]string, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
lines[i] = fmt.Sprintf("Line %d", i+1)
|
||||
}
|
||||
result := strings.Join(lines, "\n")
|
||||
|
||||
// 保存结果
|
||||
err := storage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试第一页(每页3行)
|
||||
page, err := storage.GetResultPage(executionID, 1, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("获取第一页失败: %v", err)
|
||||
}
|
||||
|
||||
if page.Page != 1 {
|
||||
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
|
||||
}
|
||||
|
||||
if page.Limit != 3 {
|
||||
t.Errorf("每页行数不匹配。期望: 3, 实际: %d", page.Limit)
|
||||
}
|
||||
|
||||
if page.TotalLines != 10 {
|
||||
t.Errorf("总行数不匹配。期望: 10, 实际: %d", page.TotalLines)
|
||||
}
|
||||
|
||||
if page.TotalPages != 4 {
|
||||
t.Errorf("总页数不匹配。期望: 4, 实际: %d", page.TotalPages)
|
||||
}
|
||||
|
||||
if len(page.Lines) != 3 {
|
||||
t.Errorf("第一页行数不匹配。期望: 3, 实际: %d", len(page.Lines))
|
||||
}
|
||||
|
||||
if page.Lines[0] != "Line 1" {
|
||||
t.Errorf("第一行内容不匹配。期望: Line 1, 实际: %s", page.Lines[0])
|
||||
}
|
||||
|
||||
// 测试第二页
|
||||
page2, err := storage.GetResultPage(executionID, 2, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("获取第二页失败: %v", err)
|
||||
}
|
||||
|
||||
if len(page2.Lines) != 3 {
|
||||
t.Errorf("第二页行数不匹配。期望: 3, 实际: %d", len(page2.Lines))
|
||||
}
|
||||
|
||||
if page2.Lines[0] != "Line 4" {
|
||||
t.Errorf("第二页第一行内容不匹配。期望: Line 4, 实际: %s", page2.Lines[0])
|
||||
}
|
||||
|
||||
// 测试最后一页(可能不满一页)
|
||||
page4, err := storage.GetResultPage(executionID, 4, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("获取第四页失败: %v", err)
|
||||
}
|
||||
|
||||
if len(page4.Lines) != 1 {
|
||||
t.Errorf("第四页行数不匹配。期望: 1, 实际: %d", len(page4.Lines))
|
||||
}
|
||||
|
||||
// 测试超出范围的页码(应该返回最后一页)
|
||||
page5, err := storage.GetResultPage(executionID, 5, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("获取第五页失败: %v", err)
|
||||
}
|
||||
|
||||
// 超出范围的页码会被修正为最后一页,所以应该返回最后一页的内容
|
||||
if page5.Page != 4 {
|
||||
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 4, 实际: %d", page5.Page)
|
||||
}
|
||||
|
||||
// 最后一页应该只有1行
|
||||
if len(page5.Lines) != 1 {
|
||||
t.Errorf("最后一页应该只有1行。实际: %d行", len(page5.Lines))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_SearchResult(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
executionID := "test_exec_005"
|
||||
toolName := "test_tool"
|
||||
result := "Line 1: error occurred\nLine 2: success\nLine 3: error again\nLine 4: ok"
|
||||
|
||||
// 保存结果
|
||||
err := storage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 搜索包含"error"的行(简单字符串匹配)
|
||||
matchedLines, err := storage.SearchResult(executionID, "error", false)
|
||||
if err != nil {
|
||||
t.Fatalf("搜索失败: %v", err)
|
||||
}
|
||||
|
||||
if len(matchedLines) != 2 {
|
||||
t.Errorf("搜索结果数量不匹配。期望: 2, 实际: %d", len(matchedLines))
|
||||
}
|
||||
|
||||
// 验证搜索结果内容
|
||||
for i, line := range matchedLines {
|
||||
if !strings.Contains(line, "error") {
|
||||
t.Errorf("搜索结果第%d行不包含关键词: %s", i+1, line)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试搜索不存在的关键词
|
||||
noMatch, err := storage.SearchResult(executionID, "nonexistent", false)
|
||||
if err != nil {
|
||||
t.Fatalf("搜索失败: %v", err)
|
||||
}
|
||||
|
||||
if len(noMatch) != 0 {
|
||||
t.Errorf("搜索不存在的关键词应该返回空结果。实际: %d行", len(noMatch))
|
||||
}
|
||||
|
||||
// 测试正则表达式搜索
|
||||
regexMatched, err := storage.SearchResult(executionID, "error.*again", true)
|
||||
if err != nil {
|
||||
t.Fatalf("正则搜索失败: %v", err)
|
||||
}
|
||||
|
||||
if len(regexMatched) != 1 {
|
||||
t.Errorf("正则搜索结果数量不匹配。期望: 1, 实际: %d", len(regexMatched))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_FilterResult(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
executionID := "test_exec_006"
|
||||
toolName := "test_tool"
|
||||
result := "Line 1: warning message\nLine 2: info message\nLine 3: warning again\nLine 4: debug message"
|
||||
|
||||
// 保存结果
|
||||
err := storage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 过滤包含"warning"的行(简单字符串匹配)
|
||||
filteredLines, err := storage.FilterResult(executionID, "warning", false)
|
||||
if err != nil {
|
||||
t.Fatalf("过滤失败: %v", err)
|
||||
}
|
||||
|
||||
if len(filteredLines) != 2 {
|
||||
t.Errorf("过滤结果数量不匹配。期望: 2, 实际: %d", len(filteredLines))
|
||||
}
|
||||
|
||||
// 验证过滤结果内容
|
||||
for i, line := range filteredLines {
|
||||
if !strings.Contains(line, "warning") {
|
||||
t.Errorf("过滤结果第%d行不包含关键词: %s", i+1, line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_DeleteResult(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
executionID := "test_exec_007"
|
||||
toolName := "test_tool"
|
||||
result := "Test result"
|
||||
|
||||
// 保存结果
|
||||
err := storage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证文件存在
|
||||
resultPath := filepath.Join(tmpDir, executionID+".txt")
|
||||
metadataPath := filepath.Join(tmpDir, executionID+".meta.json")
|
||||
|
||||
if _, err := os.Stat(resultPath); os.IsNotExist(err) {
|
||||
t.Fatal("结果文件不存在")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
|
||||
t.Fatal("元数据文件不存在")
|
||||
}
|
||||
|
||||
// 删除结果
|
||||
err = storage.DeleteResult(executionID)
|
||||
if err != nil {
|
||||
t.Fatalf("删除结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证文件已删除
|
||||
if _, err := os.Stat(resultPath); !os.IsNotExist(err) {
|
||||
t.Fatal("结果文件未被删除")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(metadataPath); !os.IsNotExist(err) {
|
||||
t.Fatal("元数据文件未被删除")
|
||||
}
|
||||
|
||||
// 测试删除不存在的执行ID(应该不报错)
|
||||
err = storage.DeleteResult("nonexistent_id")
|
||||
if err != nil {
|
||||
t.Errorf("删除不存在的执行ID不应该报错: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_ConcurrentAccess(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
// 并发保存多个结果
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
executionID := fmt.Sprintf("test_exec_%d", id)
|
||||
toolName := "test_tool"
|
||||
result := fmt.Sprintf("Result %d\nLine 2\nLine 3", id)
|
||||
|
||||
err := storage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Errorf("并发保存失败 (ID: %s): %v", executionID, err)
|
||||
}
|
||||
|
||||
// 并发读取
|
||||
_, err = storage.GetResult(executionID)
|
||||
if err != nil {
|
||||
t.Errorf("并发读取失败 (ID: %s): %v", executionID, err)
|
||||
}
|
||||
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有goroutine完成
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileResultStorage_LargeResult(t *testing.T) {
|
||||
storage, tmpDir := setupTestStorage(t)
|
||||
defer cleanupTestStorage(t, tmpDir)
|
||||
|
||||
executionID := "test_exec_large"
|
||||
toolName := "test_tool"
|
||||
|
||||
// 创建大结果(1000行)
|
||||
lines := make([]string, 1000)
|
||||
for i := 0; i < 1000; i++ {
|
||||
lines[i] = fmt.Sprintf("Line %d: This is a test line with some content", i+1)
|
||||
}
|
||||
result := strings.Join(lines, "\n")
|
||||
|
||||
// 保存大结果
|
||||
err := storage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存大结果失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证元数据
|
||||
metadata, err := storage.GetResultMetadata(executionID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取元数据失败: %v", err)
|
||||
}
|
||||
|
||||
if metadata.TotalLines != 1000 {
|
||||
t.Errorf("总行数不匹配。期望: 1000, 实际: %d", metadata.TotalLines)
|
||||
}
|
||||
|
||||
// 测试分页查询大结果
|
||||
page, err := storage.GetResultPage(executionID, 1, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("获取第一页失败: %v", err)
|
||||
}
|
||||
|
||||
if page.TotalPages != 10 {
|
||||
t.Errorf("总页数不匹配。期望: 10, 实际: %d", page.TotalPages)
|
||||
}
|
||||
|
||||
if len(page.Lines) != 100 {
|
||||
t.Errorf("第一页行数不匹配。期望: 100, 实际: %d", len(page.Lines))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user