diff --git a/internal/database/database.go b/internal/database/database.go index ce675001..abe1f786 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -611,6 +611,8 @@ func (db *DB) initTables() error { input_json TEXT, output_json TEXT, error TEXT, + pending_hitl_node_id TEXT, + pending_hitl_json TEXT, started_at DATETIME NOT NULL, finished_at DATETIME, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -848,6 +850,9 @@ func (db *DB) initTables() error { db.logger.Warn("迁移webshell_connections表失败", zap.Error(err)) // 不返回错误,允许继续运行 } + if err := db.migrateWorkflowRunsTable(); err != nil { + db.logger.Warn("迁移workflow_runs表失败", zap.Error(err)) + } if _, err := db.Exec(createIndexes); err != nil { return fmt.Errorf("创建索引失败: %w", err) diff --git a/internal/database/workflow.go b/internal/database/workflow.go index 26abb20a..d883ddf5 100644 --- a/internal/database/workflow.go +++ b/internal/database/workflow.go @@ -31,6 +31,8 @@ type WorkflowRun struct { InputJSON string `json:"input_json,omitempty"` OutputJSON string `json:"output_json,omitempty"` Error string `json:"error,omitempty"` + PendingHITLNodeID string `json:"pending_hitl_node_id,omitempty"` + PendingHITLJSON string `json:"pending_hitl_json,omitempty"` StartedAt time.Time `json:"started_at"` FinishedAt *time.Time `json:"finished_at,omitempty"` } @@ -245,6 +247,124 @@ func (db *DB) FinishWorkflowNodeRun(nodeRunID, status, outputJSON, errText strin return nil } +func scanWorkflowRun(scanner interface { + Scan(dest ...interface{}) error +}) (*WorkflowRun, error) { + var row WorkflowRun + var convID, projectID, roleID, inputJSON, outputJSON, errText, pendingNode, pendingJSON sql.NullString + var finishedAt sql.NullTime + if err := scanner.Scan( + &row.ID, &row.WorkflowID, &row.WorkflowVersion, + &convID, &projectID, &roleID, &row.Status, + &inputJSON, &outputJSON, &errText, + &pendingNode, &pendingJSON, + &row.StartedAt, &finishedAt, + ); err != nil { + return nil, err + } + row.ConversationID = convID.String + row.ProjectID = projectID.String + row.RoleID = roleID.String + row.InputJSON = inputJSON.String + row.OutputJSON = outputJSON.String + row.Error = errText.String + row.PendingHITLNodeID = pendingNode.String + row.PendingHITLJSON = pendingJSON.String + if finishedAt.Valid { + t := finishedAt.Time + row.FinishedAt = &t + } + return &row, nil +} + +const workflowRunColumns = `id, workflow_id, workflow_version, conversation_id, project_id, role_id, status, input_json, output_json, error, pending_hitl_node_id, pending_hitl_json, started_at, finished_at` + +func (db *DB) GetWorkflowRun(runID string) (*WorkflowRun, error) { + runID = strings.TrimSpace(runID) + if runID == "" { + return nil, nil + } + row, err := scanWorkflowRun(db.QueryRow("SELECT "+workflowRunColumns+" FROM workflow_runs WHERE id = ?", runID)) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("查询工作流运行失败: %w", err) + } + return row, nil +} + +func (db *DB) SetWorkflowRunStatus(runID, status string) error { + runID = strings.TrimSpace(runID) + if runID == "" { + return fmt.Errorf("工作流运行 id 不能为空") + } + _, err := db.Exec(`UPDATE workflow_runs SET status = ? WHERE id = ?`, strings.TrimSpace(status), runID) + if err != nil { + return fmt.Errorf("更新工作流运行状态失败: %w", err) + } + return nil +} + +func (db *DB) SetWorkflowRunAwaitingHITL(runID, nodeID, pendingJSON string) error { + runID = strings.TrimSpace(runID) + if runID == "" { + return fmt.Errorf("工作流运行 id 不能为空") + } + _, err := db.Exec( + `UPDATE workflow_runs SET status = 'awaiting_hitl', pending_hitl_node_id = ?, pending_hitl_json = ?, finished_at = NULL WHERE id = ?`, + strings.TrimSpace(nodeID), pendingJSON, runID, + ) + if err != nil { + return fmt.Errorf("更新工作流 HITL 等待状态失败: %w", err) + } + return nil +} + +func (db *DB) ListWorkflowRunsAwaitingHITL(limit int) ([]*WorkflowRun, error) { + if limit <= 0 { + limit = 50 + } + rows, err := db.Query( + `SELECT `+workflowRunColumns+` FROM workflow_runs WHERE status = 'awaiting_hitl' ORDER BY started_at DESC LIMIT ?`, + limit, + ) + if err != nil { + return nil, fmt.Errorf("查询等待审批的工作流运行失败: %w", err) + } + defer rows.Close() + var out []*WorkflowRun + for rows.Next() { + row, err := scanWorkflowRun(rows) + if err != nil { + return nil, err + } + out = append(out, row) + } + return out, rows.Err() +} + +func (db *DB) migrateWorkflowRunsTable() error { + cols := []struct{ name, ddl string }{ + {"pending_hitl_node_id", "ALTER TABLE workflow_runs ADD COLUMN pending_hitl_node_id TEXT"}, + {"pending_hitl_json", "ALTER TABLE workflow_runs ADD COLUMN pending_hitl_json TEXT"}, + } + for _, col := range cols { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('workflow_runs') WHERE name=?", col.name).Scan(&count) + if err != nil || count > 0 { + continue + } + if _, err := db.Exec(col.ddl); err != nil { + errMsg := strings.ToLower(err.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + return err + } + } + } + return nil +} + func nullString(v string) interface{} { v = strings.TrimSpace(v) if v == "" {