mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-21 15:16:55 +02:00
Add files via upload
This commit is contained in:
+63
-12
@@ -12,7 +12,9 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -66,13 +68,14 @@ type AgentHandler struct {
|
||||
logger *zap.Logger
|
||||
tasks *AgentTaskManager
|
||||
batchTaskManager *BatchTaskManager
|
||||
knowledgeManager interface { // 知识库管理器接口
|
||||
config *config.Config // 配置引用,用于获取角色信息
|
||||
knowledgeManager interface { // 知识库管理器接口
|
||||
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
|
||||
}
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
|
||||
func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, logger *zap.Logger) *AgentHandler {
|
||||
batchTaskManager := NewBatchTaskManager()
|
||||
batchTaskManager.SetDB(db)
|
||||
|
||||
@@ -87,6 +90,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A
|
||||
logger: logger,
|
||||
tasks: NewAgentTaskManager(),
|
||||
batchTaskManager: batchTaskManager,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,6 +105,7 @@ func (h *AgentHandler) SetKnowledgeManager(manager interface {
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
Role string `json:"role,omitempty"` // 角色名称
|
||||
}
|
||||
|
||||
// ChatResponse 聊天响应
|
||||
@@ -161,14 +166,34 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
// 应用角色用户提示词和工具配置
|
||||
finalMessage := req.Message
|
||||
var roleTools []string // 角色配置的工具列表
|
||||
if req.Role != "" && req.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
|
||||
// 应用用户提示词
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + req.Message
|
||||
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
|
||||
}
|
||||
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
|
||||
if len(role.Tools) > 0 {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存用户消息(保存原始消息,不包含角色提示词)
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 执行Agent Loop,传入历史消息和对话ID
|
||||
result, err := h.agent.AgentLoopWithConversationID(c.Request.Context(), req.Message, agentHistoryMessages, conversationID)
|
||||
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
@@ -231,7 +256,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
|
||||
if eventType == "tool_call" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
if toolName == "search_knowledge_base" {
|
||||
if toolName == builtin.ToolSearchKnowledgeBase {
|
||||
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
|
||||
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
|
||||
toolCallCache[toolCallId] = argumentsObj
|
||||
@@ -245,7 +270,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
|
||||
if eventType == "tool_result" && h.knowledgeManager != nil {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
toolName, _ := dataMap["toolName"].(string)
|
||||
if toolName == "search_knowledge_base" {
|
||||
if toolName == builtin.ToolSearchKnowledgeBase {
|
||||
// 提取检索信息
|
||||
query := ""
|
||||
riskType := ""
|
||||
@@ -470,7 +495,32 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
// 应用角色用户提示词和工具配置
|
||||
finalMessage := req.Message
|
||||
var roleTools []string // 角色配置的工具列表
|
||||
if req.Role != "" && req.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
|
||||
// 应用用户提示词
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + req.Message
|
||||
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
|
||||
}
|
||||
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
|
||||
if len(role.Tools) > 0 {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
|
||||
} else if len(role.MCPs) > 0 {
|
||||
// 向后兼容:如果只有mcps字段,暂时使用空列表(表示使用所有工具)
|
||||
// 因为mcps是MCP服务器名称,不是工具列表
|
||||
h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 如果roleTools为空,表示使用所有工具(默认角色或未配置工具的角色)
|
||||
|
||||
// 保存用户消息(保存原始消息,不包含角色提示词)
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(err))
|
||||
@@ -547,9 +597,9 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
taskStatus := "completed"
|
||||
defer h.tasks.FinishTask(conversationID, taskStatus)
|
||||
|
||||
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断
|
||||
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
sendEvent("progress", "正在分析您的请求...", nil)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, conversationID, progressCallback)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
cause := context.Cause(baseCtx)
|
||||
@@ -759,7 +809,7 @@ func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
|
||||
|
||||
// BatchTaskRequest 批量任务请求
|
||||
type BatchTaskRequest struct {
|
||||
Title string `json:"title"` // 任务标题(可选)
|
||||
Title string `json:"title"` // 任务标题(可选)
|
||||
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
|
||||
}
|
||||
|
||||
@@ -1072,7 +1122,8 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
// 存储取消函数,以便在取消队列时能够取消当前任务
|
||||
h.batchTaskManager.SetTaskCancel(queueID, cancel)
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback)
|
||||
// 批量任务暂时不支持角色工具过滤,使用所有工具(传入nil)
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback, nil)
|
||||
// 任务执行完成,清理取消函数
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
cancel()
|
||||
|
||||
+114
-12
@@ -147,6 +147,7 @@ type ToolConfigInfo struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
|
||||
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
|
||||
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
|
||||
}
|
||||
|
||||
// GetConfig 获取当前配置
|
||||
@@ -272,11 +273,12 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
|
||||
// GetToolsResponse 获取工具列表响应(分页)
|
||||
type GetToolsResponse struct {
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
Tools []ToolConfigInfo `json:"tools"`
|
||||
Total int `json:"total"`
|
||||
TotalEnabled int `json:"total_enabled"` // 已启用的工具总数
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// GetTools 获取工具列表(支持分页和搜索)
|
||||
@@ -305,6 +307,23 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
searchTermLower = strings.ToLower(searchTerm)
|
||||
}
|
||||
|
||||
// 解析角色参数,用于过滤工具并标注启用状态
|
||||
roleName := c.Query("role")
|
||||
var roleToolsSet map[string]bool // 角色配置的工具集合
|
||||
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
|
||||
if roleName != "" && roleName != "默认" && h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[roleName]; exists && role.Enabled {
|
||||
if len(role.Tools) > 0 {
|
||||
// 角色配置了工具列表,只使用这些工具
|
||||
roleToolsSet = make(map[string]bool)
|
||||
for _, toolKey := range role.Tools {
|
||||
roleToolsSet[toolKey] = true
|
||||
}
|
||||
roleUsesAllTools = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有内部工具并应用搜索过滤
|
||||
configToolMap := make(map[string]bool)
|
||||
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
|
||||
@@ -325,6 +344,31 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
toolInfo.Description = desc
|
||||
}
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,标注启用的工具为role_enabled=true
|
||||
if tool.Enabled {
|
||||
roleEnabled := true
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 内部工具使用工具名称作为key
|
||||
if roleToolsSet[tool.Name] {
|
||||
roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(toolInfo.Name)
|
||||
@@ -361,6 +405,26 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
IsExternal: false,
|
||||
}
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,直接注册的工具默认启用
|
||||
roleEnabled := true
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 内部工具使用工具名称作为key
|
||||
if roleToolsSet[mcpTool.Name] {
|
||||
roleEnabled := true // 在角色列表中且工具本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有关键词,进行搜索过滤
|
||||
if searchTermLower != "" {
|
||||
nameLower := strings.ToLower(toolInfo.Name)
|
||||
@@ -439,18 +503,55 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
allTools = append(allTools, ToolConfigInfo{
|
||||
toolInfo := ToolConfigInfo{
|
||||
Name: actualToolName, // 显示实际工具名称,不带前缀
|
||||
Description: description,
|
||||
Enabled: enabled,
|
||||
IsExternal: true,
|
||||
ExternalMCP: mcpName,
|
||||
})
|
||||
}
|
||||
|
||||
// 根据角色配置标注工具状态
|
||||
if roleName != "" {
|
||||
if roleUsesAllTools {
|
||||
// 角色使用所有工具,标注启用的工具为role_enabled=true
|
||||
toolInfo.RoleEnabled = &enabled
|
||||
} else {
|
||||
// 角色配置了工具列表,检查工具是否在列表中
|
||||
// 外部工具使用 "mcpName::toolName" 格式作为key
|
||||
externalToolKey := externalTool.Name // 这是 "mcpName::toolName" 格式
|
||||
if roleToolsSet[externalToolKey] {
|
||||
roleEnabled := enabled // 工具必须在角色列表中且本身启用
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
} else {
|
||||
// 不在角色列表中,标记为false
|
||||
roleEnabled := false
|
||||
toolInfo.RoleEnabled = &roleEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allTools = append(allTools, toolInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用)
|
||||
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
|
||||
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
|
||||
|
||||
total := len(allTools)
|
||||
// 统计已启用的工具数(在角色中的启用工具数)
|
||||
totalEnabled := 0
|
||||
for _, tool := range allTools {
|
||||
if tool.RoleEnabled != nil && *tool.RoleEnabled {
|
||||
totalEnabled++
|
||||
} else if tool.RoleEnabled == nil && tool.Enabled {
|
||||
// 如果未指定角色,统计所有启用的工具
|
||||
totalEnabled++
|
||||
}
|
||||
}
|
||||
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
@@ -471,11 +572,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, GetToolsResponse{
|
||||
Tools: tools,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
Tools: tools,
|
||||
Total: total,
|
||||
TotalEnabled: totalEnabled,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,453 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RoleHandler 角色处理器
|
||||
type RoleHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRoleHandler 创建新的角色处理器
|
||||
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
|
||||
return &RoleHandler{
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRoles 获取所有角色
|
||||
func (h *RoleHandler) GetRoles(c *gin.Context) {
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
|
||||
for key, role := range h.config.Roles {
|
||||
// 确保角色的key与name一致
|
||||
if role.Name == "" {
|
||||
role.Name = key
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"roles": roles,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRole 获取单个角色
|
||||
func (h *RoleHandler) GetRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Roles == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
role, exists := h.config.Roles[roleName]
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 确保角色的name与key一致
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"role": role,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (h *RoleHandler) UpdateRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req config.RoleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 确保角色名称与请求中的name一致
|
||||
if req.Name == "" {
|
||||
req.Name = roleName
|
||||
}
|
||||
|
||||
// 初始化Roles map
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
// 删除所有与角色name相同但key不同的旧角色(避免重复)
|
||||
// 使用角色name作为key,确保唯一性
|
||||
finalKey := req.Name
|
||||
keysToDelete := make([]string, 0)
|
||||
for key := range h.config.Roles {
|
||||
// 如果key与最终的key不同,但name相同,则标记为删除
|
||||
if key != finalKey {
|
||||
role := h.config.Roles[key]
|
||||
// 确保角色的name字段正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = key
|
||||
}
|
||||
if role.Name == req.Name {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
// 删除旧的角色
|
||||
for _, key := range keysToDelete {
|
||||
delete(h.config.Roles, key)
|
||||
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
|
||||
}
|
||||
|
||||
// 如果当前更新的key与最终key不同,也需要删除旧的
|
||||
if roleName != finalKey {
|
||||
delete(h.config.Roles, roleName)
|
||||
}
|
||||
|
||||
// 如果角色名称改变,需要删除旧文件
|
||||
if roleName != finalKey {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 删除旧的角色文件
|
||||
oldSafeFileName := sanitizeFileName(roleName)
|
||||
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
|
||||
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
|
||||
|
||||
if _, err := os.Stat(oldRoleFileYaml); err == nil {
|
||||
if err := os.Remove(oldRoleFileYaml); err != nil {
|
||||
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
|
||||
}
|
||||
}
|
||||
if _, err := os.Stat(oldRoleFileYml); err == nil {
|
||||
if err := os.Remove(oldRoleFileYml); err != nil {
|
||||
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用角色name作为key来保存(确保唯一性)
|
||||
h.config.Roles[finalKey] = req
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已更新",
|
||||
"role": req,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateRole 创建新角色
|
||||
func (h *RoleHandler) CreateRole(c *gin.Context) {
|
||||
var req config.RoleConfig
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化Roles map
|
||||
if h.config.Roles == nil {
|
||||
h.config.Roles = make(map[string]config.RoleConfig)
|
||||
}
|
||||
|
||||
// 检查角色是否已存在
|
||||
if _, exists := h.config.Roles[req.Name]; exists {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建角色(默认启用)
|
||||
if !req.Enabled {
|
||||
req.Enabled = true
|
||||
}
|
||||
|
||||
h.config.Roles[req.Name] = req
|
||||
|
||||
// 保存配置到文件
|
||||
if err := h.saveConfig(); err != nil {
|
||||
h.logger.Error("保存配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已创建",
|
||||
"role": req,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色
|
||||
func (h *RoleHandler) DeleteRole(c *gin.Context) {
|
||||
roleName := c.Param("name")
|
||||
if roleName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.config.Roles == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, exists := h.config.Roles[roleName]; !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 不允许删除"默认"角色
|
||||
if roleName == "默认" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
|
||||
return
|
||||
}
|
||||
|
||||
delete(h.config.Roles, roleName)
|
||||
|
||||
// 删除对应的角色文件
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 尝试删除角色文件(.yaml 和 .yml)
|
||||
safeFileName := sanitizeFileName(roleName)
|
||||
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
|
||||
|
||||
// 删除 .yaml 文件(如果存在)
|
||||
if _, err := os.Stat(roleFileYaml); err == nil {
|
||||
if err := os.Remove(roleFileYaml); err != nil {
|
||||
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 .yml 文件(如果存在)
|
||||
if _, err := os.Stat(roleFileYml); err == nil {
|
||||
if err := os.Remove(roleFileYml); err != nil {
|
||||
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已删除",
|
||||
})
|
||||
}
|
||||
|
||||
// saveConfig 保存配置到目录中的文件
|
||||
func (h *RoleHandler) saveConfig() error {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(rolesDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存每个角色到独立的文件
|
||||
if h.config.Roles != nil {
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
|
||||
safeFileName := sanitizeFileName(role.Name)
|
||||
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
|
||||
// 将角色配置序列化为YAML
|
||||
roleData, err := yaml.Marshal(&role)
|
||||
if err != nil {
|
||||
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
|
||||
roleDataStr := string(roleData)
|
||||
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
|
||||
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
|
||||
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
|
||||
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
|
||||
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
|
||||
roleData = []byte(roleDataStr)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
|
||||
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeFileName 将角色名称转换为安全的文件名
|
||||
func sanitizeFileName(name string) string {
|
||||
// 替换可能不安全的字符
|
||||
replacer := map[rune]string{
|
||||
'/': "_",
|
||||
'\\': "_",
|
||||
':': "_",
|
||||
'*': "_",
|
||||
'?': "_",
|
||||
'"': "_",
|
||||
'<': "_",
|
||||
'>': "_",
|
||||
'|': "_",
|
||||
' ': "_",
|
||||
}
|
||||
|
||||
var result []rune
|
||||
for _, r := range name {
|
||||
if replacement, ok := replacer[r]; ok {
|
||||
result = append(result, []rune(replacement)...)
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
|
||||
fileName := string(result)
|
||||
// 如果文件名为空,使用默认名称
|
||||
if fileName == "" {
|
||||
fileName = "role"
|
||||
}
|
||||
|
||||
return fileName
|
||||
}
|
||||
|
||||
// updateRolesConfig 更新角色配置
|
||||
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
|
||||
root := doc.Content[0]
|
||||
rolesNode := ensureMap(root, "roles")
|
||||
|
||||
// 清空现有角色
|
||||
if rolesNode.Kind == yaml.MappingNode {
|
||||
rolesNode.Content = nil
|
||||
}
|
||||
|
||||
// 添加新角色(使用name作为key,确保唯一性)
|
||||
if cfg.Roles != nil {
|
||||
// 先建立一个以name为key的map,去重(保留最后一个)
|
||||
rolesByName := make(map[string]config.RoleConfig)
|
||||
for roleKey, role := range cfg.Roles {
|
||||
// 确保角色的name字段正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleKey
|
||||
}
|
||||
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
|
||||
rolesByName[role.Name] = role
|
||||
}
|
||||
|
||||
// 将去重后的角色写入YAML
|
||||
for roleName, role := range rolesByName {
|
||||
roleNode := ensureMap(rolesNode, roleName)
|
||||
setStringInMap(roleNode, "name", role.Name)
|
||||
setStringInMap(roleNode, "description", role.Description)
|
||||
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
|
||||
if role.Icon != "" {
|
||||
setStringInMap(roleNode, "icon", role.Icon)
|
||||
}
|
||||
setBoolInMap(roleNode, "enabled", role.Enabled)
|
||||
|
||||
// 添加工具列表(优先使用tools字段)
|
||||
if len(role.Tools) > 0 {
|
||||
toolsNode := ensureArray(roleNode, "tools")
|
||||
toolsNode.Content = nil
|
||||
for _, toolKey := range role.Tools {
|
||||
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
|
||||
toolsNode.Content = append(toolsNode.Content, toolNode)
|
||||
}
|
||||
} else if len(role.MCPs) > 0 {
|
||||
// 向后兼容:如果没有tools但有mcps,保存mcps
|
||||
mcpsNode := ensureArray(roleNode, "mcps")
|
||||
mcpsNode.Content = nil
|
||||
for _, mcpName := range role.MCPs {
|
||||
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
|
||||
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureArray 确保数组中存在指定key的数组节点
|
||||
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
|
||||
_, valueNode := ensureKeyValue(parent, key)
|
||||
if valueNode.Kind != yaml.SequenceNode {
|
||||
valueNode.Kind = yaml.SequenceNode
|
||||
valueNode.Tag = "!!seq"
|
||||
valueNode.Content = nil
|
||||
}
|
||||
return valueNode
|
||||
}
|
||||
Reference in New Issue
Block a user