Add files via upload

This commit is contained in:
公明
2026-01-11 02:03:33 +08:00
committed by GitHub
parent 4ca1aa9aa8
commit 3aee7022c4
30 changed files with 3759 additions and 86 deletions
+63 -12
View File
@@ -12,7 +12,9 @@ import (
"unicode/utf8"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@@ -66,13 +68,14 @@ type AgentHandler struct {
logger *zap.Logger
tasks *AgentTaskManager
batchTaskManager *BatchTaskManager
knowledgeManager interface { // 知识库管理器接口
config *config.Config // 配置引用,用于获取角色信息
knowledgeManager interface { // 知识库管理器接口
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
}
}
// NewAgentHandler 创建新的Agent处理器
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, logger *zap.Logger) *AgentHandler {
batchTaskManager := NewBatchTaskManager()
batchTaskManager.SetDB(db)
@@ -87,6 +90,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A
logger: logger,
tasks: NewAgentTaskManager(),
batchTaskManager: batchTaskManager,
config: cfg,
}
}
@@ -101,6 +105,7 @@ func (h *AgentHandler) SetKnowledgeManager(manager interface {
type ChatRequest struct {
Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"`
Role string `json:"role,omitempty"` // 角色名称
}
// ChatResponse 聊天响应
@@ -161,14 +166,34 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 保存用户消息
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 执行Agent Loop,传入历史消息和对话ID
result, err := h.agent.AgentLoopWithConversationID(c.Request.Context(), req.Message, agentHistoryMessages, conversationID)
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
@@ -231,7 +256,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
if eventType == "tool_call" {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" {
if toolName == builtin.ToolSearchKnowledgeBase {
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
toolCallCache[toolCallId] = argumentsObj
@@ -245,7 +270,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
if eventType == "tool_result" && h.knowledgeManager != nil {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" {
if toolName == builtin.ToolSearchKnowledgeBase {
// 提取检索信息
query := ""
riskType := ""
@@ -470,7 +495,32 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 保存用户消息
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
} else if len(role.MCPs) > 0 {
// 向后兼容:如果只有mcps字段,暂时使用空列表(表示使用所有工具)
// 因为mcps是MCP服务器名称,不是工具列表
h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role))
}
}
}
}
// 如果roleTools为空,表示使用所有工具(默认角色或未配置工具的角色)
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
@@ -547,9 +597,9 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
taskStatus := "completed"
defer h.tasks.FinishTask(conversationID, taskStatus)
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表)
sendEvent("progress", "正在分析您的请求...", nil)
result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, conversationID, progressCallback)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
cause := context.Cause(baseCtx)
@@ -759,7 +809,7 @@ func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
// BatchTaskRequest 批量任务请求
type BatchTaskRequest struct {
Title string `json:"title"` // 任务标题(可选)
Title string `json:"title"` // 任务标题(可选)
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
}
@@ -1072,7 +1122,8 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
// 存储取消函数,以便在取消队列时能够取消当前任务
h.batchTaskManager.SetTaskCancel(queueID, cancel)
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback)
// 批量任务暂时不支持角色工具过滤,使用所有工具(传入nil)
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback, nil)
// 任务执行完成,清理取消函数
h.batchTaskManager.SetTaskCancel(queueID, nil)
cancel()
+114 -12
View File
@@ -147,6 +147,7 @@ type ToolConfigInfo struct {
Enabled bool `json:"enabled"`
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
}
// GetConfig 获取当前配置
@@ -272,11 +273,12 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
// GetToolsResponse 获取工具列表响应(分页)
type GetToolsResponse struct {
Tools []ToolConfigInfo `json:"tools"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
Tools []ToolConfigInfo `json:"tools"`
Total int `json:"total"`
TotalEnabled int `json:"total_enabled"` // 已启用的工具总数
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}
// GetTools 获取工具列表(支持分页和搜索)
@@ -305,6 +307,23 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
searchTermLower = strings.ToLower(searchTerm)
}
// 解析角色参数,用于过滤工具并标注启用状态
roleName := c.Query("role")
var roleToolsSet map[string]bool // 角色配置的工具集合
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
if roleName != "" && roleName != "默认" && h.config.Roles != nil {
if role, exists := h.config.Roles[roleName]; exists && role.Enabled {
if len(role.Tools) > 0 {
// 角色配置了工具列表,只使用这些工具
roleToolsSet = make(map[string]bool)
for _, toolKey := range role.Tools {
roleToolsSet[toolKey] = true
}
roleUsesAllTools = false
}
}
}
// 获取所有内部工具并应用搜索过滤
configToolMap := make(map[string]bool)
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
@@ -325,6 +344,31 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
toolInfo.Description = desc
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
if tool.Enabled {
roleEnabled := true
toolInfo.RoleEnabled = &roleEnabled
} else {
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 内部工具使用工具名称作为key
if roleToolsSet[tool.Name] {
roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
@@ -361,6 +405,26 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
IsExternal: false,
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,直接注册的工具默认启用
roleEnabled := true
toolInfo.RoleEnabled = &roleEnabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 内部工具使用工具名称作为key
if roleToolsSet[mcpTool.Name] {
roleEnabled := true // 在角色列表中且工具本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
@@ -439,18 +503,55 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
}
}
allTools = append(allTools, ToolConfigInfo{
toolInfo := ToolConfigInfo{
Name: actualToolName, // 显示实际工具名称,不带前缀
Description: description,
Enabled: enabled,
IsExternal: true,
ExternalMCP: mcpName,
})
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
toolInfo.RoleEnabled = &enabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 外部工具使用 "mcpName::toolName" 格式作为key
externalToolKey := externalTool.Name // 这是 "mcpName::toolName" 格式
if roleToolsSet[externalToolKey] {
roleEnabled := enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
allTools = append(allTools, toolInfo)
}
}
}
// 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用)
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
total := len(allTools)
// 统计已启用的工具数(在角色中的启用工具数)
totalEnabled := 0
for _, tool := range allTools {
if tool.RoleEnabled != nil && *tool.RoleEnabled {
totalEnabled++
} else if tool.RoleEnabled == nil && tool.Enabled {
// 如果未指定角色,统计所有启用的工具
totalEnabled++
}
}
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
@@ -471,11 +572,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
}
c.JSON(http.StatusOK, GetToolsResponse{
Tools: tools,
Total: total,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
Tools: tools,
Total: total,
TotalEnabled: totalEnabled,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
})
}
+453
View File
@@ -0,0 +1,453 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// RoleHandler 角色处理器
type RoleHandler struct {
config *config.Config
configPath string
logger *zap.Logger
}
// NewRoleHandler 创建新的角色处理器
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
return &RoleHandler{
config: cfg,
configPath: configPath,
logger: logger,
}
}
// GetRoles 获取所有角色
func (h *RoleHandler) GetRoles(c *gin.Context) {
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
for key, role := range h.config.Roles {
// 确保角色的key与name一致
if role.Name == "" {
role.Name = key
}
roles = append(roles, role)
}
c.JSON(http.StatusOK, gin.H{
"roles": roles,
})
}
// GetRole 获取单个角色
func (h *RoleHandler) GetRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
role, exists := h.config.Roles[roleName]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 确保角色的name与key一致
if role.Name == "" {
role.Name = roleName
}
c.JSON(http.StatusOK, gin.H{
"role": role,
})
}
// UpdateRole 更新角色
func (h *RoleHandler) UpdateRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 确保角色名称与请求中的name一致
if req.Name == "" {
req.Name = roleName
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 删除所有与角色name相同但key不同的旧角色(避免重复)
// 使用角色name作为key,确保唯一性
finalKey := req.Name
keysToDelete := make([]string, 0)
for key := range h.config.Roles {
// 如果key与最终的key不同,但name相同,则标记为删除
if key != finalKey {
role := h.config.Roles[key]
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = key
}
if role.Name == req.Name {
keysToDelete = append(keysToDelete, key)
}
}
}
// 删除旧的角色
for _, key := range keysToDelete {
delete(h.config.Roles, key)
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
}
// 如果当前更新的key与最终key不同,也需要删除旧的
if roleName != finalKey {
delete(h.config.Roles, roleName)
}
// 如果角色名称改变,需要删除旧文件
if roleName != finalKey {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 删除旧的角色文件
oldSafeFileName := sanitizeFileName(roleName)
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
if _, err := os.Stat(oldRoleFileYaml); err == nil {
if err := os.Remove(oldRoleFileYaml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
}
}
if _, err := os.Stat(oldRoleFileYml); err == nil {
if err := os.Remove(oldRoleFileYml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
}
}
}
// 使用角色name作为key来保存(确保唯一性)
h.config.Roles[finalKey] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已更新",
"role": req,
})
}
// CreateRole 创建新角色
func (h *RoleHandler) CreateRole(c *gin.Context) {
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 检查角色是否已存在
if _, exists := h.config.Roles[req.Name]; exists {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
return
}
// 创建角色(默认启用)
if !req.Enabled {
req.Enabled = true
}
h.config.Roles[req.Name] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("创建角色", zap.String("roleName", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已创建",
"role": req,
})
}
// DeleteRole 删除角色
func (h *RoleHandler) DeleteRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
if _, exists := h.config.Roles[roleName]; !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 不允许删除"默认"角色
if roleName == "默认" {
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
return
}
delete(h.config.Roles, roleName)
// 删除对应的角色文件
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 尝试删除角色文件(.yaml 和 .yml)
safeFileName := sanitizeFileName(roleName)
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
// 删除 .yaml 文件(如果存在)
if _, err := os.Stat(roleFileYaml); err == nil {
if err := os.Remove(roleFileYaml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
}
}
// 删除 .yml 文件(如果存在)
if _, err := os.Stat(roleFileYml); err == nil {
if err := os.Remove(roleFileYml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
}
}
h.logger.Info("删除角色", zap.String("roleName", roleName))
c.JSON(http.StatusOK, gin.H{
"message": "角色已删除",
})
}
// saveConfig 保存配置到目录中的文件
func (h *RoleHandler) saveConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeFileName 将角色名称转换为安全的文件名
func sanitizeFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// updateRolesConfig 更新角色配置
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
root := doc.Content[0]
rolesNode := ensureMap(root, "roles")
// 清空现有角色
if rolesNode.Kind == yaml.MappingNode {
rolesNode.Content = nil
}
// 添加新角色(使用name作为key,确保唯一性)
if cfg.Roles != nil {
// 先建立一个以name为key的map,去重(保留最后一个)
rolesByName := make(map[string]config.RoleConfig)
for roleKey, role := range cfg.Roles {
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = roleKey
}
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
rolesByName[role.Name] = role
}
// 将去重后的角色写入YAML
for roleName, role := range rolesByName {
roleNode := ensureMap(rolesNode, roleName)
setStringInMap(roleNode, "name", role.Name)
setStringInMap(roleNode, "description", role.Description)
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
if role.Icon != "" {
setStringInMap(roleNode, "icon", role.Icon)
}
setBoolInMap(roleNode, "enabled", role.Enabled)
// 添加工具列表(优先使用tools字段)
if len(role.Tools) > 0 {
toolsNode := ensureArray(roleNode, "tools")
toolsNode.Content = nil
for _, toolKey := range role.Tools {
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
toolsNode.Content = append(toolsNode.Content, toolNode)
}
} else if len(role.MCPs) > 0 {
// 向后兼容:如果没有tools但有mcps,保存mcps
mcpsNode := ensureArray(roleNode, "mcps")
mcpsNode.Content = nil
for _, mcpName := range role.MCPs {
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
}
}
}
}
}
// ensureArray 确保数组中存在指定key的数组节点
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
_, valueNode := ensureKeyValue(parent, key)
if valueNode.Kind != yaml.SequenceNode {
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
}
return valueNode
}