Add files via upload

This commit is contained in:
公明
2026-01-11 02:03:33 +08:00
committed by GitHub
parent 4ca1aa9aa8
commit 3aee7022c4
30 changed files with 3759 additions and 86 deletions
+35 -8
View File
@@ -12,6 +12,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/storage"
@@ -302,16 +303,16 @@ type ProgressCallback func(eventType, message string, data interface{})
// AgentLoop 执行Agent循环
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil)
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, "", nil, nil)
}
// AgentLoopWithConversationID 执行Agent循环(带对话ID
func (a *Agent) AgentLoopWithConversationID(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string) (*AgentLoopResult, error) {
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil)
return a.AgentLoopWithProgress(ctx, userInput, historyMessages, conversationID, nil, nil)
}
// AgentLoopWithProgress 执行Agent循环(带进度回调和对话ID)
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback) (*AgentLoopResult, error) {
func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, historyMessages []ChatMessage, conversationID string, callback ProgressCallback, roleTools []string) (*AgentLoopResult, error) {
// 设置当前对话ID
a.mu.Lock()
a.currentConversationID = conversationID
@@ -401,8 +402,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
漏洞记录要求:
- 当你发现有效漏洞时,必须使用 record_vulnerability 工具记录漏洞详情
- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
- 当你发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 工具记录漏洞详情
` + `- 漏洞记录应包含:标题、描述、严重程度、类型、目标、证明(POC)、影响和修复建议
- 严重程度评估标准:
* critical(严重):可导致系统完全被控制、数据泄露、服务中断等
* high(高):可导致敏感信息泄露、权限提升、重要功能被绕过等
@@ -512,7 +513,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}
// 获取可用工具
tools := a.getAvailableTools()
tools := a.getAvailableTools(roleTools)
// 记录当前上下文的Token用量,展示压缩器运行状态
if a.memoryCompressor != nil {
@@ -837,13 +838,29 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// getAvailableTools 获取可用工具
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗
func (a *Agent) getAvailableTools() []Tool {
// roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
func (a *Agent) getAvailableTools(roleTools []string) []Tool {
// 构建角色工具集合(用于快速查找)
roleToolSet := make(map[string]bool)
if len(roleTools) > 0 {
for _, toolKey := range roleTools {
roleToolSet[toolKey] = true
}
}
// 从MCP服务器获取所有已注册的内部工具
mcpTools := a.mcpServer.GetAllTools()
// 转换为OpenAI格式的工具定义
tools := make([]Tool, 0, len(mcpTools))
for _, mcpTool := range mcpTools {
// 如果指定了角色工具列表,只添加在列表中的工具
if len(roleToolSet) > 0 {
toolKey := mcpTool.Name // 内置工具使用工具名称作为key
if !roleToolSet[toolKey] {
continue // 不在角色工具列表中,跳过
}
}
// 使用简短描述(如果存在),否则使用详细描述
description := mcpTool.ShortDescription
if description == "" {
@@ -883,6 +900,16 @@ func (a *Agent) getAvailableTools() []Tool {
// 将外部MCP工具添加到工具列表(只添加启用的工具)
for _, externalTool := range externalTools {
// 外部工具使用 "mcpName::toolName" 作为toolKey
externalToolKey := externalTool.Name
// 如果指定了角色工具列表,只添加在列表中的工具
if len(roleToolSet) > 0 {
if !roleToolSet[externalToolKey] {
continue // 不在角色工具列表中,跳过
}
}
// 解析工具名称:mcpName::toolName
var mcpName, actualToolName string
if idx := strings.Index(externalTool.Name, "::"); idx > 0 {
@@ -1136,7 +1163,7 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
)
// 如果是record_vulnerability工具,自动添加conversation_id
if toolName == "record_vulnerability" {
if toolName == builtin.ToolRecordVulnerability {
a.mu.RLock()
conversationID := a.currentConversationID
a.mu.RUnlock()
+13 -2
View File
@@ -16,6 +16,7 @@ import (
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage"
@@ -278,7 +279,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
}
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager)
@@ -292,6 +293,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
// 创建 App 实例(部分字段稍后填充)
app := &App{
@@ -368,6 +370,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
attackChainHandler,
app, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler,
roleHandler,
mcpServer,
authManager,
)
@@ -428,6 +431,7 @@ func setupRoutes(
attackChainHandler *handler.AttackChainHandler,
app *App, // 传递 App 实例以便动态获取 knowledgeHandler
vulnerabilityHandler *handler.VulnerabilityHandler,
roleHandler *handler.RoleHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
) {
@@ -653,6 +657,13 @@ func setupRoutes(
protected.PUT("/vulnerabilities/:id", vulnerabilityHandler.UpdateVulnerability)
protected.DELETE("/vulnerabilities/:id", vulnerabilityHandler.DeleteVulnerability)
// 角色管理
protected.GET("/roles", roleHandler.GetRoles)
protected.GET("/roles/:name", roleHandler.GetRole)
protected.POST("/roles", roleHandler.CreateRole)
protected.PUT("/roles/:name", roleHandler.UpdateRole)
protected.DELETE("/roles/:name", roleHandler.DeleteRole)
// MCP端点
protected.POST("/mcp", func(c *gin.Context) {
mcpServer.HandleHTTP(c.Writer, c.Request)
@@ -672,7 +683,7 @@ func setupRoutes(
// registerVulnerabilityTool 注册漏洞记录工具到MCP服务器
func registerVulnerabilityTool(mcpServer *mcp.Server, db *database.DB, logger *zap.Logger) {
tool := mcp.Tool{
Name: "record_vulnerability",
Name: builtin.ToolRecordVulnerability,
Description: "记录发现的漏洞详情到漏洞管理系统。当发现有效漏洞时,使用此工具记录漏洞信息,包括标题、描述、严重程度、类型、目标、证明、影响和建议等。",
ShortDescription: "记录发现的漏洞详情到漏洞管理系统",
InputSchema: map[string]interface{}{
+135
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"gopkg.in/yaml.v3"
@@ -22,6 +23,8 @@ type Config struct {
Auth AuthConfig `yaml:"auth"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
}
type ServerConfig struct {
@@ -207,6 +210,29 @@ func Load(path string) (*Config, error) {
}
}
// 从角色目录加载角色配置
if cfg.RolesDir != "" {
configDir := filepath.Dir(path)
rolesDir := cfg.RolesDir
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
roles, err := LoadRolesFromDir(rolesDir)
if err != nil {
return nil, fmt.Errorf("从角色目录加载角色配置失败: %w", err)
}
cfg.Roles = roles
} else {
// 如果未配置 roles_dir,初始化为空 map
if cfg.Roles == nil {
cfg.Roles = make(map[string]RoleConfig)
}
}
return &cfg, nil
}
@@ -375,6 +401,98 @@ func LoadToolFromFile(path string) (*ToolConfig, error) {
return &tool, nil
}
// LoadRolesFromDir 从目录加载所有角色配置文件
func LoadRolesFromDir(dir string) (map[string]RoleConfig, error) {
roles := make(map[string]RoleConfig)
// 检查目录是否存在
if _, err := os.Stat(dir); os.IsNotExist(err) {
return roles, nil // 目录不存在时返回空map,不报错
}
// 读取目录中的所有 .yaml 和 .yml 文件
entries, err := os.ReadDir(dir)
if err != nil {
return nil, fmt.Errorf("读取角色目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") {
continue
}
filePath := filepath.Join(dir, name)
role, err := LoadRoleFromFile(filePath)
if err != nil {
// 记录错误但继续加载其他文件
fmt.Printf("警告: 加载角色配置文件 %s 失败: %v\n", filePath, err)
continue
}
// 使用角色名称作为key
roleName := role.Name
if roleName == "" {
// 如果角色名称为空,使用文件名(去掉扩展名)作为名称
roleName = strings.TrimSuffix(strings.TrimSuffix(name, ".yaml"), ".yml")
role.Name = roleName
}
roles[roleName] = *role
}
return roles, nil
}
// LoadRoleFromFile 从单个文件加载角色配置
func LoadRoleFromFile(path string) (*RoleConfig, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
var role RoleConfig
if err := yaml.Unmarshal(data, &role); err != nil {
return nil, fmt.Errorf("解析角色配置失败: %w", err)
}
// 处理 icon 字段:如果包含 Unicode 转义格式(\U0001F3C6),转换为实际的 Unicode 字符
// Go 的 yaml 库可能不会自动解析 \U 转义序列,需要手动转换
if role.Icon != "" {
icon := role.Icon
// 去除可能的引号
icon = strings.Trim(icon, `"`)
// 检查是否是 Unicode 转义格式 \U0001F3C68位十六进制)或 \uXXXX(4位十六进制)
if len(icon) >= 3 && icon[0] == '\\' {
if icon[1] == 'U' && len(icon) >= 10 {
// \U0001F3C6 格式(8位十六进制)
if codePoint, err := strconv.ParseInt(icon[2:10], 16, 32); err == nil {
role.Icon = string(rune(codePoint))
}
} else if icon[1] == 'u' && len(icon) >= 6 {
// \uXXXX 格式(4位十六进制)
if codePoint, err := strconv.ParseInt(icon[2:6], 16, 32); err == nil {
role.Icon = string(rune(codePoint))
}
}
}
}
// 验证必需字段
if role.Name == "" {
// 如果名称为空,尝试从文件名获取
baseName := filepath.Base(path)
role.Name = strings.TrimSuffix(strings.TrimSuffix(baseName, ".yaml"), ".yml")
}
return &role, nil
}
func Default() *Config {
return &Config{
Server: ServerConfig{
@@ -448,3 +566,20 @@ type RetrievalConfig struct {
SimilarityThreshold float64 `yaml:"similarity_threshold" json:"similarity_threshold"` // 相似度阈值
HybridWeight float64 `yaml:"hybrid_weight" json:"hybrid_weight"` // 向量检索权重(0-1
}
// RolesConfig 角色配置(已废弃,使用 map[string]RoleConfig 替代)
// 保留此类型以兼容旧代码,但建议直接使用 map[string]RoleConfig
type RolesConfig struct {
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"`
}
// RoleConfig 单个角色配置
type RoleConfig struct {
Name string `yaml:"name" json:"name"` // 角色名称
Description string `yaml:"description" json:"description"` // 角色描述
UserPrompt string `yaml:"user_prompt" json:"user_prompt"` // 用户提示词(追加到用户消息前)
Icon string `yaml:"icon,omitempty" json:"icon,omitempty"` // 角色图标(可选)
Tools []string `yaml:"tools,omitempty" json:"tools,omitempty"` // 关联的工具列表(toolKey格式,如 "toolName" 或 "mcpName::toolName"
MCPs []string `yaml:"mcps,omitempty" json:"mcps,omitempty"` // 向后兼容:关联的MCP服务器列表(已废弃,使用tools替代)
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用
}
+63 -12
View File
@@ -12,7 +12,9 @@ import (
"unicode/utf8"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@@ -66,13 +68,14 @@ type AgentHandler struct {
logger *zap.Logger
tasks *AgentTaskManager
batchTaskManager *BatchTaskManager
knowledgeManager interface { // 知识库管理器接口
config *config.Config // 配置引用,用于获取角色信息
knowledgeManager interface { // 知识库管理器接口
LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error
}
}
// NewAgentHandler 创建新的Agent处理器
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, logger *zap.Logger) *AgentHandler {
batchTaskManager := NewBatchTaskManager()
batchTaskManager.SetDB(db)
@@ -87,6 +90,7 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *A
logger: logger,
tasks: NewAgentTaskManager(),
batchTaskManager: batchTaskManager,
config: cfg,
}
}
@@ -101,6 +105,7 @@ func (h *AgentHandler) SetKnowledgeManager(manager interface {
type ChatRequest struct {
Message string `json:"message" binding:"required"`
ConversationID string `json:"conversationId,omitempty"`
Role string `json:"role,omitempty"` // 角色名称
}
// ChatResponse 聊天响应
@@ -161,14 +166,34 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 保存用户消息
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
// 执行Agent Loop,传入历史消息和对话ID
result, err := h.agent.AgentLoopWithConversationID(c.Request.Context(), req.Message, agentHistoryMessages, conversationID)
// 执行Agent Loop,传入历史消息和对话ID(使用包含角色提示词的finalMessage和角色工具列表)
result, err := h.agent.AgentLoopWithProgress(c.Request.Context(), finalMessage, agentHistoryMessages, conversationID, nil, roleTools)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
@@ -231,7 +256,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
if eventType == "tool_call" {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" {
if toolName == builtin.ToolSearchKnowledgeBase {
if toolCallId, ok := dataMap["toolCallId"].(string); ok && toolCallId != "" {
if argumentsObj, ok := dataMap["argumentsObj"].(map[string]interface{}); ok {
toolCallCache[toolCallId] = argumentsObj
@@ -245,7 +270,7 @@ func (h *AgentHandler) createProgressCallback(conversationID, assistantMessageID
if eventType == "tool_result" && h.knowledgeManager != nil {
if dataMap, ok := data.(map[string]interface{}); ok {
toolName, _ := dataMap["toolName"].(string)
if toolName == "search_knowledge_base" {
if toolName == builtin.ToolSearchKnowledgeBase {
// 提取检索信息
query := ""
riskType := ""
@@ -470,7 +495,32 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 保存用户消息
// 应用角色用户提示词和工具配置
finalMessage := req.Message
var roleTools []string // 角色配置的工具列表
if req.Role != "" && req.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + req.Message
h.logger.Info("应用角色用户提示词", zap.String("role", req.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("role", req.Role), zap.Int("toolCount", len(roleTools)))
} else if len(role.MCPs) > 0 {
// 向后兼容:如果只有mcps字段,暂时使用空列表(表示使用所有工具)
// 因为mcps是MCP服务器名称,不是工具列表
h.logger.Info("角色配置使用旧的mcps字段,将使用所有工具", zap.String("role", req.Role))
}
}
}
}
// 如果roleTools为空,表示使用所有工具(默认角色或未配置工具的角色)
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
@@ -547,9 +597,9 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
taskStatus := "completed"
defer h.tasks.FinishTask(conversationID, taskStatus)
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断
// 执行Agent Loop,传入独立的上下文,确保任务不会因客户端断开而中断(使用包含角色提示词的finalMessage和角色工具列表)
sendEvent("progress", "正在分析您的请求...", nil)
result, err := h.agent.AgentLoopWithProgress(taskCtx, req.Message, agentHistoryMessages, conversationID, progressCallback)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
cause := context.Cause(baseCtx)
@@ -759,7 +809,7 @@ func (h *AgentHandler) ListCompletedTasks(c *gin.Context) {
// BatchTaskRequest 批量任务请求
type BatchTaskRequest struct {
Title string `json:"title"` // 任务标题(可选)
Title string `json:"title"` // 任务标题(可选)
Tasks []string `json:"tasks" binding:"required"` // 任务列表,每行一个任务
}
@@ -1072,7 +1122,8 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
// 存储取消函数,以便在取消队列时能够取消当前任务
h.batchTaskManager.SetTaskCancel(queueID, cancel)
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback)
// 批量任务暂时不支持角色工具过滤,使用所有工具(传入nil)
result, err := h.agent.AgentLoopWithProgress(ctx, task.Message, []agent.ChatMessage{}, conversationID, progressCallback, nil)
// 任务执行完成,清理取消函数
h.batchTaskManager.SetTaskCancel(queueID, nil)
cancel()
+114 -12
View File
@@ -147,6 +147,7 @@ type ToolConfigInfo struct {
Enabled bool `json:"enabled"`
IsExternal bool `json:"is_external,omitempty"` // 是否为外部MCP工具
ExternalMCP string `json:"external_mcp,omitempty"` // 外部MCP名称(如果是外部工具)
RoleEnabled *bool `json:"role_enabled,omitempty"` // 该工具在当前角色中是否启用(nil表示未指定角色或使用所有工具)
}
// GetConfig 获取当前配置
@@ -272,11 +273,12 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
// GetToolsResponse 获取工具列表响应(分页)
type GetToolsResponse struct {
Tools []ToolConfigInfo `json:"tools"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
Tools []ToolConfigInfo `json:"tools"`
Total int `json:"total"`
TotalEnabled int `json:"total_enabled"` // 已启用的工具总数
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}
// GetTools 获取工具列表(支持分页和搜索)
@@ -305,6 +307,23 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
searchTermLower = strings.ToLower(searchTerm)
}
// 解析角色参数,用于过滤工具并标注启用状态
roleName := c.Query("role")
var roleToolsSet map[string]bool // 角色配置的工具集合
var roleUsesAllTools bool = true // 角色是否使用所有工具(默认角色)
if roleName != "" && roleName != "默认" && h.config.Roles != nil {
if role, exists := h.config.Roles[roleName]; exists && role.Enabled {
if len(role.Tools) > 0 {
// 角色配置了工具列表,只使用这些工具
roleToolsSet = make(map[string]bool)
for _, toolKey := range role.Tools {
roleToolsSet[toolKey] = true
}
roleUsesAllTools = false
}
}
}
// 获取所有内部工具并应用搜索过滤
configToolMap := make(map[string]bool)
allTools := make([]ToolConfigInfo, 0, len(h.config.Security.Tools))
@@ -325,6 +344,31 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
toolInfo.Description = desc
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
if tool.Enabled {
roleEnabled := true
toolInfo.RoleEnabled = &roleEnabled
} else {
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 内部工具使用工具名称作为key
if roleToolsSet[tool.Name] {
roleEnabled := tool.Enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
@@ -361,6 +405,26 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
IsExternal: false,
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,直接注册的工具默认启用
roleEnabled := true
toolInfo.RoleEnabled = &roleEnabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 内部工具使用工具名称作为key
if roleToolsSet[mcpTool.Name] {
roleEnabled := true // 在角色列表中且工具本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
// 如果有关键词,进行搜索过滤
if searchTermLower != "" {
nameLower := strings.ToLower(toolInfo.Name)
@@ -439,18 +503,55 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
}
}
allTools = append(allTools, ToolConfigInfo{
toolInfo := ToolConfigInfo{
Name: actualToolName, // 显示实际工具名称,不带前缀
Description: description,
Enabled: enabled,
IsExternal: true,
ExternalMCP: mcpName,
})
}
// 根据角色配置标注工具状态
if roleName != "" {
if roleUsesAllTools {
// 角色使用所有工具,标注启用的工具为role_enabled=true
toolInfo.RoleEnabled = &enabled
} else {
// 角色配置了工具列表,检查工具是否在列表中
// 外部工具使用 "mcpName::toolName" 格式作为key
externalToolKey := externalTool.Name // 这是 "mcpName::toolName" 格式
if roleToolsSet[externalToolKey] {
roleEnabled := enabled // 工具必须在角色列表中且本身启用
toolInfo.RoleEnabled = &roleEnabled
} else {
// 不在角色列表中,标记为false
roleEnabled := false
toolInfo.RoleEnabled = &roleEnabled
}
}
}
allTools = append(allTools, toolInfo)
}
}
}
// 如果角色配置了工具列表,过滤工具(只保留列表中的工具,但保留其他工具并标记为禁用)
// 注意:这里我们不直接过滤掉工具,而是保留所有工具,但通过 role_enabled 字段标注状态
// 这样前端可以显示所有工具,并标注哪些工具在当前角色中可用
total := len(allTools)
// 统计已启用的工具数(在角色中的启用工具数)
totalEnabled := 0
for _, tool := range allTools {
if tool.RoleEnabled != nil && *tool.RoleEnabled {
totalEnabled++
} else if tool.RoleEnabled == nil && tool.Enabled {
// 如果未指定角色,统计所有启用的工具
totalEnabled++
}
}
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
@@ -471,11 +572,12 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
}
c.JSON(http.StatusOK, GetToolsResponse{
Tools: tools,
Total: total,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
Tools: tools,
Total: total,
TotalEnabled: totalEnabled,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
})
}
+453
View File
@@ -0,0 +1,453 @@
package handler
import (
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// RoleHandler 角色处理器
type RoleHandler struct {
config *config.Config
configPath string
logger *zap.Logger
}
// NewRoleHandler 创建新的角色处理器
func NewRoleHandler(cfg *config.Config, configPath string, logger *zap.Logger) *RoleHandler {
return &RoleHandler{
config: cfg,
configPath: configPath,
logger: logger,
}
}
// GetRoles 获取所有角色
func (h *RoleHandler) GetRoles(c *gin.Context) {
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
roles := make([]config.RoleConfig, 0, len(h.config.Roles))
for key, role := range h.config.Roles {
// 确保角色的key与name一致
if role.Name == "" {
role.Name = key
}
roles = append(roles, role)
}
c.JSON(http.StatusOK, gin.H{
"roles": roles,
})
}
// GetRole 获取单个角色
func (h *RoleHandler) GetRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
role, exists := h.config.Roles[roleName]
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 确保角色的name与key一致
if role.Name == "" {
role.Name = roleName
}
c.JSON(http.StatusOK, gin.H{
"role": role,
})
}
// UpdateRole 更新角色
func (h *RoleHandler) UpdateRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
// 确保角色名称与请求中的name一致
if req.Name == "" {
req.Name = roleName
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 删除所有与角色name相同但key不同的旧角色(避免重复)
// 使用角色name作为key,确保唯一性
finalKey := req.Name
keysToDelete := make([]string, 0)
for key := range h.config.Roles {
// 如果key与最终的key不同,但name相同,则标记为删除
if key != finalKey {
role := h.config.Roles[key]
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = key
}
if role.Name == req.Name {
keysToDelete = append(keysToDelete, key)
}
}
}
// 删除旧的角色
for _, key := range keysToDelete {
delete(h.config.Roles, key)
h.logger.Info("删除重复的角色", zap.String("oldKey", key), zap.String("name", req.Name))
}
// 如果当前更新的key与最终key不同,也需要删除旧的
if roleName != finalKey {
delete(h.config.Roles, roleName)
}
// 如果角色名称改变,需要删除旧文件
if roleName != finalKey {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 删除旧的角色文件
oldSafeFileName := sanitizeFileName(roleName)
oldRoleFileYaml := filepath.Join(rolesDir, oldSafeFileName+".yaml")
oldRoleFileYml := filepath.Join(rolesDir, oldSafeFileName+".yml")
if _, err := os.Stat(oldRoleFileYaml); err == nil {
if err := os.Remove(oldRoleFileYaml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYaml), zap.Error(err))
}
}
if _, err := os.Stat(oldRoleFileYml); err == nil {
if err := os.Remove(oldRoleFileYml); err != nil {
h.logger.Warn("删除旧角色配置文件失败", zap.String("file", oldRoleFileYml), zap.Error(err))
}
}
}
// 使用角色name作为key来保存(确保唯一性)
h.config.Roles[finalKey] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已更新",
"role": req,
})
}
// CreateRole 创建新角色
func (h *RoleHandler) CreateRole(c *gin.Context) {
var req config.RoleConfig
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
return
}
if req.Name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
// 初始化Roles map
if h.config.Roles == nil {
h.config.Roles = make(map[string]config.RoleConfig)
}
// 检查角色是否已存在
if _, exists := h.config.Roles[req.Name]; exists {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色已存在"})
return
}
// 创建角色(默认启用)
if !req.Enabled {
req.Enabled = true
}
h.config.Roles[req.Name] = req
// 保存配置到文件
if err := h.saveConfig(); err != nil {
h.logger.Error("保存配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
h.logger.Info("创建角色", zap.String("roleName", req.Name))
c.JSON(http.StatusOK, gin.H{
"message": "角色已创建",
"role": req,
})
}
// DeleteRole 删除角色
func (h *RoleHandler) DeleteRole(c *gin.Context) {
roleName := c.Param("name")
if roleName == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "角色名称不能为空"})
return
}
if h.config.Roles == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
if _, exists := h.config.Roles[roleName]; !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "角色不存在"})
return
}
// 不允许删除"默认"角色
if roleName == "默认" {
c.JSON(http.StatusBadRequest, gin.H{"error": "不能删除默认角色"})
return
}
delete(h.config.Roles, roleName)
// 删除对应的角色文件
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 尝试删除角色文件(.yaml 和 .yml)
safeFileName := sanitizeFileName(roleName)
roleFileYaml := filepath.Join(rolesDir, safeFileName+".yaml")
roleFileYml := filepath.Join(rolesDir, safeFileName+".yml")
// 删除 .yaml 文件(如果存在)
if _, err := os.Stat(roleFileYaml); err == nil {
if err := os.Remove(roleFileYaml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYaml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYaml))
}
}
// 删除 .yml 文件(如果存在)
if _, err := os.Stat(roleFileYml); err == nil {
if err := os.Remove(roleFileYml); err != nil {
h.logger.Warn("删除角色配置文件失败", zap.String("file", roleFileYml), zap.Error(err))
} else {
h.logger.Info("已删除角色配置文件", zap.String("file", roleFileYml))
}
}
h.logger.Info("删除角色", zap.String("roleName", roleName))
c.JSON(http.StatusOK, gin.H{
"message": "角色已删除",
})
}
// saveConfig 保存配置到目录中的文件
func (h *RoleHandler) saveConfig() error {
configDir := filepath.Dir(h.configPath)
rolesDir := h.config.RolesDir
if rolesDir == "" {
rolesDir = "roles" // 默认目录
}
// 如果是相对路径,相对于配置文件所在目录
if !filepath.IsAbs(rolesDir) {
rolesDir = filepath.Join(configDir, rolesDir)
}
// 确保目录存在
if err := os.MkdirAll(rolesDir, 0755); err != nil {
return fmt.Errorf("创建角色目录失败: %w", err)
}
// 保存每个角色到独立的文件
if h.config.Roles != nil {
for roleName, role := range h.config.Roles {
// 确保角色名称正确设置
if role.Name == "" {
role.Name = roleName
}
// 使用角色名称作为文件名(安全化文件名,避免特殊字符)
safeFileName := sanitizeFileName(role.Name)
roleFile := filepath.Join(rolesDir, safeFileName+".yaml")
// 将角色配置序列化为YAML
roleData, err := yaml.Marshal(&role)
if err != nil {
h.logger.Error("序列化角色配置失败", zap.String("role", roleName), zap.Error(err))
continue
}
// 处理icon字段:确保包含\U的icon值被引号包围(YAML需要引号才能正确解析Unicode转义)
roleDataStr := string(roleData)
if role.Icon != "" && strings.HasPrefix(role.Icon, "\\U") {
// 匹配 icon: \UXXXXXXXX 格式(没有引号),排除已经有引号的情况
// 使用负向前瞻确保后面没有引号,或者直接匹配没有引号的情况
re := regexp.MustCompile(`(?m)^(icon:\s+)(\\U[0-9A-F]{8})(\s*)$`)
roleDataStr = re.ReplaceAllString(roleDataStr, `${1}"${2}"${3}`)
roleData = []byte(roleDataStr)
}
// 写入文件
if err := os.WriteFile(roleFile, roleData, 0644); err != nil {
h.logger.Error("保存角色配置文件失败", zap.String("role", roleName), zap.String("file", roleFile), zap.Error(err))
continue
}
h.logger.Info("角色配置已保存到文件", zap.String("role", roleName), zap.String("file", roleFile))
}
}
return nil
}
// sanitizeFileName 将角色名称转换为安全的文件名
func sanitizeFileName(name string) string {
// 替换可能不安全的字符
replacer := map[rune]string{
'/': "_",
'\\': "_",
':': "_",
'*': "_",
'?': "_",
'"': "_",
'<': "_",
'>': "_",
'|': "_",
' ': "_",
}
var result []rune
for _, r := range name {
if replacement, ok := replacer[r]; ok {
result = append(result, []rune(replacement)...)
} else {
result = append(result, r)
}
}
fileName := string(result)
// 如果文件名为空,使用默认名称
if fileName == "" {
fileName = "role"
}
return fileName
}
// updateRolesConfig 更新角色配置
func updateRolesConfig(doc *yaml.Node, cfg config.RolesConfig) {
root := doc.Content[0]
rolesNode := ensureMap(root, "roles")
// 清空现有角色
if rolesNode.Kind == yaml.MappingNode {
rolesNode.Content = nil
}
// 添加新角色(使用name作为key,确保唯一性)
if cfg.Roles != nil {
// 先建立一个以name为key的map,去重(保留最后一个)
rolesByName := make(map[string]config.RoleConfig)
for roleKey, role := range cfg.Roles {
// 确保角色的name字段正确设置
if role.Name == "" {
role.Name = roleKey
}
// 使用name作为最终key,如果有多个key对应相同的name,只保留最后一个
rolesByName[role.Name] = role
}
// 将去重后的角色写入YAML
for roleName, role := range rolesByName {
roleNode := ensureMap(rolesNode, roleName)
setStringInMap(roleNode, "name", role.Name)
setStringInMap(roleNode, "description", role.Description)
setStringInMap(roleNode, "user_prompt", role.UserPrompt)
if role.Icon != "" {
setStringInMap(roleNode, "icon", role.Icon)
}
setBoolInMap(roleNode, "enabled", role.Enabled)
// 添加工具列表(优先使用tools字段)
if len(role.Tools) > 0 {
toolsNode := ensureArray(roleNode, "tools")
toolsNode.Content = nil
for _, toolKey := range role.Tools {
toolNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: toolKey}
toolsNode.Content = append(toolsNode.Content, toolNode)
}
} else if len(role.MCPs) > 0 {
// 向后兼容:如果没有tools但有mcps,保存mcps
mcpsNode := ensureArray(roleNode, "mcps")
mcpsNode.Content = nil
for _, mcpName := range role.MCPs {
mcpNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: mcpName}
mcpsNode.Content = append(mcpsNode.Content, mcpNode)
}
}
}
}
}
// ensureArray 确保数组中存在指定key的数组节点
func ensureArray(parent *yaml.Node, key string) *yaml.Node {
_, valueNode := ensureKeyValue(parent, key)
if valueNode.Kind != yaml.SequenceNode {
valueNode.Kind = yaml.SequenceNode
valueNode.Tag = "!!seq"
valueNode.Content = nil
}
return valueNode
}
+8 -8
View File
@@ -161,14 +161,14 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
// 查询所有向量(或按风险类型过滤)
// 使用精确匹配(=)以提高性能和准确性
// 由于系统提供了 list_knowledge_risk_types 工具,用户应该使用准确的category名称
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
var rows *sql.Rows
if req.RiskType != "" {
// 使用精确匹配(=),性能更好且更准确
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
// 建议用户先调用 list_knowledge_risk_types 获取准确的category名称
// 由于系统提供了内置工具来获取风险类型列表,用户应该使用准确的category名称
// 同时,向量嵌入中已包含category信息,即使SQL过滤不完全匹配,向量相似度也能帮助匹配
var rows *sql.Rows
if req.RiskType != "" {
// 使用精确匹配(=),性能更好且更准确
// 使用 COLLATE NOCASE 实现大小写不敏感匹配,提高容错性
// 注意:如果用户输入的risk_type与category不完全一致,可能匹配不到
// 建议用户先调用相应的内置工具获取准确的category名称
rows, err = r.db.Query(`
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
FROM knowledge_embeddings e
+12 -11
View File
@@ -8,6 +8,7 @@ import (
"strings"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"go.uber.org/zap"
)
@@ -21,7 +22,7 @@ func RegisterKnowledgeTool(
) {
// 注册第一个工具:获取所有可用的风险类型列表
listRiskTypesTool := mcp.Tool{
Name: "list_knowledge_risk_types",
Name: builtin.ToolListKnowledgeRiskTypes,
Description: "获取知识库中所有可用的风险类型(risk_type)列表。在搜索知识库之前,可以先调用此工具获取可用的风险类型,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间并提高检索准确性。",
ShortDescription: "获取知识库中所有可用的风险类型列表",
InputSchema: map[string]interface{}{
@@ -62,7 +63,7 @@ func RegisterKnowledgeTool(
for i, category := range categories {
resultText.WriteString(fmt.Sprintf("%d. %s\n", i+1, category))
}
resultText.WriteString("\n提示:在调用 search_knowledge_base 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
resultText.WriteString("\n提示:在调用 " + builtin.ToolSearchKnowledgeBase + " 工具时,可以使用上述风险类型之一作为 risk_type 参数,以缩小搜索范围并提高检索效率。")
return &mcp.ToolResult{
Content: []mcp.Content{
@@ -79,8 +80,8 @@ func RegisterKnowledgeTool(
// 注册第二个工具:搜索知识库(保持原有功能)
searchTool := mcp.Tool{
Name: "search_knowledge_base",
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 list_knowledge_risk_types 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
Name: builtin.ToolSearchKnowledgeBase,
Description: "在知识库中搜索相关的安全知识。当你需要了解特定漏洞类型、攻击技术、检测方法等安全知识时,可以使用此工具进行检索。工具使用向量检索和混合搜索技术,能够根据查询内容的语义相似度和关键词匹配,自动找到最相关的知识片段。建议:在搜索前可以先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型,然后使用正确的 risk_type 参数进行精确搜索,这样可以大幅减少检索时间。",
ShortDescription: "搜索知识库中的安全知识(支持向量检索和混合搜索)",
InputSchema: map[string]interface{}{
"type": "object",
@@ -91,7 +92,7 @@ func RegisterKnowledgeTool(
},
"risk_type": map[string]interface{}{
"type": "string",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 list_knowledge_risk_types 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
"description": "可选:指定风险类型(如:SQL注入、XSS、文件上传等)。建议先调用 " + builtin.ToolListKnowledgeRiskTypes + " 工具获取可用的风险类型列表,然后使用正确的风险类型进行精确搜索,这样可以大幅减少检索时间。如果不指定则搜索所有类型。",
},
},
"required": []string{"query"},
@@ -165,9 +166,9 @@ func RegisterKnowledgeTool(
// 按文档分组结果,以便更好地展示上下文
// 使用有序的slice来保持文档顺序(按最高混合分数)
type itemGroup struct {
itemID string
results []*RetrievalResult
maxScore float64 // 该文档的最高混合分数
itemID string
results []*RetrievalResult
maxScore float64 // 该文档的最高混合分数
}
itemGroups := make([]*itemGroup, 0)
itemMap := make(map[string]*itemGroup)
@@ -177,8 +178,8 @@ func RegisterKnowledgeTool(
group, exists := itemMap[itemID]
if !exists {
group = &itemGroup{
itemID: itemID,
results: make([]*RetrievalResult, 0),
itemID: itemID,
results: make([]*RetrievalResult, 0),
maxScore: result.Score,
}
itemMap[itemID] = group
@@ -219,7 +220,7 @@ func RegisterKnowledgeTool(
})
// 显示主结果(混合分数最高的,同时显示相似度和混合分数)
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
resultText.WriteString(fmt.Sprintf("--- 结果 %d (相似度: %.2f%%, 混合分数: %.2f%%) ---\n",
resultIndex, mainResult.Similarity*100, mainResult.Score*100))
resultText.WriteString(fmt.Sprintf("来源: [%s] %s (ID: %s)\n", mainResult.Item.Category, mainResult.Item.Title, mainResult.Item.ID))
+33
View File
@@ -0,0 +1,33 @@
package builtin
// 内置工具名称常量
// 所有代码中使用内置工具名称的地方都应该使用这些常量,而不是硬编码字符串
const (
// 漏洞管理工具
ToolRecordVulnerability = "record_vulnerability"
// 知识库工具
ToolListKnowledgeRiskTypes = "list_knowledge_risk_types"
ToolSearchKnowledgeBase = "search_knowledge_base"
)
// IsBuiltinTool 检查工具名称是否是内置工具
func IsBuiltinTool(toolName string) bool {
switch toolName {
case ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase:
return true
default:
return false
}
}
// GetAllBuiltinTools 返回所有内置工具名称列表
func GetAllBuiltinTools() []string {
return []string{
ToolRecordVulnerability,
ToolListKnowledgeRiskTypes,
ToolSearchKnowledgeBase,
}
}