mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-21 10:16:32 +02:00
427 lines
11 KiB
Go
427 lines
11 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"strings"
|
|
"time"
|
|
|
|
"cyberstrike-ai/internal/mcp"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// SaveToolExecution 保存工具执行记录
|
|
func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
|
|
argsJSON, err := json.Marshal(exec.Arguments)
|
|
if err != nil {
|
|
db.logger.Warn("序列化执行参数失败", zap.Error(err))
|
|
argsJSON = []byte("{}")
|
|
}
|
|
|
|
var resultJSON sql.NullString
|
|
if exec.Result != nil {
|
|
resultBytes, err := json.Marshal(exec.Result)
|
|
if err != nil {
|
|
db.logger.Warn("序列化执行结果失败", zap.Error(err))
|
|
} else {
|
|
resultJSON = sql.NullString{String: string(resultBytes), Valid: true}
|
|
}
|
|
}
|
|
|
|
var errorText sql.NullString
|
|
if exec.Error != "" {
|
|
errorText = sql.NullString{String: exec.Error, Valid: true}
|
|
}
|
|
|
|
var endTime sql.NullTime
|
|
if exec.EndTime != nil {
|
|
endTime = sql.NullTime{Time: *exec.EndTime, Valid: true}
|
|
}
|
|
|
|
var durationMs sql.NullInt64
|
|
if exec.Duration > 0 {
|
|
durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true}
|
|
}
|
|
|
|
query := `
|
|
INSERT OR REPLACE INTO tool_executions
|
|
(id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
`
|
|
|
|
_, err = db.Exec(query,
|
|
exec.ID,
|
|
exec.ToolName,
|
|
string(argsJSON),
|
|
exec.Status,
|
|
resultJSON,
|
|
errorText,
|
|
exec.StartTime,
|
|
endTime,
|
|
durationMs,
|
|
time.Now(),
|
|
)
|
|
|
|
if err != nil {
|
|
db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CountToolExecutions 统计工具执行记录总数
|
|
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
|
|
query := `SELECT COUNT(*) FROM tool_executions`
|
|
args := []interface{}{}
|
|
conditions := []string{}
|
|
if status != "" {
|
|
conditions = append(conditions, "status = ?")
|
|
args = append(args, status)
|
|
}
|
|
if toolName != "" {
|
|
// 支持部分匹配(模糊搜索),不区分大小写
|
|
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
|
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
|
}
|
|
if len(conditions) > 0 {
|
|
query += ` WHERE ` + conditions[0]
|
|
for i := 1; i < len(conditions); i++ {
|
|
query += ` AND ` + conditions[i]
|
|
}
|
|
}
|
|
var count int
|
|
err := db.QueryRow(query, args...).Scan(&count)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// LoadToolExecutions 加载所有工具执行记录(支持分页)
|
|
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
|
|
return db.LoadToolExecutionsWithPagination(0, 1000, "", "")
|
|
}
|
|
|
|
// LoadToolExecutionsWithPagination 分页加载工具执行记录
|
|
// limit: 最大返回记录数,0 表示使用默认值 1000
|
|
// offset: 跳过的记录数,用于分页
|
|
// status: 状态筛选,空字符串表示不过滤
|
|
// toolName: 工具名称筛选,空字符串表示不过滤
|
|
func (db *DB) LoadToolExecutionsWithPagination(offset, limit int, status, toolName string) ([]*mcp.ToolExecution, error) {
|
|
if limit <= 0 {
|
|
limit = 1000 // 默认限制
|
|
}
|
|
if limit > 10000 {
|
|
limit = 10000 // 最大限制,防止一次性加载过多数据
|
|
}
|
|
|
|
query := `
|
|
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
|
FROM tool_executions
|
|
`
|
|
args := []interface{}{}
|
|
conditions := []string{}
|
|
if status != "" {
|
|
conditions = append(conditions, "status = ?")
|
|
args = append(args, status)
|
|
}
|
|
if toolName != "" {
|
|
// 支持部分匹配(模糊搜索),不区分大小写
|
|
conditions = append(conditions, "LOWER(tool_name) LIKE ?")
|
|
args = append(args, "%"+strings.ToLower(toolName)+"%")
|
|
}
|
|
if len(conditions) > 0 {
|
|
query += ` WHERE ` + conditions[0]
|
|
for i := 1; i < len(conditions); i++ {
|
|
query += ` AND ` + conditions[i]
|
|
}
|
|
}
|
|
query += ` ORDER BY start_time DESC LIMIT ? OFFSET ?`
|
|
args = append(args, limit, offset)
|
|
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var executions []*mcp.ToolExecution
|
|
for rows.Next() {
|
|
var exec mcp.ToolExecution
|
|
var argsJSON string
|
|
var resultJSON sql.NullString
|
|
var errorText sql.NullString
|
|
var endTime sql.NullTime
|
|
var durationMs sql.NullInt64
|
|
|
|
err := rows.Scan(
|
|
&exec.ID,
|
|
&exec.ToolName,
|
|
&argsJSON,
|
|
&exec.Status,
|
|
&resultJSON,
|
|
&errorText,
|
|
&exec.StartTime,
|
|
&endTime,
|
|
&durationMs,
|
|
)
|
|
if err != nil {
|
|
db.logger.Warn("加载执行记录失败", zap.Error(err))
|
|
continue
|
|
}
|
|
|
|
// 解析参数
|
|
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
|
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
|
exec.Arguments = make(map[string]interface{})
|
|
}
|
|
|
|
// 解析结果
|
|
if resultJSON.Valid && resultJSON.String != "" {
|
|
var result mcp.ToolResult
|
|
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
|
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
|
} else {
|
|
exec.Result = &result
|
|
}
|
|
}
|
|
|
|
// 设置错误
|
|
if errorText.Valid {
|
|
exec.Error = errorText.String
|
|
}
|
|
|
|
// 设置结束时间
|
|
if endTime.Valid {
|
|
exec.EndTime = &endTime.Time
|
|
}
|
|
|
|
// 设置持续时间
|
|
if durationMs.Valid {
|
|
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
|
}
|
|
|
|
executions = append(executions, &exec)
|
|
}
|
|
|
|
return executions, nil
|
|
}
|
|
|
|
// GetToolExecution 根据ID获取单条工具执行记录
|
|
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
|
query := `
|
|
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
|
FROM tool_executions
|
|
WHERE id = ?
|
|
`
|
|
|
|
row := db.QueryRow(query, id)
|
|
|
|
var exec mcp.ToolExecution
|
|
var argsJSON string
|
|
var resultJSON sql.NullString
|
|
var errorText sql.NullString
|
|
var endTime sql.NullTime
|
|
var durationMs sql.NullInt64
|
|
|
|
err := row.Scan(
|
|
&exec.ID,
|
|
&exec.ToolName,
|
|
&argsJSON,
|
|
&exec.Status,
|
|
&resultJSON,
|
|
&errorText,
|
|
&exec.StartTime,
|
|
&endTime,
|
|
&durationMs,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
|
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
|
exec.Arguments = make(map[string]interface{})
|
|
}
|
|
|
|
if resultJSON.Valid && resultJSON.String != "" {
|
|
var result mcp.ToolResult
|
|
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
|
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
|
} else {
|
|
exec.Result = &result
|
|
}
|
|
}
|
|
|
|
if errorText.Valid {
|
|
exec.Error = errorText.String
|
|
}
|
|
|
|
if endTime.Valid {
|
|
exec.EndTime = &endTime.Time
|
|
}
|
|
|
|
if durationMs.Valid {
|
|
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
|
}
|
|
|
|
return &exec, nil
|
|
}
|
|
|
|
// DeleteToolExecution 删除工具执行记录
|
|
func (db *DB) DeleteToolExecution(id string) error {
|
|
query := `DELETE FROM tool_executions WHERE id = ?`
|
|
_, err := db.Exec(query, id)
|
|
if err != nil {
|
|
db.logger.Error("删除工具执行记录失败", zap.Error(err), zap.String("executionId", id))
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SaveToolStats 保存工具统计信息
|
|
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
|
var lastCallTime sql.NullTime
|
|
if stats.LastCallTime != nil {
|
|
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
|
|
}
|
|
|
|
query := `
|
|
INSERT OR REPLACE INTO tool_stats
|
|
(tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
`
|
|
|
|
_, err := db.Exec(query,
|
|
toolName,
|
|
stats.TotalCalls,
|
|
stats.SuccessCalls,
|
|
stats.FailedCalls,
|
|
lastCallTime,
|
|
time.Now(),
|
|
)
|
|
|
|
if err != nil {
|
|
db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadToolStats 加载所有工具统计信息
|
|
func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) {
|
|
query := `
|
|
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
|
|
FROM tool_stats
|
|
`
|
|
|
|
rows, err := db.Query(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
stats := make(map[string]*mcp.ToolStats)
|
|
for rows.Next() {
|
|
var stat mcp.ToolStats
|
|
var lastCallTime sql.NullTime
|
|
|
|
err := rows.Scan(
|
|
&stat.ToolName,
|
|
&stat.TotalCalls,
|
|
&stat.SuccessCalls,
|
|
&stat.FailedCalls,
|
|
&lastCallTime,
|
|
)
|
|
if err != nil {
|
|
db.logger.Warn("加载统计信息失败", zap.Error(err))
|
|
continue
|
|
}
|
|
|
|
if lastCallTime.Valid {
|
|
stat.LastCallTime = &lastCallTime.Time
|
|
}
|
|
|
|
stats[stat.ToolName] = &stat
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// UpdateToolStats 更新工具统计信息(累加模式)
|
|
func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
|
|
var lastCallTimeSQL sql.NullTime
|
|
if lastCallTime != nil {
|
|
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
|
|
}
|
|
|
|
query := `
|
|
INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(tool_name) DO UPDATE SET
|
|
total_calls = total_calls + ?,
|
|
success_calls = success_calls + ?,
|
|
failed_calls = failed_calls + ?,
|
|
last_call_time = COALESCE(?, last_call_time),
|
|
updated_at = ?
|
|
`
|
|
|
|
_, err := db.Exec(query,
|
|
toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
|
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
|
)
|
|
|
|
if err != nil {
|
|
db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DecreaseToolStats 减少工具统计信息(用于删除执行记录时)
|
|
// 如果统计信息变为0,则删除该统计记录
|
|
func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error {
|
|
// 先更新统计信息
|
|
query := `
|
|
UPDATE tool_stats SET
|
|
total_calls = CASE WHEN total_calls - ? < 0 THEN 0 ELSE total_calls - ? END,
|
|
success_calls = CASE WHEN success_calls - ? < 0 THEN 0 ELSE success_calls - ? END,
|
|
failed_calls = CASE WHEN failed_calls - ? < 0 THEN 0 ELSE failed_calls - ? END,
|
|
updated_at = ?
|
|
WHERE tool_name = ?
|
|
`
|
|
|
|
_, err := db.Exec(query, totalCalls, totalCalls, successCalls, successCalls, failedCalls, failedCalls, time.Now(), toolName)
|
|
if err != nil {
|
|
db.logger.Error("减少工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
|
return err
|
|
}
|
|
|
|
// 检查更新后的 total_calls 是否为 0,如果是则删除该统计记录
|
|
checkQuery := `SELECT total_calls FROM tool_stats WHERE tool_name = ?`
|
|
var newTotalCalls int
|
|
err = db.QueryRow(checkQuery, toolName).Scan(&newTotalCalls)
|
|
if err != nil {
|
|
// 如果查询失败(记录不存在),直接返回
|
|
return nil
|
|
}
|
|
|
|
// 如果 total_calls 为 0,删除该统计记录
|
|
if newTotalCalls == 0 {
|
|
deleteQuery := `DELETE FROM tool_stats WHERE tool_name = ?`
|
|
_, err = db.Exec(deleteQuery, toolName)
|
|
if err != nil {
|
|
db.logger.Warn("删除零统计记录失败", zap.Error(err), zap.String("toolName", toolName))
|
|
// 不返回错误,因为主要操作(更新统计)已成功
|
|
} else {
|
|
db.logger.Info("已删除零统计记录", zap.String("toolName", toolName))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|