mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-02 04:21:41 +02:00
Add files via upload
This commit is contained in:
+35
-8
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
|
||||
@@ -302,16 +303,16 @@ type ProgressCallback func(eventType, message string, data interface{})
|
||||
|
||||
// AgentLoop 执行Agent循环
|
||||
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil)
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil)
|
||||
}
|
||||
|
||||
// AgentLoopWithConversationID 执行Agent循环(带对话ID)
|
||||
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil)
|
||||
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil)
|
||||
}
|
||||
|
||||
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback) (*AgentLoopResult, error) {
|
||||
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) {
|
||||
// 设置当前对话ID
|
||||
a.mu.Lock()
|
||||
a.currentConversationID = conversationID
|
||||
@@ -401,8 +402,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
|
||||
|
||||
漏洞记录要求:
|
||||
- 当你发现有效漏洞时,必须使用 record_vulnerability 工具记录漏洞详情
|
||||
- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
|
||||
- 当你发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 工具记录漏洞详情
|
||||
` + `- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
|
||||
- 严重程度评估标准:
|
||||
* critical(严重):可导致系统完全被控制、数据泄露、服务中断等
|
||||
* high(高):可导致敏感信息泄露、权限提升、重要功能被绕过等
|
||||
@@ -512,7 +513,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
}
|
||||
|
||||
// 获取可用工具
|
||||
tools := a.getAvailableTools()
|
||||
tools := a.getAvailableTools(roleTools)
|
||||
|
||||
// 记录当前上下文的Token用量,展示压缩器运行状态
|
||||
if a.memoryCompressor != nil {
|
||||
@@ -837,13 +838,29 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// getAvailableTools 获取可用工具
|
||||
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗
|
||||
func (a *Agent) getAvailableTools() []Tool {
|
||||
// roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
|
||||
func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
// 构建角色工具集合(用于快速查找)
|
||||
roleToolSet := make(map[string]bool)
|
||||
if len(roleTools) > 0 {
|
||||
for _, toolKey := range roleTools {
|
||||
roleToolSet[toolKey] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 从MCP服务器获取所有已注册的内部工具
|
||||
mcpTools := a.mcpServer.GetAllTools()
|
||||
|
||||
// 转换为OpenAI格式的工具定义
|
||||
tools := make([]Tool, 0, len(mcpTools))
|
||||
for _, mcpTool := range mcpTools {
|
||||
// 如果指定了角色工具列表,只添加在列表中的工具
|
||||
if len(roleToolSet) > 0 {
|
||||
toolKey := mcpTool.Name // 内置工具使用工具名称作为key
|
||||
if !roleToolSet[toolKey] {
|
||||
continue // 不在角色工具列表中,跳过
|
||||
}
|
||||
}
|
||||
// 使用简短描述(如果存在),否则使用详细描述
|
||||
description := mcpTool.ShortDescription
|
||||
if description == "" {
|
||||
@@ -883,6 +900,16 @@ func (a *Agent) getAvailableTools() []Tool {
|
||||
|
||||
// 将外部MCP工具添加到工具列表(只添加启用的工具)
|
||||
for _, externalTool := range externalTools {
|
||||
// 外部工具使用 "mcpName::toolName" 作为toolKey
|
||||
externalToolKey := externalTool.Name
|
||||
|
||||
// 如果指定了角色工具列表,只添加在列表中的工具
|
||||
if len(roleToolSet) > 0 {
|
||||
if !roleToolSet[externalToolKey] {
|
||||
continue // 不在角色工具列表中,跳过
|
||||
}
|
||||
}
|
||||
|
||||
// 解析工具名称:mcpName::toolName
|
||||
var mcpName, actualToolName string
|
||||
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
|
||||
@@ -1136,7 +1163,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
||||
)
|
||||
|
||||
// 如果是record_vulnerability工具,自动添加conversation_id
|
||||
if toolName == "record_vulnerability" {
|
||||
if toolName == builtin.ToolRecordVulnerability {
|
||||
a.mu.RLock()
|
||||
conversationID := a.currentConversationID
|
||||
a.mu.RUnlock()
|
||||
|
||||
+13
-2
@@ -16,6 +16,7 @@ import (
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
@@ -278,7 +279,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
}
|
||||
|
||||
// 创建处理器
|
||||
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
|
||||
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
|
||||
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
|
||||
if knowledgeManager != nil {
|
||||
agentHandler.SetKnowledgeManager(knowledgeManager)
|
||||
@@ -292,6 +293,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||
|
||||
// 创建 App 实例(部分字段稍后填充)
|
||||
app := &App{
|
||||
@@ -368,6 +370,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
attackChainHandler,
|
||||
app, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler,
|
||||
roleHandler,
|
||||
mcpServer,
|
||||
authManager,
|
||||
)
|
||||
@@ -428,6 +431,7 @@ func setupRoutes(
|
||||
attackChainHandler *handler.AttackChainHandler,
|
||||
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
|
||||
vulnerabilityHandler *handler.VulnerabilityHandler,
|
||||
roleHandler *handler.RoleHandler,
|
||||
mcpServer *mcp.Server,
|
||||
authManager *security.AuthManager,
|
||||
) {
|
||||
@@ -653,6 +657,13 @@ func setupRoutes(
|
||||
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
|
||||
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
|
||||
|
||||
// 角色管理
|
||||
protected.GET("/roles", roleHandler.GetRoles)
|
||||
protected.GET("/roles/:name", roleHandler.GetRole)
|
||||
protected.POST("/roles", roleHandler.CreateRole)
|
||||
protected.PUT("/roles/:name", roleHandler.UpdateRole)
|
||||
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
|
||||
|
||||
// MCP端点
|
||||
protected.POST("/mcp", func(c *gin.Context) {
|
||||
mcpServer.HandleHTTP(c.Writer, c.Request)
|
||||
@@ -672,7 +683,7 @@ func setupRoutes(
|
||||
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
|
||||
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
|
||||
tool := mcp.Tool{
|
||||
Name: "record_vulnerability",
|
||||
Name: builtin.ToolRecordVulnerability,
|
||||
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
|
||||
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
|
||||
InputSchema: map[string]interface{}{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -22,6 +23,8 @@ type Config struct {
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
|
||||
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -207,6 +210,29 @@ func Load(path string) (*Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 从角色目录加载角色配置
|
||||
if cfg.RolesDir != "" {
|
||||
configDir := filepath.Dir(path)
|
||||
rolesDir := cfg.RolesDir
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
roles, err := LoadRolesFromDir(rolesDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从角色目录加载角色配置失败: %w", err)
|
||||
}
|
||||
|
||||
cfg.Roles = roles
|
||||
} else {
|
||||
// 如果未配置 roles_dir,初始化为空 map
|
||||
if cfg.Roles == nil {
|
||||
cfg.Roles = make(map[string]RoleConfig)
|
||||
}
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
@@ -375,6 +401,98 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
|
||||
return &tool, nil
|
||||
}
|
||||
|
||||
// LoadRolesFromDir 从目录加载所有角色配置文件
|
||||
func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) {
|
||||
roles := make(map[string]RoleConfig)
|
||||
|
||||
// 检查目录是否存在
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return roles, nil // 目录不存在时返回空map,不报错
|
||||
}
|
||||
|
||||
// 读取目录中的所有 .yaml 和 .yml 文件
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(dir, name)
|
||||
role, err := LoadRoleFromFile(filePath)
|
||||
if err != nil {
|
||||
// 记录错误但继续加载其他文件
|
||||
fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 使用角色名称作为key
|
||||
roleName := role.Name
|
||||
if roleName == "" {
|
||||
// 如果角色名称为空,使用文件名(去掉扩展名)作为名称
|
||||
roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml")
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
roles[roleName] = *role
|
||||
}
|
||||
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// LoadRoleFromFile 从单个文件加载角色配置
|
||||
func LoadRoleFromFile(path string) (*RoleConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取文件失败: %w", err)
|
||||
}
|
||||
|
||||
var role RoleConfig
|
||||
if err := yaml.Unmarshal(data, &role); err != nil {
|
||||
return nil, fmt.Errorf("解析角色配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符
|
||||
// Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换
|
||||
if role.Icon != "" {
|
||||
icon := role.Icon
|
||||
// 去除可能的引号
|
||||
icon = strings.Trim(icon, `"`)
|
||||
|
||||
// 检查是否是 Unicode 转义格式 \U0001F3C6(8位十六进制)或 \uXXXX(4位十六进制)
|
||||
if len(icon) >= 3 && icon[0] == '\\' {
|
||||
if icon[1] == 'U' && len(icon) >= 10 {
|
||||
// \U0001F3C6 格式(8位十六进制)
|
||||
if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil {
|
||||
role.Icon = string(rune(codePoint))
|
||||
}
|
||||
} else if icon[1] == 'u' && len(icon) >= 6 {
|
||||
// \uXXXX 格式(4位十六进制)
|
||||
if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil {
|
||||
role.Icon = string(rune(codePoint))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证必需字段
|
||||
if role.Name == "" {
|
||||
// 如果名称为空,尝试从文件名获取
|
||||
baseName := filepath.Base(path)
|
||||
role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml")
|
||||
}
|
||||
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
Server: ServerConfig{
|
||||
@@ -448,3 +566,20 @@ type RetrievalConfig struct {
|
||||
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 相似度阈值
|
||||
HybridWeight float64 `yaml:"hybrid_weight" json:"hybrid_weight"` // 向量检索权重(0-1)
|
||||
}
|
||||
|
||||
// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代)
|
||||
// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig
|
||||
type RolesConfig struct {
|
||||
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"`
|
||||
}
|
||||
|
||||
// RoleConfig 单个角色配置
|
||||
type RoleConfig struct {
|
||||
Name string `yaml:"name" json:"name"` // 角色名称
|
||||
Description string `yaml:"description" json:"description"` // 角色描述
|
||||
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
|
||||
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
|
||||
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName")
|
||||
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
|
||||
}
|
||||
|
||||
+63
-12
@@ -12,7 +12,9 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -66,13 +68,14 @@ type AgentHandler struct {
|
||||
logger *zap.Logger
|
||||
tasks *AgentTaskManager
|
||||
batchTaskManager *BatchTaskManager
|
||||
knowledgeManager interface { // 知识库管理器接口
|
||||
config *config.Config // 配置引用,用于获取角色信息
|
||||
knowledgeManager interface { // 知识库管理器接口
|
||||
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
|
||||
}
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
|
||||
func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, logger *zap.Logger) *AgentHandler {
|
||||
batchTaskManager := NewBatchTaskManager()
|
||||
batchTaskManager.SetDB(db)
|
||||
|
||||
@@ -87,6 +90,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A
|
||||
logger: logger,
|
||||
tasks: NewAgentTaskManager(),
|
||||
batchTaskManager: batchTaskManager,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,6 +105,7 @@ func (h *AgentHandler) SetKnowledgeManager(manager interface {
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
Role string `json:"role,omitempty"` // 角色名称
|
||||
}
|
||||
|
||||
// ChatResponse 聊天响应
|
||||
@@ -161,14 +166,34 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
// 应用角色用户提示词和工具配置
|
||||
finalMessage := req.Message
|
||||
var roleTools []string // 角色配置的工具列表
|
||||
if req.Role != "" && req.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
|
||||
// 应用用户提示词
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + req.Message
|
||||
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
|
||||
}
|
||||
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
|
||||
if len(role.Tools) > 0 {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存用户消息(保存原始消息,不包含角色提示词)
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 执行Agent Loop,传入历史消息和对话ID
|
||||
result, err := h.agent.AgentLoopWithConversationID(c.Request.Context(), req.Message, agentHistoryMessages, conversationID)
|
||||
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
@@ -231,7 +256,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
|
||||
if eventType == "tool_call" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
if toolName == "search_knowledge_base" {
|
||||
if toolName == builtin.ToolSearchKnowledgeBase {
|
||||
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
|
||||
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
|
||||
toolCallCache[toolCallId] = argumentsObj
|
||||
@@ -245,7 +270,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
|
||||
if eventType == "tool_result" && h.knowledgeManager != nil {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
if toolName == "search_knowledge_base" {
|
||||
if toolName == builtin.ToolSearchKnowledgeBase {
|
||||
// 提取检索信息
|
||||
query := ""
|
||||
riskType := ""
|
||||
@@ -470,7 +495,32 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
// 应用角色用户提示词和工具配置
|
||||
finalMessage := req.Message
|
||||
var roleTools []string // 角色配置的工具列表
|
||||
if req.Role != "" && req.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
|
||||
// 应用用户提示词
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + req.Message
|
||||
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
|
||||
}
|
||||
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
|
||||
if len(role.Tools) > 0 {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
|
||||
} else if len(role.MCPs) > 0 {
|
||||
// 向后兼容:如果只有mcps字段,暂时使用空列表(表示使用所有工具)
|
||||
// 因为mcps是MCP服务器名称,不是工具列表
|
||||
h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 如果roleTools为空,表示使用所有工具(默认角色或未配置工具的角色)
|
||||
|
||||
// 保存用户消息(保存原始消息,不包含角色提示词)
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(err))
|
||||
@@ -547,9 +597,9 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
taskStatus := "completed"
|
||||
defer h.tasks.FinishTask(conversationID, taskStatus)
|
||||
|
||||
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断
|
||||
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
sendEvent("progress", "正在分析您的请求...", nil)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, conversationID, progressCallback)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
cause := context.Cause(baseCtx)
|
||||
@@ -759,7 +809,7 @@ func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
|
||||
|
||||
// BatchTaskRequest 批量任务请求
|
||||
type BatchTaskRequest struct {
|
||||
Title string `json:"title"` // 任务标题(可选)
|
||||
Title string `json:"title"` // 任务标题(可选)
|
||||
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
|
||||
}
|
||||
|
||||
@@ -1072,7 +1122,8 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
// 存储取消函数,以便在取消队列时能够取消当前任务
|
||||
h.batchTaskManager.SetTaskCancel(queueID, cancel)
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback)
|
||||
// 批量任务暂时不支持角色工具过滤,使用所有工具(传入nil)
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback, nil)
|
||||
// 任务执行完成,清理取消函数
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
cancel()
|
||||
|
||||
+114
-12
@@ -147,6 +147,7 @@ type ToolConfigInfo struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
|
||||
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
|
||||
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
|
||||
}
|
||||
|
||||
// GetConfig 获取当前配置
|
||||
@@ -272,11 +273,12 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
|
||||
// GetToolsResponse 获取工具列表响应(分页)
|
||||
type GetToolsResponse struct {
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Total int `json:"total"`
|
||||
TotalEnabled int `json:"total_enabled"` // 已启用的工具总数
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// GetTools 获取工具列表(支持分页和搜索)
|
||||
@@ -305,6 +307,23 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
searchTermLower = strings.ToLower(searchTerm)
|
||||
}
|
||||
|
||||
// 解析角色参数,用于过滤工具并标注启用状态
|
||||
roleName := c.Query("role")
|
||||
var roleToolsSet map[string]bool // 角色配置的工具集合
|
||||
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
|
||||
if roleName != "" && roleName != "默认" && h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[roleName]; exists && role.Enabled {
|
||||
if len(role.Tools) > 0 {
|
||||
// 角色配置了工具列表,只使用这些工具
|
||||
roleToolsSet = make(map[string]bool)
|
||||
for _, toolKey := range role.Tools {
|
||||
roleToolsSet[toolKey] = true
|
||||
}
|
||||
roleUsesAllTools = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有内部工具并应用搜索过滤
|
||||
configToolMap := make(map[string]bool)
|
||||
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
||||
@@ -325,6 +344,31 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
toolInfo.Description = desc
|
||||
}
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,标注启用的工具为role_enabled=true
|
||||
if tool.Enabled {
|
||||
roleEnabled := true
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 内部工具使用工具名称作为key
|
||||
if roleToolsSet[tool.Name] {
|
||||
roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(toolInfo.Name)
|
||||
@@ -361,6 +405,26 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
IsExternal: false,
|
||||
}
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,直接注册的工具默认启用
|
||||
roleEnabled := true
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 内部工具使用工具名称作为key
|
||||
if roleToolsSet[mcpTool.Name] {
|
||||
roleEnabled := true // 在角色列表中且工具本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(toolInfo.Name)
|
||||
@@ -439,18 +503,55 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
allTools = append(allTools, ToolConfigInfo{
|
||||
toolInfo := ToolConfigInfo{
|
||||
Name: actualToolName, // 显示实际工具名称,不带前缀
|
||||
Description: description,
|
||||
Enabled: enabled,
|
||||
IsExternal: true,
|
||||
ExternalMCP: mcpName,
|
||||
})
|
||||
}
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,标注启用的工具为role_enabled=true
|
||||
toolInfo.RoleEnabled = &enabled
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 外部工具使用 "mcpName::toolName" 格式作为key
|
||||
externalToolKey := externalTool.Name // 这是 "mcpName::toolName" 格式
|
||||
if roleToolsSet[externalToolKey] {
|
||||
roleEnabled := enabled // 工具必须在角色列表中且本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allTools = append(allTools, toolInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用)
|
||||
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
|
||||
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
|
||||
|
||||
total := len(allTools)
|
||||
// 统计已启用的工具数(在角色中的启用工具数)
|
||||
totalEnabled := 0
|
||||
for _, tool := range allTools {
|
||||
if tool.RoleEnabled != nil && *tool.RoleEnabled {
|
||||
totalEnabled++
|
||||
} else if tool.RoleEnabled == nil && tool.Enabled {
|
||||
// 如果未指定角色,统计所有启用的工具
|
||||
totalEnabled++
|
||||
}
|
||||
}
|
||||
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
@@ -471,11 +572,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, GetToolsResponse{
|
||||
Tools: tools,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
Tools: tools,
|
||||
Total: total,
|
||||
TotalEnabled: totalEnabled,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,453 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RoleHandler 角色处理器
|
||||
type RoleHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRoleHandler 创建新的角色处理器
|
||||
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
|
||||
return &RoleHandler{
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRoles 获取所有角色
|
||||
func (h *RoleHandler) GetRoles(c *gin.Context) {
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
|
||||
for key, role := range h.config.Roles {
|
||||
// 确保角色的key与name一致
|
||||
if role.Name == "" {
|
||||
role.Name = key
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"roles": roles,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRole 获取单个角色
|
||||
func (h *RoleHandler) GetRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Roles == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
role, exists := h.config.Roles[roleName]
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 确保角色的name与key一致
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"role": role,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (h *RoleHandler) UpdateRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req config.RoleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 确保角色名称与请求中的name一致
|
||||
if req.Name == "" {
|
||||
req.Name = roleName
|
||||
}
|
||||
|
||||
// 初始化Roles map
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
// 删除所有与角色name相同但key不同的旧角色(避免重复)
|
||||
// 使用角色name作为key,确保唯一性
|
||||
finalKey := req.Name
|
||||
keysToDelete := make([]string, 0)
|
||||
for key := range h.config.Roles {
|
||||
// 如果key与最终的key不同,但name相同,则标记为删除
|
||||
if key != finalKey {
|
||||
role := h.config.Roles[key]
|
||||
// 确保角色的name字段正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = key
|
||||
}
|
||||
if role.Name == req.Name {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 删除旧的角色
|
||||
for _, key := range keysToDelete {
|
||||
delete(h.config.Roles, key)
|
||||
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
|
||||
}
|
||||
|
||||
// 如果当前更新的key与最终key不同,也需要删除旧的
|
||||
if roleName != finalKey {
|
||||
delete(h.config.Roles, roleName)
|
||||
}
|
||||
|
||||
// 如果角色名称改变,需要删除旧文件
|
||||
if roleName != finalKey {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 删除旧的角色文件
|
||||
oldSafeFileName := sanitizeFileName(roleName)
|
||||
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
|
||||
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
|
||||
|
||||
if _, err := os.Stat(oldRoleFileYaml); err == nil {
|
||||
if err := os.Remove(oldRoleFileYaml); err != nil {
|
||||
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
|
||||
}
|
||||
}
|
||||
if _, err := os.Stat(oldRoleFileYml); err == nil {
|
||||
if err := os.Remove(oldRoleFileYml); err != nil {
|
||||
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用角色name作为key来保存(确保唯一性)
|
||||
h.config.Roles[finalKey] = req
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已更新",
|
||||
"role": req,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateRole 创建新角色
|
||||
func (h *RoleHandler) CreateRole(c *gin.Context) {
|
||||
var req config.RoleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化Roles map
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
// 检查角色是否已存在
|
||||
if _, exists := h.config.Roles[req.Name]; exists {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建角色(默认启用)
|
||||
if !req.Enabled {
|
||||
req.Enabled = true
|
||||
}
|
||||
|
||||
h.config.Roles[req.Name] = req
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已创建",
|
||||
"role": req,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色
|
||||
func (h *RoleHandler) DeleteRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Roles == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := h.config.Roles[roleName]; !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 不允许删除"默认"角色
|
||||
if roleName == "默认" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
|
||||
return
|
||||
}
|
||||
|
||||
delete(h.config.Roles, roleName)
|
||||
|
||||
// 删除对应的角色文件
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 尝试删除角色文件(.yaml 和 .yml)
|
||||
safeFileName := sanitizeFileName(roleName)
|
||||
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
|
||||
|
||||
// 删除 .yaml 文件(如果存在)
|
||||
if _, err := os.Stat(roleFileYaml); err == nil {
|
||||
if err := os.Remove(roleFileYaml); err != nil {
|
||||
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 .yml 文件(如果存在)
|
||||
if _, err := os.Stat(roleFileYml); err == nil {
|
||||
if err := os.Remove(roleFileYml); err != nil {
|
||||
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已删除",
|
||||
})
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到目录中的文件
|
||||
func (h *RoleHandler) saveConfig() error {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(rolesDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存每个角色到独立的文件
|
||||
if h.config.Roles != nil {
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
|
||||
safeFileName := sanitizeFileName(role.Name)
|
||||
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
|
||||
// 将角色配置序列化为YAML
|
||||
roleData, err := yaml.Marshal(&role)
|
||||
if err != nil {
|
||||
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
|
||||
roleDataStr := string(roleData)
|
||||
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
|
||||
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
|
||||
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
|
||||
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
|
||||
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
|
||||
roleData = []byte(roleDataStr)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
|
||||
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeFileName 将角色名称转换为安全的文件名
|
||||
func sanitizeFileName(name string) string {
|
||||
// 替换可能不安全的字符
|
||||
replacer := map[rune]string{
|
||||
'/': "_",
|
||||
'\\': "_",
|
||||
':': "_",
|
||||
'*': "_",
|
||||
'?': "_",
|
||||
'"': "_",
|
||||
'<': "_",
|
||||
'>': "_",
|
||||
'|': "_",
|
||||
' ': "_",
|
||||
}
|
||||
|
||||
var result []rune
|
||||
for _, r := range name {
|
||||
if replacement, ok := replacer[r]; ok {
|
||||
result = append(result, []rune(replacement)...)
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
|
||||
fileName := string(result)
|
||||
// 如果文件名为空,使用默认名称
|
||||
if fileName == "" {
|
||||
fileName = "role"
|
||||
}
|
||||
|
||||
return fileName
|
||||
}
|
||||
|
||||
// updateRolesConfig 更新角色配置
|
||||
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
|
||||
root := doc.Content[0]
|
||||
rolesNode := ensureMap(root, "roles")
|
||||
|
||||
// 清空现有角色
|
||||
if rolesNode.Kind == yaml.MappingNode {
|
||||
rolesNode.Content = nil
|
||||
}
|
||||
|
||||
// 添加新角色(使用name作为key,确保唯一性)
|
||||
if cfg.Roles != nil {
|
||||
// 先建立一个以name为key的map,去重(保留最后一个)
|
||||
rolesByName := make(map[string]config.RoleConfig)
|
||||
for roleKey, role := range cfg.Roles {
|
||||
// 确保角色的name字段正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleKey
|
||||
}
|
||||
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
|
||||
rolesByName[role.Name] = role
|
||||
}
|
||||
|
||||
// 将去重后的角色写入YAML
|
||||
for roleName, role := range rolesByName {
|
||||
roleNode := ensureMap(rolesNode, roleName)
|
||||
setStringInMap(roleNode, "name", role.Name)
|
||||
setStringInMap(roleNode, "description", role.Description)
|
||||
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
|
||||
if role.Icon != "" {
|
||||
setStringInMap(roleNode, "icon", role.Icon)
|
||||
}
|
||||
setBoolInMap(roleNode, "enabled", role.Enabled)
|
||||
|
||||
// 添加工具列表(优先使用tools字段)
|
||||
if len(role.Tools) > 0 {
|
||||
toolsNode := ensureArray(roleNode, "tools")
|
||||
toolsNode.Content = nil
|
||||
for _, toolKey := range role.Tools {
|
||||
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
|
||||
toolsNode.Content = append(toolsNode.Content, toolNode)
|
||||
}
|
||||
} else if len(role.MCPs) > 0 {
|
||||
// 向后兼容:如果没有tools但有mcps,保存mcps
|
||||
mcpsNode := ensureArray(roleNode, "mcps")
|
||||
mcpsNode.Content = nil
|
||||
for _, mcpName := range role.MCPs {
|
||||
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
|
||||
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureArray 确保数组中存在指定key的数组节点
|
||||
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
|
||||
_, valueNode := ensureKeyValue(parent, key)
|
||||
if valueNode.Kind != yaml.SequenceNode {
|
||||
valueNode.Kind = yaml.SequenceNode
|
||||
valueNode.Tag = "!!seq"
|
||||
valueNode.Content = nil
|
||||
}
|
||||
return valueNode
|
||||
}
|
||||
@@ -161,14 +161,14 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
|
||||
// 查询所有向量(或按风险类型过滤)
|
||||
// 使用精确匹配(=)以提高性能和准确性
|
||||
// 由于系统提供了 list_knowledge_risk_types 工具,用户应该使用准确的category名称
|
||||
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
|
||||
var rows *sql.Rows
|
||||
if req.RiskType != "" {
|
||||
// 使用精确匹配(=),性能更好且更准确
|
||||
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
|
||||
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
|
||||
// 建议用户先调用 list_knowledge_risk_types 获取准确的category名称
|
||||
// 由于系统提供了内置工具来获取风险类型列表,用户应该使用准确的category名称
|
||||
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
|
||||
var rows *sql.Rows
|
||||
if req.RiskType != "" {
|
||||
// 使用精确匹配(=),性能更好且更准确
|
||||
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
|
||||
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
|
||||
// 建议用户先调用相应的内置工具获取准确的category名称
|
||||
rows, err = r.db.Query(`
|
||||
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
|
||||
FROM knowledge_embeddings e
|
||||
|
||||
+12
-11
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ func RegisterKnowledgeTool(
|
||||
) {
|
||||
// 注册第一个工具:获取所有可用的风险类型列表
|
||||
listRiskTypesTool := mcp.Tool{
|
||||
Name: "list_knowledge_risk_types",
|
||||
Name: builtin.ToolListKnowledgeRiskTypes,
|
||||
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
|
||||
ShortDescription: "获取知识库中所有可用的风险类型列表",
|
||||
InputSchema: map[string]interface{}{
|
||||
@@ -62,7 +63,7 @@ func RegisterKnowledgeTool(
|
||||
for i, category := range categories {
|
||||
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
|
||||
}
|
||||
resultText.WriteString("\n提示:在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
|
||||
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
@@ -79,8 +80,8 @@ func RegisterKnowledgeTool(
|
||||
|
||||
// 注册第二个工具:搜索知识库(保持原有功能)
|
||||
searchTool := mcp.Tool{
|
||||
Name: "search_knowledge_base",
|
||||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
|
||||
Name: builtin.ToolSearchKnowledgeBase,
|
||||
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
|
||||
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
@@ -91,7 +92,7 @@ func RegisterKnowledgeTool(
|
||||
},
|
||||
"risk_type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
|
||||
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
@@ -165,9 +166,9 @@ func RegisterKnowledgeTool(
|
||||
// 按文档分组结果,以便更好地展示上下文
|
||||
// 使用有序的slice来保持文档顺序(按最高混合分数)
|
||||
type itemGroup struct {
|
||||
itemID string
|
||||
results []*RetrievalResult
|
||||
maxScore float64 // 该文档的最高混合分数
|
||||
itemID string
|
||||
results []*RetrievalResult
|
||||
maxScore float64 // 该文档的最高混合分数
|
||||
}
|
||||
itemGroups := make([]*itemGroup, 0)
|
||||
itemMap := make(map[string]*itemGroup)
|
||||
@@ -177,8 +178,8 @@ func RegisterKnowledgeTool(
|
||||
group, exists := itemMap[itemID]
|
||||
if !exists {
|
||||
group = &itemGroup{
|
||||
itemID: itemID,
|
||||
results: make([]*RetrievalResult, 0),
|
||||
itemID: itemID,
|
||||
results: make([]*RetrievalResult, 0),
|
||||
maxScore: result.Score,
|
||||
}
|
||||
itemMap[itemID] = group
|
||||
@@ -219,7 +220,7 @@ func RegisterKnowledgeTool(
|
||||
})
|
||||
|
||||
// 显示主结果(混合分数最高的,同时显示相似度和混合分数)
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
|
||||
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
|
||||
resultIndex, mainResult.Similarity*100, mainResult.Score*100))
|
||||
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
|
||||
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package builtin
|
||||
|
||||
// 内置工具名称常量
|
||||
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
|
||||
const (
|
||||
// 漏洞管理工具
|
||||
ToolRecordVulnerability = "record_vulnerability"
|
||||
|
||||
// 知识库工具
|
||||
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
|
||||
ToolSearchKnowledgeBase = "search_knowledge_base"
|
||||
)
|
||||
|
||||
// IsBuiltinTool 检查工具名称是否是内置工具
|
||||
func IsBuiltinTool(toolName string) bool {
|
||||
switch toolName {
|
||||
case ToolRecordVulnerability,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllBuiltinTools 返回所有内置工具名称列表
|
||||
func GetAllBuiltinTools() []string {
|
||||
return []string{
|
||||
ToolRecordVulnerability,
|
||||
ToolListKnowledgeRiskTypes,
|
||||
ToolSearchKnowledgeBase,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user