package database import ( "database/sql" "encoding/json" "time" "cyberstrike-ai/internal/mcp" "go.uber.org/zap" ) // SaveToolExecution 保存工具执行记录 func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error { argsJSON, err := json.Marshal(exec.Arguments) if err != nil { db.logger.Warn("序列化执行参数失败", zap.Error(err)) argsJSON = []byte("{}") } var resultJSON sql.NullString if exec.Result != nil { resultBytes, err := json.Marshal(exec.Result) if err != nil { db.logger.Warn("序列化执行结果失败", zap.Error(err)) } else { resultJSON = sql.NullString{String: string(resultBytes), Valid: true} } } var errorText sql.NullString if exec.Error != "" { errorText = sql.NullString{String: exec.Error, Valid: true} } var endTime sql.NullTime if exec.EndTime != nil { endTime = sql.NullTime{Time: *exec.EndTime, Valid: true} } var durationMs sql.NullInt64 if exec.Duration > 0 { durationMs = sql.NullInt64{Int64: exec.Duration.Milliseconds(), Valid: true} } query := ` INSERT OR REPLACE INTO tool_executions (id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` _, err = db.Exec(query, exec.ID, exec.ToolName, string(argsJSON), exec.Status, resultJSON, errorText, exec.StartTime, endTime, durationMs, time.Now(), ) if err != nil { db.logger.Error("保存工具执行记录失败", zap.Error(err), zap.String("executionId", exec.ID)) return err } return nil } // CountToolExecutions 统计工具执行记录总数 func (db *DB) CountToolExecutions() (int, error) { query := `SELECT COUNT(*) FROM tool_executions` var count int err := db.QueryRow(query).Scan(&count) if err != nil { return 0, err } return count, nil } // LoadToolExecutions 加载所有工具执行记录(支持分页) func (db *DB) LoadToolExecutions() ([]*mcp.ToolExecution, error) { return db.LoadToolExecutionsWithPagination(0, 1000) } // LoadToolExecutionsWithPagination 分页加载工具执行记录 // limit: 最大返回记录数,0 表示使用默认值 1000 // offset: 跳过的记录数,用于分页 func (db *DB) LoadToolExecutionsWithPagination(offset, limit int) ([]*mcp.ToolExecution, error) { if limit <= 0 { limit = 1000 // 默认限制 } if limit > 10000 { limit = 10000 // 最大限制,防止一次性加载过多数据 } query := ` SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms FROM tool_executions ORDER BY start_time DESC LIMIT ? OFFSET ? ` rows, err := db.Query(query, limit, offset) if err != nil { return nil, err } defer rows.Close() var executions []*mcp.ToolExecution for rows.Next() { var exec mcp.ToolExecution var argsJSON string var resultJSON sql.NullString var errorText sql.NullString var endTime sql.NullTime var durationMs sql.NullInt64 err := rows.Scan( &exec.ID, &exec.ToolName, &argsJSON, &exec.Status, &resultJSON, &errorText, &exec.StartTime, &endTime, &durationMs, ) if err != nil { db.logger.Warn("加载执行记录失败", zap.Error(err)) continue } // 解析参数 if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil { db.logger.Warn("解析执行参数失败", zap.Error(err)) exec.Arguments = make(map[string]interface{}) } // 解析结果 if resultJSON.Valid && resultJSON.String != "" { var result mcp.ToolResult if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil { db.logger.Warn("解析执行结果失败", zap.Error(err)) } else { exec.Result = &result } } // 设置错误 if errorText.Valid { exec.Error = errorText.String } // 设置结束时间 if endTime.Valid { exec.EndTime = &endTime.Time } // 设置持续时间 if durationMs.Valid { exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond } executions = append(executions, &exec) } return executions, nil } // GetToolExecution 根据ID获取单条工具执行记录 func (db *DB) GetToolExecution(id string) (*mcp.ToolExecution, error) { query := ` SELECT id, tool_name, arguments, status, result, error, start_time, end_time, duration_ms FROM tool_executions WHERE id = ? ` row := db.QueryRow(query, id) var exec mcp.ToolExecution var argsJSON string var resultJSON sql.NullString var errorText sql.NullString var endTime sql.NullTime var durationMs sql.NullInt64 err := row.Scan( &exec.ID, &exec.ToolName, &argsJSON, &exec.Status, &resultJSON, &errorText, &exec.StartTime, &endTime, &durationMs, ) if err != nil { return nil, err } if err := json.Unmarshal([]byte(argsJSON), &exec.Arguments); err != nil { db.logger.Warn("解析执行参数失败", zap.Error(err)) exec.Arguments = make(map[string]interface{}) } if resultJSON.Valid && resultJSON.String != "" { var result mcp.ToolResult if err := json.Unmarshal([]byte(resultJSON.String), &result); err != nil { db.logger.Warn("解析执行结果失败", zap.Error(err)) } else { exec.Result = &result } } if errorText.Valid { exec.Error = errorText.String } if endTime.Valid { exec.EndTime = &endTime.Time } if durationMs.Valid { exec.Duration = time.Duration(durationMs.Int64) * time.Millisecond } return &exec, nil } // DeleteToolExecution 删除工具执行记录 func (db *DB) DeleteToolExecution(id string) error { query := `DELETE FROM tool_executions WHERE id = ?` _, err := db.Exec(query, id) if err != nil { db.logger.Error("删除工具执行记录失败", zap.Error(err), zap.String("executionId", id)) return err } return nil } // SaveToolStats 保存工具统计信息 func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error { var lastCallTime sql.NullTime if stats.LastCallTime != nil { lastCallTime = sql.NullTime{Time: *stats.LastCallTime, Valid: true} } query := ` INSERT OR REPLACE INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) VALUES (?, ?, ?, ?, ?, ?) ` _, err := db.Exec(query, toolName, stats.TotalCalls, stats.SuccessCalls, stats.FailedCalls, lastCallTime, time.Now(), ) if err != nil { db.logger.Error("保存工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) return err } return nil } // LoadToolStats 加载所有工具统计信息 func (db *DB) LoadToolStats() (map[string]*mcp.ToolStats, error) { query := ` SELECT tool_name, total_calls, success_calls, failed_calls, last_call_time FROM tool_stats ` rows, err := db.Query(query) if err != nil { return nil, err } defer rows.Close() stats := make(map[string]*mcp.ToolStats) for rows.Next() { var stat mcp.ToolStats var lastCallTime sql.NullTime err := rows.Scan( &stat.ToolName, &stat.TotalCalls, &stat.SuccessCalls, &stat.FailedCalls, &lastCallTime, ) if err != nil { db.logger.Warn("加载统计信息失败", zap.Error(err)) continue } if lastCallTime.Valid { stat.LastCallTime = &lastCallTime.Time } stats[stat.ToolName] = &stat } return stats, nil } // UpdateToolStats 更新工具统计信息(累加模式) func (db *DB) UpdateToolStats(toolName string, totalCalls, successCalls, failedCalls int, lastCallTime *time.Time) error { var lastCallTimeSQL sql.NullTime if lastCallTime != nil { lastCallTimeSQL = sql.NullTime{Time: *lastCallTime, Valid: true} } query := ` INSERT INTO tool_stats (tool_name, total_calls, success_calls, failed_calls, last_call_time, updated_at) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(tool_name) DO UPDATE SET total_calls = total_calls + ?, success_calls = success_calls + ?, failed_calls = failed_calls + ?, last_call_time = COALESCE(?, last_call_time), updated_at = ? ` _, err := db.Exec(query, toolName, totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), totalCalls, successCalls, failedCalls, lastCallTimeSQL, time.Now(), ) if err != nil { db.logger.Error("更新工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) return err } return nil } // DecreaseToolStats 减少工具统计信息(用于删除执行记录时) // 如果统计信息变为0,则删除该统计记录 func (db *DB) DecreaseToolStats(toolName string, totalCalls, successCalls, failedCalls int) error { // 先更新统计信息 query := ` UPDATE tool_stats SET total_calls = CASE WHEN total_calls - ? < 0 THEN 0 ELSE total_calls - ? END, success_calls = CASE WHEN success_calls - ? < 0 THEN 0 ELSE success_calls - ? END, failed_calls = CASE WHEN failed_calls - ? < 0 THEN 0 ELSE failed_calls - ? END, updated_at = ? WHERE tool_name = ? ` _, err := db.Exec(query, totalCalls, totalCalls, successCalls, successCalls, failedCalls, failedCalls, time.Now(), toolName) if err != nil { db.logger.Error("减少工具统计信息失败", zap.Error(err), zap.String("toolName", toolName)) return err } // 检查更新后的 total_calls 是否为 0,如果是则删除该统计记录 checkQuery := `SELECT total_calls FROM tool_stats WHERE tool_name = ?` var newTotalCalls int err = db.QueryRow(checkQuery, toolName).Scan(&newTotalCalls) if err != nil { // 如果查询失败(记录不存在),直接返回 return nil } // 如果 total_calls 为 0,删除该统计记录 if newTotalCalls == 0 { deleteQuery := `DELETE FROM tool_stats WHERE tool_name = ?` _, err = db.Exec(deleteQuery, toolName) if err != nil { db.logger.Warn("删除零统计记录失败", zap.Error(err), zap.String("toolName", toolName)) // 不返回错误,因为主要操作(更新统计)已成功 } else { db.logger.Info("已删除零统计记录", zap.String("toolName", toolName)) } } return nil }