mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-04-01 00:30:33 +02:00
310 lines
7.0 KiB
Go
310 lines
7.0 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"time"
|
|
|
|
"cyberstrike-ai/internal/mcp"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// SaveToolExecution 保存工具执行记录
|
|
func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
|
|
argsJSON, err := json.Marshal(exec.Arguments)
|
|
if err != nil {
|
|
db.logger.Warn("序列化执行参数失败", zap.Error(err))
|
|
argsJSON = []byte("{}")
|
|
}
|
|
|
|
var resultJSON sql.NullString
|
|
if exec.Result != nil {
|
|
resultBytes, err := json.Marshal(exec.Result)
|
|
if err != nil {
|
|
db.logger.Warn("序列化执行结果失败", zap.Error(err))
|
|
} else {
|
|
resultJSON = sql.NullString{String: string(resultBytes), Valid: true}
|
|
}
|
|
}
|
|
|
|
var errorText sql.NullString
|
|
if exec.Error != "" {
|
|
errorText = sql.NullString{String: exec.Error, Valid: true}
|
|
}
|
|
|
|
var endTime sql.NullTime
|
|
if exec.EndTime != nil {
|
|
endTime = sql.NullTime{Time: *exec.EndTime, Valid: true}
|
|
}
|
|
|
|
var durationMs sql.NullInt64
|
|
if exec.Duration > 0 {
|
|
durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true}
|
|
}
|
|
|
|
query := `
|
|
INSERT OR REPLACE INTO tool_executions
|
|
(id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
`
|
|
|
|
_, err = db.Exec(query,
|
|
exec.ID,
|
|
exec.ToolName,
|
|
string(argsJSON),
|
|
exec.Status,
|
|
resultJSON,
|
|
errorText,
|
|
exec.StartTime,
|
|
endTime,
|
|
durationMs,
|
|
time.Now(),
|
|
)
|
|
|
|
if err != nil {
|
|
db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadToolExecutions 加载所有工具执行记录
|
|
func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) {
|
|
query := `
|
|
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
|
FROM tool_executions
|
|
ORDER BY start_time DESC
|
|
LIMIT 1000
|
|
`
|
|
|
|
rows, err := db.Query(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var executions []*mcp.ToolExecution
|
|
for rows.Next() {
|
|
var exec mcp.ToolExecution
|
|
var argsJSON string
|
|
var resultJSON sql.NullString
|
|
var errorText sql.NullString
|
|
var endTime sql.NullTime
|
|
var durationMs sql.NullInt64
|
|
|
|
err := rows.Scan(
|
|
&exec.ID,
|
|
&exec.ToolName,
|
|
&argsJSON,
|
|
&exec.Status,
|
|
&resultJSON,
|
|
&errorText,
|
|
&exec.StartTime,
|
|
&endTime,
|
|
&durationMs,
|
|
)
|
|
if err != nil {
|
|
db.logger.Warn("加载执行记录失败", zap.Error(err))
|
|
continue
|
|
}
|
|
|
|
// 解析参数
|
|
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
|
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
|
exec.Arguments = make(map[string]interface{})
|
|
}
|
|
|
|
// 解析结果
|
|
if resultJSON.Valid && resultJSON.String != "" {
|
|
var result mcp.ToolResult
|
|
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
|
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
|
} else {
|
|
exec.Result = &result
|
|
}
|
|
}
|
|
|
|
// 设置错误
|
|
if errorText.Valid {
|
|
exec.Error = errorText.String
|
|
}
|
|
|
|
// 设置结束时间
|
|
if endTime.Valid {
|
|
exec.EndTime = &endTime.Time
|
|
}
|
|
|
|
// 设置持续时间
|
|
if durationMs.Valid {
|
|
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
|
}
|
|
|
|
executions = append(executions, &exec)
|
|
}
|
|
|
|
return executions, nil
|
|
}
|
|
|
|
// GetToolExecution 根据ID获取单条工具执行记录
|
|
func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) {
|
|
query := `
|
|
SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms
|
|
FROM tool_executions
|
|
WHERE id = ?
|
|
`
|
|
|
|
row := db.QueryRow(query, id)
|
|
|
|
var exec mcp.ToolExecution
|
|
var argsJSON string
|
|
var resultJSON sql.NullString
|
|
var errorText sql.NullString
|
|
var endTime sql.NullTime
|
|
var durationMs sql.NullInt64
|
|
|
|
err := row.Scan(
|
|
&exec.ID,
|
|
&exec.ToolName,
|
|
&argsJSON,
|
|
&exec.Status,
|
|
&resultJSON,
|
|
&errorText,
|
|
&exec.StartTime,
|
|
&endTime,
|
|
&durationMs,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil {
|
|
db.logger.Warn("解析执行参数失败", zap.Error(err))
|
|
exec.Arguments = make(map[string]interface{})
|
|
}
|
|
|
|
if resultJSON.Valid && resultJSON.String != "" {
|
|
var result mcp.ToolResult
|
|
if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil {
|
|
db.logger.Warn("解析执行结果失败", zap.Error(err))
|
|
} else {
|
|
exec.Result = &result
|
|
}
|
|
}
|
|
|
|
if errorText.Valid {
|
|
exec.Error = errorText.String
|
|
}
|
|
|
|
if endTime.Valid {
|
|
exec.EndTime = &endTime.Time
|
|
}
|
|
|
|
if durationMs.Valid {
|
|
exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond
|
|
}
|
|
|
|
return &exec, nil
|
|
}
|
|
|
|
// SaveToolStats 保存工具统计信息
|
|
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
|
var lastCallTime sql.NullTime
|
|
if stats.LastCallTime != nil {
|
|
lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true}
|
|
}
|
|
|
|
query := `
|
|
INSERT OR REPLACE INTO tool_stats
|
|
(tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
`
|
|
|
|
_, err := db.Exec(query,
|
|
toolName,
|
|
stats.TotalCalls,
|
|
stats.SuccessCalls,
|
|
stats.FailedCalls,
|
|
lastCallTime,
|
|
time.Now(),
|
|
)
|
|
|
|
if err != nil {
|
|
db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadToolStats 加载所有工具统计信息
|
|
func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) {
|
|
query := `
|
|
SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time
|
|
FROM tool_stats
|
|
`
|
|
|
|
rows, err := db.Query(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
stats := make(map[string]*mcp.ToolStats)
|
|
for rows.Next() {
|
|
var stat mcp.ToolStats
|
|
var lastCallTime sql.NullTime
|
|
|
|
err := rows.Scan(
|
|
&stat.ToolName,
|
|
&stat.TotalCalls,
|
|
&stat.SuccessCalls,
|
|
&stat.FailedCalls,
|
|
&lastCallTime,
|
|
)
|
|
if err != nil {
|
|
db.logger.Warn("加载统计信息失败", zap.Error(err))
|
|
continue
|
|
}
|
|
|
|
if lastCallTime.Valid {
|
|
stat.LastCallTime = &lastCallTime.Time
|
|
}
|
|
|
|
stats[stat.ToolName] = &stat
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// UpdateToolStats 更新工具统计信息(累加模式)
|
|
func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error {
|
|
var lastCallTimeSQL sql.NullTime
|
|
if lastCallTime != nil {
|
|
lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true}
|
|
}
|
|
|
|
query := `
|
|
INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(tool_name) DO UPDATE SET
|
|
total_calls = total_calls + ?,
|
|
success_calls = success_calls + ?,
|
|
failed_calls = failed_calls + ?,
|
|
last_call_time = COALESCE(?, last_call_time),
|
|
updated_at = ?
|
|
`
|
|
|
|
_, err := db.Exec(query,
|
|
toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
|
totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(),
|
|
)
|
|
|
|
if err != nil {
|
|
db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName))
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|