From e606369e31ba364c533c24817ab392551d488374 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Wed, 3 Jun 2026 17:16:48 +0800 Subject: [PATCH] Add files via upload --- internal/handler/agent.go | 2 +- internal/handler/config.go | 147 ++++++++++++++++ internal/handler/openapi.go | 92 +++++++++- internal/project/vision_image_prompt.go | 22 +++ internal/vision/client.go | 132 +++++++++++++++ internal/vision/client_test.go | 12 ++ internal/vision/path.go | 142 ++++++++++++++++ internal/vision/path_test.go | 43 +++++ internal/vision/preprocess.go | 212 ++++++++++++++++++++++++ internal/vision/preprocess_test.go | 109 ++++++++++++ internal/vision/tool.go | 130 +++++++++++++++ 11 files changed, 1041 insertions(+), 2 deletions(-) create mode 100644 internal/project/vision_image_prompt.go create mode 100644 internal/vision/client.go create mode 100644 internal/vision/client_test.go create mode 100644 internal/vision/path.go create mode 100644 internal/vision/path_test.go create mode 100644 internal/vision/preprocess.go create mode 100644 internal/vision/preprocess_test.go create mode 100644 internal/vision/tool.go diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 1a600706..46685162 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -491,7 +491,7 @@ func appendAttachmentsToMessage(msg string, attachments []ChatAttachment, savedP } var b strings.Builder b.WriteString(msg) - b.WriteString("\n\n[用户上传的文件已保存到以下路径(请按需读取文件内容,而不是依赖内联内容)]\n") + b.WriteString("\n\n[用户上传的文件]\n") for i, a := range attachments { if i < len(savedPaths) && savedPaths[i] != "" { b.WriteString(fmt.Sprintf("- %s: %s\n", a.FileName, savedPaths[i])) diff --git a/internal/handler/config.go b/internal/handler/config.go index 212296ac..7944a61b 100644 --- a/internal/handler/config.go +++ b/internal/handler/config.go @@ -237,6 +237,7 @@ func (h *ConfigHandler) ApplyWechatRobotBinding(wc 
config.RobotWechatConfig) err // GetConfigResponse 获取配置响应 type GetConfigResponse struct { OpenAI config.OpenAIConfig `json:"openai"` + Vision config.VisionConfig `json:"vision"` FOFA config.FofaConfig `json:"fofa"` MCP config.MCPConfig `json:"mcp"` Tools []ToolConfigInfo `json:"tools"` @@ -333,6 +334,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) { c.JSON(http.StatusOK, GetConfigResponse{ OpenAI: h.config.OpenAI, + Vision: h.config.Vision, FOFA: h.config.FOFA, MCP: h.config.MCP, Tools: tools, @@ -638,6 +640,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) { // UpdateConfigRequest 更新配置请求 type UpdateConfigRequest struct { OpenAI *config.OpenAIConfig `json:"openai,omitempty"` + Vision *config.VisionConfig `json:"vision,omitempty"` FOFA *config.FofaConfig `json:"fofa,omitempty"` MCP *config.MCPConfig `json:"mcp,omitempty"` Tools []ToolEnableStatus `json:"tools,omitempty"` @@ -707,6 +710,14 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) { ) } + if req.Vision != nil { + h.config.Vision = *req.Vision + h.logger.Info("更新 Vision 配置", + zap.Bool("enabled", h.config.Vision.Enabled), + zap.String("model", h.config.Vision.Model), + ) + } + // 更新FOFA配置 if req.FOFA != nil { h.config.FOFA = *req.FOFA @@ -1031,6 +1042,99 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) { }) } +// TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。 +type TestVisionRequest struct { + Vision config.VisionConfig `json:"vision"` + OpenAI config.OpenAIConfig `json:"openai,omitempty"` +} + +// TestVision 测试视觉模型 API 连接(最小 chat completion)。 +func (h *ConfigHandler) TestVision(c *gin.Context) { + var req TestVisionRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()}) + return + } + oa := req.Vision.OpenAICfgEffective(req.OpenAI) + if strings.TrimSpace(oa.APIKey) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空(可填写 vision.api_key 或 openai.api_key)"}) + return + } 
+ if strings.TrimSpace(oa.Model) == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "视觉模型不能为空"}) + return + } + + baseURL := strings.TrimSuffix(strings.TrimSpace(oa.BaseURL), "/") + if baseURL == "" { + if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") { + baseURL = "https://api.anthropic.com" + } else { + baseURL = "https://api.openai.com/v1" + } + } + + payload := map[string]interface{}{ + "model": oa.Model, + "messages": []map[string]string{ + {"role": "user", "content": "Hi"}, + }, + "max_completion_tokens": 5, + } + + tmpCfg := &config.OpenAIConfig{ + Provider: oa.Provider, + BaseURL: baseURL, + APIKey: strings.TrimSpace(oa.APIKey), + Model: oa.Model, + } + client := openai.NewClient(tmpCfg, nil, h.logger) + + ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second) + defer cancel() + + start := time.Now() + var chatResp struct { + Model string `json:"model"` + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + err := client.ChatCompletion(ctx, payload, &chatResp) + latency := time.Since(start) + + if err != nil { + if apiErr, ok := err.(*openai.APIError); ok { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body), + "status_code": apiErr.StatusCode, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": "连接失败: " + err.Error(), + }) + return + } + if len(chatResp.Choices) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "error": "API 响应缺少 choices 字段,请检查 Base URL 与视觉模型名称", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "model": chatResp.Model, + "latency_ms": latency.Milliseconds(), + }) +} + // ApplyConfig 应用配置(重新加载并重启相关服务) func (h *ConfigHandler) ApplyConfig(c *gin.Context) { // 先检查是否需要动态初始化知识库(在锁外执行,避免阻塞其他请求) @@ -1286,6 +1390,7 @@ func (h *ConfigHandler) saveConfig() error { updateAgentConfig(root, h.config.Agent) 
updateMCPConfig(root, h.config.MCP) updateOpenAIConfig(root, h.config.OpenAI) + updateVisionConfig(root, h.config.Vision) updateFOFAConfig(root, h.config.FOFA) updateKnowledgeConfig(root, h.config.Knowledge) updateC2Config(root, h.config.C2) @@ -1406,6 +1511,48 @@ func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) { setIntInMap(mcpNode, "port", cfg.Port) } +func updateVisionConfig(doc *yaml.Node, cfg config.VisionConfig) { + root := doc.Content[0] + visionNode := ensureMap(root, "vision") + setBoolInMap(visionNode, "enabled", cfg.Enabled) + if strings.TrimSpace(cfg.APIKey) != "" { + setStringInMap(visionNode, "api_key", cfg.APIKey) + } else { + setStringInMap(visionNode, "api_key", "") + } + if strings.TrimSpace(cfg.BaseURL) != "" { + setStringInMap(visionNode, "base_url", cfg.BaseURL) + } else { + setStringInMap(visionNode, "base_url", "") + } + setStringInMap(visionNode, "model", cfg.Model) + if strings.TrimSpace(cfg.Provider) != "" { + setStringInMap(visionNode, "provider", cfg.Provider) + } + if cfg.TimeoutSeconds > 0 { + setIntInMap(visionNode, "timeout_seconds", cfg.TimeoutSeconds) + } + if cfg.MaxImageBytes > 0 { + setIntInMap(visionNode, "max_image_bytes", int(cfg.MaxImageBytes)) + } + if cfg.MaxDimension > 0 { + setIntInMap(visionNode, "max_dimension", cfg.MaxDimension) + } + if cfg.JPEGQuality > 0 { + setIntInMap(visionNode, "jpeg_quality", cfg.JPEGQuality) + } + if cfg.MaxPayloadBytes > 0 { + setIntInMap(visionNode, "max_payload_bytes", int(cfg.MaxPayloadBytes)) + } + setIntInMap(visionNode, "skip_preprocess_below_bytes", int(cfg.SkipPreprocessBelowBytes)) + if strings.TrimSpace(cfg.Detail) != "" { + setStringInMap(visionNode, "detail", cfg.Detail) + } + if len(cfg.AllowedRoots) > 0 { + setStringSliceInMap(visionNode, "allowed_roots", cfg.AllowedRoots) + } +} + func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) { root := doc.Content[0] openaiNode := ensureMap(root, "openai") diff --git a/internal/handler/openapi.go 
b/internal/handler/openapi.go index e3271a81..428248a5 100644 --- a/internal/handler/openapi.go +++ b/internal/handler/openapi.go @@ -778,11 +778,55 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { }, "ConfigResponse": map[string]interface{}{ "type": "object", - "description": "配置信息", + "description": "配置信息(含 openai、vision、multi_agent 等)", + "properties": map[string]interface{}{ + "vision": map[string]interface{}{ + "$ref": "#/components/schemas/VisionConfig", + }, + }, }, "UpdateConfigRequest": map[string]interface{}{ "type": "object", "description": "更新配置请求", + "properties": map[string]interface{}{ + "vision": map[string]interface{}{ + "$ref": "#/components/schemas/VisionConfig", + }, + }, + }, + "VisionConfig": map[string]interface{}{ + "type": "object", + "description": "视觉分析(analyze_image MCP 工具);enabled 且 model 非空时注册工具", + "properties": map[string]interface{}{ + "enabled": map[string]interface{}{"type": "boolean", "description": "是否启用 analyze_image"}, + "model": map[string]interface{}{"type": "string", "description": "视觉模型名(必填)", "example": "qwen-vl-max"}, + "api_key": map[string]interface{}{"type": "string", "description": "API Key;留空复用 openai.api_key"}, + "base_url": map[string]interface{}{"type": "string", "description": "Base URL;留空复用 openai.base_url"}, + "provider": map[string]interface{}{"type": "string", "description": "提供商;留空复用 openai.provider"}, + "timeout_seconds": map[string]interface{}{"type": "integer", "description": "VL 调用超时(秒)"}, + "max_image_bytes": map[string]interface{}{"type": "integer", "description": "原始文件大小上限(字节)"}, + "max_dimension": map[string]interface{}{"type": "integer", "description": "长边缩放像素"}, + "jpeg_quality": map[string]interface{}{"type": "integer", "description": "JPEG 质量 60-100"}, + "max_payload_bytes": map[string]interface{}{"type": "integer", "description": "送 API 体积上限(字节)"}, + "skip_preprocess_below_bytes": map[string]interface{}{"type": "integer", "description": "低于该字节且尺寸合规时可原图直传;0=始终压缩"}, + "detail": 
map[string]interface{}{"type": "string", "enum": []string{"low", "high", "auto"}, "description": "OpenAI 兼容 image detail"}, + "allowed_roots": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "额外允许读取的绝对路径根"}, + }, + }, + "AnalyzeImageToolCall": map[string]interface{}{ + "type": "object", + "description": "内置 MCP 工具 analyze_image:分析服务器本地图片,返回纯文本(验证码/UI/报错等)", + "properties": map[string]interface{}{ + "path": map[string]interface{}{ + "type": "string", + "description": "图片路径(cwd、chat_uploads、result_storage_dir 或 allowed_roots 下)", + }, + "question": map[string]interface{}{ + "type": "string", + "description": "可选:重点问题;验证码建议「只输出验证码字符」", + }, + }, + "required": []string{"path"}, }, "ExternalMCPConfig": map[string]interface{}{ "type": "object", @@ -4900,6 +4944,52 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) { }, // ==================== 配置管理 - 缺失端点 ==================== + "/api/config/test-vision": map[string]interface{}{ + "post": map[string]interface{}{ + "tags": []string{"配置管理"}, + "summary": "测试视觉模型连接", + "description": "测试 Vision 模型 API 是否可用。vision.api_key/base_url 留空时可传 openai 段作回退。", + "operationId": "testVision", + "requestBody": map[string]interface{}{ + "required": true, + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "required": []string{"vision"}, + "properties": map[string]interface{}{ + "vision": map[string]interface{}{"$ref": "#/components/schemas/VisionConfig"}, + "openai": map[string]interface{}{ + "type": "object", + "description": "主 LLM 配置(vision 字段留空时用于 API Key/Base URL 回退)", + }, + }, + }, + }, + }, + }, + "responses": map[string]interface{}{ + "200": map[string]interface{}{ + "description": "测试结果", + "content": map[string]interface{}{ + "application/json": map[string]interface{}{ + "schema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "success": 
map[string]interface{}{"type": "boolean"}, + "error": map[string]interface{}{"type": "string"}, + "model": map[string]interface{}{"type": "string"}, + "latency_ms": map[string]interface{}{"type": "number"}, + }, + }, + }, + }, + }, + "400": map[string]interface{}{"description": "参数错误"}, + "401": map[string]interface{}{"description": "未授权"}, + }, + }, + }, "/api/config/test-openai": map[string]interface{}{ "post": map[string]interface{}{ "tags": []string{"配置管理"}, diff --git a/internal/project/vision_image_prompt.go b/internal/project/vision_image_prompt.go new file mode 100644 index 00000000..9cb960ac --- /dev/null +++ b/internal/project/vision_image_prompt.go @@ -0,0 +1,22 @@ +package project + +import "strings" + +// VisionImageAnalysisSection 单/多代理共用的图片分析提示(analyze_image;上下文仅保留文字摘要)。 +func VisionImageAnalysisSection() string { + var b strings.Builder + b.WriteString("## 图片分析\n\n") + b.WriteString("- 遇到图片文件(截图、验证码、登录页、报告配图)时,若存在工具 analyze_image,请传入服务器上的文件路径进行分析。\n") + b.WriteString("- 不要对二进制图片使用 read_file 指望理解内容;用户消息中「📎 xxx.png: /path」即为可传给 analyze_image 的路径。\n") + b.WriteString("- 验证码类:若已从页面或接口保存为本地图片(如 captcha.png),用 analyze_image,question 写明「只输出验证码字符」;识别失败则刷新验证码后重新保存再识;复杂滑块/行为验证码勿指望单次识图成功。\n") + b.WriteString("- 委派子代理时,若子任务含验证码/截图识读,在 task description 中写明图片路径与期望输出格式。\n") + return b.String() +} + +// AppendVisionImageAnalysisIfReady 仅在 vision.enabled 且 model 已配置时追加图片分析提示。 +func AppendVisionImageAnalysisIfReady(base string, visionReady bool) string { + if !visionReady { + return base + } + return AppendSystemPromptBlock(base, VisionImageAnalysisSection()) +} diff --git a/internal/vision/client.go b/internal/vision/client.go new file mode 100644 index 00000000..dbbe52b7 --- /dev/null +++ b/internal/vision/client.go @@ -0,0 +1,132 @@ +package vision + +import ( + "context" + "encoding/base64" + "fmt" + "net" + "net/http" + "strings" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/openai" + + einoopenai 
"github.com/cloudwego/eino-ext/components/model/openai" + "github.com/cloudwego/eino/schema" +) + +// Client 调用独立 Vision ChatModel(单次 Generate)。 +type Client struct { + cfg config.VisionConfig + mainOA config.OpenAIConfig +} + +// NewClient 构造视觉客户端。 +func NewClient(visionCfg config.VisionConfig, mainOpenAI config.OpenAIConfig) *Client { + return &Client{cfg: visionCfg, mainOA: mainOpenAI} +} + +// Analyze 将图片字节送入 VL 模型并返回文本描述。 +func (c *Client) Analyze(ctx context.Context, img ImagePayload, question string) (string, error) { + if len(img.Bytes) == 0 { + return "", fmt.Errorf("empty image payload") + } + mime := strings.TrimSpace(img.MIMEType) + if mime == "" { + mime = "image/jpeg" + } + oa := c.cfg.OpenAICfgEffective(c.mainOA) + if strings.TrimSpace(oa.APIKey) == "" { + return "", fmt.Errorf("vision API key is empty (set vision.api_key or openai.api_key)") + } + if strings.TrimSpace(oa.Model) == "" { + return "", fmt.Errorf("vision model is empty") + } + + timeout := time.Duration(c.cfg.TimeoutSecondsEffective()) * time.Second + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + httpClient := &http.Client{ + Timeout: timeout + 15*time.Second, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 60 * time.Second, + KeepAlive: 60 * time.Second, + }).DialContext, + ResponseHeaderTimeout: timeout + 10*time.Second, + }, + } + httpClient = openai.NewEinoHTTPClient(&oa, httpClient) + + modelCfg := &einoopenai.ChatModelConfig{ + APIKey: oa.APIKey, + BaseURL: strings.TrimSuffix(oa.BaseURL, "/"), + Model: oa.Model, + HTTPClient: httpClient, + } + chatModel, err := einoopenai.NewChatModel(ctx, modelCfg) + if err != nil { + return "", fmt.Errorf("vision chat model: %w", err) + } + + b64 := base64.StdEncoding.EncodeToString(img.Bytes) + detail := schema.ImageURLDetailLow + switch c.cfg.DetailEffective() { + case "high": + detail = schema.ImageURLDetailHigh + case "auto": + detail = schema.ImageURLDetailAuto + } + + prompt := 
buildVisionPrompt(question) + userMsg := &schema.Message{ + Role: schema.User, + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeText, Text: prompt}, + { + Type: schema.ChatMessagePartTypeImageURL, + Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + Base64Data: &b64, + MIMEType: mime, + }, + Detail: detail, + }, + }, + }, + } + + resp, err := chatModel.Generate(ctx, []*schema.Message{userMsg}) + if err != nil { + return "", fmt.Errorf("vision generate: %w", err) + } + if resp == nil || strings.TrimSpace(resp.Content) == "" { + return "", fmt.Errorf("vision model returned empty content") + } + return strings.TrimSpace(resp.Content), nil +} + +func buildVisionPrompt(question string) string { + q := strings.TrimSpace(question) + if q == "" { + q = "请对图片做通用描述,侧重授权安全测试场景(可见文本、表单、按钮、验证码、错误信息、技术栈线索)。" + } + extra := "" + if looksLikeCaptchaQuestion(q) { + extra = "\n若为验证码:仅输出你辨认出的字符序列,不要空格、标点、解释;看不清则明确说无法识别。" + } + return `你是授权安全测试助手。请根据图片回答用户问题,只描述你能从图中确认的内容,不要编造。 +用户问题:` + q + extra +} + +func looksLikeCaptchaQuestion(q string) bool { + s := strings.ToLower(q) + for _, kw := range []string{"验证码", "captcha", "verification code", "verify code", "vcode", "图形码"} { + if strings.Contains(s, kw) { + return true + } + } + return strings.Contains(s, "只输出") && (strings.Contains(s, "字符") || strings.Contains(s, "character")) +} diff --git a/internal/vision/client_test.go b/internal/vision/client_test.go new file mode 100644 index 00000000..101aa943 --- /dev/null +++ b/internal/vision/client_test.go @@ -0,0 +1,12 @@ +package vision + +import "testing" + +func TestLooksLikeCaptchaQuestion(t *testing.T) { + if !looksLikeCaptchaQuestion("识别验证码,只输出字符") { + t.Fatal("expected captcha hint") + } + if looksLikeCaptchaQuestion("描述登录页布局") { + t.Fatal("expected non-captcha") + } +} diff --git a/internal/vision/path.go b/internal/vision/path.go new file mode 100644 index 00000000..439921ac --- /dev/null +++ 
// chatUploadsDirName is the directory, relative to the working directory,
// that is always accepted as an image root.
const chatUploadsDirName = "chat_uploads"

