From 77d212098d979b2e215565d4b8a58e3118d3cb2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Fri, 1 May 2026 01:03:28 +0800 Subject: [PATCH] Add files via upload --- internal/handler/agent.go | 14 +- internal/handler/multi_agent_prepare.go | 7 +- internal/handler/webshell.go | 507 +++++++++++++++------ internal/handler/webshell_context.go | 106 +++++ internal/handler/webshell_context_test.go | 170 +++++++ internal/handler/webshell_encoding_test.go | 103 +++++ internal/handler/webshell_os_test.go | 348 ++++++++++++++ internal/handler/webshell_probe.go | 127 ++++++ internal/handler/webshell_probe_test.go | 68 +++ 9 files changed, 1294 insertions(+), 156 deletions(-) create mode 100644 internal/handler/webshell_context.go create mode 100644 internal/handler/webshell_context_test.go create mode 100644 internal/handler/webshell_encoding_test.go create mode 100644 internal/handler/webshell_os_test.go create mode 100644 internal/handler/webshell_probe.go create mode 100644 internal/handler/webshell_probe_test.go diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 9af3b67d..a2adb8bb 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -539,12 +539,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "未找到该 WebShell 连接"}) return } - remark := conn.Remark - if remark == "" { - remark = conn.URL - } - webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。\n\n用户请求:%s", - conn.ID, remark, conn.ID, req.Message) + webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, req.Message) // WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具) if req.Role != 
"" && req.Role != "默认" && h.config.Roles != nil { if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" { @@ -1400,12 +1395,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { sendEvent("error", "未找到该 WebShell 连接", nil) return } - remark := conn.Remark - if remark == "" { - remark = conn.URL - } - webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。\n\n用户请求:%s", - conn.ID, remark, conn.ID, req.Message) + webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, req.Message) // WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具) if req.Role != "" && req.Role != "默认" && h.config.Roles != nil { if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" { diff --git a/internal/handler/multi_agent_prepare.go b/internal/handler/multi_agent_prepare.go index 0e9cd43b..51703e86 100644 --- a/internal/handler/multi_agent_prepare.go +++ b/internal/handler/multi_agent_prepare.go @@ -73,12 +73,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn)) return nil, fmt.Errorf("未找到该 WebShell 连接") } - remark := conn.Remark - if remark == "" { - remark = conn.URL - } - webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用 Eino 多代理内置 `skill` 工具。\n\n用户请求:%s", - conn.ID, remark, conn.ID, req.Message) + webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, req.Message) 
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具) if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil { if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" { diff --git a/internal/handler/webshell.go b/internal/handler/webshell.go index 5afa44c6..f94a564e 100644 --- a/internal/handler/webshell.go +++ b/internal/handler/webshell.go @@ -3,20 +3,302 @@ package handler import ( "bytes" "database/sql" + "encoding/base64" "encoding/json" "io" "net/http" "net/url" "strings" "time" + "unicode/utf8" "cyberstrike-ai/internal/database" "github.com/gin-gonic/gin" "github.com/google/uuid" "go.uber.org/zap" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" ) +// webshellSupportedEncodings is the set of accepted WebShell response-encoding values (lowercase; the empty string stands for auto). +// Only the most common encodings are exposed for now; others (e.g. Big5, Shift_JIS) can be added later. +var webshellSupportedEncodings = map[string]struct{}{ + "": {}, // not configured; treated as auto + "auto": {}, + "utf-8": {}, + "utf8": {}, + "gbk": {}, + "gb18030": {}, +} + +// normalizeWebshellEncoding canonicalizes an encoding label for persistence: the value is trimmed and lowercased, "" and unsupported values become "auto", and "utf8" becomes "utf-8". +func normalizeWebshellEncoding(enc string) string { + enc = strings.ToLower(strings.TrimSpace(enc)) + if _, ok := webshellSupportedEncodings[enc]; !ok { + return "auto" + } + if enc == "" { + return "auto" + } + if enc == "utf8" { + return "utf-8" + } + return enc +} + +// decodeWebshellOutput converts the bytes returned by a WebShell into a string according to the given encoding (normalized with normalizeWebshellEncoding first). +// Conventions: +// - "" / "auto": input that is already valid UTF-8 is returned unchanged; anything else is decoded as GB18030 (a superset of GBK). +// - "utf-8" / "utf8": returned unchanged; invalid bytes are left for the JSON layer to replace with U+FFFD (keeps the previous behavior). +// - "gbk" / "gb18030": always decoded with that encoding; the raw bytes are returned only when the decoder reports an error. NOTE(review): golang.org/x/text decoders are documented to substitute U+FFFD for undecodable bytes rather than return an error, so this raw-bytes fallback is probably rarely taken - confirm. +// +// Empty input returns "" immediately, skipping any conversion. +func decodeWebshellOutput(raw []byte, encoding string) string { + if len(raw) == 0 { + return "" + } + enc := normalizeWebshellEncoding(encoding) + switch enc { + case "utf-8": + return string(raw) + case "gbk": + if out, _, err := transform.Bytes(simplifiedchinese.GBK.NewDecoder(), raw); err == nil { + return string(out) + } + return string(raw) +
case "gb18030": + if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil { + return string(out) + } + return string(raw) + default: // auto + if utf8.Valid(raw) { + return string(raw) + } + // GB18030 is a superset of GBK with the widest coverage, so auto mode always falls back to it + if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil { + return string(out) + } + return string(raw) + } +} + +// webshellSupportedOS is the set of accepted WebShell target operating systems (lowercase; the empty string stands for auto) +var webshellSupportedOS = map[string]struct{}{ + "": {}, + "auto": {}, + "linux": {}, + "windows": {}, +} + +// normalizeWebshellOS canonicalizes an OS label for persistence: the value is trimmed and lowercased, and "" or unsupported values become "auto" +func normalizeWebshellOS(osTag string) string { + osTag = strings.ToLower(strings.TrimSpace(osTag)) + if _, ok := webshellSupportedOS[osTag]; !ok { + return "auto" + } + if osTag == "" { + return "auto" + } + return osTag +} + +// resolveWebshellOS derives the effective target OS from a connection's os value and shellType; it only ever returns "linux" or "windows". +// Rules: +// - explicit linux / windows: the configured choice is used. +// - auto or unrecognized: asp/aspx → windows, everything else → linux. This keeps the historical behavior for backward compatibility. +func resolveWebshellOS(osTag, shellType string) string { + osTag = strings.ToLower(strings.TrimSpace(osTag)) + switch osTag { + case "linux": + return "linux" + case "windows": + return "windows" + } + t := strings.ToLower(strings.TrimSpace(shellType)) + if t == "asp" || t == "aspx" { + return "windows" + } + return "linux" +} + +// quoteCmdPath quotes a path for Windows cmd.exe. +// The path is wrapped in double quotes and each embedded double quote is doubled to ""; an empty path yields the quoted current directory ".". +func quoteCmdPath(p string) string { + if p == "" { + return "\".\"" + } + return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\"" +} + +// quotePsSingle quotes a string as a PowerShell single-quoted literal (embedded ' → ''). +// Intended for values inside PowerShell scripts that use only single quotes, so the script can be wrapped in double quotes by the outer cmd call. NOTE(review): only ' is escaped here; a value containing a double quote would presumably still end the outer cmd quoting - confirm callers never pass one. +func quotePsSingle(s string) string { + return "'" + strings.ReplaceAll(s, "'", "''") + "'" +} + +// quoteShellSinglePosix quotes a path as a POSIX sh single-quoted string (embedded ' → '\''); an empty path yields an unquoted . +func quoteShellSinglePosix(p string) string { + if p == "" { + return "."
+ } + return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'" +} + +// quoteWebshellPath picks the quoting scheme for the target OS: cmd double quotes for windows, POSIX single quotes otherwise. No shell type is passed to resolveWebshellOS, so an auto/unknown osTag is treated as linux here. +func quoteWebshellPath(path, osTag string) string { + if resolveWebshellOS(osTag, "") == "windows" { + return quoteCmdPath(path) + } + return quoteShellSinglePosix(path) +} + +// buildWindowsPowerShellWrite builds the Windows command that decodes base64 content and writes it to the target path in a single call. +// The script is handed to powershell -Command inside double quotes and uses only single-quoted PowerShell strings, to avoid nested-quote pitfalls. +func buildWindowsPowerShellWrite(path, b64 string) string { + script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" + + "[IO.File]::WriteAllBytes(" + quotePsSingle(path) + ",$b)" + return "powershell -NoProfile -NonInteractive -Command \"" + script + "\"" +} + +// buildWindowsPowerShellAppend builds the Windows command that decodes base64 content and appends it to the target path (used for chunked uploads) +func buildWindowsPowerShellAppend(path, b64 string) string { + script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" + + "$f=[IO.File]::Open(" + quotePsSingle(path) + ",[IO.FileMode]::Append,[IO.FileAccess]::Write,[IO.FileShare]::None);" + + "try{$f.Write($b,0,$b.Length)}finally{$f.Close()}" + return "powershell -NoProfile -NonInteractive -Command \"" + script + "\"" +} + +// fileCommandInput bundles the inputs of buildFileCommand to avoid a long parameter list +type fileCommandInput struct { + Action string + Path string + TargetPath string + Content string + ChunkIndex int + OS string + ShellType string +} + +// buildFileCommand builds the remote command string for a file action on the target OS resolved from in.OS and in.ShellType. +// The same implementation serves the HTTP entry point (FileOp) and the MCP entry point (FileOpWithConnection), avoiding duplicate maintenance. +// The second return value is a user-visible validation error (e.g. "path is required"). NOTE(review): for upload / upload_chunk, Content is spliced into the command as-is and is never checked to be base64, so a quote character in it could break out of the quoting - confirm callers validate it. +func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error) { + targetOS := resolveWebshellOS(in.OS, in.ShellType) + action := strings.ToLower(strings.TrimSpace(in.Action)) + path := strings.TrimSpace(in.Path) + + switch action { + case "list": + p := path + if p == "" { + p = "."
+ } + if targetOS == "windows" { + return "dir /a " + quoteCmdPath(p), nil + } + return "ls -la " + quoteShellSinglePosix(p), nil + + case "read": + if path == "" { + return "", errFileOpPathRequired + } + if targetOS == "windows" { + return "type " + quoteCmdPath(path), nil + } + return "cat " + quoteShellSinglePosix(path), nil + + case "delete": + if path == "" { + return "", errFileOpPathRequired + } + if targetOS == "windows" { + return "del /q /f " + quoteCmdPath(path), nil + } + return "rm -f " + quoteShellSinglePosix(path), nil + + case "mkdir": + if path == "" { + return "", errFileOpPathRequired + } + if targetOS == "windows" { + // cmd's md is expected to create missing intermediate directories by default, similar to mkdir -p; NOTE(review): presumably it still reports an error when the directory already exists, unlike mkdir -p - confirm + return "md " + quoteCmdPath(path), nil + } + return "mkdir -p " + quoteShellSinglePosix(path), nil + + case "rename": + oldPath := path + newPath := strings.TrimSpace(in.TargetPath) + if oldPath == "" || newPath == "" { + return "", errFileOpRenameNeedsBothPaths + } + if targetOS == "windows" { + return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil + } + return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil + + case "write": + if path == "" { + return "", errFileOpPathRequired + } + // Uniform strategy: base64-encode the content here, then decode it on the target with the platform's own tooling, + // so arbitrary binary data or text containing quotes can be written without fighting each shell's escaping rules. + b64 := base64.StdEncoding.EncodeToString([]byte(in.Content)) + if targetOS == "windows" { + return buildWindowsPowerShellWrite(path, b64), nil + } + return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil + + case "upload": + if path == "" { + return "", errFileOpPathRequired + } + if len(in.Content) > 512*1024 { + return "", errFileOpUploadTooLarge + } + if targetOS == "windows" { + return buildWindowsPowerShellWrite(path, in.Content), nil + } + return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil + + case "upload_chunk": + if path == "" { + return "", errFileOpPathRequired + } + if targetOS == "windows" {
+ if in.ChunkIndex == 0 { + return buildWindowsPowerShellWrite(path, in.Content), nil + } + return buildWindowsPowerShellAppend(path, in.Content), nil + } + redir := ">>" + if in.ChunkIndex == 0 { + redir = ">" + } + return "echo '" + in.Content + "' | base64 -d " + redir + " " + quoteShellSinglePosix(path), nil + } + + return "", errFileOpUnsupportedAction(action) +} + +// Validation errors returned by buildFileCommand, so callers can surface a consistent user-visible message +var ( + errFileOpPathRequired = simpleError("path is required") + errFileOpRenameNeedsBothPaths = simpleError("path and target_path are required for rename") + errFileOpUploadTooLarge = simpleError("upload content too large (max 512KB base64)") +) + +func errFileOpUnsupportedAction(action string) error { + return simpleError("unsupported action: " + action) +} + +// simpleError is a lightweight string-backed error type without a stack trace, used by buildFileCommand for expected parameter-validation errors +type simpleError string + +func (e simpleError) Error() string { return string(e) } + // WebShellHandler 代理执行 WebShell 命令(类似冰蝎/蚁剑),避免前端跨域并统一构建请求 type WebShellHandler struct { logger *zap.Logger @@ -44,6 +326,8 @@ type CreateConnectionRequest struct { Method string `json:"method"` CmdParam string `json:"cmd_param"` Remark string `json:"remark"` + Encoding string `json:"encoding"` + OS string `json:"os"` } // UpdateConnectionRequest 更新连接请求 @@ -54,6 +338,8 @@ type UpdateConnectionRequest struct { Method string `json:"method"` CmdParam string `json:"cmd_param"` Remark string `json:"remark"` + Encoding string `json:"encoding"` + OS string `json:"os"` } // ListConnections 列出所有 WebShell 连接(GET /api/webshell/connections) @@ -109,6 +395,8 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) { Method: method, CmdParam: strings.TrimSpace(req.CmdParam), Remark: strings.TrimSpace(req.Remark), + Encoding: normalizeWebshellEncoding(req.Encoding), + OS: normalizeWebshellOS(req.OS), CreatedAt: time.Now(), } if err := h.db.CreateWebshellConnection(conn); err != nil { @@ -159,6 +447,8 @@ func (h *WebShellHandler) UpdateConnection(c *gin.Context) { Method: method,
CmdParam: strings.TrimSpace(req.CmdParam), Remark: strings.TrimSpace(req.Remark), + Encoding: normalizeWebshellEncoding(req.Encoding), + OS: normalizeWebshellOS(req.OS), } if err := h.db.UpdateWebshellConnection(conn); err != nil { if err == sql.ErrNoRows { @@ -331,6 +621,8 @@ type ExecRequest struct { Type string `json:"type"` // php, asp, aspx, jsp, custom Method string `json:"method"` // GET 或 POST,空则默认 POST CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd + Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto + OS string `json:"os"` // 目标操作系统:auto / linux / windows,当前 exec 不用它,保留字段便于未来扩展 Command string `json:"command" binding:"required"` } @@ -344,23 +636,27 @@ type ExecResponse struct { // FileOpRequest 文件操作请求 type FileOpRequest struct { - URL string `json:"url" binding:"required"` - Password string `json:"password"` - Type string `json:"type"` - Method string `json:"method"` // GET 或 POST,空则默认 POST - CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd - Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk - Path string `json:"path"` - TargetPath string `json:"target_path"` // rename 时目标路径 - Content string `json:"content"` // write/upload 时使用 - ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块 + URL string `json:"url" binding:"required"` + Password string `json:"password"` + Type string `json:"type"` + Method string `json:"method"` // GET 或 POST,空则默认 POST + CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd + Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto + OS string `json:"os"` // 目标操作系统:auto / linux / windows,空则按 shellType 推断 + ConnectionID string `json:"connection_id,omitempty"` // 可选:连接 ID;服务端探活出 OS 后会回写到此连接 + Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk + Path string `json:"path"` + TargetPath string `json:"target_path"` 
// rename 时目标路径 + Content string `json:"content"` // write/upload 时使用 + ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块 } // FileOpResponse 文件操作响应 type FileOpResponse struct { - OK bool `json:"ok"` - Output string `json:"output"` - Error string `json:"error,omitempty"` + OK bool `json:"ok"` + Output string `json:"output"` + Error string `json:"error,omitempty"` + DetectedOS string `json:"detected_os,omitempty"` // 仅在 auto 模式且探活成功时返回,前端应更新本地缓存 } func (h *WebShellHandler) Exec(c *gin.Context) { @@ -415,7 +711,7 @@ func (h *WebShellHandler) Exec(c *gin.Context) { if readErr != nil { h.logger.Warn("webshell exec read body", zap.Error(readErr)) } - output := string(out) + output := decodeWebshellOutput(out, req.Encoding) httpCode := resp.StatusCode c.JSON(http.StatusOK, ExecResponse{ @@ -474,83 +770,32 @@ func (h *WebShellHandler) FileOp(c *gin.Context) { return } - // 通过执行系统命令实现文件操作(与通用一句话兼容) - var command string - shellType := strings.ToLower(strings.TrimSpace(req.Type)) - switch req.Action { - case "list": - path := strings.TrimSpace(req.Path) - if path == "" { - path = "." 
+ // 若 OS 未显式配置,先发一次探活命令,识别出真实 OS 再构造文件操作命令。 + // 这解决了 "Windows + PHP + OS=auto" 场景下旧 fallback 错发 `ls -la` 导致目录列不出来的问题。 + osTag := req.OS + detectedOS := "" + if normalizeWebshellOS(osTag) == "auto" { + if probed := probeWebshellOSViaExec(h.newHTTPExecFn(req.URL, req.Password, req.Type, req.Method, req.CmdParam, req.Encoding)); probed != "" { + osTag = probed + detectedOS = probed + // 若前端带了 connection_id,顺带把探活结果持久化到该连接,后续刷新零成本 + if cid := strings.TrimSpace(req.ConnectionID); cid != "" { + h.persistDetectedOS(cid, probed) + } } - if shellType == "asp" || shellType == "aspx" { - command = "dir " + h.escapePath(path) - } else { - command = "ls -la " + h.escapePath(path) - } - case "read": - if shellType == "asp" || shellType == "aspx" { - command = "type " + h.escapePath(strings.TrimSpace(req.Path)) - } else { - command = "cat " + h.escapePath(strings.TrimSpace(req.Path)) - } - case "delete": - if shellType == "asp" || shellType == "aspx" { - command = "del " + h.escapePath(strings.TrimSpace(req.Path)) - } else { - command = "rm -f " + h.escapePath(strings.TrimSpace(req.Path)) - } - case "write": - path := h.escapePath(strings.TrimSpace(req.Path)) - command = "echo " + h.escapeForEcho(req.Content) + " > " + path - case "mkdir": - path := strings.TrimSpace(req.Path) - if path == "" { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for mkdir"}) - return - } - if shellType == "asp" || shellType == "aspx" { - command = "md " + h.escapePath(path) - } else { - command = "mkdir -p " + h.escapePath(path) - } - case "rename": - oldPath := strings.TrimSpace(req.Path) - newPath := strings.TrimSpace(req.TargetPath) - if oldPath == "" || newPath == "" { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path and target_path are required for rename"}) - return - } - if shellType == "asp" || shellType == "aspx" { - command = "move /y " + h.escapePath(oldPath) + " " + h.escapePath(newPath) - } else { - command = "mv " + 
h.escapePath(oldPath) + " " + h.escapePath(newPath) - } - case "upload": - path := strings.TrimSpace(req.Path) - if path == "" { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload"}) - return - } - if len(req.Content) > 512*1024 { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "upload content too large (max 512KB base64)"}) - return - } - // base64 仅含 A-Za-z0-9+/=,用单引号包裹安全 - command = "echo " + "'" + req.Content + "'" + " | base64 -d > " + h.escapePath(path) - case "upload_chunk": - path := strings.TrimSpace(req.Path) - if path == "" { - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload_chunk"}) - return - } - redir := ">>" - if req.ChunkIndex == 0 { - redir = ">" - } - command = "echo " + "'" + req.Content + "'" + " | base64 -d " + redir + " " + h.escapePath(path) - default: - c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "unsupported action: " + req.Action}) + } + + command, cmdErr := h.buildFileCommand(fileCommandInput{ + Action: req.Action, + Path: req.Path, + TargetPath: req.TargetPath, + Content: req.Content, + ChunkIndex: req.ChunkIndex, + OS: osTag, + ShellType: req.Type, + }) + if cmdErr != nil { + c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: cmdErr.Error()}) return } @@ -585,27 +830,15 @@ func (h *WebShellHandler) FileOp(c *gin.Context) { if readErr != nil { h.logger.Warn("webshell fileop read body", zap.Error(readErr)) } - output := string(out) + output := decodeWebshellOutput(out, req.Encoding) c.JSON(http.StatusOK, FileOpResponse{ - OK: resp.StatusCode == http.StatusOK, - Output: output, + OK: resp.StatusCode == http.StatusOK, + Output: output, + DetectedOS: detectedOS, }) } -func (h *WebShellHandler) escapePath(p string) string { - if p == "" { - return "." 
- } - // 简单转义空格与敏感字符,避免命令注入 - return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'" -} - -func (h *WebShellHandler) escapeForEcho(s string) string { - // 仅用于 write:base64 写入更安全,这里简单用单引号包裹 - return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" -} - // ExecWithConnection 在指定 WebShell 连接上执行命令(供 MCP/Agent 等非 HTTP 调用) func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, command string) (output string, ok bool, errMsg string) { if conn == nil { @@ -643,7 +876,7 @@ func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, if readErr != nil { h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr)) } - return string(out), resp.StatusCode == http.StatusOK, "" + return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, "" } // FileOpWithConnection 在指定 WebShell 连接上执行文件操作(供 MCP/Agent 调用),支持 list / read / write @@ -652,40 +885,38 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection return "", false, "connection is nil" } action = strings.ToLower(strings.TrimSpace(action)) - shellType := strings.ToLower(strings.TrimSpace(conn.Type)) - if shellType == "" { - shellType = "php" - } - var command string + // MCP 入口仅开放 list / read / write 三种动作,与工具文档的承诺保持一致 switch action { - case "list": - if path == "" { - path = "." 
- } - if shellType == "asp" || shellType == "aspx" { - command = "dir " + h.escapePath(strings.TrimSpace(path)) - } else { - command = "ls -la " + h.escapePath(strings.TrimSpace(path)) - } - case "read": - path = strings.TrimSpace(path) - if path == "" { - return "", false, "path is required for read" - } - if shellType == "asp" || shellType == "aspx" { - command = "type " + h.escapePath(path) - } else { - command = "cat " + h.escapePath(path) - } - case "write": - path = strings.TrimSpace(path) - if path == "" { - return "", false, "path is required for write" - } - command = "echo " + h.escapeForEcho(content) + " > " + h.escapePath(path) + case "list", "read", "write": + // 支持的动作 default: return "", false, "unsupported action: " + action + " (supported: list, read, write)" } + + // 若连接的 OS 为 auto,先探活并持久化,避免 AI/MCP 每次都对 Windows 发 `ls -la` + osTag := conn.OS + if normalizeWebshellOS(osTag) == "auto" { + if probed := probeWebshellOSViaExec(func(cmd string) (string, bool) { + out, exOk, _ := h.ExecWithConnection(conn, cmd) + return out, exOk + }); probed != "" { + osTag = probed + conn.OS = probed // 本次请求内使用探活结果 + h.persistDetectedOS(conn.ID, probed) + } + } + + command, cmdErr := h.buildFileCommand(fileCommandInput{ + Action: action, + Path: path, + TargetPath: targetPath, + Content: content, + OS: osTag, + ShellType: conn.Type, + }) + if cmdErr != nil { + return "", false, cmdErr.Error() + } useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET" cmdParam := strings.TrimSpace(conn.CmdParam) if cmdParam == "" { @@ -714,5 +945,5 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection if readErr != nil { h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr)) } - return string(out), resp.StatusCode == http.StatusOK, "" + return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, "" } diff --git a/internal/handler/webshell_context.go b/internal/handler/webshell_context.go new file mode 
100644 index 00000000..17541f5a --- /dev/null +++ b/internal/handler/webshell_context.go @@ -0,0 +1,106 @@ +package handler + +import ( + "strings" + + "cyberstrike-ai/internal/database" +) + +// WebshellSkillHintDefault is the Skills note shared by the chat page and the Eino single-agent flow; it is placed near the end of the webshell context, +// as guidance for the AI on which entry point loads skills. +const WebshellSkillHintDefault = "Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。" + +// WebshellSkillHintMultiAgent is the Skills note used by the multi-agent / Eino multi-agent preparation stage +const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `skill` 工具。" + +// webshellAssistantToolList is the list of tools the AI assistant may use in WebShell context (text shown to the model). +// Note: this is a display string only; per the original author the real permission restriction is the roleTools slice set by the caller (not visible in this file). +const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base" + +// BuildWebshellAssistantContext assembles the AI assistant's context prompt from the connection and the user's original message. +// The context contains: connection ID, remark (falling back to the URL when empty), target system with its recommended command set, response encoding, the available tool list, the Skills loading hint, +// and finally the user's request. Callers only choose the skillHint text (an empty skillHint uses WebshellSkillHintDefault); a nil conn returns userMsg unchanged. +// +// The logic lives in one shared function to avoid copy-pasting it across agent.go / multi_agent_prepare.go, +// and so that future changes to the OS / Encoding wording are made and tested in one place. +func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint, userMsg string) string { + if conn == nil { + // Defensive fallback: callers are expected to pass a non-nil conn; return the original message unchanged + return userMsg + } + remark := conn.Remark + if remark == "" { + remark = conn.URL + } + + targetOS := resolveWebshellOS(conn.OS, conn.Type) // normalized to "linux" / "windows" + encoding := normalizeWebshellEncoding(conn.Encoding) + if skillHint == "" { + skillHint = WebshellSkillHintDefault + } + + var b strings.Builder + b.Grow(512 + len(userMsg)) + + b.WriteString("[WebShell 助手上下文] 连接 ID:") + b.WriteString(conn.ID) + b.WriteString(",备注:") + b.WriteString(remark) + b.WriteByte('\n') + + // Target system: tell the AI explicitly which command set it can / cannot use, so it does not send ls/cat/rm to Windows + b.WriteString("- 目标系统:") + b.WriteString(describeTargetOSForPrompt(targetOS)) + b.WriteByte('\n') + + // Response encoding: stated only when it is not auto; in auto mode the backend adapts on its own, so the model is not bothered + if
encHint := describeEncodingForPrompt(encoding); encHint != "" { + b.WriteString("- 响应编码:") + b.WriteString(encHint) + b.WriteByte('\n') + } + + // Tool list & connection_id constraint: keep the established wording the AI is already used to + b.WriteString("可用工具(仅在该连接上操作时使用,connection_id 填 \"") + b.WriteString(conn.ID) + b.WriteString("\"):") + b.WriteString(webshellAssistantToolList) + b.WriteString("。") + b.WriteString(skillHint) + b.WriteString("\n\n用户请求:") + b.WriteString(userMsg) + + return b.String() +} + +// describeTargetOSForPrompt returns the Chinese description of an OS together with its recommended command set and counter-examples, +// The command lists cover the six most common file-management actions (list/read/delete/rename/mkdir/find) so the AI can copy them verbatim. +func describeTargetOSForPrompt(targetOS string) string { + switch targetOS { + case "windows": + return "Windows(推荐 cmd/PowerShell:dir /a、type、del /q /f、move /y、md、ren;" + + "查找文件用 `dir /s /b 过滤词` 或 PowerShell `Get-ChildItem -Recurse`;" + + "避免 ls / cat / rm / mv / find 等 Unix 命令,否则将返回 `不是内部或外部命令`)" + case "linux": + return "Linux/Unix(推荐 sh/bash:ls -la、cat、rm -f、mv、mkdir -p;" + + "查找文件用 `find /path -name '*pattern*'`;" + + "避免 dir、type、del、move 等 Windows 命令)" + default: + // Not expected to be reached: resolveWebshellOS already falls back to linux + return "未知(请先执行 `uname || ver` 探测再决定命令集)" + } +} + +// describeEncodingForPrompt returns a human-readable description of the response encoding; auto (and any other unlisted value) returns an empty string to save tokens. +func describeEncodingForPrompt(encoding string) string { + switch encoding { + case "utf-8": + return "UTF-8(目标原生 UTF-8,无需额外解码)" + case "gbk": + return "GBK(中文 Windows;后端已自动转码为 UTF-8 返回,若仍出现大量 \\uFFFD 替换字符说明命令失败或编码识别错误)" + case "gb18030": + return "GB18030(后端已自动转码为 UTF-8 返回)" + default: + return "" + } +} diff --git a/internal/handler/webshell_context_test.go b/internal/handler/webshell_context_test.go new file mode 100644 index 00000000..743c1a9e --- /dev/null +++ b/internal/handler/webshell_context_test.go @@ -0,0 +1,170 @@ +package handler + +import ( + "strings" + "testing" + + "cyberstrike-ai/internal/database" +) + +func TestBuildWebshellAssistantContext_WindowsExplicit(t *testing.T) { + conn := &database.WebShellConnection{ + ID: "ws_win01", + Remark: "IIS Windows 靶机", +
URL: "http://example.com/shell.php", + Type: "php", + OS: "windows", + Encoding: "gbk", + } + got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "列出当前目录并告诉我 flag 在哪") + + mustContain(t, got, + "[WebShell 助手上下文]", + "ws_win01", + "IIS Windows 靶机", + "目标系统:Windows", + "dir /a", + "move /y", + "避免 ls / cat / rm", + "响应编码:GBK", + "后端已自动转码为 UTF-8", + "connection_id 填 \"ws_win01\"", + "webshell_exec、webshell_file_list", + WebshellSkillHintDefault, + "用户请求:列出当前目录并告诉我 flag 在哪", + ) + // The Linux command recommendations must not appear for a Windows target + mustNotContain(t, got, "推荐 sh/bash") +} + +func TestBuildWebshellAssistantContext_LinuxAutoFromPHP(t *testing.T) { + conn := &database.WebShellConnection{ + ID: "ws_lnx01", + Remark: "", // empty remark: exercises the fallback to URL + URL: "http://example.com/a.php", + Type: "php", + OS: "auto", // auto + php → linux + Encoding: "", // auto encoding is not mentioned explicitly + } + got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "看看 /etc/passwd") + + mustContain(t, got, + "连接 ID:ws_lnx01", + "备注:http://example.com/a.php", // empty remark falls back to URL + "目标系统:Linux/Unix", + "ls -la", + "mkdir -p", + "避免 dir、type、del、move", + "用户请求:看看 /etc/passwd", + ) + // with encoding=auto the response-encoding line must be absent + mustNotContain(t, got, "响应编码:") + // The Windows command recommendations must not appear for a Linux target + mustNotContain(t, got, "推荐 cmd/PowerShell") +} + +func TestBuildWebshellAssistantContext_AutoFromASPDefaultsToWindows(t *testing.T) { + // Backward compatibility: legacy connections without an os value are treated as Windows when shellType=asp + conn := &database.WebShellConnection{ + ID: "ws_asp01", + Remark: "老 ASP 靶机", + Type: "asp", + OS: "", // empty string is equivalent to auto + Encoding: "gb18030", + } + got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "查当前用户") + + mustContain(t, got, + "目标系统:Windows", + "响应编码:GB18030", + "后端已自动转码为 UTF-8 返回", + WebshellSkillHintMultiAgent, + ) + // The multi-agent skill hint does not mention DeepAgent; the default hint must not leak in + mustNotContain(t, got, "DeepAgent") +} + +func TestBuildWebshellAssistantContext_MultiAgentSkillHint(t *testing.T) { + conn := &database.WebShellConnection{ID: "ws_m1",
Remark: "x", Type: "php", OS: "linux"} + got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "hi") + mustContain(t, got, WebshellSkillHintMultiAgent) + mustNotContain(t, got, "DeepAgent") +} + +func TestBuildWebshellAssistantContext_DefaultSkillHintFallback(t *testing.T) { + conn := &database.WebShellConnection{ID: "ws_d1", Remark: "x", Type: "php", OS: "linux"} + // an empty skillHint must fall back to the default hint + got := BuildWebshellAssistantContext(conn, "", "hi") + mustContain(t, got, WebshellSkillHintDefault) +} + +func TestBuildWebshellAssistantContext_UTF8EncodingIsAnnotated(t *testing.T) { + conn := &database.WebShellConnection{ + ID: "ws_u1", Remark: "u", Type: "jsp", OS: "linux", Encoding: "utf-8", + } + got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "hi") + mustContain(t, got, "响应编码:UTF-8", "目标原生 UTF-8") +} + +func TestBuildWebshellAssistantContext_NilConnReturnsUserMsg(t *testing.T) { + // Defensive: a nil conn must not panic; the original message is returned as-is + got := BuildWebshellAssistantContext(nil, WebshellSkillHintDefault, "just the message") + if got != "just the message" { + t.Errorf("nil conn should return userMsg as-is, got %q", got) + } +} + +func TestDescribeTargetOSForPrompt(t *testing.T) { + cases := map[string][]string{ + "windows": {"Windows", "dir /a", "move /y", "PowerShell"}, + "linux": {"Linux/Unix", "ls -la", "mkdir -p"}, + "": {"未知", "uname"}, // defensive branch + } + for in, wants := range cases { + got := describeTargetOSForPrompt(in) + for _, w := range wants { + if !strings.Contains(got, w) { + t.Errorf("describeTargetOSForPrompt(%q) should contain %q, got: %s", in, w, got) + } + } + } +} + +func TestDescribeEncodingForPrompt(t *testing.T) { + cases := map[string]string{ + "utf-8": "UTF-8", + "gbk": "GBK", + "gb18030": "GB18030", + "auto": "", + "": "", + } + for in, want := range cases { + got := describeEncodingForPrompt(in) + if want == "" && got != "" { + t.Errorf("describeEncodingForPrompt(%q) should return empty string, got: %s", in, got) + } +
if want != "" && !strings.Contains(got, want) { + t.Errorf("describeEncodingForPrompt(%q) should contain %q, got: %s", in, want, got) + } + } +} + +// ---- helpers ---- + +func mustContain(t *testing.T, text string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if !strings.Contains(text, s) { + t.Errorf("expected text to contain %q\n--- text ---\n%s", s, text) + } + } +} + +func mustNotContain(t *testing.T, text string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if strings.Contains(text, s) { + t.Errorf("text should not contain %q\n--- text ---\n%s", s, text) + } + } +} diff --git a/internal/handler/webshell_encoding_test.go b/internal/handler/webshell_encoding_test.go new file mode 100644 index 00000000..f246008a --- /dev/null +++ b/internal/handler/webshell_encoding_test.go @@ -0,0 +1,103 @@ +package handler + +import ( + "testing" + + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" +) + +// mustEncode encodes a UTF-8 string with the given encoding and returns the raw bytes, used to build test inputs +func mustEncode(t *testing.T, s string, enc string) []byte { + t.Helper() + var tr transform.Transformer + switch enc { + case "gbk": + tr = simplifiedchinese.GBK.NewEncoder() + case "gb18030": + tr = simplifiedchinese.GB18030.NewEncoder() + default: + t.Fatalf("unsupported test encoding: %s", enc) + } + out, _, err := transform.Bytes(tr, []byte(s)) + if err != nil { + t.Fatalf("mustEncode(%s) failed: %v", enc, err) + } + return out +} + +func TestNormalizeWebshellEncoding(t *testing.T) { + cases := map[string]string{ + "": "auto", + " ": "auto", + "auto": "auto", + "AUTO": "auto", + "utf-8": "utf-8", + "UTF-8": "utf-8", + "utf8": "utf-8", + "gbk": "gbk", + "GBK": "gbk", + "gb18030": "gb18030", + "big5": "auto", // unsupported values fall back to auto + "anything": "auto", + } + for in, want := range cases { + if got := normalizeWebshellEncoding(in); got != want { + t.Errorf("normalizeWebshellEncoding(%q) = %q, want %q", in, got, want) + } + } +} + +func
TestDecodeWebshellOutput_AutoDetectsGBK(t *testing.T) { + // Simulate the GBK byte stream produced by cmd on Chinese Windows + want := "用户名 SID 类型" + raw := mustEncode(t, want, "gbk") + + // auto mode: after UTF-8 validation fails, the GB18030 fallback should recover the original Chinese text + got := decodeWebshellOutput(raw, "auto") + if got != want { + t.Errorf("decodeWebshellOutput(auto) = %q, want %q", got, want) + } + + // explicit GBK mode: must decode correctly as well + got = decodeWebshellOutput(raw, "gbk") + if got != want { + t.Errorf("decodeWebshellOutput(gbk) = %q, want %q", got, want) + } + + // explicit GB18030 mode: GBK is a subset of GB18030, so this must decode correctly too + got = decodeWebshellOutput(raw, "gb18030") + if got != want { + t.Errorf("decodeWebshellOutput(gb18030) = %q, want %q", got, want) + } +} + +func TestDecodeWebshellOutput_PassthroughUTF8(t *testing.T) { + // Chinese text that is already UTF-8 must come back unchanged in the auto and utf-8 modes + want := "hello 世界" + for _, enc := range []string{"", "auto", "utf-8"} { + if got := decodeWebshellOutput([]byte(want), enc); got != want { + t.Errorf("decodeWebshellOutput(%q) passthrough = %q, want %q", enc, got, want) + } + } +} + +func TestDecodeWebshellOutput_ASCIIStable(t *testing.T) { + // Pure ASCII must stay unchanged in every mode + want := "whoami\nAdministrator\n" + for _, enc := range []string{"", "auto", "utf-8", "gbk", "gb18030"} { + if got := decodeWebshellOutput([]byte(want), enc); got != want { + t.Errorf("decodeWebshellOutput(%q) ASCII = %q, want %q", enc, got, want) + } + } +} + +func TestDecodeWebshellOutput_EmptyInput(t *testing.T) { + // Empty input (nil or zero-length) returns the empty string directly + if got := decodeWebshellOutput(nil, "gbk"); got != "" { + t.Errorf("decodeWebshellOutput(nil) = %q, want empty", got) + } + if got := decodeWebshellOutput([]byte{}, "auto"); got != "" { + t.Errorf("decodeWebshellOutput([]) = %q, want empty", got) + } +} diff --git a/internal/handler/webshell_os_test.go b/internal/handler/webshell_os_test.go new file mode 100644 index 00000000..5cf47b6b --- /dev/null +++ b/internal/handler/webshell_os_test.go @@ -0,0 +1,348 @@ +package handler + +import ( + "encoding/base64" + "strings" + "testing" + + "go.uber.org/zap"
+) + +func newTestWebShellHandler() *WebShellHandler { + return NewWebShellHandler(zap.NewNop(), nil) +} + +func TestNormalizeWebshellOS(t *testing.T) { + cases := map[string]string{ + "": "auto", + " ": "auto", + "auto": "auto", + "AUTO": "auto", + "linux": "linux", + "Linux": "linux", + "windows": "windows", + "WINDOWS": "windows", + "macos": "auto", // 未支持的回退 auto + "solaris": "auto", + } + for in, want := range cases { + if got := normalizeWebshellOS(in); got != want { + t.Errorf("normalizeWebshellOS(%q) = %q, want %q", in, got, want) + } + } +} + +func TestResolveWebshellOS(t *testing.T) { + type testCase struct { + osTag string + shellType string + want string + } + cases := []testCase{ + // 显式 OS:按用户选择,忽略 shellType + {"linux", "asp", "linux"}, + {"windows", "php", "windows"}, + {"LINUX", "jsp", "linux"}, + + // auto + 各种 shellType:asp/aspx → windows,其他 → linux + {"auto", "asp", "windows"}, + {"auto", "aspx", "windows"}, + {"auto", "ASP", "windows"}, + {"auto", "php", "linux"}, + {"auto", "jsp", "linux"}, + {"auto", "custom", "linux"}, + {"auto", "", "linux"}, + + // 空/未知 OS 等价 auto + {"", "asp", "windows"}, + {"", "php", "linux"}, + {"unknown", "aspx", "windows"}, + } + for _, c := range cases { + got := resolveWebshellOS(c.osTag, c.shellType) + if got != c.want { + t.Errorf("resolveWebshellOS(%q,%q) = %q, want %q", c.osTag, c.shellType, got, c.want) + } + } +} + +func TestQuoteCmdPath(t *testing.T) { + cases := map[string]string{ + "": `"."`, + `C:\Windows\Temp`: `"C:\Windows\Temp"`, + `C:\Program Files\a`: `"C:\Program Files\a"`, + `C:\weird"name\f.txt`: `"C:\weird""name\f.txt"`, + `.`: `"."`, + } + for in, want := range cases { + if got := quoteCmdPath(in); got != want { + t.Errorf("quoteCmdPath(%q) = %q, want %q", in, got, want) + } + } +} + +func TestQuoteShellSinglePosix(t *testing.T) { + cases := map[string]string{ + "": ".", + "/tmp/a b": "'/tmp/a b'", + "/tmp/it's.txt": `'/tmp/it'\''s.txt'`, + } + for in, want := range cases { + if got := 
quoteShellSinglePosix(in); got != want { + t.Errorf("quoteShellSinglePosix(%q) = %q, want %q", in, got, want) + } + } +} + +// TestBuildFileCommand_LinuxBranch 覆盖 Linux 目标下每个 action 产出的命令 +func TestBuildFileCommand_LinuxBranch(t *testing.T) { + h := newTestWebShellHandler() + base := fileCommandInput{OS: "linux", ShellType: "php"} + + mustContain := func(t *testing.T, cmd string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if !strings.Contains(cmd, s) { + t.Errorf("expected command to contain %q, got: %s", s, cmd) + } + } + } + mustNotContain := func(t *testing.T, cmd string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if strings.Contains(cmd, s) { + t.Errorf("command should not contain %q, got: %s", s, cmd) + } + } + } + + // list with empty path defaults to '.' + in := base + in.Action = "list" + cmd, err := h.buildFileCommand(in) + if err != nil { + t.Fatalf("list linux: unexpected err: %v", err) + } + mustContain(t, cmd, "ls -la", "'.'") + + // list with path containing spaces + in.Path = "/tmp/my files" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "ls -la ", "'/tmp/my files'") + + // read with path + in = base + in.Action = "read" + in.Path = "/etc/passwd" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "cat ", "'/etc/passwd'") + + // read without path → error + in.Path = "" + if _, err := h.buildFileCommand(in); err != errFileOpPathRequired { + t.Errorf("read empty path: want errFileOpPathRequired, got %v", err) + } + + // delete + in = base + in.Action = "delete" + in.Path = "/tmp/a.txt" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "rm -f ", "'/tmp/a.txt'") + mustNotContain(t, cmd, "del") + + // mkdir + in.Action = "mkdir" + in.Path = "/tmp/new/sub" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "mkdir -p ", "'/tmp/new/sub'") + + // rename + in = base + in.Action = "rename" + in.Path = "/tmp/a" + in.TargetPath = "/tmp/b" + cmd, _ = h.buildFileCommand(in) + 
mustContain(t, cmd, "mv -f ", "'/tmp/a'", "'/tmp/b'") + + // rename missing target → error + in.TargetPath = "" + if _, err := h.buildFileCommand(in); err != errFileOpRenameNeedsBothPaths { + t.Errorf("rename empty target: want errFileOpRenameNeedsBothPaths, got %v", err) + } + + // write + in = base + in.Action = "write" + in.Path = "/tmp/w.txt" + in.Content = "hello 世界" + cmd, _ = h.buildFileCommand(in) + b64 := base64.StdEncoding.EncodeToString([]byte("hello 世界")) + mustContain(t, cmd, "echo '"+b64+"'", "| base64 -d", "> '/tmp/w.txt'") + + // upload + in = base + in.Action = "upload" + in.Path = "/tmp/bin" + in.Content = "YWJjZA==" // base64 of "abcd" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "echo 'YWJjZA=='", "| base64 -d", "> '/tmp/bin'") + + // upload oversized content → error + in.Content = strings.Repeat("A", 513*1024) + if _, err := h.buildFileCommand(in); err != errFileOpUploadTooLarge { + t.Errorf("upload too large: want errFileOpUploadTooLarge, got %v", err) + } + + // upload_chunk with chunk_index=0 uses single redirect + in = base + in.Action = "upload_chunk" + in.Path = "/tmp/bin" + in.Content = "YWJj" + in.ChunkIndex = 0 + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "base64 -d > '/tmp/bin'") + mustNotContain(t, cmd, ">>") + + // upload_chunk with chunk_index>0 uses append redirect + in.ChunkIndex = 1 + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "base64 -d >> '/tmp/bin'") + + // unsupported action + in = base + in.Action = "nope" + if _, err := h.buildFileCommand(in); err == nil || !strings.Contains(err.Error(), "unsupported action") { + t.Errorf("unknown action: want unsupported action error, got %v", err) + } +} + +// TestBuildFileCommand_WindowsBranch 覆盖 Windows 目标下每个 action 产出的命令 +func TestBuildFileCommand_WindowsBranch(t *testing.T) { + h := newTestWebShellHandler() + base := fileCommandInput{OS: "windows", ShellType: "php"} + + mustContain := func(t *testing.T, cmd string, substrings ...string) { + t.Helper() 
+ for _, s := range substrings { + if !strings.Contains(cmd, s) { + t.Errorf("expected command to contain %q, got: %s", s, cmd) + } + } + } + mustNotContain := func(t *testing.T, cmd string, substrings ...string) { + t.Helper() + for _, s := range substrings { + if strings.Contains(cmd, s) { + t.Errorf("command should not contain %q, got: %s", s, cmd) + } + } + } + + // list + in := base + in.Action = "list" + cmd, _ := h.buildFileCommand(in) + mustContain(t, cmd, "dir /a ", `"."`) + mustNotContain(t, cmd, "ls -la") + + in.Path = `C:\Users\Public Docs` + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "dir /a ", `"C:\Users\Public Docs"`) + + // read + in = base + in.Action = "read" + in.Path = `C:\flag.txt` + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "type ", `"C:\flag.txt"`) + + // delete + in.Action = "delete" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "del /q /f ", `"C:\flag.txt"`) + mustNotContain(t, cmd, "rm -f") + + // mkdir + in.Action = "mkdir" + in.Path = `C:\a\b\c` + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "md ", `"C:\a\b\c"`) + + // rename + in = base + in.Action = "rename" + in.Path = `C:\a.txt` + in.TargetPath = `C:\b.txt` + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "move /y ", `"C:\a.txt"`, `"C:\b.txt"`) + + // write → PowerShell base64 one-liner + in = base + in.Action = "write" + in.Path = `C:\out.txt` + in.Content = "hello 世界" + cmd, _ = h.buildFileCommand(in) + wantB64 := base64.StdEncoding.EncodeToString([]byte("hello 世界")) + mustContain(t, cmd, + "powershell -NoProfile -NonInteractive -Command", + "[Convert]::FromBase64String('"+wantB64+"')", + "[IO.File]::WriteAllBytes('C:\\out.txt'", + ) + mustNotContain(t, cmd, "echo ", "base64 -d") + + // upload (chunk_index=0 equivalent) uses WriteAllBytes + in = base + in.Action = "upload" + in.Path = `C:\bin\f` + in.Content = "YWJjZA==" + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "WriteAllBytes('C:\\bin\\f'", 
"FromBase64String('YWJjZA==')") + + // upload_chunk index=0 → WriteAllBytes + in.Action = "upload_chunk" + in.ChunkIndex = 0 + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "WriteAllBytes(") + mustNotContain(t, cmd, "FileMode]::Append") + + // upload_chunk index>0 → append (Open with Append mode) + in.ChunkIndex = 1 + cmd, _ = h.buildFileCommand(in) + mustContain(t, cmd, "[IO.FileMode]::Append", "FromBase64String('YWJjZA==')") +} + +// TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior 确保 os=auto 时与旧版 shellType 判定行为完全一致 +// asp/aspx 视为 Windows(旧行为),其他视为 Linux。 +func TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior(t *testing.T) { + h := newTestWebShellHandler() + + // asp + auto → windows 命令 + cmd, _ := h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "asp"}) + if !strings.Contains(cmd, "dir /a") { + t.Errorf("auto + asp should use Windows cmd, got: %s", cmd) + } + + cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "aspx"}) + if !strings.Contains(cmd, "dir /a") { + t.Errorf("auto + aspx should use Windows cmd, got: %s", cmd) + } + + // php/jsp/custom + auto → linux 命令(与历史行为一致) + for _, st := range []string{"php", "jsp", "custom", ""} { + cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: st}) + if !strings.Contains(cmd, "ls -la") { + t.Errorf("auto + %q should use Linux cmd, got: %s", st, cmd) + } + } + + // 显式 OS 覆盖 shellType + cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "windows", ShellType: "php"}) + if !strings.Contains(cmd, "dir /a") { + t.Errorf("explicit windows should override php shellType, got: %s", cmd) + } + cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "linux", ShellType: "asp"}) + if !strings.Contains(cmd, "ls -la") { + t.Errorf("explicit linux should override asp shellType, got: %s", cmd) + } +} diff --git a/internal/handler/webshell_probe.go b/internal/handler/webshell_probe.go new file mode 100644 
index 00000000..75917206
--- /dev/null
+++ b/internal/handler/webshell_probe.go
@@ -0,0 +1,151 @@
+package handler
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"strings"
+
+	"go.uber.org/zap"
+)
+
+// webshellOSProbeCommand is the OS probe command. It relies on Windows cmd and
+// POSIX shells treating `%OS%` differently:
+//   - Windows cmd expands `%OS%` to `Windows_NT`, so the echo is
+//     `:OSPROBE_Windows_NT:END`.
+//   - POSIX sh/bash has no `%VAR%` syntax and keeps the text literally, so the
+//     echo is `:OSPROBE_%OS%:END`.
+//
+// A single command therefore yields a clear, mutually exclusive signal, which
+// is cheaper than probing with two commands. The colons wrap the payload so
+// matching still works when a shell emits stray whitespace or a BOM.
+//
+// NOTE(review): PowerShell does not expand `%VAR%` either, so a webshell that
+// runs commands through PowerShell would presumably echo the literal and be
+// classified as linux; confirm whether such targets need separate handling.
+const webshellOSProbeCommand = "echo :OSPROBE_%OS%:END"
+
+// probeWebshellOSViaExec infers the target operating system from the echo of a
+// single command execution.
+//
+// It returns "windows" or "linux" when the echo is recognized, and "" when the
+// OS cannot be determined (execFn is nil, execFn reports ok == false, or the
+// echo is inconclusive); callers should then keep their existing fallback.
+//
+// execFn is a closure that sends a command and returns its echo, so the HTTP
+// and MCP entry points can share the probe without this function knowing how
+// the request is actually sent.
+func probeWebshellOSViaExec(execFn func(cmd string) (output string, ok bool)) string {
+	if execFn == nil {
+		return ""
+	}
+	out, ok := execFn(webshellOSProbeCommand)
+	if !ok {
+		return ""
+	}
+	return classifyWebshellOSProbeOutput(out)
+}
+
+// classifyWebshellOSProbeOutput is a pure function that maps the probe echo to
+// an OS tag: "windows", "linux", or "" when the echo is inconclusive. It is
+// separate from the transport so unit tests can cover every branch without a
+// real HTTP call.
+func classifyWebshellOSProbeOutput(out string) string {
+	switch {
+	case out == "":
+		return ""
+	case strings.Contains(out, "Windows_NT"):
+		// Strong Windows signal: cmd.exe expanded the %OS% variable.
+		return "windows"
+	case strings.Contains(strings.ToLower(out), "microsoft windows"):
+		// Secondary Windows clue for the rare case where %OS% did not expand to
+		// Windows_NT: a "Microsoft Windows" banner somewhere in the echo.
+		return "windows"
+	case strings.Contains(out, "%OS%"):
+		// Strong Linux/Unix signal: the literal came back untouched, so the
+		// shell is not cmd.exe.
+		return "linux"
+	default:
+		// Neither marker is present (for example the echo was truncated or
+		// filtered); stay conservative and let the caller fall back.
+		return ""
+	}
+}
+
+// newHTTPExecFn builds the "send a command, return its echo" closure used by
+// the HTTP FileOp path so the OS probe can reuse it. The arguments come from
+// the HTTP request; requests are assembled with buildExecURL / buildExecBody so
+// the probe goes through the same webshell protocol (GET or POST, parameter
+// name, encoding) as the real file operations.
+//
+// method selects GET when it equals "GET" ignoring case and surrounding
+// whitespace, otherwise POST; a blank cmdParam defaults to "cmd". The closure
+// returns the response body decoded with decodeWebshellOutput and ok == true
+// only for HTTP 200. Any failure to build, send, or fully read the request
+// yields ("", false).
+func (h *WebShellHandler) newHTTPExecFn(targetURL, password, shellType, method, cmdParam, encoding string) func(string) (string, bool) {
+	useGET := strings.EqualFold(strings.TrimSpace(method), "GET")
+	if strings.TrimSpace(cmdParam) == "" {
+		cmdParam = "cmd"
+	}
+	return func(cmd string) (string, bool) {
+		var (
+			httpReq *http.Request
+			err     error
+		)
+		if useGET {
+			u := h.buildExecURL(targetURL, shellType, password, cmdParam, cmd)
+			httpReq, err = http.NewRequest(http.MethodGet, u, nil)
+		} else {
+			body := h.buildExecBody(shellType, password, cmdParam, cmd)
+			httpReq, err = http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body))
+			if err == nil {
+				httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+			}
+		}
+		if err != nil {
+			return "", false
+		}
+		httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
+		resp, err := h.client.Do(httpReq)
+		if err != nil {
+			return "", false
+		}
+		defer resp.Body.Close()
+		raw, err := io.ReadAll(resp.Body)
+		if err != nil {
+			// A partially read body could be misclassified, so report failure
+			// and let the caller fall back instead of trusting truncated data.
+			return "", false
+		}
+		return decodeWebshellOutput(raw, encoding), resp.StatusCode == http.StatusOK
+	}
+}
+
+// persistDetectedOS writes a probe result back to the connection table. It is
+// best effort: a failed update is only logged and never blocks the caller. It
+// only updates an existing record, so an unknown connectionID is silently
+// ignored, as is a handler without a database.
+func (h *WebShellHandler) persistDetectedOS(connectionID, detected string) {
+	connectionID = strings.TrimSpace(connectionID)
+	detected = normalizeWebshellOS(detected)
+	if h.db == nil || connectionID == "" || detected == "" || detected == "auto" {
+		return
+	}
+	conn, err := h.db.GetWebshellConnection(connectionID)
+	if err != nil || conn == nil {
+		// Not every caller has a valid ID (for example ad-hoc tests), so give
+		// up quietly.
+		return
+	}
+	if normalizeWebshellOS(conn.OS) != "auto" {
+		// The user explicitly chose an OS; respect it and do not overwrite.
+		return
+	}
+	conn.OS = detected
+	if err := h.db.UpdateWebshellConnection(conn); err != nil {
+		h.logger.Warn("webshell 探活结果持久化失败", zap.String("id", connectionID), zap.String("os", detected), zap.Error(err))
+		return
+	}
+	h.logger.Info("webshell auto OS 探活成功并持久化", zap.String("id", connectionID), zap.String("os", detected))
+}
diff --git a/internal/handler/webshell_probe_test.go
b/internal/handler/webshell_probe_test.go
new file mode 100644
index 00000000..03917315
--- /dev/null
+++ b/internal/handler/webshell_probe_test.go
@@ -0,0 +1,82 @@
+package handler
+
+import "testing"
+
+// TestClassifyWebshellOSProbeOutput walks a table of probe echoes and checks
+// the OS tag that classifyWebshellOSProbeOutput derives from each of them.
+func TestClassifyWebshellOSProbeOutput(t *testing.T) {
+	type probeCase struct {
+		name string
+		in   string
+		want string
+	}
+	table := []probeCase{
+		{name: "Windows cmd 回显完整", in: ":OSPROBE_Windows_NT:END\r\n", want: "windows"},
+		{name: "Windows cmd 回显带额外空行", in: "\r\n:OSPROBE_Windows_NT:END\r\n", want: "windows"},
+		{name: "Windows 次级线索 - ver banner", in: "Microsoft Windows [版本 10.0.19045]\r\n", want: "windows"},
+		{name: "Linux sh 字面量回显", in: ":OSPROBE_%OS%:END\n", want: "linux"},
+		{name: "Linux 紧凑输出(无换行)", in: ":OSPROBE_%OS%:END", want: "linux"},
+		{name: "空输出 - 无法判定", in: "", want: ""},
+		{name: "被过滤的输出 - 无法判定", in: "something weird", want: ""},
+		{name: "仅有 OSPROBE 前缀但被截断 - 保守返回空", in: ":OSPROBE_:END", want: ""},
+	}
+	for i := range table {
+		tc := table[i]
+		got := classifyWebshellOSProbeOutput(tc.in)
+		if got == tc.want {
+			continue
+		}
+		t.Errorf("case %q: got %q, want %q", tc.name, got, tc.want)
+	}
+}
+
+// TestProbeWebshellOSViaExec_SendsOneCommandOnly verifies that a probe issues
+// exactly one exec call and that the call carries webshellOSProbeCommand.
+func TestProbeWebshellOSViaExec_SendsOneCommandOnly(t *testing.T) {
+	var seen []string
+	recorder := func(cmd string) (string, bool) {
+		seen = append(seen, cmd)
+		return ":OSPROBE_Windows_NT:END", true
+	}
+	if got := probeWebshellOSViaExec(recorder); got != "windows" {
+		t.Fatalf("want windows, got %q", got)
+	}
+	if n := len(seen); n != 1 {
+		t.Fatalf("probe should issue exactly one exec call, got %d: %v", n, seen)
+	}
+	if first := seen[0]; first != webshellOSProbeCommand {
+		t.Errorf("probe command mismatch: got %q", first)
+	}
+}
+
+// TestProbeWebshellOSViaExec_NotOkReturnsEmpty covers an execFn that reports
+// ok == false (for example a non-200 HTTP response): the probe must give up.
+func TestProbeWebshellOSViaExec_NotOkReturnsEmpty(t *testing.T) {
+	failing := func(string) (string, bool) { return "whatever", false }
+	got := probeWebshellOSViaExec(failing)
+	if got != "" {
+		t.Errorf("want empty when exec not ok, got %q", got)
+	}
+}
+
+// TestProbeWebshellOSViaExec_NilSafeguard checks that a nil execFn is tolerated
+// and produces an empty result.
+func TestProbeWebshellOSViaExec_NilSafeguard(t *testing.T) {
+	got := probeWebshellOSViaExec(nil)
+	if got != "" {
+		t.Errorf("nil execFn should return empty, got %q", got)
+	}
+}
+
+// TestProbeWebshellOSViaExec_LinuxUname covers the standard Linux path, where
+// the `%OS%` literal is echoed back untouched. Despite its name, the test does
+// not involve uname.
+func TestProbeWebshellOSViaExec_LinuxUname(t *testing.T) {
+	literalEcho := func(string) (string, bool) {
+		return ":OSPROBE_%OS%:END\n", true
+	}
+	got := probeWebshellOSViaExec(literalEcho)
+	if got != "linux" {
+		t.Errorf("Linux case: want linux, got %q", got)
+	}
+}