mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-07-04 11:37:57 +02:00
Add files via upload
This commit is contained in:
@@ -611,6 +611,8 @@ func (db *DB) initTables() error {
|
||||
input_json TEXT,
|
||||
output_json TEXT,
|
||||
error TEXT,
|
||||
pending_hitl_node_id TEXT,
|
||||
pending_hitl_json TEXT,
|
||||
started_at DATETIME NOT NULL,
|
||||
finished_at DATETIME,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
@@ -848,6 +850,9 @@ func (db *DB) initTables() error {
|
||||
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
if err := db.migrateWorkflowRunsTable(); err != nil {
|
||||
db.logger.Warn("迁移workflow_runs表失败", zap.Error(err))
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createIndexes); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
|
||||
@@ -31,6 +31,8 @@ type WorkflowRun struct {
|
||||
InputJSON string `json:"input_json,omitempty"`
|
||||
OutputJSON string `json:"output_json,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
PendingHITLNodeID string `json:"pending_hitl_node_id,omitempty"`
|
||||
PendingHITLJSON string `json:"pending_hitl_json,omitempty"`
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
}
|
||||
@@ -245,6 +247,124 @@ func (db *DB) FinishWorkflowNodeRun(nodeRunID, status, outputJSON, errText strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func scanWorkflowRun(scanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}) (*WorkflowRun, error) {
|
||||
var row WorkflowRun
|
||||
var convID, projectID, roleID, inputJSON, outputJSON, errText, pendingNode, pendingJSON sql.NullString
|
||||
var finishedAt sql.NullTime
|
||||
if err := scanner.Scan(
|
||||
&row.ID, &row.WorkflowID, &row.WorkflowVersion,
|
||||
&convID, &projectID, &roleID, &row.Status,
|
||||
&inputJSON, &outputJSON, &errText,
|
||||
&pendingNode, &pendingJSON,
|
||||
&row.StartedAt, &finishedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
row.ConversationID = convID.String
|
||||
row.ProjectID = projectID.String
|
||||
row.RoleID = roleID.String
|
||||
row.InputJSON = inputJSON.String
|
||||
row.OutputJSON = outputJSON.String
|
||||
row.Error = errText.String
|
||||
row.PendingHITLNodeID = pendingNode.String
|
||||
row.PendingHITLJSON = pendingJSON.String
|
||||
if finishedAt.Valid {
|
||||
t := finishedAt.Time
|
||||
row.FinishedAt = &t
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
const workflowRunColumns = `id, workflow_id, workflow_version, conversation_id, project_id, role_id, status, input_json, output_json, error, pending_hitl_node_id, pending_hitl_json, started_at, finished_at`
|
||||
|
||||
func (db *DB) GetWorkflowRun(runID string) (*WorkflowRun, error) {
|
||||
runID = strings.TrimSpace(runID)
|
||||
if runID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
row, err := scanWorkflowRun(db.QueryRow("SELECT "+workflowRunColumns+" FROM workflow_runs WHERE id = ?", runID))
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询工作流运行失败: %w", err)
|
||||
}
|
||||
return row, nil
|
||||
}
|
||||
|
||||
func (db *DB) SetWorkflowRunStatus(runID, status string) error {
|
||||
runID = strings.TrimSpace(runID)
|
||||
if runID == "" {
|
||||
return fmt.Errorf("工作流运行 id 不能为空")
|
||||
}
|
||||
_, err := db.Exec(`UPDATE workflow_runs SET status = ? WHERE id = ?`, strings.TrimSpace(status), runID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新工作流运行状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) SetWorkflowRunAwaitingHITL(runID, nodeID, pendingJSON string) error {
|
||||
runID = strings.TrimSpace(runID)
|
||||
if runID == "" {
|
||||
return fmt.Errorf("工作流运行 id 不能为空")
|
||||
}
|
||||
_, err := db.Exec(
|
||||
`UPDATE workflow_runs SET status = 'awaiting_hitl', pending_hitl_node_id = ?, pending_hitl_json = ?, finished_at = NULL WHERE id = ?`,
|
||||
strings.TrimSpace(nodeID), pendingJSON, runID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新工作流 HITL 等待状态失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) ListWorkflowRunsAwaitingHITL(limit int) ([]*WorkflowRun, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
rows, err := db.Query(
|
||||
`SELECT `+workflowRunColumns+` FROM workflow_runs WHERE status = 'awaiting_hitl' ORDER BY started_at DESC LIMIT ?`,
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询等待审批的工作流运行失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*WorkflowRun
|
||||
for rows.Next() {
|
||||
row, err := scanWorkflowRun(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (db *DB) migrateWorkflowRunsTable() error {
|
||||
cols := []struct{ name, ddl string }{
|
||||
{"pending_hitl_node_id", "ALTER TABLE workflow_runs ADD COLUMN pending_hitl_node_id TEXT"},
|
||||
{"pending_hitl_json", "ALTER TABLE workflow_runs ADD COLUMN pending_hitl_json TEXT"},
|
||||
}
|
||||
for _, col := range cols {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('workflow_runs') WHERE name=?", col.name).Scan(&count)
|
||||
if err != nil || count > 0 {
|
||||
continue
|
||||
}
|
||||
if _, err := db.Exec(col.ddl); err != nil {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nullString(v string) interface{} {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
|
||||
Reference in New Issue
Block a user