// allowedImageExt lists the lower-case file extensions accepted as images.
var allowedImageExt = map[string]struct{}{
	".png": {}, ".jpg": {}, ".jpeg": {}, ".webp": {}, ".gif": {},
	".bmp": {}, ".tif": {}, ".tiff": {},
}

// PathOptions describes the directories images may be read from.
type PathOptions struct {
	CWD              string   // working directory; os.Getwd() is used when blank
	ResultStorageDir string   // relative to CWD unless absolute; "tmp" when blank
	ExtraRoots       []string // additional roots (vision.allowed_roots)
}

// ResolveImagePath validates path and returns its absolute, symlink-resolved
// form.
//
// Relative paths are interpreted against opt.CWD. The resolved path must
// carry an allowed image extension, lie inside one of the allowed roots (the
// working directory, its chat_uploads directory, the result storage
// directory, or an extra root), and name an existing regular file no larger
// than 1 GiB. Any violation is reported as an error.
func ResolveImagePath(path string, opt PathOptions) (string, error) {
	// Hard ceiling applied here, independent of any limit enforced later.
	const maxOnDiskBytes = 1 << 30

	p := strings.TrimSpace(path)
	if p == "" {
		return "", fmt.Errorf("path is empty")
	}

	cwd := strings.TrimSpace(opt.CWD)
	if cwd == "" {
		wd, err := os.Getwd()
		if err != nil {
			return "", fmt.Errorf("getwd: %w", err)
		}
		cwd = wd
	}
	cwdAbs, err := filepath.Abs(filepath.Clean(cwd))
	if err != nil {
		return "", err
	}

	candidate := filepath.Clean(p)
	if !filepath.IsAbs(p) {
		candidate = filepath.Clean(filepath.Join(cwdAbs, p))
	}
	// Resolve symlinks first so every later check sees the real target.
	candidate = normalizeAbsPath(candidate)
	if candidate == "" {
		return "", fmt.Errorf("invalid path")
	}

	ext := strings.ToLower(filepath.Ext(candidate))
	if _, ok := allowedImageExt[ext]; !ok {
		return "", fmt.Errorf("unsupported image extension %q", ext)
	}

	resolved, err := evalUnderAllowedRoots(candidate, buildAllowedRoots(cwdAbs, opt))
	if err != nil {
		return "", err
	}

	st, err := os.Stat(resolved)
	if err != nil {
		return "", fmt.Errorf("stat: %w", err)
	}
	// Reject directories and also FIFOs, devices and sockets: reading those
	// could block or return unbounded data.
	if !st.Mode().IsRegular() {
		return "", fmt.Errorf("not a regular file")
	}
	if st.Size() > maxOnDiskBytes {
		return "", fmt.Errorf("file too large on disk")
	}
	return resolved, nil
}

