mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 08:19:54 +02:00
1、修复删除知识项后总分类数统计错误:将 updateKnowledgeStats 中的 || 改为 != null 检查,并移除会错误更新统计的 updateKnowledgeStatsAfterDelete 调用。 2、为 MCP 状态监控页面添加了批量删除功能(复选框、全选、批量删除按钮)和每页显示数量配置(选择器位于分页控件左侧,设置保存到 localStorage)。
538 lines
14 KiB
Go
538 lines
14 KiB
Go
package database
|
||
|
||
import (
|
||
"database/sql"
|
||
"encoding/json"
|
||
"strings"
|
||
"time"
|
||
|
||
"cyberstrike-ai/internal/mcp"
|
||
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// SaveToolExecution 保存工具执行记录
|
||
func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
|
||
argsJSON, err := json.Marshal(exec.Arguments)
|
||
if err != nil {
|
||
db.logger.Warn("序列化执行参数失败", zap.Error(err))
|
||
argsJSON = []byte("{}")
|
||
}
|
||
|
||
var resultJSON sql.NullString
|
||
if exec.Result != nil {
|
||
resultBytes, err := json.Marshal(exec.Result)
|
||
if err != nil {
|
||
db.logger.Warn("序列化执行结果失败", zap.Error(err))
|
||
} else {
|
||
resultJSON = sql.NullString{String: string(resultBytes), Valid: true}
|
||
}
|
||
}
|
||
|
||
var errorText sql.NullString
|
||
if exec.Error != "" {
|
||
errorText = sql.NullString{String: exec.Error, Valid: true}
|
||
}
|
||
|
||
var endTime sql.NullTime
|
||
if exec.EndTime != nil {
|
||
endTime = sql.NullTime{Time: *exec.EndTime, Valid: true}
|
||
}
|
||
|
||
var durationMs sql.NullInt64
|
||
if exec.Duration > 0 {
|
||
durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true}
|
||
}
|
||
|
||
query := `
|
||
INSERT OR REPLACE INTO tool_executions
|
||
(id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||
`
|
||
|
||
_, err = db.Exec(query,
|
||
exec.ID,
|
||
exec.ToolName,
|
||
string(argsJSON),
|
||
exec.Status,
|
||
resultJSON,
|
||
errorText,
|
||
exec.StartTime,
|
||
endTime,
|
||
durationMs,
|
||
time.Now(),
|
||
)
|
||
|
||
if err != nil {
|
||
db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID))
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// CountToolExecutions 统计工具执行记录总数
|
||
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
|
||
query := `SELECT COUNT(*) FROM tool_executions`
|
||
args := []interface{}{}
|
||
conditions := []string{}
|
||
if status != "" {
|
||
conditions = append(conditions, "status = ?")
|
||
args = append(args, status)
|
||
}
|
||
if toolName != "" {
|
||
// 支持部分匹配(模糊搜索),不区分大小写
|
||
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
||
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
||
}
|
||
if len(conditions) > 0 {
|
||
query += ` WHERE ` + conditions[0]
|
||
for i := 1; i < len(conditions); i++ {
|
||
query += ` AND ` + conditions[i]
|
||
}
|
||
}
|
||
var count int
|
||
err := db.QueryRow(query, args...).Scan(&count)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return count, nil
|
||
}
|
||
|
||
// LoadToolExecutions 加载所有工具执行记录(支持分页)
|
||
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
|
||
return db.LoadToolExecutionsWithPagination(0, 1000, "", "")
|
||
}
|
||
|
||
// LoadToolExecutionsWithPagination 分页加载工具执行记录
|
||
// limit: 最大返回记录数,0 表示使用默认值 1000
|
||
// offset: 跳过的记录数,用于分页
|
||
// status: 状态筛选,空字符串表示不过滤
|
||
// toolName: 工具名称筛选,空字符串表示不过滤
|
||
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
|
||
if limit <= 0 {
|
||
limit = 1000 // 默认限制
|
||
}
|
||
if limit > 10000 {
|
||
limit = 10000 // 最大限制,防止一次性加载过多数据
|
||
}
|
||
|
||
query := `
|
||
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
||
FROM tool_executions
|
||
`
|
||
args := []interface{}{}
|
||
conditions := []string{}
|
||
if status != "" {
|
||
conditions = append(conditions, "status = ?")
|
||
args = append(args, status)
|
||
}
|
||
if toolName != "" {
|
||
// 支持部分匹配(模糊搜索),不区分大小写
|
||
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
||
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
||
}
|
||
if len(conditions) > 0 {
|
||
query += ` WHERE ` + conditions[0]
|
||
for i := 1; i < len(conditions); i++ {
|
||
query += ` AND ` + conditions[i]
|
||
}
|
||
}
|
||
query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
|
||
args = append(args, limit, offset)
|
||
|
||
rows, err := db.Query(query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var executions []*mcp.ToolExecution
|
||
for rows.Next() {
|
||
var exec mcp.ToolExecution
|
||
var argsJSON string
|
||
var resultJSON sql.NullString
|
||
var errorText sql.NullString
|
||
var endTime sql.NullTime
|
||
var durationMs sql.NullInt64
|
||
|
||
err := rows.Scan(
|
||
&exec.ID,
|
||
&exec.ToolName,
|
||
&argsJSON,
|
||
&exec.Status,
|
||
&resultJSON,
|
||
&errorText,
|
||
&exec.StartTime,
|
||
&endTime,
|
||
&durationMs,
|
||
)
|
||
if err != nil {
|
||
db.logger.Warn("加载执行记录失败", zap.Error(err))
|
||
continue
|
||
}
|
||
|
||
// 解析参数
|
||
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
||
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
||
exec.Arguments = make(map[string]interface{})
|
||
}
|
||
|
||
// 解析结果
|
||
if resultJSON.Valid && resultJSON.String != "" {
|
||
var result mcp.ToolResult
|
||
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
||
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
||
} else {
|
||
exec.Result = &result
|
||
}
|
||
}
|
||
|
||
// 设置错误
|
||
if errorText.Valid {
|
||
exec.Error = errorText.String
|
||
}
|
||
|
||
// 设置结束时间
|
||
if endTime.Valid {
|
||
exec.EndTime = &endTime.Time
|
||
}
|
||
|
||
// 设置持续时间
|
||
if durationMs.Valid {
|
||
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||
}
|
||
|
||
executions = append(executions, &exec)
|
||
}
|
||
|
||
return executions, nil
|
||
}
|
||
|
||
// GetToolExecution 根据ID获取单条工具执行记录
|
||
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
||
query := `
|
||
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
||
FROM tool_executions
|
||
WHERE id = ?
|
||
`
|
||
|
||
row := db.QueryRow(query, id)
|
||
|
||
var exec mcp.ToolExecution
|
||
var argsJSON string
|
||
var resultJSON sql.NullString
|
||
var errorText sql.NullString
|
||
var endTime sql.NullTime
|
||
var durationMs sql.NullInt64
|
||
|
||
err := row.Scan(
|
||
&exec.ID,
|
||
&exec.ToolName,
|
||
&argsJSON,
|
||
&exec.Status,
|
||
&resultJSON,
|
||
&errorText,
|
||
&exec.StartTime,
|
||
&endTime,
|
||
&durationMs,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
||
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
||
exec.Arguments = make(map[string]interface{})
|
||
}
|
||
|
||
if resultJSON.Valid && resultJSON.String != "" {
|
||
var result mcp.ToolResult
|
||
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
||
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
||
} else {
|
||
exec.Result = &result
|
||
}
|
||
}
|
||
|
||
if errorText.Valid {
|
||
exec.Error = errorText.String
|
||
}
|
||
|
||
if endTime.Valid {
|
||
exec.EndTime = &endTime.Time
|
||
}
|
||
|
||
if durationMs.Valid {
|
||
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||
}
|
||
|
||
return &exec, nil
|
||
}
|
||
|
||
// DeleteToolExecution 删除工具执行记录
|
||
func (db *DB) DeleteToolExecution(id string) error {
|
||
query := `DELETE FROM tool_executions WHERE id = ?`
|
||
_, err := db.Exec(query, id)
|
||
if err != nil {
|
||
db.logger.Error("删除工具执行记录失败", zap.Error(err), zap.String("executionId", id))
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DeleteToolExecutions 批量删除工具执行记录
|
||
func (db *DB) DeleteToolExecutions(ids []string) error {
|
||
if len(ids) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 构建 IN 查询的占位符
|
||
placeholders := make([]string, len(ids))
|
||
args := make([]interface{}, len(ids))
|
||
for i, id := range ids {
|
||
placeholders[i] = "?"
|
||
args[i] = id
|
||
}
|
||
|
||
query := `DELETE FROM tool_executions WHERE id IN (` + strings.Join(placeholders, ",") + `)`
|
||
_, err := db.Exec(query, args...)
|
||
if err != nil {
|
||
db.logger.Error("批量删除工具执行记录失败", zap.Error(err), zap.Int("count", len(ids)))
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetToolExecutionsByIds 根据ID列表获取工具执行记录(用于批量删除前获取统计信息)
|
||
func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error) {
|
||
if len(ids) == 0 {
|
||
return []*mcp.ToolExecution{}, nil
|
||
}
|
||
|
||
// 构建 IN 查询的占位符
|
||
placeholders := make([]string, len(ids))
|
||
args := make([]interface{}, len(ids))
|
||
for i, id := range ids {
|
||
placeholders[i] = "?"
|
||
args[i] = id
|
||
}
|
||
|
||
query := `
|
||
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
||
FROM tool_executions
|
||
WHERE id IN (` + strings.Join(placeholders, ",") + `)
|
||
`
|
||
|
||
rows, err := db.Query(query, args...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
var executions []*mcp.ToolExecution
|
||
for rows.Next() {
|
||
var exec mcp.ToolExecution
|
||
var argsJSON string
|
||
var resultJSON sql.NullString
|
||
var errorText sql.NullString
|
||
var endTime sql.NullTime
|
||
var durationMs sql.NullInt64
|
||
|
||
err := rows.Scan(
|
||
&exec.ID,
|
||
&exec.ToolName,
|
||
&argsJSON,
|
||
&exec.Status,
|
||
&resultJSON,
|
||
&errorText,
|
||
&exec.StartTime,
|
||
&endTime,
|
||
&durationMs,
|
||
)
|
||
if err != nil {
|
||
db.logger.Warn("加载执行记录失败", zap.Error(err))
|
||
continue
|
||
}
|
||
|
||
// 解析参数
|
||
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
||
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
||
exec.Arguments = make(map[string]interface{})
|
||
}
|
||
|
||
// 解析结果
|
||
if resultJSON.Valid && resultJSON.String != "" {
|
||
var result mcp.ToolResult
|
||
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
||
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
||
} else {
|
||
exec.Result = &result
|
||
}
|
||
}
|
||
|
||
// 设置错误
|
||
if errorText.Valid {
|
||
exec.Error = errorText.String
|
||
}
|
||
|
||
// 设置结束时间
|
||
if endTime.Valid {
|
||
exec.EndTime = &endTime.Time
|
||
}
|
||
|
||
// 设置持续时间
|
||
if durationMs.Valid {
|
||
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
||
}
|
||
|
||
executions = append(executions, &exec)
|
||
}
|
||
|
||
return executions, nil
|
||
}
|
||
|
||
// SaveToolStats 保存工具统计信息
|
||
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
||
var lastCallTime sql.NullTime
|
||
if stats.LastCallTime != nil {
|
||
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
|
||
}
|
||
|
||
query := `
|
||
INSERT OR REPLACE INTO tool_stats
|
||
(tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
||
VALUES (?, ?, ?, ?, ?, ?)
|
||
`
|
||
|
||
_, err := db.Exec(query,
|
||
toolName,
|
||
stats.TotalCalls,
|
||
stats.SuccessCalls,
|
||
stats.FailedCalls,
|
||
lastCallTime,
|
||
time.Now(),
|
||
)
|
||
|
||
if err != nil {
|
||
db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// LoadToolStats 加载所有工具统计信息
|
||
func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) {
|
||
query := `
|
||
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
|
||
FROM tool_stats
|
||
`
|
||
|
||
rows, err := db.Query(query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer rows.Close()
|
||
|
||
stats := make(map[string]*mcp.ToolStats)
|
||
for rows.Next() {
|
||
var stat mcp.ToolStats
|
||
var lastCallTime sql.NullTime
|
||
|
||
err := rows.Scan(
|
||
&stat.ToolName,
|
||
&stat.TotalCalls,
|
||
&stat.SuccessCalls,
|
||
&stat.FailedCalls,
|
||
&lastCallTime,
|
||
)
|
||
if err != nil {
|
||
db.logger.Warn("加载统计信息失败", zap.Error(err))
|
||
continue
|
||
}
|
||
|
||
if lastCallTime.Valid {
|
||
stat.LastCallTime = &lastCallTime.Time
|
||
}
|
||
|
||
stats[stat.ToolName] = &stat
|
||
}
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
// UpdateToolStats 更新工具统计信息(累加模式)
|
||
func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
|
||
var lastCallTimeSQL sql.NullTime
|
||
if lastCallTime != nil {
|
||
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
|
||
}
|
||
|
||
query := `
|
||
INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
||
VALUES (?, ?, ?, ?, ?, ?)
|
||
ON CONFLICT(tool_name) DO UPDATE SET
|
||
total_calls = total_calls + ?,
|
||
success_calls = success_calls + ?,
|
||
failed_calls = failed_calls + ?,
|
||
last_call_time = COALESCE(?, last_call_time),
|
||
updated_at = ?
|
||
`
|
||
|
||
_, err := db.Exec(query,
|
||
toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
||
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
||
)
|
||
|
||
if err != nil {
|
||
db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
|
||
// 如果统计信息变为0,则删除该统计记录
|
||
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
|
||
// 先更新统计信息
|
||
query := `
|
||
UPDATE tool_stats SET
|
||
total_calls = CASE WHEN total_calls - ? < 0 THEN 0 ELSE total_calls - ? END,
|
||
success_calls = CASE WHEN success_calls - ? < 0 THEN 0 ELSE success_calls - ? END,
|
||
failed_calls = CASE WHEN failed_calls - ? < 0 THEN 0 ELSE failed_calls - ? END,
|
||
updated_at = ?
|
||
WHERE tool_name = ?
|
||
`
|
||
|
||
_, err := db.Exec(query, totalCalls, totalCalls, successCalls, successCalls, failedCalls, failedCalls, time.Now(), toolName)
|
||
if err != nil {
|
||
db.logger.Error("减少工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
||
return err
|
||
}
|
||
|
||
// 检查更新后的 total_calls 是否为 0,如果是则删除该统计记录
|
||
checkQuery := `SELECT total_calls FROM tool_stats WHERE tool_name = ?`
|
||
var newTotalCalls int
|
||
err = db.QueryRow(checkQuery, toolName).Scan(&newTotalCalls)
|
||
if err != nil {
|
||
// 如果查询失败(记录不存在),直接返回
|
||
return nil
|
||
}
|
||
|
||
// 如果 total_calls 为 0,删除该统计记录
|
||
if newTotalCalls == 0 {
|
||
deleteQuery := `DELETE FROM tool_stats WHERE tool_name = ?`
|
||
_, err = db.Exec(deleteQuery, toolName)
|
||
if err != nil {
|
||
db.logger.Warn("删除零统计记录失败", zap.Error(err), zap.String("toolName", toolName))
|
||
// 不返回错误,因为主要操作(更新统计)已成功
|
||
} else {
|
||
db.logger.Info("已删除零统计记录", zap.String("toolName", toolName))
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|