mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-03 21:08:08 +02:00
Add files via upload
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/skills"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -72,6 +73,7 @@ type AgentHandler struct {
|
||||
knowledgeManager interface { // 知识库管理器接口
|
||||
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
|
||||
}
|
||||
skillsManager *skills.Manager // Skills管理器
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
@@ -101,6 +103,11 @@ func (h *AgentHandler) SetKnowledgeManager(manager interface {
|
||||
h.knowledgeManager = manager
|
||||
}
|
||||
|
||||
// SetSkillsManager 设置Skills管理器
|
||||
func (h *AgentHandler) SetSkillsManager(manager *skills.Manager) {
|
||||
h.skillsManager = manager
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
@@ -169,6 +176,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
// 应用角色用户提示词和工具配置
|
||||
finalMessage := req.Message
|
||||
var roleTools []string // 角色配置的工具列表
|
||||
var roleSkills []string // 角色配置的skills列表(用于提示AI,但不硬编码内容)
|
||||
if req.Role != "" && req.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
|
||||
@@ -182,6 +190,11 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
|
||||
}
|
||||
// 获取角色配置的skills列表(用于在系统提示词中提示AI,但不硬编码内容)
|
||||
if len(role.Skills) > 0 {
|
||||
roleSkills = role.Skills
|
||||
h.logger.Info("角色配置了skills,将在系统提示词中提示AI", zap.String("role", req.Role), zap.Int("skillCount", len(roleSkills)), zap.Strings("skills", roleSkills))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -193,7 +206,8 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
|
||||
// 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills
|
||||
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools, roleSkills)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
@@ -515,6 +529,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
// 因为mcps是MCP服务器名称,不是工具列表
|
||||
h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role))
|
||||
}
|
||||
// 注意:角色配置的skills不再硬编码注入,AI可以通过list_skills和read_skill工具按需调用
|
||||
if len(role.Skills) > 0 {
|
||||
h.logger.Info("角色配置了skills,AI可通过工具按需调用", zap.String("role", req.Role), zap.Int("skillCount", len(role.Skills)), zap.Strings("skills", role.Skills))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -599,7 +617,18 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
|
||||
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
sendEvent("progress", "正在分析您的请求...", nil)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
// 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills
|
||||
var roleSkills []string // 角色配置的skills列表(用于提示AI,但不硬编码内容)
|
||||
if req.Role != "" && req.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
|
||||
if len(role.Skills) > 0 {
|
||||
roleSkills = role.Skills
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools, roleSkills)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
cause := context.Cause(baseCtx)
|
||||
@@ -1099,6 +1128,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
// 应用角色用户提示词和工具配置
|
||||
finalMessage := task.Message
|
||||
var roleTools []string // 角色配置的工具列表
|
||||
var roleSkills []string // 角色配置的skills列表(用于提示AI,但不硬编码内容)
|
||||
if queue.Role != "" && queue.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
||||
@@ -1112,6 +1142,11 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
||||
}
|
||||
// 获取角色配置的skills列表(用于在系统提示词中提示AI,但不硬编码内容)
|
||||
if len(role.Skills) > 0 {
|
||||
roleSkills = role.Skills
|
||||
h.logger.Info("角色配置了skills,将在系统提示词中提示AI", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("skillCount", len(roleSkills)), zap.Strings("skills", roleSkills))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1144,7 +1179,8 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
// 存储取消函数,以便在取消队列时能够取消当前任务
|
||||
h.batchTaskManager.SetTaskCancel(queueID, cancel)
|
||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
||||
// 注意:skills不会硬编码注入,但会在系统提示词中提示AI这个角色推荐使用哪些skills
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools, roleSkills)
|
||||
// 任务执行完成,清理取消函数
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
cancel()
|
||||
|
||||
@@ -18,9 +18,15 @@ import (
|
||||
|
||||
// RoleHandler 角色处理器
|
||||
type RoleHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
skillsManager SkillsManager // Skills管理器接口(可选)
|
||||
}
|
||||
|
||||
// SkillsManager Skills管理器接口
|
||||
type SkillsManager interface {
|
||||
ListSkills() ([]string, error)
|
||||
}
|
||||
|
||||
// NewRoleHandler 创建新的角色处理器
|
||||
@@ -32,6 +38,34 @@ func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *
|
||||
}
|
||||
}
|
||||
|
||||
// SetSkillsManager 设置Skills管理器
|
||||
func (h *RoleHandler) SetSkillsManager(manager SkillsManager) {
|
||||
h.skillsManager = manager
|
||||
}
|
||||
|
||||
// GetSkills 获取所有可用的skills列表
|
||||
func (h *RoleHandler) GetSkills(c *gin.Context) {
|
||||
if h.skillsManager == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": []string{},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
skills, err := h.skillsManager.ListSkills()
|
||||
if err != nil {
|
||||
h.logger.Warn("获取skills列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": []string{},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": skills,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRoles 获取所有角色
|
||||
func (h *RoleHandler) GetRoles(c *gin.Context) {
|
||||
if h.config.Roles == nil {
|
||||
|
||||
@@ -0,0 +1,778 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/skills"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// SkillsHandler Skills处理器
|
||||
type SkillsHandler struct {
|
||||
manager *skills.Manager
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
db *database.DB // 数据库连接(用于获取调用统计)
|
||||
}
|
||||
|
||||
// NewSkillsHandler 创建新的Skills处理器
|
||||
func NewSkillsHandler(manager *skills.Manager, cfg *config.Config, configPath string, logger *zap.Logger) *SkillsHandler {
|
||||
return &SkillsHandler{
|
||||
manager: manager,
|
||||
config: cfg,
|
||||
configPath: configPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SetDB 设置数据库连接(用于获取调用统计)
|
||||
func (h *SkillsHandler) SetDB(db *database.DB) {
|
||||
h.db = db
|
||||
}
|
||||
|
||||
// GetSkills 获取所有skills列表(支持分页和搜索)
|
||||
func (h *SkillsHandler) GetSkills(c *gin.Context) {
|
||||
skillList, err := h.manager.ListSkills()
|
||||
if err != nil {
|
||||
h.logger.Error("获取skills列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 搜索参数
|
||||
searchKeyword := strings.TrimSpace(c.Query("search"))
|
||||
|
||||
// 先加载所有skills的详细信息用于搜索过滤
|
||||
allSkillsInfo := make([]map[string]interface{}, 0, len(skillList))
|
||||
for _, skillName := range skillList {
|
||||
skill, err := h.manager.LoadSkill(skillName)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取文件信息
|
||||
skillPath := skill.Path
|
||||
skillFile := filepath.Join(skillPath, "SKILL.md")
|
||||
// 尝试其他可能的文件名
|
||||
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
|
||||
alternatives := []string{
|
||||
filepath.Join(skillPath, "skill.md"),
|
||||
filepath.Join(skillPath, "README.md"),
|
||||
filepath.Join(skillPath, "readme.md"),
|
||||
}
|
||||
for _, alt := range alternatives {
|
||||
if _, err := os.Stat(alt); err == nil {
|
||||
skillFile = alt
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fileInfo, _ := os.Stat(skillFile)
|
||||
var fileSize int64
|
||||
var modTime string
|
||||
if fileInfo != nil {
|
||||
fileSize = fileInfo.Size()
|
||||
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
skillInfo := map[string]interface{}{
|
||||
"name": skill.Name,
|
||||
"description": skill.Description,
|
||||
"path": skill.Path,
|
||||
"file_size": fileSize,
|
||||
"mod_time": modTime,
|
||||
}
|
||||
allSkillsInfo = append(allSkillsInfo, skillInfo)
|
||||
}
|
||||
|
||||
// 如果有搜索关键词,进行过滤
|
||||
filteredSkillsInfo := allSkillsInfo
|
||||
if searchKeyword != "" {
|
||||
keywordLower := strings.ToLower(searchKeyword)
|
||||
filteredSkillsInfo = make([]map[string]interface{}, 0)
|
||||
for _, skillInfo := range allSkillsInfo {
|
||||
name := strings.ToLower(fmt.Sprintf("%v", skillInfo["name"]))
|
||||
description := strings.ToLower(fmt.Sprintf("%v", skillInfo["description"]))
|
||||
path := strings.ToLower(fmt.Sprintf("%v", skillInfo["path"]))
|
||||
|
||||
if strings.Contains(name, keywordLower) ||
|
||||
strings.Contains(description, keywordLower) ||
|
||||
strings.Contains(path, keywordLower) {
|
||||
filteredSkillsInfo = append(filteredSkillsInfo, skillInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 分页参数
|
||||
limit := 20 // 默认每页20条
|
||||
offset := 0
|
||||
if limitStr := c.Query("limit"); limitStr != "" {
|
||||
if parsed, err := parseInt(limitStr); err == nil && parsed > 0 {
|
||||
// 允许更大的limit用于搜索场景,但设置一个合理的上限(10000)
|
||||
if parsed <= 10000 {
|
||||
limit = parsed
|
||||
} else {
|
||||
limit = 10000
|
||||
}
|
||||
}
|
||||
}
|
||||
if offsetStr := c.Query("offset"); offsetStr != "" {
|
||||
if parsed, err := parseInt(offsetStr); err == nil && parsed >= 0 {
|
||||
offset = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// 计算分页范围
|
||||
total := len(filteredSkillsInfo)
|
||||
start := offset
|
||||
end := offset + limit
|
||||
if start > total {
|
||||
start = total
|
||||
}
|
||||
if end > total {
|
||||
end = total
|
||||
}
|
||||
|
||||
// 获取当前页的skill列表
|
||||
var paginatedSkillsInfo []map[string]interface{}
|
||||
if start < end {
|
||||
paginatedSkillsInfo = filteredSkillsInfo[start:end]
|
||||
} else {
|
||||
paginatedSkillsInfo = []map[string]interface{}{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skills": paginatedSkillsInfo,
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetSkill 获取单个skill的详细信息
|
||||
func (h *SkillsHandler) GetSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
skill, err := h.manager.LoadSkill(skillName)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取文件信息
|
||||
skillPath := skill.Path
|
||||
skillFile := filepath.Join(skillPath, "SKILL.md")
|
||||
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
|
||||
alternatives := []string{
|
||||
filepath.Join(skillPath, "skill.md"),
|
||||
filepath.Join(skillPath, "README.md"),
|
||||
filepath.Join(skillPath, "readme.md"),
|
||||
}
|
||||
for _, alt := range alternatives {
|
||||
if _, err := os.Stat(alt); err == nil {
|
||||
skillFile = alt
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fileInfo, _ := os.Stat(skillFile)
|
||||
var fileSize int64
|
||||
var modTime string
|
||||
if fileInfo != nil {
|
||||
fileSize = fileInfo.Size()
|
||||
modTime = fileInfo.ModTime().Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skill": map[string]interface{}{
|
||||
"name": skill.Name,
|
||||
"description": skill.Description,
|
||||
"content": skill.Content,
|
||||
"path": skill.Path,
|
||||
"file_size": fileSize,
|
||||
"mod_time": modTime,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetSkillBoundRoles 获取绑定指定skill的角色列表
|
||||
func (h *SkillsHandler) GetSkillBoundRoles(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
boundRoles := h.getRolesBoundToSkill(skillName)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"skill": skillName,
|
||||
"bound_roles": boundRoles,
|
||||
"bound_count": len(boundRoles),
|
||||
})
|
||||
}
|
||||
|
||||
// getRolesBoundToSkill 获取绑定指定skill的角色列表(不修改配置)
|
||||
func (h *SkillsHandler) getRolesBoundToSkill(skillName string) []string {
|
||||
if h.config.Roles == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
boundRoles := make([]string, 0)
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 检查角色的Skills列表中是否包含该skill
|
||||
if len(role.Skills) > 0 {
|
||||
for _, skill := range role.Skills {
|
||||
if skill == skillName {
|
||||
boundRoles = append(boundRoles, roleName)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return boundRoles
|
||||
}
|
||||
|
||||
// CreateSkill 创建新skill
|
||||
func (h *SkillsHandler) CreateSkill(c *gin.Context) {
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证skill名称(只允许字母、数字、连字符和下划线)
|
||||
if !isValidSkillName(req.Name) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称只能包含字母、数字、连字符和下划线"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取skills目录
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
|
||||
// 创建skill目录
|
||||
skillDir := filepath.Join(skillsDir, req.Name)
|
||||
if err := os.MkdirAll(skillDir, 0755); err != nil {
|
||||
h.logger.Error("创建skill目录失败", zap.String("skill", req.Name), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill目录失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否已存在
|
||||
skillFile := filepath.Join(skillDir, "SKILL.md")
|
||||
if _, err := os.Stat(skillFile); err == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill已存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建SKILL.md内容
|
||||
var content strings.Builder
|
||||
content.WriteString("---\n")
|
||||
content.WriteString(fmt.Sprintf("name: %s\n", req.Name))
|
||||
if req.Description != "" {
|
||||
// 如果描述包含特殊字符,需要加引号
|
||||
desc := req.Description
|
||||
if strings.Contains(desc, ":") || strings.Contains(desc, "\n") {
|
||||
desc = fmt.Sprintf(`"%s"`, strings.ReplaceAll(desc, `"`, `\"`))
|
||||
}
|
||||
content.WriteString(fmt.Sprintf("description: %s\n", desc))
|
||||
}
|
||||
content.WriteString("version: 1.0.0\n")
|
||||
content.WriteString("---\n\n")
|
||||
content.WriteString(req.Content)
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(skillFile, []byte(content.String()), 0644); err != nil {
|
||||
h.logger.Error("创建skill文件失败", zap.String("skill", req.Name), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建skill文件失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已创建",
|
||||
"skill": map[string]interface{}{
|
||||
"name": req.Name,
|
||||
"path": skillDir,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSkill 更新skill
|
||||
func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取skills目录
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
|
||||
// 查找skill文件
|
||||
skillDir := filepath.Join(skillsDir, skillName)
|
||||
skillFile := filepath.Join(skillDir, "SKILL.md")
|
||||
if _, err := os.Stat(skillFile); os.IsNotExist(err) {
|
||||
alternatives := []string{
|
||||
filepath.Join(skillDir, "skill.md"),
|
||||
filepath.Join(skillDir, "README.md"),
|
||||
filepath.Join(skillDir, "readme.md"),
|
||||
}
|
||||
found := false
|
||||
for _, alt := range alternatives {
|
||||
if _, err := os.Stat(alt); err == nil {
|
||||
skillFile = alt
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "skill不存在"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 读取现有文件以保留front matter中的name
|
||||
existingContent, err := os.ReadFile(skillFile)
|
||||
if err != nil {
|
||||
h.logger.Error("读取skill文件失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取skill文件失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析现有内容,提取name
|
||||
existingName := skillName
|
||||
contentStr := string(existingContent)
|
||||
if strings.HasPrefix(contentStr, "---") {
|
||||
parts := strings.SplitN(contentStr, "---", 3)
|
||||
if len(parts) >= 2 {
|
||||
frontMatter := parts[1]
|
||||
lines := strings.Split(frontMatter, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "name:") {
|
||||
name := strings.TrimSpace(strings.TrimPrefix(line, "name:"))
|
||||
name = strings.Trim(name, `"'`)
|
||||
if name != "" {
|
||||
existingName = name
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建新的SKILL.md内容
|
||||
var newContent strings.Builder
|
||||
newContent.WriteString("---\n")
|
||||
newContent.WriteString(fmt.Sprintf("name: %s\n", existingName))
|
||||
if req.Description != "" {
|
||||
// 如果描述包含特殊字符,需要加引号
|
||||
desc := req.Description
|
||||
if strings.Contains(desc, ":") || strings.Contains(desc, "\n") {
|
||||
desc = fmt.Sprintf(`"%s"`, strings.ReplaceAll(desc, `"`, `\"`))
|
||||
}
|
||||
newContent.WriteString(fmt.Sprintf("description: %s\n", desc))
|
||||
}
|
||||
newContent.WriteString("version: 1.0.0\n")
|
||||
newContent.WriteString("---\n\n")
|
||||
newContent.WriteString(req.Content)
|
||||
|
||||
// 写入文件(统一使用SKILL.md)
|
||||
targetFile := filepath.Join(skillDir, "SKILL.md")
|
||||
if err := os.WriteFile(targetFile, []byte(newContent.String()), 0644); err != nil {
|
||||
h.logger.Error("更新skill文件失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新skill文件失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果原文件不是SKILL.md,删除旧文件
|
||||
if skillFile != targetFile {
|
||||
os.Remove(skillFile)
|
||||
}
|
||||
|
||||
h.logger.Info("更新skill成功", zap.String("skill", skillName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已更新",
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteSkill 删除skill
|
||||
func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否有角色绑定了该skill,如果有则自动移除绑定
|
||||
affectedRoles := h.removeSkillFromRoles(skillName)
|
||||
if len(affectedRoles) > 0 {
|
||||
h.logger.Info("从角色中移除skill绑定",
|
||||
zap.String("skill", skillName),
|
||||
zap.Strings("roles", affectedRoles))
|
||||
}
|
||||
|
||||
// 获取skills目录
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
|
||||
// 删除skill目录
|
||||
skillDir := filepath.Join(skillsDir, skillName)
|
||||
if err := os.RemoveAll(skillDir); err != nil {
|
||||
h.logger.Error("删除skill失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除skill失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
responseMsg := "skill已删除"
|
||||
if len(affectedRoles) > 0 {
|
||||
responseMsg = fmt.Sprintf("skill已删除,已自动从 %d 个角色中移除绑定: %s",
|
||||
len(affectedRoles), strings.Join(affectedRoles, ", "))
|
||||
}
|
||||
|
||||
h.logger.Info("删除skill成功", zap.String("skill", skillName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": responseMsg,
|
||||
"affected_roles": affectedRoles,
|
||||
})
|
||||
}
|
||||
|
||||
// GetSkillStats 获取skills调用统计信息
|
||||
func (h *SkillsHandler) GetSkillStats(c *gin.Context) {
|
||||
skillList, err := h.manager.ListSkills()
|
||||
if err != nil {
|
||||
h.logger.Error("获取skills列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取skills目录
|
||||
skillsDir := h.config.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = "skills"
|
||||
}
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
if !filepath.IsAbs(skillsDir) {
|
||||
skillsDir = filepath.Join(configDir, skillsDir)
|
||||
}
|
||||
|
||||
// 从数据库加载调用统计
|
||||
var skillStatsMap map[string]*database.SkillStats
|
||||
if h.db != nil {
|
||||
dbStats, err := h.db.LoadSkillStats()
|
||||
if err != nil {
|
||||
h.logger.Warn("从数据库加载Skills统计信息失败", zap.Error(err))
|
||||
skillStatsMap = make(map[string]*database.SkillStats)
|
||||
} else {
|
||||
skillStatsMap = dbStats
|
||||
}
|
||||
} else {
|
||||
skillStatsMap = make(map[string]*database.SkillStats)
|
||||
}
|
||||
|
||||
// 构建统计信息(包含所有skills,即使没有调用记录)
|
||||
statsList := make([]map[string]interface{}, 0, len(skillList))
|
||||
totalCalls := 0
|
||||
totalSuccess := 0
|
||||
totalFailed := 0
|
||||
|
||||
for _, skillName := range skillList {
|
||||
stat, exists := skillStatsMap[skillName]
|
||||
if !exists {
|
||||
stat = &database.SkillStats{
|
||||
SkillName: skillName,
|
||||
TotalCalls: 0,
|
||||
SuccessCalls: 0,
|
||||
FailedCalls: 0,
|
||||
}
|
||||
}
|
||||
|
||||
totalCalls += stat.TotalCalls
|
||||
totalSuccess += stat.SuccessCalls
|
||||
totalFailed += stat.FailedCalls
|
||||
|
||||
lastCallTimeStr := ""
|
||||
if stat.LastCallTime != nil {
|
||||
lastCallTimeStr = stat.LastCallTime.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
statsList = append(statsList, map[string]interface{}{
|
||||
"skill_name": stat.SkillName,
|
||||
"total_calls": stat.TotalCalls,
|
||||
"success_calls": stat.SuccessCalls,
|
||||
"failed_calls": stat.FailedCalls,
|
||||
"last_call_time": lastCallTimeStr,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"total_skills": len(skillList),
|
||||
"total_calls": totalCalls,
|
||||
"total_success": totalSuccess,
|
||||
"total_failed": totalFailed,
|
||||
"skills_dir": skillsDir,
|
||||
"stats": statsList,
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSkillStats 清空所有Skills统计信息
|
||||
func (h *SkillsHandler) ClearSkillStats(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.ClearSkillStats(); err != nil {
|
||||
h.logger.Error("清空Skills统计信息失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("已清空所有Skills统计信息")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "已清空所有Skills统计信息",
|
||||
})
|
||||
}
|
||||
|
||||
// ClearSkillStatsByName 清空指定skill的统计信息
|
||||
func (h *SkillsHandler) ClearSkillStatsByName(c *gin.Context) {
|
||||
skillName := c.Param("name")
|
||||
if skillName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "skill名称不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "数据库连接未配置"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.db.ClearSkillStatsByName(skillName); err != nil {
|
||||
h.logger.Error("清空指定skill统计信息失败", zap.String("skill", skillName), zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空统计信息失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("已清空指定skill统计信息", zap.String("skill", skillName))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": fmt.Sprintf("已清空skill '%s' 的统计信息", skillName),
|
||||
})
|
||||
}
|
||||
|
||||
// removeSkillFromRoles 从所有角色中移除指定的skill绑定
|
||||
// 返回受影响角色名称列表
|
||||
func (h *SkillsHandler) removeSkillFromRoles(skillName string) []string {
|
||||
if h.config.Roles == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
affectedRoles := make([]string, 0)
|
||||
rolesToUpdate := make(map[string]config.RoleConfig)
|
||||
|
||||
// 遍历所有角色,查找并移除skill绑定
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 检查角色的Skills列表中是否包含要删除的skill
|
||||
if len(role.Skills) > 0 {
|
||||
updated := false
|
||||
newSkills := make([]string, 0, len(role.Skills))
|
||||
for _, skill := range role.Skills {
|
||||
if skill != skillName {
|
||||
newSkills = append(newSkills, skill)
|
||||
} else {
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
role.Skills = newSkills
|
||||
rolesToUpdate[roleName] = role
|
||||
affectedRoles = append(affectedRoles, roleName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有角色需要更新,保存到文件
|
||||
if len(rolesToUpdate) > 0 {
|
||||
// 更新内存中的配置
|
||||
for roleName, role := range rolesToUpdate {
|
||||
h.config.Roles[roleName] = role
|
||||
}
|
||||
// 保存更新后的角色配置到文件
|
||||
if err := h.saveRolesConfig(); err != nil {
|
||||
h.logger.Error("保存角色配置失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return affectedRoles
|
||||
}
|
||||
|
||||
// saveRolesConfig 保存角色配置到文件(从SkillsHandler调用)
|
||||
func (h *SkillsHandler) saveRolesConfig() error {
|
||||
configDir := filepath.Dir(h.configPath)
|
||||
rolesDir := h.config.RolesDir
|
||||
if rolesDir == "" {
|
||||
rolesDir = "roles" // 默认目录
|
||||
}
|
||||
|
||||
// 如果是相对路径,相对于配置文件所在目录
|
||||
if !filepath.IsAbs(rolesDir) {
|
||||
rolesDir = filepath.Join(configDir, rolesDir)
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(rolesDir, 0755); err != nil {
|
||||
return fmt.Errorf("创建角色目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存每个角色到独立的文件
|
||||
if h.config.Roles != nil {
|
||||
for roleName, role := range h.config.Roles {
|
||||
// 确保角色名称正确设置
|
||||
if role.Name == "" {
|
||||
role.Name = roleName
|
||||
}
|
||||
|
||||
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
|
||||
safeFileName := sanitizeRoleFileName(role.Name)
|
||||
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
|
||||
|
||||
// 将角色配置序列化为YAML
|
||||
roleData, err := yaml.Marshal(&role)
|
||||
if err != nil {
|
||||
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
|
||||
roleDataStr := string(roleData)
|
||||
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
|
||||
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
|
||||
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
|
||||
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
|
||||
roleData = []byte(roleDataStr)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
|
||||
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeRoleFileName 将角色名称转换为安全的文件名
|
||||
func sanitizeRoleFileName(name string) string {
|
||||
// 替换可能不安全的字符
|
||||
replacer := map[rune]string{
|
||||
'/': "_",
|
||||
'\\': "_",
|
||||
':': "_",
|
||||
'*': "_",
|
||||
'?': "_",
|
||||
'"': "_",
|
||||
'<': "_",
|
||||
'>': "_",
|
||||
'|': "_",
|
||||
' ': "_",
|
||||
}
|
||||
|
||||
var result []rune
|
||||
for _, r := range name {
|
||||
if replacement, ok := replacer[r]; ok {
|
||||
result = append(result, []rune(replacement)...)
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
|
||||
fileName := string(result)
|
||||
// 如果文件名为空,使用默认名称
|
||||
if fileName == "" {
|
||||
fileName = "role"
|
||||
}
|
||||
|
||||
return fileName
|
||||
}
|
||||
|
||||
// isValidSkillName 验证skill名称是否有效
|
||||
func isValidSkillName(name string) bool {
|
||||
if name == "" || len(name) > 100 {
|
||||
return false
|
||||
}
|
||||
// 只允许字母、数字、连字符和下划线
|
||||
for _, r := range name {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user