// normalizeAbsPath returns the absolute form of p with symlinks resolved.
// When symlinks cannot be resolved (for example because the path does not
// exist) the plain absolute path is returned; "" means p could not be made
// absolute at all.
func normalizeAbsPath(p string) string {
	abs, err := filepath.Abs(filepath.Clean(p))
	if err != nil {
		return ""
	}
	if link, err := filepath.EvalSymlinks(abs); err == nil {
		return link
	}
	return abs
}

// buildAllowedRoots returns the de-duplicated, normalized list of directories
// images may live in: cwdAbs, cwdAbs/chat_uploads, the result storage
// directory and every extra root. Blank entries are ignored.
func buildAllowedRoots(cwdAbs string, opt PathOptions) []string {
	seen := make(map[string]struct{})
	var roots []string
	add := func(r string) {
		r = strings.TrimSpace(r)
		if r == "" {
			return
		}
		abs := normalizeAbsPath(r)
		if abs == "" {
			return
		}
		if _, dup := seen[abs]; dup {
			return
		}
		seen[abs] = struct{}{}
		roots = append(roots, abs)
	}

	add(cwdAbs)
	add(filepath.Join(cwdAbs, chatUploadsDirName))

	rs := strings.TrimSpace(opt.ResultStorageDir)
	if rs == "" {
		rs = "tmp"
	}
	if !filepath.IsAbs(rs) {
		rs = filepath.Join(cwdAbs, rs)
	}
	add(rs)

	for _, r := range opt.ExtraRoots {
		add(r)
	}
	return roots
}

