diff --git a/internal/database/database.go b/internal/database/database.go index 30cba35b..d82b23f9 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -269,6 +269,8 @@ func (db *DB) initTables() error { method TEXT NOT NULL DEFAULT 'post', cmd_param TEXT NOT NULL DEFAULT '', remark TEXT NOT NULL DEFAULT '', + encoding TEXT NOT NULL DEFAULT '', + os TEXT NOT NULL DEFAULT '', created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP );` @@ -402,6 +404,11 @@ func (db *DB) initTables() error { // 不返回错误,允许继续运行 } + if err := db.migrateWebshellConnectionsTable(); err != nil { + db.logger.Warn("迁移webshell_connections表失败", zap.Error(err)) + // 不返回错误,允许继续运行 + } + if _, err := db.Exec(createIndexes); err != nil { return fmt.Errorf("创建索引失败: %w", err) } @@ -732,6 +739,37 @@ func (db *DB) migrateVulnerabilitiesTable() error { return nil } +// migrateWebshellConnectionsTable 迁移 webshell_connections 表,补充新字段 +func (db *DB) migrateWebshellConnectionsTable() error { + columns := []struct { + name string + stmt string + }{ + {name: "encoding", stmt: "ALTER TABLE webshell_connections ADD COLUMN encoding TEXT NOT NULL DEFAULT ''"}, + {name: "os", stmt: "ALTER TABLE webshell_connections ADD COLUMN os TEXT NOT NULL DEFAULT ''"}, + } + + for _, col := range columns { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('webshell_connections') WHERE name=?", col.name).Scan(&count) + if err != nil { + if _, addErr := db.Exec(col.stmt); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr)) + } + } + continue + } + if count == 0 { + if _, addErr := db.Exec(col.stmt); addErr != nil { + db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr)) + } + } + } + return nil +} + // NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL") diff --git a/internal/database/webshell.go b/internal/database/webshell.go index 2ea25da7..db4e912f 100644 --- a/internal/database/webshell.go +++ b/internal/database/webshell.go @@ -16,6 +16,8 @@ type WebShellConnection struct { Method string `json:"method"` CmdParam string `json:"cmdParam"` Remark string `json:"remark"` + Encoding string `json:"encoding"` // 目标响应编码:auto / utf-8 / gbk / gb18030,空值视为 auto + OS string `json:"os"` // 目标操作系统:auto / linux / windows,空值/未知视为 auto CreatedAt time.Time `json:"createdAt"` } @@ -58,7 +60,8 @@ func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) erro // ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序 func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) { query := ` - SELECT id, url, password, type, method, cmd_param, remark, created_at + SELECT id, url, password, type, method, cmd_param, remark, + COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at FROM webshell_connections ORDER BY created_at DESC ` @@ -72,7 +75,7 @@ func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) { var list []WebShellConnection for rows.Next() { var c WebShellConnection - err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt) + err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt) if err != nil { db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err)) continue @@ -85,11 +88,12 @@ func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) { // GetWebshellConnection 根据 ID 获取一条连接 func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) { query := ` - SELECT id, url, password, type, method, cmd_param, remark, created_at + SELECT id, url, password, type, method, cmd_param, remark, + COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at FROM webshell_connections WHERE id = ? ` var c WebShellConnection - err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt) + err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt) if err == sql.ErrNoRows { return nil, nil } @@ -103,10 +107,10 @@ func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) { // CreateWebshellConnection 创建 WebShell 连接 func (db *DB) CreateWebshellConnection(c *WebShellConnection) error { query := ` - INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, encoding, os, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` - _, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.CreatedAt) + _, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.CreatedAt) if err != nil { db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID)) return err @@ -118,10 +122,10 @@ func (db *DB) CreateWebshellConnection(c *WebShellConnection) error { func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error { query := ` UPDATE webshell_connections - SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ? + SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?, encoding = ?, os = ? WHERE id = ? ` - result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.ID) + result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.ID) if err != nil { db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID)) return err