// evalUnderAllowedRoots normalizes candidate and returns the normalized path
// when it lies inside one of roots; otherwise it returns an error naming the
// candidate. Returning the normalized form guarantees the caller opens
// exactly the path that was checked.
func evalUnderAllowedRoots(candidate string, roots []string) (string, error) {
	check := normalizeAbsPath(candidate)
	if check != "" {
		for _, root := range roots {
			if isUnderRoot(check, root) {
				return check, nil
			}
		}
	}
	return "", fmt.Errorf("path %q is outside allowed directories", candidate)
}

// isUnderRoot reports whether path equals root or lies beneath it. The
// comparison is component-wise: "/a/bc" is not under "/a/b".
func isUnderRoot(path, root string) bool {
	path = filepath.Clean(path)
	root = filepath.Clean(root)
	if path == root {
		return true
	}
	sep := string(filepath.Separator)
	// Clean keeps the trailing separator only on a filesystem root ("/" or
	// `C:\`); appending another one there would never match anything.
	if !strings.HasSuffix(root, sep) {
		root += sep
	}
	return strings.HasPrefix(path, root)
}
ResolveImagePath("../../../etc/passwd", PathOptions{CWD: dir}) + if err == nil { + t.Fatal("expected error for path outside roots") + } +} + +func TestResolveImagePath_rejectsNonImageExt(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "notes.txt") + if err := os.WriteFile(f, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + _, err := ResolveImagePath(f, PathOptions{CWD: dir}) + if err == nil { + t.Fatal("expected error for non-image extension") + } +} diff --git a/internal/vision/preprocess.go b/internal/vision/preprocess.go new file mode 100644 index 00000000..860dab63 --- /dev/null +++ b/internal/vision/preprocess.go @@ -0,0 +1,212 @@ +package vision + +import ( + "bytes" + "fmt" + "image" + "os" + "strings" + + "github.com/disintegration/imaging" +) + +// ImagePayload 送入 VL API 的图片字节与 MIME。 +type ImagePayload struct { + Bytes []byte + MIMEType string +} + +// PreprocessMeta 记录缩放与编码结果,供工具输出与排障。 +type PreprocessMeta struct { + OriginalPath string + OriginalBytes int64 + OriginalWidth int + OriginalHeight int + OutputWidth int + OutputHeight int + OutputBytes int + OutputMIMEType string + JPEGQuality int // 0 表示未 JPEG 重编码(原图直传) + PreprocessMode string // passthrough | jpeg +} + +// PreprocessOptions 图片预处理参数。 +type PreprocessOptions struct { + MaxImageBytes int64 + MaxDimension int + JPEGQuality int + MaxPayloadBytes int64 + SkipPreprocessBelowBytes int64 // 0 = 始终压缩;>0 时小图+尺寸合规可直传 +} + +// PreprocessImageFile 读取图片;大图或超尺寸走 imaging 缩放+JPEG,否则可原图直传。 +func PreprocessImageFile(path string, opt PreprocessOptions) (ImagePayload, PreprocessMeta, error) { + var meta PreprocessMeta + meta.OriginalPath = path + + st, err := os.Stat(path) + if err != nil { + return ImagePayload{}, meta, err + } + meta.OriginalBytes = st.Size() + if opt.MaxImageBytes > 0 && st.Size() > opt.MaxImageBytes { + return ImagePayload{}, meta, fmt.Errorf("file size %d exceeds max_image_bytes %d", st.Size(), opt.MaxImageBytes) + } + + cfgW, cfgH, format, err := imageDimensions(path) 
+ if err != nil { + return ImagePayload{}, meta, err + } + meta.OriginalWidth = cfgW + meta.OriginalHeight = cfgH + + maxDim := opt.MaxDimension + if maxDim <= 0 { + maxDim = 2048 + } + maxPayload := opt.MaxPayloadBytes + if maxPayload <= 0 { + maxPayload = 512 * 1024 + } + + if payload, meta, ok, err := tryPassthrough(path, st.Size(), cfgW, cfgH, format, opt, maxDim, maxPayload); ok { + return payload, meta, err + } + + return compressWithImaging(path, opt, maxDim, maxPayload, meta) +} + +func tryPassthrough(path string, size int64, w, h int, format string, opt PreprocessOptions, maxDim int, maxPayload int64) (ImagePayload, PreprocessMeta, bool, error) { + var meta PreprocessMeta + meta.OriginalPath = path + meta.OriginalBytes = size + meta.OriginalWidth = w + meta.OriginalHeight = h + + threshold := opt.SkipPreprocessBelowBytes + if threshold <= 0 { + return ImagePayload{}, meta, false, nil + } + if size > threshold { + return ImagePayload{}, meta, false, nil + } + longEdge := w + if h > longEdge { + longEdge = h + } + if longEdge > maxDim { + return ImagePayload{}, meta, false, nil + } + if size > maxPayload { + return ImagePayload{}, meta, false, nil + } + + raw, err := os.ReadFile(path) + if err != nil { + return ImagePayload{}, meta, false, err + } + mime := mimeFromImageFormat(format) + if mime == "" { + return ImagePayload{}, meta, false, nil + } + + meta.OutputWidth = w + meta.OutputHeight = h + meta.OutputBytes = len(raw) + meta.OutputMIMEType = mime + meta.PreprocessMode = "passthrough" + return ImagePayload{Bytes: raw, MIMEType: mime}, meta, true, nil +} + +func compressWithImaging(path string, opt PreprocessOptions, maxDim int, maxPayload int64, meta PreprocessMeta) (ImagePayload, PreprocessMeta, error) { + src, err := imaging.Open(path) + if err != nil { + return ImagePayload{}, meta, fmt.Errorf("open image: %w", err) + } + bounds := src.Bounds() + meta.OriginalWidth = bounds.Dx() + meta.OriginalHeight = bounds.Dy() + + dst := imaging.Fit(src, maxDim, 
maxDim, imaging.Lanczos) + outBounds := dst.Bounds() + meta.OutputWidth = outBounds.Dx() + meta.OutputHeight = outBounds.Dy() + + quality := opt.JPEGQuality + if quality <= 0 || quality > 100 { + quality = 82 + } + + dim := maxDim + for attempt := 0; attempt < 6; attempt++ { + if attempt > 0 { + dim = int(float64(dim) * 0.85) + if dim < 256 { + dim = 256 + } + dst = imaging.Fit(src, dim, dim, imaging.Lanczos) + outBounds = dst.Bounds() + meta.OutputWidth = outBounds.Dx() + meta.OutputHeight = outBounds.Dy() + } + q := quality + for q >= 60 { + var buf bytes.Buffer + if err := imaging.Encode(&buf, dst, imaging.JPEG, imaging.JPEGQuality(q)); err != nil { + return ImagePayload{}, meta, fmt.Errorf("encode jpeg: %w", err) + } + if int64(buf.Len()) <= maxPayload { + meta.JPEGQuality = q + meta.OutputBytes = buf.Len() + meta.OutputMIMEType = "image/jpeg" + meta.PreprocessMode = "jpeg" + return ImagePayload{Bytes: buf.Bytes(), MIMEType: "image/jpeg"}, meta, nil + } + q -= 5 + } + quality = 75 + } + return ImagePayload{}, meta, fmt.Errorf("could not compress image under max_payload_bytes %d", maxPayload) +} + +func imageDimensions(path string) (w, h int, format string, err error) { + f, err := os.Open(path) + if err != nil { + return 0, 0, "", err + } + defer f.Close() + cfg, format, err := image.DecodeConfig(f) + if err != nil { + return 0, 0, "", fmt.Errorf("decode image config: %w", err) + } + return cfg.Width, cfg.Height, format, nil +} + +func mimeFromImageFormat(format string) string { + switch strings.ToLower(strings.TrimSpace(format)) { + case "jpeg", "jpg": + return "image/jpeg" + case "png": + return "image/png" + case "gif": + return "image/gif" + case "webp": + return "image/webp" + case "bmp": + return "image/bmp" + case "tiff": + return "image/tiff" + default: + return "" + } +} + +// DecodeImageConfig 用于测试:确认文件可被解码。 +func DecodeImageConfig(path string) (image.Config, string, error) { + f, err := os.Open(path) + if err != nil { + return image.Config{}, "", err 
+ } + defer f.Close() + return image.DecodeConfig(f) +} diff --git a/internal/vision/preprocess_test.go b/internal/vision/preprocess_test.go new file mode 100644 index 00000000..a9b9e068 --- /dev/null +++ b/internal/vision/preprocess_test.go @@ -0,0 +1,109 @@ +package vision + +import ( + "image" + "image/color" + "image/png" + "os" + "path/filepath" + "testing" + + "github.com/disintegration/imaging" +) + +func TestPreprocessImageFile_scalesAndLimitsPayload(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "big.png") + img := imaging.New(3000, 2000, color.White) + if err := imaging.Save(img, path); err != nil { + t.Fatal(err) + } + + out, meta, err := PreprocessImageFile(path, PreprocessOptions{ + MaxImageBytes: 10 * 1024 * 1024, + MaxDimension: 1024, + JPEGQuality: 85, + MaxPayloadBytes: 600 * 1024, + SkipPreprocessBelowBytes: 0, + }) + if err != nil { + t.Fatal(err) + } + if len(out.Bytes) == 0 { + t.Fatal("empty output") + } + if meta.PreprocessMode != "jpeg" { + t.Fatalf("mode: %s", meta.PreprocessMode) + } + if meta.OutputWidth > 1024 || meta.OutputHeight > 1024 { + t.Fatalf("expected fit within 1024, got %dx%d", meta.OutputWidth, meta.OutputHeight) + } + if int64(len(out.Bytes)) > 600*1024 { + t.Fatalf("payload %d exceeds max", len(out.Bytes)) + } +} + +func TestPreprocessImageFile_passthroughSmallPNG(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "small.png") + if err := imaging.Save(imaging.New(400, 300, color.White), path); err != nil { + t.Fatal(err) + } + + out, meta, err := PreprocessImageFile(path, PreprocessOptions{ + MaxImageBytes: 5 * 1024 * 1024, + MaxDimension: 2048, + MaxPayloadBytes: 512 * 1024, + SkipPreprocessBelowBytes: 2 * 1024 * 1024, + }) + if err != nil { + t.Fatal(err) + } + if meta.PreprocessMode != "passthrough" { + t.Fatalf("expected passthrough, got %s", meta.PreprocessMode) + } + if out.MIMEType != "image/png" { + t.Fatalf("mime: %s", out.MIMEType) + } + if meta.OutputWidth != 400 || 
meta.OutputHeight != 300 { + t.Fatalf("dims: %dx%d", meta.OutputWidth, meta.OutputHeight) + } +} + +func TestPreprocessImageFile_passthroughDisabled(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "small.png") + if err := imaging.Save(imaging.New(100, 100, color.White), path); err != nil { + t.Fatal(err) + } + + _, meta, err := PreprocessImageFile(path, PreprocessOptions{ + MaxDimension: 2048, + MaxPayloadBytes: 512 * 1024, + SkipPreprocessBelowBytes: 0, + }) + if err != nil { + t.Fatal(err) + } + if meta.PreprocessMode != "jpeg" { + t.Fatalf("expected jpeg compress, got %s", meta.PreprocessMode) + } +} + +func TestPreprocessImageFile_rejectsOversizeFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tiny.png") + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + if err := png.Encode(f, image.NewRGBA(image.Rect(0, 0, 2, 2))); err != nil { + t.Fatal(err) + } + f.Close() + + _, _, err = PreprocessImageFile(path, PreprocessOptions{MaxImageBytes: 1}) + if err == nil { + t.Fatal("expected error when file exceeds max_image_bytes") + } +} diff --git a/internal/vision/tool.go b/internal/vision/tool.go new file mode 100644 index 00000000..ad5780ea --- /dev/null +++ b/internal/vision/tool.go @@ -0,0 +1,130 @@ +package vision + +import ( + "context" + "fmt" + "os" + "strings" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/mcp/builtin" + + "go.uber.org/zap" +) + +// RegisterAnalyzeImageTool 在 vision.enabled 且 model 已配置时注册 MCP 工具 analyze_image。 +func RegisterAnalyzeImageTool(mcpServer *mcp.Server, cfg *config.Config, logger *zap.Logger) { + if mcpServer == nil || cfg == nil { + return + } + if !cfg.Vision.Ready() { + if cfg.Vision.Enabled && logger != nil { + logger.Warn("vision.enabled 但 vision.model 为空,跳过注册 analyze_image") + } + return + } + + cwd, err := os.Getwd() + if err != nil { + if logger != nil { + logger.Warn("vision: getwd failed, skip analyze_image", zap.Error(err)) 
+ } + return + } + + pathOpt := PathOptions{ + CWD: cwd, + ResultStorageDir: cfg.Agent.ResultStorageDir, + ExtraRoots: cfg.Vision.AllowedRoots, + } + preOpt := PreprocessOptions{ + MaxImageBytes: cfg.Vision.MaxImageBytesEffective(), + MaxDimension: cfg.Vision.MaxDimensionEffective(), + JPEGQuality: cfg.Vision.JPEGQualityEffective(), + MaxPayloadBytes: cfg.Vision.MaxPayloadBytesEffective(), + SkipPreprocessBelowBytes: cfg.Vision.SkipPreprocessBelowBytesEffective(), + } + client := NewClient(cfg.Vision, cfg.OpenAI) + + tool := mcp.Tool{ + Name: builtin.ToolAnalyzeImage, + Description: "分析服务器上的本地图片并返回文字描述(验证码、UI 元素、报错、架构图要点等)。" + + "输入为文件路径(如用户上传的 chat_uploads 路径或工具截图路径)。" + + "输出仅为文本,不含图片数据。不要对二进制图片使用 read_file 指望理解内容。", + ShortDescription: "分析本地图片并返回文字描述(验证码/UI/报错等)", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "path": map[string]interface{}{ + "type": "string", + "description": "图片绝对路径或相对于进程工作目录的路径", + }, + "question": map[string]interface{}{ + "type": "string", + "description": "可选:希望模型重点回答的问题。验证码图建议:只输出验证码字符,不要空格和解释", + }, + }, + "required": []string{"path"}, + }, + } + + handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + path, _ := args["path"].(string) + question, _ := args["question"].(string) + + abs, err := ResolveImagePath(path, pathOpt) + if err != nil { + return textResult(fmt.Sprintf("路径校验失败: %v", err), true), nil + } + + img, meta, err := PreprocessImageFile(abs, preOpt) + if err != nil { + return textResult(fmt.Sprintf("图片预处理失败: %v", err), true), nil + } + + summary, err := client.Analyze(ctx, img, question) + if err != nil { + return textResult(fmt.Sprintf("视觉模型调用失败: %v", err), true), nil + } + + body := formatAnalysisResult(abs, meta, summary) + return textResult(body, false), nil + } + + mcpServer.RegisterTool(tool, handler) + if logger != nil { + logger.Info("vision: analyze_image 工具已注册", zap.String("model", cfg.Vision.Model)) + } +} + +func 
textResult(text string, isError bool) *mcp.ToolResult { + return &mcp.ToolResult{ + Content: []mcp.Content{{Type: "text", Text: text}}, + IsError: isError, + } +} + +func formatAnalysisResult(path string, meta PreprocessMeta, summary string) string { + var b strings.Builder + b.WriteString("## Image analysis\n") + b.WriteString("- **path**: ") + b.WriteString(path) + b.WriteString("\n") + switch meta.PreprocessMode { + case "passthrough": + b.WriteString(fmt.Sprintf("- **preprocess**: passthrough %dx%d, %s, %dKB (original %dKB)\n\n", + meta.OutputWidth, meta.OutputHeight, meta.OutputMIMEType, + (meta.OutputBytes+1023)/1024, (meta.OriginalBytes+1023)/1024)) + default: + b.WriteString(fmt.Sprintf("- **preprocess**: %dx%d → %dx%d, jpeg q=%d, %dKB (original %dKB)\n\n", + meta.OriginalWidth, meta.OriginalHeight, + meta.OutputWidth, meta.OutputHeight, + meta.JPEGQuality, (meta.OutputBytes+1023)/1024, + (meta.OriginalBytes+1023)/1024)) + } + b.WriteString("### Summary\n") + b.WriteString(strings.TrimSpace(summary)) + b.WriteString("\n") + return b.String() +}