From add33e1cf72209938793f6a793dbffa35825e0bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AC=E6=98=8E?= <83812544+Ed1s0nZ@users.noreply.github.com> Date: Sat, 8 Nov 2025 18:56:23 +0800 Subject: [PATCH] Add files via upload --- README.md | 319 ++++++++++++ cmd/server/main.go | 36 ++ config.yaml | 174 +++++++ data/conversations.db | Bin 0 -> 4096 bytes data/conversations.db-shm | Bin 0 -> 32768 bytes data/conversations.db-wal | Bin 0 -> 41232 bytes go.mod | 38 ++ go.sum | 96 ++++ internal/agent/agent.go | 576 +++++++++++++++++++++ internal/app/app.go | 163 ++++++ internal/config/config.go | 114 +++++ internal/database/conversation.go | 256 ++++++++++ internal/database/database.go | 90 ++++ internal/handler/agent.go | 134 +++++ internal/handler/conversation.go | 102 ++++ internal/handler/monitor.go | 92 ++++ internal/logger/logger.go | 60 +++ internal/mcp/server.go | 798 ++++++++++++++++++++++++++++++ internal/mcp/types.go | 232 +++++++++ internal/security/executor.go | 730 +++++++++++++++++++++++++++ run.sh | 35 ++ web/static/css/style.css | 691 ++++++++++++++++++++++++++ web/static/js/app.js | 400 +++++++++++++++ web/templates/index.html | 92 ++++ 24 files changed, 5228 insertions(+) create mode 100644 README.md create mode 100644 cmd/server/main.go create mode 100644 config.yaml create mode 100644 data/conversations.db create mode 100644 data/conversations.db-shm create mode 100644 data/conversations.db-wal create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/agent/agent.go create mode 100644 internal/app/app.go create mode 100644 internal/config/config.go create mode 100644 internal/database/conversation.go create mode 100644 internal/database/database.go create mode 100644 internal/handler/agent.go create mode 100644 internal/handler/conversation.go create mode 100644 internal/handler/monitor.go create mode 100644 internal/logger/logger.go create mode 100644 internal/mcp/server.go create mode 100644 internal/mcp/types.go create mode 100644 internal/security/executor.go create mode 100644 run.sh create mode 100644 web/static/css/style.css create mode 100644 web/static/js/app.js create mode 100644 web/templates/index.html diff --git a/README.md b/README.md new file mode 100644 index 00000000..479312ce --- /dev/null +++ b/README.md @@ -0,0 +1,319 @@ +# CyberStrikeAI + +基于Golang和Gin框架的AI驱动自主渗透测试平台,使用MCP协议集成安全工具。 + +## 功能特性 + +- 🤖 **AI代理连接** - 支持Claude、GPT等兼容MCP的AI代理通过FastMCP协议连接 +- 🧠 **智能分析** - 决策引擎分析目标并选择最佳测试策略 +- ⚡ **自主执行** - AI代理执行全面的安全评估 +- 🔄 **实时适应** - 系统根据结果和发现的漏洞进行调整 +- 📊 **高级报告** - 可视化方式输出漏洞卡片和风险分析 +- 💬 **对话式交互** - 前端以对话形式调用后端agent-loop +- 📈 **实时监控** - 监控安全工具的执行状态、结果、调用次数等 + +## 项目结构 + +``` +CyberStrikeAI/ +├── cmd/ +│ └── server/ +│ └── main.go # 程序入口 +├── internal/ +│ ├── agent/ # AI代理模块 +│ ├── app/ # 应用初始化 +│ ├── config/ # 配置管理 +│ ├── handler/ # HTTP处理器 +│ ├── logger/ # 日志系统 +│ ├── mcp/ # MCP协议实现 +│ └── security/ # 安全工具执行器 +├── web/ +│ ├── static/ # 静态资源 +│ │ ├── css/ +│ │ └── js/ +│ └── templates/ # HTML模板 +├── config.yaml # 配置文件 +├── go.mod # Go模块文件 +└── README.md # 说明文档 +``` + +## 快速开始 + +### 前置要求 + +- Go 1.21 或更高版本 +- OpenAI API Key(或其他兼容OpenAI协议的API) +- 安全工具(可选):nmap, sqlmap, nikto, dirb + +### 安装步骤 + +1. **克隆项目** +```bash +cd /Users/temp/Desktop/wenjian/tools/CyberStrikeAI +``` + +2. **安装依赖** +```bash +go mod download +``` + +3. **配置** +编辑 `config.yaml` 文件,设置您的OpenAI API Key: +```yaml +openai: + api_key: "sk-your-api-key-here" + base_url: "https://api.openai.com/v1" + model: "gpt-4" +``` + +4. **启动服务器** + +#### 方式一:使用启动脚本 +```bash +./run.sh +``` + +#### 方式二:直接运行 +```bash +go run cmd/server/main.go +``` + +#### 方式三:编译后运行 +```bash +go build -o cyberstrike-ai cmd/server/main.go +./cyberstrike-ai +``` + +5. **访问应用** +打开浏览器访问:http://localhost:8080 + +## 配置说明 + +### 服务器配置 +```yaml +server: + host: "0.0.0.0" + port: 8080 +``` + +### MCP配置 +```yaml +mcp: + enabled: true + host: "0.0.0.0" + port: 8081 +``` + +### 安全工具配置 +```yaml +security: + tools: + - name: "nmap" + command: "nmap" + args: ["-sV", "-sC"] + description: "网络扫描工具" + enabled: true +``` + +## 使用示例 + +### 对话式渗透测试 + +在"对话测试"标签页中,您可以: + +1. **网络扫描** + ``` + 扫描 192.168.1.1 的开放端口 + ``` + +2. **SQL注入检测** + ``` + 检测 https://example.com 的SQL注入漏洞 + ``` + +3. **Web漏洞扫描** + ``` + 扫描 https://example.com 的Web服务器漏洞 + ``` + +4. **目录扫描** + ``` + 扫描 https://example.com 的隐藏目录 + ``` + +### 监控工具执行 + +在"工具监控"标签页中,您可以: + +- 查看所有工具的执行统计 +- 查看详细的执行记录 +- 查看发现的漏洞列表 +- 实时监控工具状态 + +## API接口 + +### Agent Loop API + +**POST** `/api/agent-loop` + +请求体: +```json +{ + "message": "扫描 192.168.1.1" +} +``` + +使用示例: +```bash +curl -X POST http://localhost:8080/api/agent-loop \ + -H "Content-Type: application/json" \ + -d '{"message": "扫描 192.168.1.1"}' +``` + +### 监控API + +- **GET** `/api/monitor` - 获取所有监控信息 +- **GET** `/api/monitor/execution/:id` - 获取特定执行记录 +- **GET** `/api/monitor/stats` - 获取统计信息 +- **GET** `/api/monitor/vulnerabilities` - 获取漏洞列表 + +使用示例: +```bash +# 获取所有监控信息 +curl http://localhost:8080/api/monitor + +# 获取统计信息 +curl http://localhost:8080/api/monitor/stats + +# 获取漏洞列表 +curl http://localhost:8080/api/monitor/vulnerabilities +``` + +### MCP接口 + +**POST** `/api/mcp` - MCP协议端点 + +## MCP协议 + +本项目实现了MCP(Model Context Protocol)协议,支持: + +- `initialize` - 初始化连接 +- `tools/list` - 列出可用工具 +- `tools/call` - 调用工具 + +工具调用是异步执行的,系统会跟踪每个工具的执行状态和结果。 + +### MCP协议使用示例 + +#### 初始化连接 + +```bash +curl -X POST http://localhost:8080/api/mcp \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": "1", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0" + } + } + }' +``` + +#### 列出工具 + +```bash +curl -X POST http://localhost:8080/api/mcp \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": "2", + "method": "tools/list" + }' +``` + +#### 调用工具 + +```bash +curl -X POST http://localhost:8080/api/mcp \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc": "2.0", + "id": "3", + "method": "tools/call", + "params": { + "name": "nmap", + "arguments": { + "target": "192.168.1.1", + "ports": "1-1000" + } + } + }' +``` + +## 安全工具支持 + +当前支持的安全工具: +- **nmap** - 网络扫描 +- **sqlmap** - SQL注入检测 +- **nikto** - Web服务器扫描 +- **dirb** - 目录扫描 + +可以通过修改 `config.yaml` 添加更多工具。 + +## 故障排除 + +### 问题:无法连接到OpenAI API + +- 检查API Key是否正确 +- 检查网络连接 +- 检查base_url配置 + +### 问题:工具执行失败 + +- 确保已安装相应的安全工具(nmap, sqlmap等) +- 检查工具是否在PATH中 +- 某些工具可能需要root权限 + +### 问题:前端无法加载 + +- 检查服务器是否正常运行 +- 检查端口8080是否被占用 +- 查看浏览器控制台错误信息 + +## 安全注意事项 + +⚠️ **重要提示**: + +- 仅对您拥有或已获得授权的系统进行测试 +- 遵守相关法律法规 +- 建议在隔离的测试环境中使用 +- 不要在生产环境中使用 +- 某些安全工具可能需要root权限 + +## 开发 + +### 添加新工具 + +1. 在 `config.yaml` 中添加工具配置 +2. 在 `internal/security/executor.go` 的 `buildCommandArgs` 方法中添加参数构建逻辑 +3. 在 `internal/agent/agent.go` 的 `getAvailableTools` 方法中添加工具定义 + +### 构建 + +```bash +go build -o cyberstrike-ai cmd/server/main.go +``` + +## 许可证 + +本项目仅供学习和研究使用。 + +## 贡献 + +欢迎提交Issue和Pull Request! diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 00000000..cb4292bd --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "cyberstrike-ai/internal/app" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/logger" + "flag" + "fmt" +) + +func main() { + var configPath = flag.String("config", "config.yaml", "配置文件路径") + flag.Parse() + + // 加载配置 + cfg, err := config.Load(*configPath) + if err != nil { + fmt.Printf("加载配置失败: %v\n", err) + return + } + + // 初始化日志 + log := logger.New(cfg.Log.Level, cfg.Log.Output) + + // 创建应用 + application, err := app.New(cfg, log) + if err != nil { + log.Fatal("应用初始化失败", "error", err) + } + + // 启动服务器 + if err := application.Run(); err != nil { + log.Fatal("服务器启动失败", "error", err) + } +} + diff --git a/config.yaml b/config.yaml new file mode 100644 index 00000000..fa2672db --- /dev/null +++ b/config.yaml @@ -0,0 +1,174 @@ +server: + host: "0.0.0.0" + port: 8080 + +log: + level: "info" + output: "stdout" + +mcp: + enabled: true + host: "0.0.0.0" + port: 8081 + +openai: + api_key: "sk-xxx" # 请设置您的OpenAI API Key + base_url: "https://api.deepseek.com/v1" + model: "deepseek-chat" + +database: + path: "data/conversations.db" + +security: + tools: + # 示例1: 使用参数定义的工具(推荐方式) + - name: "nmap" + command: "nmap" + args: ["-sT", "-sV", "-sC"] # 固定参数 + description: "网络扫描工具,用于发现网络主机和服务" + enabled: true + parameters: + - name: "target" + type: "string" + description: "目标IP地址或域名" + required: true + position: 0 # 位置参数,放在最后 + format: "positional" + - name: "ports" + type: "string" + description: "端口范围,例如: 1-1000, 80,443,8080" + required: false + flag: "-p" + format: "flag" + + # 示例2: 标志参数工具 + - name: "sqlmap" + command: "sqlmap" + description: "SQL注入检测和利用工具" + enabled: true + parameters: + - name: "url" + type: "string" + description: "目标URL,例如: http://example.com/page?id=1" + required: true + flag: "-u" + format: "flag" + - name: "batch" + type: "bool" + description: "非交互模式" + required: false + default: true + flag: "--batch" + format: "flag" + - name: "level" + type: "int" + description: "测试级别 (1-5)" + required: false + default: 3 + flag: "--level" + format: "combined" # --level=3 + + # 示例3: 位置参数工具 + - name: "nikto" + command: "nikto" + description: "Web服务器扫描工具" + enabled: true + parameters: + - name: "target" + type: "string" + description: "目标URL或IP地址" + required: true + flag: "-h" + format: "flag" + + # 示例4: 简单位置参数 + - name: "dirb" + command: "dirb" + description: "Web目录扫描工具" + enabled: true + parameters: + - name: "url" + type: "string" + description: "目标URL" + required: true + position: 0 + format: "positional" + - name: "wordlist" + type: "string" + description: "字典文件路径" + required: false + flag: "-w" + format: "flag" + + # 示例5: 执行系统命令 + - name: "exec" + command: "sh" + args: ["-c"] + description: "执行系统命令(谨慎使用)" + enabled: true + parameters: + - name: "command" + type: "string" + description: "要执行的系统命令" + required: true + position: 0 + format: "positional" + - name: "shell" + type: "string" + description: "使用的shell(可选,默认为sh)" + required: false + default: "sh" + - name: "workdir" + type: "string" + description: "工作目录" + required: false + + # 示例6: 自定义工具 - 使用模板格式 + - name: "custom_scanner" + command: "my-scanner" + description: "自定义扫描工具示例" + enabled: false # 默认禁用,需要时启用 + parameters: + - name: "target" + type: "string" + description: "扫描目标" + required: true + flag: "--target" + format: "flag" + - name: "mode" + type: "string" + description: "扫描模式" + required: false + default: "normal" + options: ["normal", "aggressive", "stealth"] # 枚举值 + flag: "--mode" + format: "combined" # --mode=normal + - name: "threads" + type: "int" + description: "线程数" + required: false + default: 10 + flag: "-t" + format: "flag" + - name: "output" + type: "string" + description: "输出文件路径" + required: false + flag: "-o" + format: "template" + template: "-o {value}" # 自定义模板 + - name: "verbose" + type: "bool" + description: "详细输出" + required: false + default: false + flag: "-v" + format: "flag" # 布尔值:如果为true,只添加-v,不添加值 + + # 示例7: 向后兼容 - 不定义parameters,使用旧的硬编码逻辑 + # - name: "legacy_tool" + # command: "legacy" + # args: ["--option"] + # description: "旧工具(使用硬编码逻辑)" + # enabled: false + diff --git a/data/conversations.db b/data/conversations.db new file mode 100644 index 0000000000000000000000000000000000000000..4ebf78cf96088c7148cb47caf804cc036de75a34 GIT binary patch literal 4096 zcmWFz^vNtqRY=P(%1ta$FlG>7U}9o$P*7lCU|@t|AVoG{WY8;GzzfnYK(-m98b?E5 nGz3ONU^E0qLtr!nMnhmU1V%$(Gz3ONU^E0qLtr!nC=3ArfDQ*E literal 0 HcmV?d00001 diff --git a/data/conversations.db-shm b/data/conversations.db-shm new file mode 100644 index 0000000000000000000000000000000000000000..8dfd5497315c2907118971c9b5253c20f10b838f GIT binary patch literal 32768 zcmeI)D{jL;6a~=9r}-h5WfL*&0tyQP0eqZ`&NWNodVf3Bn0Xd$ZjnmP^UmP$_arw1+t4#2-GQ%4W~kY009C72oNAZ sfB*pk1PBlyK!5-N0t5&UAV7cs0RjXF5FkK+009C72oNAZ;64R@0k_K}uK)l5 literal 0 HcmV?d00001 diff --git a/data/conversations.db-wal b/data/conversations.db-wal new file mode 100644 index 0000000000000000000000000000000000000000..9b8a63c9d735180741f40902ff5f232649f3b0b0 GIT binary patch literal 41232 zcmeI*L1^1n9LI6liS5P2b9?f}2zw1HcGe~(WDs_XYrnK6jsX+Vgm8}yyFVSO&(@;)cX}J{@8AI-Y}g?t=b4| zgmYz?)uqt0XNytA?9Dx;30;)FDHc`p zOx-pe+cMmtciug#UKp3czkl(n^~U}2uD&nB(bRZ4q*bbxz0_Ekb4T)qx#-lI+eYBQ z8H0gMHLm|)%T|r^UtKR)Ug&?Z_Ne+ap&BKZ|Zmm}z%{r->S&DBrh>bY`V;Ap<$JSzkcKmY**5I_I{1Q0*~0R&nCfjWZG)BoPD zcdtF(rH&xd>fWp?*moU4XJWNkN03|`#Rmu|rROlLHe>v>ba-lH)Fy|u{uTu0Cx(|dV=pDr&H*Y9@xM;$?%e?XcLKmY** z5I_I{1Q0*~0R#{@dIIX+f5f{kaO?c%BM%b)e?c9=(R-nJMhGB)00IagfB*srAbDW{&MBB z;oq!hyW|C8zQ3xn0{hMjM7#g=b8AgU#4}+KJo(0J1F$**`f@-M|PUw(Wy9XXK<|Brej-% z>4UO 0 { + toolCallsJSON := make([]map[string]interface{}, len(cm.ToolCalls)) + for i, tc := range cm.ToolCalls { + // 将arguments转换为JSON字符串 + argsJSON := "" + if tc.Function.Arguments != nil { + argsBytes, err := json.Marshal(tc.Function.Arguments) + if err != nil { + return nil, err + } + argsJSON = string(argsBytes) + } + + toolCallsJSON[i] = map[string]interface{}{ + "id": tc.ID, + "type": tc.Type, + "function": map[string]interface{}{ + "name": tc.Function.Name, + "arguments": argsJSON, + }, + } + } + aux["tool_calls"] = toolCallsJSON + } + + return json.Marshal(aux) +} + +// OpenAIRequest OpenAI API请求 +type OpenAIRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Tools []Tool `json:"tools,omitempty"` +} + +// OpenAIResponse OpenAI API响应 +type OpenAIResponse struct { + ID string `json:"id"` + Choices []Choice `json:"choices"` + Error *Error `json:"error,omitempty"` +} + +// Choice 选择 +type Choice struct { + Message MessageWithTools `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// MessageWithTools 带工具调用的消息 +type MessageWithTools struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +// Tool OpenAI工具定义 +type Tool struct { + Type string `json:"type"` + Function FunctionDefinition `json:"function"` +} + +// FunctionDefinition 函数定义 +type FunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} + +// Error OpenAI错误 +type Error struct { + Message string `json:"message"` + Type string `json:"type"` +} + +// ToolCall 工具调用 +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function FunctionCall `json:"function"` +} + +// FunctionCall 函数调用 +type FunctionCall struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// UnmarshalJSON 自定义JSON解析,处理arguments可能是字符串或对象的情况 +func (fc *FunctionCall) UnmarshalJSON(data []byte) error { + type Alias FunctionCall + aux := &struct { + Name string `json:"name"` + Arguments interface{} `json:"arguments"` + *Alias + }{ + Alias: (*Alias)(fc), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + fc.Name = aux.Name + + // 处理arguments可能是字符串或对象的情况 + switch v := aux.Arguments.(type) { + case map[string]interface{}: + fc.Arguments = v + case string: + // 如果是字符串,尝试解析为JSON + if err := json.Unmarshal([]byte(v), &fc.Arguments); err != nil { + // 如果解析失败,创建一个包含原始字符串的map + fc.Arguments = map[string]interface{}{ + "raw": v, + } + } + case nil: + fc.Arguments = make(map[string]interface{}) + default: + // 其他类型,尝试转换为map + fc.Arguments = map[string]interface{}{ + "value": v, + } + } + + return nil +} + +// AgentLoopResult Agent Loop执行结果 +type AgentLoopResult struct { + Response string + MCPExecutionIDs []string +} + +// AgentLoop 执行Agent循环 +func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) { + messages := []ChatMessage{ + { + Role: "system", + Content: "你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。当需要执行工具时,使用提供的工具函数。", + }, + } + + // 添加历史消息(数据库只保存user和assistant消息) + a.logger.Info("处理历史消息", + zap.Int("count", len(historyMessages)), + ) + addedCount := 0 + for i, msg := range historyMessages { + // 只添加有内容的消息 + if msg.Content != "" { + messages = append(messages, ChatMessage{ + Role: msg.Role, + Content: msg.Content, + }) + addedCount++ + contentPreview := msg.Content + if len(contentPreview) > 50 { + contentPreview = contentPreview[:50] + "..." + } + a.logger.Info("添加历史消息到上下文", + zap.Int("index", i), + zap.String("role", msg.Role), + zap.String("content", contentPreview), + ) + } + } + + a.logger.Info("构建消息数组", + zap.Int("historyMessages", len(historyMessages)), + zap.Int("addedMessages", addedCount), + zap.Int("totalMessages", len(messages)), + ) + + // 添加当前用户消息 + messages = append(messages, ChatMessage{ + Role: "user", + Content: userInput, + }) + + result := &AgentLoopResult{ + MCPExecutionIDs: make([]string, 0), + } + + maxIterations := 10 + for i := 0; i < maxIterations; i++ { + // 获取可用工具 + tools := a.getAvailableTools() + + // 记录每次调用OpenAI + if i == 0 { + a.logger.Info("调用OpenAI", + zap.Int("iteration", i+1), + zap.Int("messagesCount", len(messages)), + ) + // 记录前几条消息的内容(用于调试) + for j, msg := range messages { + if j >= 5 { // 只记录前5条 + break + } + contentPreview := msg.Content + if len(contentPreview) > 100 { + contentPreview = contentPreview[:100] + "..." + } + a.logger.Debug("消息内容", + zap.Int("index", j), + zap.String("role", msg.Role), + zap.String("content", contentPreview), + ) + } + } else { + a.logger.Info("调用OpenAI", + zap.Int("iteration", i+1), + zap.Int("messagesCount", len(messages)), + ) + } + + // 调用OpenAI + response, err := a.callOpenAI(ctx, messages, tools) + if err != nil { + result.Response = "" + return result, fmt.Errorf("调用OpenAI失败: %w", err) + } + + if response.Error != nil { + result.Response = "" + return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) + } + + if len(response.Choices) == 0 { + result.Response = "" + return result, fmt.Errorf("没有收到响应") + } + + choice := response.Choices[0] + + // 检查是否有工具调用 + if len(choice.Message.ToolCalls) > 0 { + // 添加assistant消息(包含工具调用) + messages = append(messages, ChatMessage{ + Role: "assistant", + Content: choice.Message.Content, + ToolCalls: choice.Message.ToolCalls, + }) + + // 执行所有工具调用 + for _, toolCall := range choice.Message.ToolCalls { + // 执行工具 + execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments) + if err != nil { + messages = append(messages, ChatMessage{ + Role: "tool", + ToolCallID: toolCall.ID, + Content: fmt.Sprintf("工具执行失败: %v", err), + }) + } else { + messages = append(messages, ChatMessage{ + Role: "tool", + ToolCallID: toolCall.ID, + Content: execResult.Result, + }) + // 收集执行ID + if execResult.ExecutionID != "" { + result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID) + } + } + } + continue + } + + // 添加assistant响应 + messages = append(messages, ChatMessage{ + Role: "assistant", + Content: choice.Message.Content, + }) + + // 如果完成,返回结果 + if choice.FinishReason == "stop" { + result.Response = choice.Message.Content + return result, nil + } + } + + result.Response = "达到最大迭代次数" + return result, nil +} + +// getAvailableTools 获取可用工具 +func (a *Agent) getAvailableTools() []Tool { + // 从MCP服务器获取工具列表 + executions := a.mcpServer.GetAllExecutions() + toolNames := make(map[string]bool) + for _, exec := range executions { + toolNames[exec.ToolName] = true + } + + tools := []Tool{ + { + Type: "function", + Function: FunctionDefinition{ + Name: "nmap", + Description: "使用nmap进行网络扫描,发现开放端口和服务。支持IP地址、域名或URL(会自动提取域名)。使用TCP连接扫描,不需要root权限。", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "target": map[string]interface{}{ + "type": "string", + "description": "目标IP地址、域名或URL(如 https://example.com)。如果是URL,会自动提取域名部分。", + }, + "ports": map[string]interface{}{ + "type": "string", + "description": "要扫描的端口范围,例如: 1-1000 或 80,443,8080。如果不指定,将扫描常用端口。", + }, + }, + "required": []string{"target"}, + }, + }, + }, + { + Type: "function", + Function: FunctionDefinition{ + Name: "sqlmap", + Description: "使用sqlmap检测SQL注入漏洞", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "url": map[string]interface{}{ + "type": "string", + "description": "目标URL", + }, + }, + "required": []string{"url"}, + }, + }, + }, + { + Type: "function", + Function: FunctionDefinition{ + Name: "nikto", + Description: "使用nikto扫描Web服务器漏洞", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "target": map[string]interface{}{ + "type": "string", + "description": "目标URL", + }, + }, + "required": []string{"target"}, + }, + }, + }, + { + Type: "function", + Function: FunctionDefinition{ + Name: "dirb", + Description: "使用dirb进行目录扫描", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "url": map[string]interface{}{ + "type": "string", + "description": "目标URL", + }, + }, + "required": []string{"url"}, + }, + }, + }, + { + Type: "function", + Function: FunctionDefinition{ + Name: "exec", + Description: "执行系统命令(谨慎使用,仅用于必要的系统操作)", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "command": map[string]interface{}{ + "type": "string", + "description": "要执行的系统命令", + }, + "shell": map[string]interface{}{ + "type": "string", + "description": "使用的shell(可选,默认为sh)", + }, + "workdir": map[string]interface{}{ + "type": "string", + "description": "工作目录(可选)", + }, + }, + "required": []string{"command"}, + }, + }, + }, + } + + return tools +} + +// callOpenAI 调用OpenAI API +func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) { + reqBody := OpenAIRequest{ + Model: a.config.Model, + Messages: messages, + } + + if len(tools) > 0 { + reqBody.Tools = tools + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+a.config.APIKey) + + resp, err := a.openAIClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // 记录响应内容(用于调试) + if resp.StatusCode != http.StatusOK { + a.logger.Warn("OpenAI API返回非200状态码", + zap.Int("status", resp.StatusCode), + zap.String("body", string(body)), + ) + } + + var response OpenAIResponse + if err := json.Unmarshal(body, &response); err != nil { + a.logger.Error("解析OpenAI响应失败", + zap.Error(err), + zap.String("body", string(body)), + ) + return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body)) + } + + return &response, nil +} + +// parseToolCall 解析工具调用 +func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) { + // 简单解析,实际应该更复杂 + // 格式: [TOOL_CALL]tool_name:arg1=value1,arg2=value2 + if !strings.HasPrefix(content, "[TOOL_CALL]") { + return nil, fmt.Errorf("不是有效的工具调用格式") + } + + parts := strings.Split(content[len("[TOOL_CALL]"):], ":") + if len(parts) < 2 { + return nil, fmt.Errorf("工具调用格式错误") + } + + toolName := strings.TrimSpace(parts[0]) + argsStr := strings.TrimSpace(parts[1]) + + args := make(map[string]interface{}) + argPairs := strings.Split(argsStr, ",") + for _, pair := range argPairs { + kv := strings.Split(pair, "=") + if len(kv) == 2 { + args[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) + } + } + + args["_tool_name"] = toolName + return args, nil +} + +// ToolExecutionResult 工具执行结果 +type ToolExecutionResult struct { + Result string + ExecutionID string +} + +// executeToolViaMCP 通过MCP执行工具 +func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map[string]interface{}) (*ToolExecutionResult, error) { + a.logger.Info("通过MCP执行工具", + zap.String("tool", toolName), + zap.Any("args", args), + ) + + // 通过MCP服务器调用工具 + result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args) + if err != nil { + return nil, fmt.Errorf("工具执行失败: %w", err) + } + + // 格式化结果 + var resultText strings.Builder + for _, content := range result.Content { + resultText.WriteString(content.Text) + resultText.WriteString("\n") + } + + return &ToolExecutionResult{ + Result: resultText.String(), + ExecutionID: executionID, + }, nil +} + diff --git a/internal/app/app.go b/internal/app/app.go new file mode 100644 index 00000000..8f78d2b7 --- /dev/null +++ b/internal/app/app.go @@ -0,0 +1,163 @@ +package app + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/handler" + "cyberstrike-ai/internal/logger" + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/security" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// App 应用 +type App struct { + config *config.Config + logger *logger.Logger + router *gin.Engine + mcpServer *mcp.Server + agent *agent.Agent + executor *security.Executor + db *database.DB +} + +// New 创建新应用 +func New(cfg *config.Config, log *logger.Logger) (*App, error) { + gin.SetMode(gin.ReleaseMode) + router := gin.Default() + + // CORS中间件 + router.Use(corsMiddleware()) + + // 初始化数据库 + dbPath := cfg.Database.Path + if dbPath == "" { + dbPath = "data/conversations.db" + } + + // 确保目录存在 + if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { + return nil, fmt.Errorf("创建数据库目录失败: %w", err) + } + + db, err := database.NewDB(dbPath, log.Logger) + if err != nil { + return nil, fmt.Errorf("初始化数据库失败: %w", err) + } + + // 创建MCP服务器 + mcpServer := mcp.NewServer(log.Logger) + + // 创建安全工具执行器 + executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) + + // 注册工具 + executor.RegisterTools(mcpServer) + + // 创建Agent + agent := agent.NewAgent(&cfg.OpenAI, mcpServer, log.Logger) + + // 创建处理器 + agentHandler := handler.NewAgentHandler(agent, db, log.Logger) + monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger) + conversationHandler := handler.NewConversationHandler(db, log.Logger) + + // 设置路由 + setupRoutes(router, agentHandler, monitorHandler, conversationHandler, mcpServer) + + return &App{ + config: cfg, + logger: log, + router: router, + mcpServer: mcpServer, + agent: agent, + executor: executor, + db: db, + }, nil +} + +// Run 启动应用 +func (a *App) Run() error { + // 启动MCP服务器(如果启用) + if a.config.MCP.Enabled { + go func() { + mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port) + a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr)) + + mux := http.NewServeMux() + mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP) + + if err := http.ListenAndServe(mcpAddr, mux); err != nil { + a.logger.Error("MCP服务器启动失败", zap.Error(err)) + } + }() + } + + // 启动主服务器 + addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port) + a.logger.Info("启动HTTP服务器", zap.String("address", addr)) + + return a.router.Run(addr) +} + +// setupRoutes 设置路由 +func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, mcpServer *mcp.Server) { + // API路由 + api := router.Group("/api") + { + // Agent Loop + api.POST("/agent-loop", agentHandler.AgentLoop) + + // 对话历史 + api.POST("/conversations", conversationHandler.CreateConversation) + api.GET("/conversations", conversationHandler.ListConversations) + api.GET("/conversations/:id", conversationHandler.GetConversation) + api.DELETE("/conversations/:id", conversationHandler.DeleteConversation) + + // 监控 + api.GET("/monitor", monitorHandler.Monitor) + api.GET("/monitor/execution/:id", monitorHandler.GetExecution) + api.GET("/monitor/stats", monitorHandler.GetStats) + api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities) + + // MCP端点 + api.POST("/mcp", func(c *gin.Context) { + mcpServer.HandleHTTP(c.Writer, c.Request) + }) + } + + // 静态文件 + router.Static("/static", "./web/static") + router.LoadHTMLGlob("web/templates/*") + + // 前端页面 + router.GET("/", func(c *gin.Context) { + c.HTML(http.StatusOK, "index.html", nil) + }) +} + +// corsMiddleware CORS中间件 +func corsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + c.Next() + } +} + diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 00000000..2215bce7 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,114 @@ +package config + +import ( + "fmt" + "os" + + "gopkg.in/yaml.v3" +) + +type Config struct { + Server ServerConfig `yaml:"server"` + Log LogConfig `yaml:"log"` + MCP MCPConfig `yaml:"mcp"` + OpenAI OpenAIConfig `yaml:"openai"` + Security SecurityConfig `yaml:"security"` + Database DatabaseConfig `yaml:"database"` +} + +type ServerConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` +} + +type LogConfig struct { + Level string `yaml:"level"` + Output string `yaml:"output"` +} + +type MCPConfig struct { + Enabled bool `yaml:"enabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` +} + +type OpenAIConfig struct { + APIKey string `yaml:"api_key"` + BaseURL string `yaml:"base_url"` + Model string `yaml:"model"` +} + +type SecurityConfig struct { + Tools []ToolConfig `yaml:"tools"` +} + +type DatabaseConfig struct { + Path string `yaml:"path"` +} + +type ToolConfig struct { + Name string `yaml:"name"` + Command string `yaml:"command"` + Args []string `yaml:"args,omitempty"` // 固定参数(可选) + Description string `yaml:"description"` + Enabled bool `yaml:"enabled"` + Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选) + ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选) +} + +// ParameterConfig 参数配置 +type ParameterConfig struct { + Name string `yaml:"name"` // 参数名称 + Type string `yaml:"type"` // 参数类型: string, int, bool, array + Description string `yaml:"description"` // 参数描述 + Required bool `yaml:"required,omitempty"` // 是否必需 + Default interface{} `yaml:"default,omitempty"` // 默认值 + Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p" + Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始) + Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template" + Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}" + Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举) +} + +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取配置文件失败: %w", err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("解析配置文件失败: %w", err) + } + + return &cfg, nil +} + +func Default() *Config { + return &Config{ + Server: ServerConfig{ + Host: "0.0.0.0", + Port: 8080, + }, + Log: LogConfig{ + Level: "info", + Output: "stdout", + }, + MCP: MCPConfig{ + Enabled: true, + Host: "0.0.0.0", + Port: 8081, + }, + OpenAI: OpenAIConfig{ + BaseURL: "https://api.openai.com/v1", + Model: "gpt-4", + }, + Security: SecurityConfig{ + Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 加载,不在此硬编码 + }, + Database: DatabaseConfig{ + Path: "data/conversations.db", + }, + } +} + diff --git a/internal/database/conversation.go b/internal/database/conversation.go new file mode 100644 index 00000000..95905772 --- /dev/null +++ b/internal/database/conversation.go @@ -0,0 +1,256 @@ +package database + +import ( + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Conversation 对话 +type Conversation struct { + ID string `json:"id"` + Title string `json:"title"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` + Messages []Message `json:"messages,omitempty"` +} + +// Message 消息 +type Message struct { + ID string `json:"id"` + ConversationID string `json:"conversationId"` + Role string `json:"role"` + Content string `json:"content"` + MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` + CreatedAt time.Time `json:"createdAt"` +} + +// CreateConversation 创建新对话 +func (db *DB) CreateConversation(title string) (*Conversation, error) { + id := uuid.New().String() + now := time.Now() + + _, err := db.Exec( + "INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)", + id, title, now, now, + ) + if err != nil { + return nil, fmt.Errorf("创建对话失败: %w", err) + } + + return &Conversation{ + ID: id, + Title: title, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +// GetConversation 获取对话 +func (db *DB) GetConversation(id string) (*Conversation, error) { + var conv Conversation + var createdAt, updatedAt string + + err := db.QueryRow( + "SELECT id, title, created_at, updated_at FROM conversations WHERE id = ?", + id, + ).Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("对话不存在") + } + return nil, fmt.Errorf("查询对话失败: %w", err) + } + + // 尝试多种时间格式解析 + var err1, err2 error + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt) + } + + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt) + } + + // 加载消息 + messages, err := db.GetMessages(id) + if err != nil { + return nil, fmt.Errorf("加载消息失败: %w", err) + } + conv.Messages = messages + + return &conv, nil +} + +// ListConversations 列出所有对话 +func (db *DB) ListConversations(limit, offset int) ([]*Conversation, error) { + rows, err := db.Query( + "SELECT id, title, created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?", + limit, offset, + ) + if err != nil { + return nil, fmt.Errorf("查询对话列表失败: %w", err) + } + defer rows.Close() + + var conversations []*Conversation + for rows.Next() { + var conv Conversation + var createdAt, updatedAt string + + if err := rows.Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt); err != nil { + return nil, fmt.Errorf("扫描对话失败: %w", err) + } + + // 尝试多种时间格式解析 + var err1, err2 error + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err1 != nil { + conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err1 != nil { + conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt) + } + + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt) + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt) + } + if err2 != nil { + conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt) + } + + conversations = append(conversations, &conv) + } + + return conversations, nil +} + +// UpdateConversationTitle 更新对话标题 +func (db *DB) UpdateConversationTitle(id, title string) error { + _, err := db.Exec( + "UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?", + title, time.Now(), id, + ) + if err != nil { + return fmt.Errorf("更新对话标题失败: %w", err) + } + return nil +} + +// UpdateConversationTime 更新对话时间 +func (db *DB) UpdateConversationTime(id string) error { + _, err := db.Exec( + "UPDATE conversations SET updated_at = ? WHERE id = ?", + time.Now(), id, + ) + if err != nil { + return fmt.Errorf("更新对话时间失败: %w", err) + } + return nil +} + +// DeleteConversation 删除对话 +func (db *DB) DeleteConversation(id string) error { + _, err := db.Exec("DELETE FROM conversations WHERE id = ?", id) + if err != nil { + return fmt.Errorf("删除对话失败: %w", err) + } + return nil +} + +// AddMessage 添加消息 +func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) { + id := uuid.New().String() + + var mcpIDsJSON string + if len(mcpExecutionIDs) > 0 { + jsonData, err := json.Marshal(mcpExecutionIDs) + if err != nil { + db.logger.Warn("序列化MCP执行ID失败", zap.Error(err)) + } else { + mcpIDsJSON = string(jsonData) + } + } + + _, err := db.Exec( + "INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)", + id, conversationID, role, content, mcpIDsJSON, time.Now(), + ) + if err != nil { + return nil, fmt.Errorf("添加消息失败: %w", err) + } + + // 更新对话时间 + if err := db.UpdateConversationTime(conversationID); err != nil { + db.logger.Warn("更新对话时间失败", zap.Error(err)) + } + + message := &Message{ + ID: id, + ConversationID: conversationID, + Role: role, + Content: content, + MCPExecutionIDs: mcpExecutionIDs, + CreatedAt: time.Now(), + } + + return message, nil +} + +// GetMessages 获取对话的所有消息 +func (db *DB) GetMessages(conversationID string) ([]Message, error) { + rows, err := db.Query( + "SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC", + conversationID, + ) + if err != nil { + return nil, fmt.Errorf("查询消息失败: %w", err) + } + defer rows.Close() + + var messages []Message + for rows.Next() { + var msg Message + var mcpIDsJSON sql.NullString + var createdAt string + + if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil { + return nil, fmt.Errorf("扫描消息失败: %w", err) + } + + // 尝试多种时间格式解析 + var err error + msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt) + if err != nil { + msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt) + } + if err != nil { + msg.CreatedAt, err = time.Parse(time.RFC3339, createdAt) + } + + // 解析MCP执行ID + if mcpIDsJSON.Valid && mcpIDsJSON.String != "" { + if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil { + db.logger.Warn("解析MCP执行ID失败", zap.Error(err)) + } + } + + messages = append(messages, msg) + } + + return messages, nil +} + diff --git a/internal/database/database.go b/internal/database/database.go new file mode 100644 index 00000000..5091e5ef --- /dev/null +++ b/internal/database/database.go @@ -0,0 +1,90 @@ +package database + +import ( + "database/sql" + "fmt" + + _ "github.com/mattn/go-sqlite3" + "go.uber.org/zap" +) + +// DB 数据库连接 +type DB struct { + *sql.DB + logger *zap.Logger +} + +// NewDB 创建数据库连接 +func NewDB(dbPath string, logger *zap.Logger) (*DB, error) { + db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1") + if err != nil { + return nil, fmt.Errorf("打开数据库失败: %w", err) + } + + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("连接数据库失败: %w", err) + } + + database := &DB{ + DB: db, + logger: logger, + } + + // 初始化表 + if err := database.initTables(); err != nil { + return nil, fmt.Errorf("初始化表失败: %w", err) + } + + return database, nil +} + +// initTables 初始化数据库表 +func (db *DB) initTables() error { + // 创建对话表 + createConversationsTable := ` + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + created_at DATETIME NOT NULL, + updated_at DATETIME NOT NULL + );` + + // 创建消息表 + createMessagesTable := ` + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + mcp_execution_ids TEXT, + created_at DATETIME NOT NULL, + FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE + );` + + // 创建索引 + createIndexes := ` + CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id); + CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at); + ` + + if _, err := db.Exec(createConversationsTable); err != nil { + return fmt.Errorf("创建conversations表失败: %w", err) + } + + if _, err := db.Exec(createMessagesTable); err != nil { + return fmt.Errorf("创建messages表失败: %w", err) + } + + if _, err := db.Exec(createIndexes); err != nil { + return fmt.Errorf("创建索引失败: %w", err) + } + + db.logger.Info("数据库表初始化完成") + return nil +} + +// Close 关闭数据库连接 +func (db *DB) Close() error { + return db.DB.Close() +} + diff --git a/internal/handler/agent.go b/internal/handler/agent.go new file mode 100644 index 00000000..b628b08c --- /dev/null +++ b/internal/handler/agent.go @@ -0,0 +1,134 @@ +package handler + +import ( + "net/http" + "time" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/database" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// AgentHandler Agent处理器 +type AgentHandler struct { + agent *agent.Agent + db *database.DB + logger *zap.Logger +} + +// NewAgentHandler 创建新的Agent处理器 +func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler { + return &AgentHandler{ + agent: agent, + db: db, + logger: logger, + } +} + +// ChatRequest 聊天请求 +type ChatRequest struct { + Message string `json:"message" binding:"required"` + ConversationID string `json:"conversationId,omitempty"` +} + +// ChatResponse 聊天响应 +type ChatResponse struct { + Response string `json:"response"` + MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` // 本次对话中执行的MCP调用ID列表 + ConversationID string `json:"conversationId"` // 对话ID + Time time.Time `json:"time"` +} + +// AgentLoop 处理Agent Loop请求 +func (h *AgentHandler) AgentLoop(c *gin.Context) { + var req ChatRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + h.logger.Info("收到Agent Loop请求", + zap.String("message", req.Message), + zap.String("conversationId", req.ConversationID), + ) + + // 如果没有对话ID,创建新对话 + conversationID := req.ConversationID + if conversationID == "" { + title := req.Message + if len(title) > 50 { + title = title[:50] + "..." + } + conv, err := h.db.CreateConversation(title) + if err != nil { + h.logger.Error("创建对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + conversationID = conv.ID + } + + // 获取历史消息(排除当前消息,因为还没保存) + historyMessages, err := h.db.GetMessages(conversationID) + if err != nil { + h.logger.Warn("获取历史消息失败", zap.Error(err)) + historyMessages = []database.Message{} + } + + h.logger.Info("获取历史消息", + zap.String("conversationId", conversationID), + zap.Int("count", len(historyMessages)), + ) + + // 将数据库消息转换为Agent消息格式 + agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages)) + for i, msg := range historyMessages { + agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{ + Role: msg.Role, + Content: msg.Content, + }) + contentPreview := msg.Content + if len(contentPreview) > 50 { + contentPreview = contentPreview[:50] + "..." + } + h.logger.Info("添加历史消息", + zap.Int("index", i), + zap.String("role", msg.Role), + zap.String("content", contentPreview), + ) + } + + h.logger.Info("历史消息转换完成", + zap.Int("originalCount", len(historyMessages)), + zap.Int("convertedCount", len(agentHistoryMessages)), + ) + + // 保存用户消息 + _, err = h.db.AddMessage(conversationID, "user", req.Message, nil) + if err != nil { + h.logger.Error("保存用户消息失败", zap.Error(err)) + } + + // 执行Agent Loop,传入历史消息 + result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages) + if err != nil { + h.logger.Error("Agent Loop执行失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 保存助手回复 + _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs) + if err != nil { + h.logger.Error("保存助手消息失败", zap.Error(err)) + } + + c.JSON(http.StatusOK, ChatResponse{ + Response: result.Response, + MCPExecutionIDs: result.MCPExecutionIDs, + ConversationID: conversationID, + Time: time.Now(), + }) +} + diff --git a/internal/handler/conversation.go b/internal/handler/conversation.go new file mode 100644 index 00000000..7e2a1d22 --- /dev/null +++ b/internal/handler/conversation.go @@ -0,0 +1,102 @@ +package handler + +import ( + "net/http" + "strconv" + + "cyberstrike-ai/internal/database" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ConversationHandler 对话处理器 +type ConversationHandler struct { + db *database.DB + logger *zap.Logger +} + +// NewConversationHandler 创建新的对话处理器 +func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler { + return &ConversationHandler{ + db: db, + logger: logger, + } +} + +// CreateConversationRequest 创建对话请求 +type CreateConversationRequest struct { + Title string `json:"title"` +} + +// CreateConversation 创建新对话 +func (h *ConversationHandler) CreateConversation(c *gin.Context) { + var req CreateConversationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + title := req.Title + if title == "" { + title = "新对话" + } + + conv, err := h.db.CreateConversation(title) + if err != nil { + h.logger.Error("创建对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, conv) +} + +// ListConversations 列出对话 +func (h *ConversationHandler) ListConversations(c *gin.Context) { + limitStr := c.DefaultQuery("limit", "50") + offsetStr := c.DefaultQuery("offset", "0") + + limit, _ := strconv.Atoi(limitStr) + offset, _ := strconv.Atoi(offsetStr) + + if limit <= 0 || limit > 100 { + limit = 50 + } + + conversations, err := h.db.ListConversations(limit, offset) + if err != nil { + h.logger.Error("获取对话列表失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, conversations) +} + +// GetConversation 获取对话 +func (h *ConversationHandler) GetConversation(c *gin.Context) { + id := c.Param("id") + + conv, err := h.db.GetConversation(id) + if err != nil { + h.logger.Error("获取对话失败", zap.Error(err)) + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + c.JSON(http.StatusOK, conv) +} + +// DeleteConversation 删除对话 +func (h *ConversationHandler) DeleteConversation(c *gin.Context) { + id := c.Param("id") + + if err := h.db.DeleteConversation(id); err != nil { + h.logger.Error("删除对话失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) +} + diff --git a/internal/handler/monitor.go b/internal/handler/monitor.go new file mode 100644 index 00000000..a5893a1e --- /dev/null +++ b/internal/handler/monitor.go @@ -0,0 +1,92 @@ +package handler + +import ( + "net/http" + "time" + + "cyberstrike-ai/internal/mcp" + "cyberstrike-ai/internal/security" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// MonitorHandler 监控处理器 +type MonitorHandler struct { + mcpServer *mcp.Server + executor *security.Executor + logger *zap.Logger + vulns []security.Vulnerability +} + +// NewMonitorHandler 创建新的监控处理器 +func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, logger *zap.Logger) *MonitorHandler { + return &MonitorHandler{ + mcpServer: mcpServer, + executor: executor, + logger: logger, + vulns: []security.Vulnerability{}, + } +} + +// MonitorResponse 监控响应 +type MonitorResponse struct { + Executions []*mcp.ToolExecution `json:"executions"` + Stats map[string]*mcp.ToolStats `json:"stats"` + Vulnerabilities []security.Vulnerability `json:"vulnerabilities"` + Report map[string]interface{} `json:"report"` + Timestamp time.Time `json:"timestamp"` +} + +// Monitor 获取监控信息 +func (h *MonitorHandler) Monitor(c *gin.Context) { + // 获取所有执行记录 + executions := h.mcpServer.GetAllExecutions() + + // 分析执行结果,提取漏洞 + for _, exec := range executions { + if exec.Status == "completed" && exec.Result != nil { + vulns := h.executor.AnalyzeResults(exec.ToolName, exec.Result) + h.vulns = append(h.vulns, vulns...) + } + } + + // 获取统计信息 + stats := h.mcpServer.GetStats() + + // 生成报告 + report := h.executor.GetVulnerabilityReport(h.vulns) + + c.JSON(http.StatusOK, MonitorResponse{ + Executions: executions, + Stats: stats, + Vulnerabilities: h.vulns, + Report: report, + Timestamp: time.Now(), + }) +} + +// GetExecution 获取特定执行记录 +func (h *MonitorHandler) GetExecution(c *gin.Context) { + id := c.Param("id") + + exec, exists := h.mcpServer.GetExecution(id) + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"}) + return + } + + c.JSON(http.StatusOK, exec) +} + +// GetStats 获取统计信息 +func (h *MonitorHandler) GetStats(c *gin.Context) { + stats := h.mcpServer.GetStats() + c.JSON(http.StatusOK, stats) +} + +// GetVulnerabilities 获取漏洞列表 +func (h *MonitorHandler) GetVulnerabilities(c *gin.Context) { + report := h.executor.GetVulnerabilityReport(h.vulns) + c.JSON(http.StatusOK, report) +} + diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 00000000..549bceb0 --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,60 @@ +package logger + +import ( + "os" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type Logger struct { + *zap.Logger +} + +func New(level, output string) *Logger { + var zapLevel zapcore.Level + switch level { + case "debug": + zapLevel = zapcore.DebugLevel + case "info": + zapLevel = zapcore.InfoLevel + case "warn": + zapLevel = zapcore.WarnLevel + case "error": + zapLevel = zapcore.ErrorLevel + default: + zapLevel = zapcore.InfoLevel + } + + config := zap.NewProductionConfig() + config.Level = zap.NewAtomicLevelAt(zapLevel) + config.EncoderConfig.TimeKey = "timestamp" + config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + + var writeSyncer zapcore.WriteSyncer + if output == "stdout" { + writeSyncer = zapcore.AddSync(os.Stdout) + } else { + file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + writeSyncer = zapcore.AddSync(os.Stdout) + } else { + writeSyncer = zapcore.AddSync(file) + } + } + + core := zapcore.NewCore( + zapcore.NewJSONEncoder(config.EncoderConfig), + writeSyncer, + zapLevel, + ) + + logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel)) + + return &Logger{Logger: logger} +} + +func (l *Logger) Fatal(msg string, fields ...interface{}) { + l.Logger.Fatal(msg, zap.Any("fields", fields)) +} + diff --git a/internal/mcp/server.go b/internal/mcp/server.go new file mode 100644 index 00000000..c9e53ea5 --- /dev/null +++ b/internal/mcp/server.go @@ -0,0 +1,798 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Server MCP服务器 +type Server struct { + tools map[string]ToolHandler + toolDefs map[string]Tool // 工具定义 + executions map[string]*ToolExecution + stats map[string]*ToolStats + prompts map[string]*Prompt // 提示词模板 + resources map[string]*Resource // 资源 + mu sync.RWMutex + logger *zap.Logger +} + +// ToolHandler 工具处理函数 +type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error) + +// NewServer 创建新的MCP服务器 +func NewServer(logger *zap.Logger) *Server { + s := &Server{ + tools: make(map[string]ToolHandler), + toolDefs: make(map[string]Tool), + executions: make(map[string]*ToolExecution), + stats: make(map[string]*ToolStats), + prompts: make(map[string]*Prompt), + resources: make(map[string]*Resource), + logger: logger, + } + + // 初始化默认提示词和资源 + s.initDefaultPrompts() + s.initDefaultResources() + + return s +} + +// RegisterTool 注册工具 +func (s *Server) RegisterTool(tool Tool, handler ToolHandler) { + s.mu.Lock() + defer s.mu.Unlock() + s.tools[tool.Name] = handler + s.toolDefs[tool.Name] = tool + + // 自动为工具创建资源文档 + resourceURI := fmt.Sprintf("tool://%s", tool.Name) + s.resources[resourceURI] = &Resource{ + URI: resourceURI, + Name: fmt.Sprintf("%s工具文档", tool.Name), + Description: tool.Description, + MimeType: "text/plain", + } +} + +// HandleHTTP 处理HTTP请求 +func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + s.sendError(w, nil, -32700, "Parse error", err.Error()) + return + } + + var msg Message + if err := json.Unmarshal(body, &msg); err != nil { + s.sendError(w, nil, -32700, "Parse error", err.Error()) + return + } + + // 处理消息 + response := s.handleMessage(&msg) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// handleMessage 处理MCP消息 +func (s *Server) handleMessage(msg *Message) *Message { + if msg.ID == "" { + msg.ID = uuid.New().String() + } + + switch msg.Method { + case "initialize": + return s.handleInitialize(msg) + case "tools/list": + return s.handleListTools(msg) + case "tools/call": + return s.handleCallTool(msg) + case "prompts/list": + return s.handleListPrompts(msg) + case "prompts/get": + return s.handleGetPrompt(msg) + case "resources/list": + return s.handleListResources(msg) + case "resources/read": + return s.handleReadResource(msg) + case "sampling/request": + return s.handleSamplingRequest(msg) + default: + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32601, Message: "Method not found"}, + } + } +} + +// handleInitialize 处理初始化请求 +func (s *Server) handleInitialize(msg *Message) *Message { + var req InitializeRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + response := InitializeResponse{ + ProtocolVersion: ProtocolVersion, + Capabilities: ServerCapabilities{ + Tools: map[string]interface{}{ + "listChanged": true, + }, + Prompts: map[string]interface{}{ + "listChanged": true, + }, + Resources: map[string]interface{}{ + "subscribe": true, + "listChanged": true, + }, + Sampling: map[string]interface{}{}, + }, + ServerInfo: ServerInfo{ + Name: "CyberStrikeAI", + Version: "1.0.0", + }, + } + + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// handleListTools 处理列出工具请求 +func (s *Server) handleListTools(msg *Message) *Message { + s.mu.RLock() + tools := make([]Tool, 0, len(s.toolDefs)) + for _, tool := range s.toolDefs { + tools = append(tools, tool) + } + s.mu.RUnlock() + + response := ListToolsResponse{Tools: tools} + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// handleCallTool 处理工具调用请求 +func (s *Server) handleCallTool(msg *Message) *Message { + var req CallToolRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + // 创建执行记录 + executionID := uuid.New().String() + execution := &ToolExecution{ + ID: executionID, + ToolName: req.Name, + Arguments: req.Arguments, + Status: "running", + StartTime: time.Now(), + } + + s.mu.Lock() + s.executions[executionID] = execution + s.mu.Unlock() + + // 更新统计 + s.updateStats(req.Name, false) + + // 执行工具 + s.mu.RLock() + handler, exists := s.tools[req.Name] + s.mu.RUnlock() + + if !exists { + execution.Status = "failed" + execution.Error = "Tool not found" + now := time.Now() + execution.EndTime = &now + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32601, Message: "Tool not found"}, + } + } + + // 同步执行所有工具,确保错误能正确返回 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + s.logger.Info("开始执行工具", + zap.String("toolName", req.Name), + zap.Any("arguments", req.Arguments), + ) + + result, err := handler(ctx, req.Arguments) + + s.mu.Lock() + now := time.Now() + execution.EndTime = &now + execution.Duration = now.Sub(execution.StartTime) + + if err != nil { + execution.Status = "failed" + execution.Error = err.Error() + s.updateStats(req.Name, true) + s.mu.Unlock() + + s.logger.Error("工具执行失败", + zap.String("toolName", req.Name), + zap.Error(err), + ) + + // 返回错误结果 + errorResult, _ := json.Marshal(CallToolResponse{ + Content: []Content{ + {Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)}, + }, + IsError: true, + }) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: errorResult, + } + } + + // 检查result是否为错误 + if result != nil && result.IsError { + execution.Status = "failed" + if len(result.Content) > 0 { + execution.Error = result.Content[0].Text + } + s.updateStats(req.Name, true) + } else { + execution.Status = "completed" + execution.Result = result + s.updateStats(req.Name, false) + } + s.mu.Unlock() + + // 返回执行结果 + if result == nil { + result = &ToolResult{ + Content: []Content{ + {Type: "text", Text: "工具执行完成,但未返回结果"}, + }, + } + } + + resultJSON, _ := json.Marshal(CallToolResponse{ + Content: result.Content, + IsError: result.IsError, + }) + + s.logger.Info("工具执行完成", + zap.String("toolName", req.Name), + zap.Bool("isError", result.IsError), + ) + + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: resultJSON, + } +} + +// updateStats 更新统计信息 +func (s *Server) updateStats(toolName string, failed bool) { + if s.stats[toolName] == nil { + s.stats[toolName] = &ToolStats{ + ToolName: toolName, + } + } + + stats := s.stats[toolName] + stats.TotalCalls++ + now := time.Now() + stats.LastCallTime = &now + + if failed { + stats.FailedCalls++ + } else { + stats.SuccessCalls++ + } +} + +// GetExecution 获取执行记录 +func (s *Server) GetExecution(id string) (*ToolExecution, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + exec, exists := s.executions[id] + return exec, exists +} + +// GetAllExecutions 获取所有执行记录 +func (s *Server) GetAllExecutions() []*ToolExecution { + s.mu.RLock() + defer s.mu.RUnlock() + executions := make([]*ToolExecution, 0, len(s.executions)) + for _, exec := range s.executions { + executions = append(executions, exec) + } + return executions +} + +// GetStats 获取统计信息 +func (s *Server) GetStats() map[string]*ToolStats { + s.mu.RLock() + defer s.mu.RUnlock() + stats := make(map[string]*ToolStats) + for k, v := range s.stats { + stats[k] = v + } + return stats +} + +// CallTool 直接调用工具(用于内部调用) +func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) { + s.mu.RLock() + handler, exists := s.tools[toolName] + s.mu.RUnlock() + + if !exists { + return nil, "", fmt.Errorf("工具 %s 未找到", toolName) + } + + // 创建执行记录 + executionID := uuid.New().String() + execution := &ToolExecution{ + ID: executionID, + ToolName: toolName, + Arguments: args, + Status: "running", + StartTime: time.Now(), + } + + s.mu.Lock() + s.executions[executionID] = execution + s.mu.Unlock() + + // 更新统计 + s.updateStats(toolName, false) + + // 执行工具 + result, err := handler(ctx, args) + + s.mu.Lock() + now := time.Now() + execution.EndTime = &now + execution.Duration = now.Sub(execution.StartTime) + + if err != nil { + execution.Status = "failed" + execution.Error = err.Error() + s.updateStats(toolName, true) + s.mu.Unlock() + return nil, executionID, err + } else { + execution.Status = "completed" + execution.Result = result + s.updateStats(toolName, false) + s.mu.Unlock() + return result, executionID, nil + } +} + +// initDefaultPrompts 初始化默认提示词模板 +func (s *Server) initDefaultPrompts() { + s.mu.Lock() + defer s.mu.Unlock() + + // 网络安全测试提示词 + s.prompts["security_scan"] = &Prompt{ + Name: "security_scan", + Description: "生成网络安全扫描任务的提示词", + Arguments: []PromptArgument{ + {Name: "target", Description: "扫描目标(IP地址或域名)", Required: true}, + {Name: "scan_type", Description: "扫描类型(port, vuln, web等)", Required: false}, + }, + } + + // 渗透测试提示词 + s.prompts["penetration_test"] = &Prompt{ + Name: "penetration_test", + Description: "生成渗透测试任务的提示词", + Arguments: []PromptArgument{ + {Name: "target", Description: "测试目标", Required: true}, + {Name: "scope", Description: "测试范围", Required: false}, + }, + } +} + +// initDefaultResources 初始化默认资源 +// 注意:工具资源现在在 RegisterTool 时自动创建,此函数保留用于其他非工具资源 +func (s *Server) initDefaultResources() { + // 工具资源已改为在 RegisterTool 时自动创建,无需在此硬编码 +} + +// handleListPrompts 处理列出提示词请求 +func (s *Server) handleListPrompts(msg *Message) *Message { + s.mu.RLock() + prompts := make([]Prompt, 0, len(s.prompts)) + for _, prompt := range s.prompts { + prompts = append(prompts, *prompt) + } + s.mu.RUnlock() + + response := ListPromptsResponse{ + Prompts: prompts, + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// handleGetPrompt 处理获取提示词请求 +func (s *Server) handleGetPrompt(msg *Message) *Message { + var req GetPromptRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + s.mu.RLock() + prompt, exists := s.prompts[req.Name] + s.mu.RUnlock() + + if !exists { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32601, Message: "Prompt not found"}, + } + } + + // 根据提示词名称生成消息 + messages := s.generatePromptMessages(prompt, req.Arguments) + + response := GetPromptResponse{ + Messages: messages, + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// generatePromptMessages 生成提示词消息 +func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage { + messages := []PromptMessage{} + + switch prompt.Name { + case "security_scan": + target, _ := args["target"].(string) + scanType, _ := args["scan_type"].(string) + if scanType == "" { + scanType = "comprehensive" + } + + content := fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括: +1. 端口扫描和服务识别 +2. 漏洞检测 +3. Web应用安全测试 +4. 生成详细的安全报告`, target, scanType) + + messages = append(messages, PromptMessage{ + Role: "user", + Content: content, + }) + + case "penetration_test": + target, _ := args["target"].(string) + scope, _ := args["scope"].(string) + + content := fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target) + if scope != "" { + content += fmt.Sprintf("测试范围:%s", scope) + } + content += "\n请按照OWASP Top 10进行全面的安全测试。" + + messages = append(messages, PromptMessage{ + Role: "user", + Content: content, + }) + + default: + messages = append(messages, PromptMessage{ + Role: "user", + Content: "请执行安全测试任务", + }) + } + + return messages +} + +// handleListResources 处理列出资源请求 +func (s *Server) handleListResources(msg *Message) *Message { + s.mu.RLock() + resources := make([]Resource, 0, len(s.resources)) + for _, resource := range s.resources { + resources = append(resources, *resource) + } + s.mu.RUnlock() + + response := ListResourcesResponse{ + Resources: resources, + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// handleReadResource 处理读取资源请求 +func (s *Server) handleReadResource(msg *Message) *Message { + var req ReadResourceRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + s.mu.RLock() + resource, exists := s.resources[req.URI] + s.mu.RUnlock() + + if !exists { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32601, Message: "Resource not found"}, + } + } + + // 生成资源内容 + content := s.generateResourceContent(resource) + + response := ReadResourceResponse{ + Contents: []ResourceContent{content}, + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// generateResourceContent 生成资源内容 +func (s *Server) generateResourceContent(resource *Resource) ResourceContent { + content := ResourceContent{ + URI: resource.URI, + MimeType: resource.MimeType, + } + + // 如果是工具资源,生成详细文档 + if strings.HasPrefix(resource.URI, "tool://") { + toolName := strings.TrimPrefix(resource.URI, "tool://") + content.Text = s.generateToolDocumentation(toolName, resource) + } else { + // 其他资源使用描述或默认内容 + content.Text = resource.Description + } + + return content +} + +// generateToolDocumentation 生成工具文档 +func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string { + // 获取工具定义以获取更详细的信息 + s.mu.RLock() + tool, hasTool := s.toolDefs[toolName] + s.mu.RUnlock() + + // 为常见工具生成详细文档 + switch toolName { + case "nmap": + return `Nmap (Network Mapper) 是一个强大的网络扫描工具。 + +主要功能: +- 端口扫描:发现目标主机开放的端口 +- 服务识别:识别运行在端口上的服务 +- 版本检测:检测服务版本信息 +- 操作系统检测:识别目标操作系统 + +常用命令: +- nmap -sT target # TCP连接扫描 +- nmap -sV target # 版本检测 +- nmap -sC target # 默认脚本扫描 +- nmap -p 1-1000 target # 扫描指定端口范围 + +参数说明: +- target: 目标IP地址或域名(必需) +- ports: 端口范围,例如: 1-1000(可选)` + + case "sqlmap": + return `SQLMap 是一个自动化的SQL注入检测和利用工具。 + +主要功能: +- 自动检测SQL注入漏洞 +- 数据库指纹识别 +- 数据提取 +- 文件系统访问 + +常用命令: +- sqlmap -u "http://target.com/page?id=1" # 检测URL参数 +- sqlmap -u "http://target.com" --forms # 检测表单 +- sqlmap -u "http://target.com" --dbs # 列出数据库 + +参数说明: +- url: 目标URL(必需)` + + case "nikto": + return `Nikto 是一个Web服务器扫描工具。 + +主要功能: +- Web服务器漏洞扫描 +- 检测过时的服务器软件 +- 检测危险文件和程序 +- 检测服务器配置问题 + +常用命令: +- nikto -h target # 扫描目标主机 +- nikto -h target -p 80,443 # 扫描指定端口 + +参数说明: +- target: 目标URL(必需)` + + case "dirb": + return `Dirb 是一个Web内容扫描器。 + +主要功能: +- 扫描Web目录和文件 +- 发现隐藏的目录和文件 +- 支持自定义字典 + +常用命令: +- dirb url # 扫描目标URL +- dirb url -w wordlist.txt # 使用自定义字典 + +参数说明: +- target: 目标URL(必需)` + + case "exec": + return `Exec 工具用于执行系统命令。 + +⚠️ 警告:此工具可以执行任意系统命令,请谨慎使用! + +参数说明: +- command: 要执行的系统命令(必需) +- shell: 使用的shell,默认为sh(可选) +- workdir: 工作目录(可选)` + + default: + // 对于其他工具,使用工具定义中的描述信息 + if hasTool { + doc := fmt.Sprintf("%s\n\n", resource.Description) + if tool.InputSchema != nil { + if props, ok := tool.InputSchema["properties"].(map[string]interface{}); ok { + doc += "参数说明:\n" + for paramName, paramInfo := range props { + if paramMap, ok := paramInfo.(map[string]interface{}); ok { + if desc, ok := paramMap["description"].(string); ok { + doc += fmt.Sprintf("- %s: %s\n", paramName, desc) + } + } + } + } + } + return doc + } + return resource.Description + } +} + +// handleSamplingRequest 处理采样请求 +func (s *Server) handleSamplingRequest(msg *Message) *Message { + var req SamplingRequest + if err := json.Unmarshal(msg.Params, &req); err != nil { + return &Message{ + ID: msg.ID, + Type: MessageTypeError, + Error: &Error{Code: -32602, Message: "Invalid params"}, + } + } + + // 注意:采样功能通常需要连接到实际的LLM服务 + // 这里返回一个占位符响应,实际实现需要集成LLM API + s.logger.Warn("Sampling request received but not fully implemented", + zap.Any("request", req), + ) + + response := SamplingResponse{ + Content: []SamplingContent{ + { + Type: "text", + Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。", + }, + }, + StopReason: "length", + } + result, _ := json.Marshal(response) + return &Message{ + ID: msg.ID, + Type: MessageTypeResponse, + Version: "2.0", + Result: result, + } +} + +// RegisterPrompt 注册提示词模板 +func (s *Server) RegisterPrompt(prompt *Prompt) { + s.mu.Lock() + defer s.mu.Unlock() + s.prompts[prompt.Name] = prompt +} + +// RegisterResource 注册资源 +func (s *Server) RegisterResource(resource *Resource) { + s.mu.Lock() + defer s.mu.Unlock() + s.resources[resource.URI] = resource +} + +// sendError 发送错误响应 +func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) { + response := Message{ + ID: fmt.Sprintf("%v", id), + Type: MessageTypeError, + Error: &Error{Code: code, Message: message, Data: data}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + diff --git a/internal/mcp/types.go b/internal/mcp/types.go new file mode 100644 index 00000000..a757196f --- /dev/null +++ b/internal/mcp/types.go @@ -0,0 +1,232 @@ +package mcp + +import ( + "encoding/json" + "time" +) + +// MCP消息类型 +const ( + MessageTypeRequest = "request" + MessageTypeResponse = "response" + MessageTypeError = "error" + MessageTypeNotify = "notify" +) + +// MCP协议版本 +const ProtocolVersion = "2024-11-05" + +// Message 表示MCP消息 +type Message struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *Error `json:"error,omitempty"` + Version string `json:"jsonrpc,omitempty"` +} + +// Error 表示MCP错误 +type Error struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// Tool 表示MCP工具定义 +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]interface{} `json:"inputSchema"` +} + +// ToolCall 表示工具调用 +type ToolCall struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// ToolResult 表示工具执行结果 +type ToolResult struct { + Content []Content `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// Content 表示内容 +type Content struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// InitializeRequest 初始化请求 +type InitializeRequest struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]interface{} `json:"capabilities"` + ClientInfo ClientInfo `json:"clientInfo"` +} + +// ClientInfo 客户端信息 +type ClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// InitializeResponse 初始化响应 +type InitializeResponse struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo ServerInfo `json:"serverInfo"` +} + +// ServerCapabilities 服务器能力 +type ServerCapabilities struct { + Tools map[string]interface{} `json:"tools,omitempty"` + Prompts map[string]interface{} `json:"prompts,omitempty"` + Resources map[string]interface{} `json:"resources,omitempty"` + Sampling map[string]interface{} `json:"sampling,omitempty"` +} + +// ServerInfo 服务器信息 +type ServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// ListToolsRequest 列出工具请求 +type ListToolsRequest struct{} + +// ListToolsResponse 列出工具响应 +type ListToolsResponse struct { + Tools []Tool `json:"tools"` +} + +// ListPromptsResponse 列出提示词响应 +type ListPromptsResponse struct { + Prompts []Prompt `json:"prompts"` +} + +// ListResourcesResponse 列出资源响应 +type ListResourcesResponse struct { + Resources []Resource `json:"resources"` +} + +// CallToolRequest 调用工具请求 +type CallToolRequest struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +// CallToolResponse 调用工具响应 +type CallToolResponse struct { + Content []Content `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +// ToolExecution 工具执行记录 +type ToolExecution struct { + ID string `json:"id"` + ToolName string `json:"toolName"` + Arguments map[string]interface{} `json:"arguments"` + Status string `json:"status"` // pending, running, completed, failed + Result *ToolResult `json:"result,omitempty"` + Error string `json:"error,omitempty"` + StartTime time.Time `json:"startTime"` + EndTime *time.Time `json:"endTime,omitempty"` + Duration time.Duration `json:"duration,omitempty"` +} + +// ToolStats 工具统计信息 +type ToolStats struct { + ToolName string `json:"toolName"` + TotalCalls int `json:"totalCalls"` + SuccessCalls int `json:"successCalls"` + FailedCalls int `json:"failedCalls"` + LastCallTime *time.Time `json:"lastCallTime,omitempty"` +} + +// Prompt 提示词模板 +type Prompt struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Arguments []PromptArgument `json:"arguments,omitempty"` +} + +// PromptArgument 提示词参数 +type PromptArgument struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Required bool `json:"required,omitempty"` +} + +// GetPromptRequest 获取提示词请求 +type GetPromptRequest struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` +} + +// GetPromptResponse 获取提示词响应 +type GetPromptResponse struct { + Messages []PromptMessage `json:"messages"` +} + +// PromptMessage 提示词消息 +type PromptMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// Resource 资源 +type Resource struct { + URI string `json:"uri"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` +} + +// ReadResourceRequest 读取资源请求 +type ReadResourceRequest struct { + URI string `json:"uri"` +} + +// ReadResourceResponse 读取资源响应 +type ReadResourceResponse struct { + Contents []ResourceContent `json:"contents"` +} + +// ResourceContent 资源内容 +type ResourceContent struct { + URI string `json:"uri"` + MimeType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` + Blob string `json:"blob,omitempty"` +} + +// SamplingRequest 采样请求 +type SamplingRequest struct { + Messages []SamplingMessage `json:"messages"` + Model string `json:"model,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` +} + +// SamplingMessage 采样消息 +type SamplingMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// SamplingResponse 采样响应 +type SamplingResponse struct { + Content []SamplingContent `json:"content"` + Model string `json:"model,omitempty"` + StopReason string `json:"stopReason,omitempty"` +} + +// SamplingContent 采样内容 +type SamplingContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + diff --git a/internal/security/executor.go b/internal/security/executor.go new file mode 100644 index 00000000..9131ea44 --- /dev/null +++ b/internal/security/executor.go @@ -0,0 +1,730 @@ +package security + +import ( + "context" + "fmt" + "os/exec" + "strings" + "time" + + "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/mcp" + "go.uber.org/zap" +) + +// Executor 安全工具执行器 +type Executor struct { + config *config.SecurityConfig + mcpServer *mcp.Server + logger *zap.Logger +} + +// NewExecutor 创建新的执行器 +func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor { + return &Executor{ + config: cfg, + mcpServer: mcpServer, + logger: logger, + } +} + +// ExecuteTool 执行安全工具 +func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[string]interface{}) (*mcp.ToolResult, error) { + e.logger.Info("ExecuteTool被调用", + zap.String("toolName", toolName), + zap.Any("args", args), + ) + + // 特殊处理:exec工具直接执行系统命令 + if toolName == "exec" { + e.logger.Info("执行exec工具") + return e.executeSystemCommand(ctx, args) + } + + // 查找工具配置 + var toolConfig *config.ToolConfig + for i := range e.config.Tools { + if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled { + toolConfig = &e.config.Tools[i] + break + } + } + + if toolConfig == nil { + e.logger.Error("工具未找到或未启用", + zap.String("toolName", toolName), + zap.Int("totalTools", len(e.config.Tools)), + ) + return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName) + } + + e.logger.Info("找到工具配置", + zap.String("toolName", toolName), + zap.String("command", toolConfig.Command), + zap.Strings("args", toolConfig.Args), + ) + + // 构建命令 - 根据工具类型使用不同的参数格式 + cmdArgs := e.buildCommandArgs(toolName, toolConfig, args) + + e.logger.Info("构建命令参数完成", + zap.String("toolName", toolName), + zap.Strings("cmdArgs", cmdArgs), + zap.Int("argsCount", len(cmdArgs)), + ) + + // 验证命令参数 + if len(cmdArgs) == 0 { + e.logger.Warn("命令参数为空", + zap.String("toolName", toolName), + zap.Any("inputArgs", args), + ) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("错误: 工具 %s 缺少必需的参数。接收到的参数: %v", toolName, args), + }, + }, + IsError: true, + }, nil + } + + // 执行命令 + cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...) + + e.logger.Info("执行安全工具", + zap.String("tool", toolName), + zap.Strings("args", cmdArgs), + ) + + output, err := cmd.CombinedOutput() + if err != nil { + e.logger.Error("工具执行失败", + zap.String("tool", toolName), + zap.Error(err), + zap.String("output", string(output)), + ) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("工具执行失败: %v\n输出: %s", err, string(output)), + }, + }, + IsError: true, + }, nil + } + + e.logger.Info("工具执行成功", + zap.String("tool", toolName), + zap.String("output", string(output)), + ) + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: string(output), + }, + }, + IsError: false, + }, nil +} + +// RegisterTools 注册工具到MCP服务器 +func (e *Executor) RegisterTools(mcpServer *mcp.Server) { + e.logger.Info("开始注册工具", + zap.Int("totalTools", len(e.config.Tools)), + ) + + for i, toolConfig := range e.config.Tools { + if !toolConfig.Enabled { + e.logger.Debug("跳过未启用的工具", + zap.String("tool", toolConfig.Name), + ) + continue + } + + // 创建工具配置的副本,避免闭包问题 + toolName := toolConfig.Name + toolConfigCopy := toolConfig + + tool := mcp.Tool{ + Name: toolConfigCopy.Name, + Description: toolConfigCopy.Description, + InputSchema: e.buildInputSchema(&toolConfigCopy), + } + + handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + e.logger.Info("工具handler被调用", + zap.String("toolName", toolName), + zap.Any("args", args), + ) + return e.ExecuteTool(ctx, toolName, args) + } + + mcpServer.RegisterTool(tool, handler) + e.logger.Info("注册安全工具成功", + zap.String("tool", toolConfigCopy.Name), + zap.String("command", toolConfigCopy.Command), + zap.Int("index", i), + ) + } + + e.logger.Info("工具注册完成", + zap.Int("registeredCount", len(e.config.Tools)), + ) +} + +// buildCommandArgs 构建命令参数 +func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConfig, args map[string]interface{}) []string { + cmdArgs := make([]string, 0) + + // 如果配置中定义了参数映射,使用配置中的映射规则 + if len(toolConfig.Parameters) > 0 { + // 先添加固定参数 + cmdArgs = append(cmdArgs, toolConfig.Args...) + + // 按位置参数排序 + positionalParams := make([]config.ParameterConfig, 0) + flagParams := make([]config.ParameterConfig, 0) + + for _, param := range toolConfig.Parameters { + if param.Position != nil { + positionalParams = append(positionalParams, param) + } else { + flagParams = append(flagParams, param) + } + } + + // 对位置参数按位置排序 + for i := 0; i < len(positionalParams); i++ { + for _, param := range positionalParams { + if param.Position != nil && *param.Position == i { + value := e.getParamValue(args, param) + if value == nil { + if param.Required { + // 必需参数缺失,返回空数组让上层处理错误 + e.logger.Warn("缺少必需的位置参数", + zap.String("tool", toolName), + zap.String("param", param.Name), + zap.Int("position", *param.Position), + ) + return []string{} + } + break + } + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + break + } + } + } + + // 处理标志参数 + for _, param := range flagParams { + value := e.getParamValue(args, param) + if value == nil { + if param.Required { + // 必需参数缺失,返回空数组让上层处理错误 + e.logger.Warn("缺少必需的标志参数", + zap.String("tool", toolName), + zap.String("param", param.Name), + ) + return []string{} + } + continue + } + + // 布尔值特殊处理:如果为 false,跳过;如果为 true,只添加标志 + if param.Type == "bool" { + if boolVal, ok := value.(bool); ok { + if !boolVal { + continue // false 时不添加任何参数 + } + // true 时只添加标志,不添加值 + if param.Flag != "" { + cmdArgs = append(cmdArgs, param.Flag) + } + continue + } + } + + format := param.Format + if format == "" { + format = "flag" // 默认格式 + } + + switch format { + case "flag": + // --flag value 或 -f value + if param.Flag != "" { + cmdArgs = append(cmdArgs, param.Flag) + } + formattedValue := e.formatParamValue(param, value) + if formattedValue != "" { + cmdArgs = append(cmdArgs, formattedValue) + } + case "combined": + // --flag=value 或 -f=value + if param.Flag != "" { + cmdArgs = append(cmdArgs, fmt.Sprintf("%s=%s", param.Flag, e.formatParamValue(param, value))) + } else { + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + } + case "template": + // 使用模板字符串 + if param.Template != "" { + template := param.Template + template = strings.ReplaceAll(template, "{flag}", param.Flag) + template = strings.ReplaceAll(template, "{value}", e.formatParamValue(param, value)) + template = strings.ReplaceAll(template, "{name}", param.Name) + cmdArgs = append(cmdArgs, strings.Fields(template)...) + } else { + // 如果没有模板,使用默认格式 + if param.Flag != "" { + cmdArgs = append(cmdArgs, param.Flag) + } + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + } + case "positional": + // 位置参数(已在上面处理) + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + default: + // 默认:直接添加值 + cmdArgs = append(cmdArgs, e.formatParamValue(param, value)) + } + } + + return cmdArgs + } + + // 向后兼容:如果没有定义参数,使用旧的硬编码逻辑 + switch toolName { + case "nmap": + // nmap -sT -sV -sC target [ports] + // 使用 -sT (TCP连接扫描) 而不是 -sS (SYN扫描),因为 -sS 需要root权限 + e.logger.Debug("处理nmap参数", + zap.Any("args", args), + ) + + // 尝试多种方式获取target参数 + var target string + var ok bool + + // 方式1: 直接获取target + if target, ok = args["target"].(string); !ok || target == "" { + // 方式2: 尝试从tool字段获取(兼容某些格式) + if toolVal, exists := args["tool"]; exists { + if toolMap, ok := toolVal.(map[string]interface{}); ok { + if t, ok := toolMap["target"].(string); ok { + target = t + } + } + } + } + + if target == "" { + e.logger.Warn("nmap缺少target参数", + zap.Any("args", args), + ) + return cmdArgs // 返回空数组,让上层处理错误 + } + + e.logger.Debug("提取到target", + zap.String("target", target), + ) + + // 处理URL格式的目标(提取域名) + if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") { + // 提取域名部分 + target = strings.TrimPrefix(target, "http://") + target = strings.TrimPrefix(target, "https://") + // 移除路径部分 + if idx := strings.Index(target, "/"); idx != -1 { + target = target[:idx] + } + } + + // 添加扫描选项:-sT (TCP连接扫描,不需要root权限), -sV (版本检测), -sC (默认脚本) + cmdArgs = append(cmdArgs, "-sT", "-sV", "-sC") + + // 添加端口范围(如果指定) + if ports, ok := args["ports"].(string); ok && ports != "" { + cmdArgs = append(cmdArgs, "-p", ports) + } + + // 添加目标 + cmdArgs = append(cmdArgs, target) + + e.logger.Debug("nmap命令参数构建完成", + zap.Strings("cmdArgs", cmdArgs), + ) + case "sqlmap": + // sqlmap -u url + if url, ok := args["url"].(string); ok { + cmdArgs = append(cmdArgs, "-u", url, "--batch", "--level=3", "--risk=2") + } + case "nikto": + // nikto -h target + if target, ok := args["target"].(string); ok { + cmdArgs = append(cmdArgs, "-h", target) + } + case "dirb": + // dirb url + if url, ok := args["url"].(string); ok { + cmdArgs = append(cmdArgs, url) + } + default: + // 通用处理 + cmdArgs = append(cmdArgs, toolConfig.Args...) + for key, value := range args { + if key == "_tool_name" { + continue + } + cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", key)) + if strValue, ok := value.(string); ok { + cmdArgs = append(cmdArgs, strValue) + } else { + cmdArgs = append(cmdArgs, fmt.Sprintf("%v", value)) + } + } + } + + return cmdArgs +} + +// getParamValue 获取参数值,支持默认值 +func (e *Executor) getParamValue(args map[string]interface{}, param config.ParameterConfig) interface{} { + // 从参数中获取值 + if value, ok := args[param.Name]; ok && value != nil { + return value + } + + // 如果参数是必需的但没有提供,返回 nil(让上层处理错误) + if param.Required { + return nil + } + + // 返回默认值 + return param.Default +} + +// formatParamValue 格式化参数值 +func (e *Executor) formatParamValue(param config.ParameterConfig, value interface{}) string { + switch param.Type { + case "bool": + // 布尔值应该在上层处理,这里不应该被调用 + if boolVal, ok := value.(bool); ok { + return fmt.Sprintf("%v", boolVal) + } + return "false" + case "array": + // 数组:转换为逗号分隔的字符串 + if arr, ok := value.([]interface{}); ok { + strs := make([]string, 0, len(arr)) + for _, item := range arr { + strs = append(strs, fmt.Sprintf("%v", item)) + } + return strings.Join(strs, ",") + } + return fmt.Sprintf("%v", value) + default: + return fmt.Sprintf("%v", value) + } +} + +// executeSystemCommand 执行系统命令 +func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { + // 获取命令 + command, ok := args["command"].(string) + if !ok { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: 缺少command参数", + }, + }, + IsError: true, + }, nil + } + + if command == "" { + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: "错误: command参数不能为空", + }, + }, + IsError: true, + }, nil + } + + // 安全检查:记录执行的命令 + e.logger.Warn("执行系统命令", + zap.String("command", command), + ) + + // 获取shell类型(可选,默认为sh) + shell := "sh" + if s, ok := args["shell"].(string); ok && s != "" { + shell = s + } + + // 获取工作目录(可选) + workDir := "" + if wd, ok := args["workdir"].(string); ok && wd != "" { + workDir = wd + } + + // 构建命令 + var cmd *exec.Cmd + if workDir != "" { + cmd = exec.CommandContext(ctx, shell, "-c", command) + cmd.Dir = workDir + } else { + cmd = exec.CommandContext(ctx, shell, "-c", command) + } + + // 执行命令 + e.logger.Info("执行系统命令", + zap.String("command", command), + zap.String("shell", shell), + zap.String("workdir", workDir), + ) + + output, err := cmd.CombinedOutput() + if err != nil { + e.logger.Error("系统命令执行失败", + zap.String("command", command), + zap.Error(err), + zap.String("output", string(output)), + ) + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)), + }, + }, + IsError: true, + }, nil + } + + e.logger.Info("系统命令执行成功", + zap.String("command", command), + zap.String("output_length", fmt.Sprintf("%d", len(output))), + ) + + return &mcp.ToolResult{ + Content: []mcp.Content{ + { + Type: "text", + Text: string(output), + }, + }, + IsError: false, + }, nil +} + +// buildInputSchema 构建输入模式 +func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} { + schema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + "required": []string{}, + } + + // 如果配置中定义了参数,优先使用配置中的参数定义 + if len(toolConfig.Parameters) > 0 { + properties := make(map[string]interface{}) + required := []string{} + + for _, param := range toolConfig.Parameters { + prop := map[string]interface{}{ + "type": param.Type, + "description": param.Description, + } + + // 添加默认值 + if param.Default != nil { + prop["default"] = param.Default + } + + // 添加枚举选项 + if len(param.Options) > 0 { + prop["enum"] = param.Options + } + + properties[param.Name] = prop + + // 添加到必需参数列表 + if param.Required { + required = append(required, param.Name) + } + } + + schema["properties"] = properties + schema["required"] = required + return schema + } + + // 向后兼容:如果没有定义参数,使用旧的硬编码逻辑 + switch toolConfig.Name { + case "nmap": + schema["properties"] = map[string]interface{}{ + "target": map[string]interface{}{ + "type": "string", + "description": "目标IP地址或域名", + }, + "ports": map[string]interface{}{ + "type": "string", + "description": "端口范围,例如: 1-1000", + }, + } + schema["required"] = []string{"target"} + case "sqlmap": + schema["properties"] = map[string]interface{}{ + "url": map[string]interface{}{ + "type": "string", + "description": "目标URL", + }, + } + schema["required"] = []string{"url"} + case "nikto", "dirb": + schema["properties"] = map[string]interface{}{ + "target": map[string]interface{}{ + "type": "string", + "description": "目标URL", + }, + } + schema["required"] = []string{"target"} + case "exec": + schema["properties"] = map[string]interface{}{ + "command": map[string]interface{}{ + "type": "string", + "description": "要执行的系统命令", + }, + "shell": map[string]interface{}{ + "type": "string", + "description": "使用的shell(可选,默认为sh)", + }, + "workdir": map[string]interface{}{ + "type": "string", + "description": "工作目录(可选)", + }, + } + schema["required"] = []string{"command"} + } + + return schema +} + +// Vulnerability 漏洞信息 +type Vulnerability struct { + ID string `json:"id"` + Type string `json:"type"` + Severity string `json:"severity"` // low, medium, high, critical + Title string `json:"title"` + Description string `json:"description"` + Target string `json:"target"` + FoundAt time.Time `json:"foundAt"` + Details string `json:"details"` +} + +// AnalyzeResults 分析工具执行结果,提取漏洞信息 +func (e *Executor) AnalyzeResults(toolName string, result *mcp.ToolResult) []Vulnerability { + vulnerabilities := []Vulnerability{} + + if result.IsError { + return vulnerabilities + } + + // 分析输出内容 + for _, content := range result.Content { + if content.Type == "text" { + vulns := e.parseToolOutput(toolName, content.Text) + vulnerabilities = append(vulnerabilities, vulns...) + } + } + + return vulnerabilities +} + +// parseToolOutput 解析工具输出 +func (e *Executor) parseToolOutput(toolName, output string) []Vulnerability { + vulnerabilities := []Vulnerability{} + + // 简单的漏洞检测逻辑 + outputLower := strings.ToLower(output) + + // SQL注入检测 + if strings.Contains(outputLower, "sql injection") || strings.Contains(outputLower, "sqli") { + vulnerabilities = append(vulnerabilities, Vulnerability{ + ID: fmt.Sprintf("sql-%d", time.Now().Unix()), + Type: "SQL Injection", + Severity: "high", + Title: "SQL注入漏洞", + Description: "检测到潜在的SQL注入漏洞", + FoundAt: time.Now(), + Details: output, + }) + } + + // XSS检测 + if strings.Contains(outputLower, "xss") || strings.Contains(outputLower, "cross-site scripting") { + vulnerabilities = append(vulnerabilities, Vulnerability{ + ID: fmt.Sprintf("xss-%d", time.Now().Unix()), + Type: "XSS", + Severity: "medium", + Title: "跨站脚本攻击漏洞", + Description: "检测到潜在的XSS漏洞", + FoundAt: time.Now(), + Details: output, + }) + } + + // 开放端口检测 + if toolName == "nmap" { + lines := strings.Split(output, "\n") + for _, line := range lines { + if strings.Contains(line, "open") && strings.Contains(line, "port") { + vulnerabilities = append(vulnerabilities, Vulnerability{ + ID: fmt.Sprintf("port-%d", time.Now().Unix()), + Type: "Open Port", + Severity: "low", + Title: "开放端口", + Description: fmt.Sprintf("发现开放端口: %s", line), + FoundAt: time.Now(), + Details: line, + }) + } + } + } + + return vulnerabilities +} + +// GetVulnerabilityReport 生成漏洞报告 +func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[string]interface{} { + severityCount := map[string]int{ + "critical": 0, + "high": 0, + "medium": 0, + "low": 0, + } + + for _, vuln := range vulnerabilities { + severityCount[vuln.Severity]++ + } + + return map[string]interface{}{ + "total": len(vulnerabilities), + "severityCount": severityCount, + "vulnerabilities": vulnerabilities, + "generatedAt": time.Now(), + } +} + diff --git a/run.sh b/run.sh new file mode 100644 index 00000000..a752063c --- /dev/null +++ b/run.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# CyberStrikeAI 启动脚本 + +echo "🚀 启动 CyberStrikeAI..." + +# 检查配置文件 +if [ ! -f "config.yaml" ]; then + echo "❌ 配置文件 config.yaml 不存在" + exit 1 +fi + +# 检查Go环境 +if ! command -v go &> /dev/null; then + echo "❌ Go 未安装,请先安装 Go 1.21 或更高版本" + exit 1 +fi + +# 下载依赖 +echo "📦 下载依赖..." +go mod download + +# 构建项目 +echo "🔨 构建项目..." +go build -o cyberstrike-ai cmd/server/main.go + +if [ $? -ne 0 ]; then + echo "❌ 构建失败" + exit 1 +fi + +# 运行服务器 +echo "✅ 启动服务器..." +./cyberstrike-ai + diff --git a/web/static/css/style.css b/web/static/css/style.css new file mode 100644 index 00000000..ad78f71d --- /dev/null +++ b/web/static/css/style.css @@ -0,0 +1,691 @@ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +:root { + --primary-color: #1a1a1a; + --secondary-color: #2d2d2d; + --accent-color: #0066ff; + --accent-hover: #0052cc; + --bg-primary: #ffffff; + --bg-secondary: #f8f9fa; + --bg-tertiary: #f1f3f5; + --text-primary: #1a1a1a; + --text-secondary: #6c757d; + --text-muted: #adb5bd; + --border-color: #e9ecef; + --success-color: #28a745; + --warning-color: #ffc107; + --error-color: #dc3545; + --shadow-sm: 0 1px 3px rgba(0, 0, 0, 0.05); + --shadow-md: 0 4px 6px rgba(0, 0, 0, 0.1); + --shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.15); +} + +body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica Neue', Arial, sans-serif; + background: var(--bg-secondary); + margin: 0; + padding: 0; + color: var(--text-primary); + line-height: 1.6; + height: 100vh; + overflow: hidden; +} + +.container { + max-width: 100%; + margin: 0; + background: var(--bg-primary); + height: 100vh; + display: flex; + flex-direction: column; + box-shadow: var(--shadow-lg); + overflow: hidden; +} + +.main-layout { + display: flex; + flex: 1; + overflow: hidden; + min-height: 0; +} + +header { + background: var(--primary-color); + color: white; + padding: 24px 32px; + border-bottom: 1px solid rgba(255, 255, 255, 0.1); + flex-shrink: 0; +} + +.header-content { + display: flex; + justify-content: space-between; + align-items: center; +} + +.logo { + display: flex; + align-items: center; + gap: 12px; +} + +.logo svg { + color: var(--accent-color); +} + +.logo h1 { + font-size: 1.75rem; + font-weight: 600; + letter-spacing: -0.5px; + margin: 0; +} + +.header-subtitle { + font-size: 0.875rem; + color: rgba(255, 255, 255, 0.7); + margin: 0; + font-weight: 400; +} + +/* 侧边栏样式 */ +.sidebar { + width: 280px; + background: var(--bg-secondary); + border-right: 1px solid var(--border-color); + display: flex; + flex-direction: column; + flex-shrink: 0; + height: 100%; + overflow: hidden; +} + +.sidebar-header { + padding: 16px; + border-bottom: 1px solid var(--border-color); + flex-shrink: 0; +} + +.new-chat-btn { + width: 100%; + padding: 10px 16px; + background: var(--accent-color); + color: white; + border: none; + border-radius: 8px; + font-size: 0.9375rem; + font-weight: 500; + cursor: pointer; + transition: all 0.2s; + display: flex; + align-items: center; + justify-content: center; + gap: 8px; +} + +.new-chat-btn:hover { + background: var(--accent-hover); + transform: translateY(-1px); + box-shadow: var(--shadow-sm); +} + +.new-chat-btn span { + font-size: 1.2em; + line-height: 1; +} + +.sidebar-content { + flex: 1; + overflow-y: auto; + overflow-x: hidden; + padding: 16px; + min-height: 0; +} + +.sidebar-title { + font-size: 0.8125rem; + font-weight: 600; + color: var(--text-secondary); + text-transform: uppercase; + letter-spacing: 0.5px; + margin-bottom: 12px; + padding: 0 8px; +} + +.conversations-list { + display: flex; + flex-direction: column; + gap: 4px; +} + +.conversation-item { + padding: 12px; + border-radius: 8px; + cursor: pointer; + transition: all 0.2s; + border: 1px solid transparent; +} + +.conversation-item:hover { + background: var(--bg-tertiary); +} + +.conversation-item.active { + background: var(--bg-primary); + border-color: var(--accent-color); +} + +.conversation-title { + font-size: 0.875rem; + font-weight: 500; + color: var(--text-primary); + margin-bottom: 4px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.conversation-time { + font-size: 0.75rem; + color: var(--text-muted); +} + +/* 对话界面样式 */ +.chat-container { + display: flex; + flex-direction: column; + flex: 1; + min-width: 0; + background: var(--bg-primary); + overflow: hidden; + height: 100%; +} + +.chat-messages { + flex: 1; + overflow-y: auto; + overflow-x: hidden; + padding: 24px; + background: var(--bg-secondary); + display: flex; + flex-direction: column; + min-height: 0; +} + +.message { + margin-bottom: 24px; + display: flex; + align-items: flex-start; + gap: 12px; + animation: fadeIn 0.3s ease-in; + width: 100%; +} + +@keyframes fadeIn { + from { + opacity: 0; + transform: translateY(10px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +.message.user { + flex-direction: row-reverse; + justify-content: flex-end; +} + +.message.system { + justify-content: center; + margin-bottom: 16px; +} + +.message-avatar { + width: 32px; + height: 32px; + border-radius: 6px; + display: flex; + align-items: center; + justify-content: center; + font-size: 0.75rem; + font-weight: 600; + flex-shrink: 0; + margin-top: 2px; +} + +.message.user .message-avatar { + background: var(--accent-color); + color: white; +} + +.message.assistant .message-avatar { + background: var(--bg-tertiary); + color: var(--text-secondary); + border: 1px solid var(--border-color); +} + +.message.system .message-avatar { + display: none; +} + +.message-content { + flex: 0 1 auto; + max-width: 70%; + min-width: 120px; + display: flex; + flex-direction: column; +} + +.message.user .message-content { + align-items: flex-end; + margin-left: auto; +} + +.message.assistant .message-content { + align-items: flex-start; + margin-right: auto; +} + +.message.system .message-content { + max-width: 90%; + align-items: center; + margin: 0 auto; +} + +.message-bubble { + padding: 12px 16px; + border-radius: 8px; + word-wrap: break-word; + word-break: break-word; + line-height: 1.6; + box-shadow: var(--shadow-sm); + white-space: pre-wrap; +} + +.message.user .message-bubble { + background: var(--accent-color); + color: white; + border-bottom-right-radius: 2px; +} + +.message.assistant .message-bubble { + background: var(--bg-primary); + color: var(--text-primary); + border: 1px solid var(--border-color); + border-bottom-left-radius: 2px; +} + +.message.assistant .message-bubble pre { + margin: 0; + white-space: pre-wrap; + word-wrap: break-word; + font-family: inherit; +} + +.message.system .message-bubble { + background: var(--bg-tertiary); + color: var(--text-secondary); + border: 1px solid var(--border-color); + text-align: center; + font-size: 0.875rem; + padding: 10px 16px; + width: 100%; +} + +.message-time { + font-size: 0.6875rem; + color: var(--text-muted); + margin-top: 4px; + padding: 0 2px; + font-weight: 400; +} + +/* MCP调用区域 */ +.mcp-call-section { + margin-top: 12px; + padding-top: 12px; + border-top: 1px solid var(--border-color); + width: 100%; +} + +.mcp-call-label { + font-size: 0.75rem; + color: var(--text-secondary); + margin-bottom: 8px; + display: flex; + align-items: center; + gap: 6px; + font-weight: 500; +} + +.mcp-call-label::before { + content: ''; + width: 4px; + height: 4px; + background: var(--accent-color); + border-radius: 50%; + display: inline-block; + flex-shrink: 0; +} + +.mcp-call-buttons { + display: flex; + flex-wrap: wrap; + gap: 6px; +} + +.chat-input-container { + display: flex; + gap: 12px; + padding: 20px 24px; + background: var(--bg-primary); + border-top: 1px solid var(--border-color); + flex-shrink: 0; +} + +.chat-input-container input { + flex: 1; + padding: 12px 16px; + border: 1px solid var(--border-color); + border-radius: 8px; + font-size: 0.9375rem; + outline: none; + transition: all 0.2s; + background: var(--bg-primary); + color: var(--text-primary); +} + +.chat-input-container input:focus { + border-color: var(--accent-color); + box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.1); +} + +.chat-input-container input::placeholder { + color: var(--text-muted); +} + +.chat-input-container button { + padding: 12px 24px; + background: var(--accent-color); + color: white; + border: none; + border-radius: 8px; + cursor: pointer; + font-size: 0.9375rem; + font-weight: 500; + transition: all 0.2s; + white-space: nowrap; +} + +.chat-input-container button:hover { + background: var(--accent-hover); + transform: translateY(-1px); + box-shadow: var(--shadow-md); +} + +.chat-input-container button:active { + transform: translateY(0); +} + +/* MCP调用详情按钮 */ +.mcp-detail-btn { + display: inline-flex; + align-items: center; + gap: 6px; + padding: 6px 12px; + background: var(--bg-primary); + color: var(--accent-color); + border: 1px solid var(--border-color); + border-radius: 6px; + font-size: 0.8125rem; + font-weight: 500; + cursor: pointer; + transition: all 0.2s; +} + +.mcp-detail-btn:hover { + background: var(--accent-color); + color: white; + border-color: var(--accent-color); + transform: translateY(-1px); + box-shadow: var(--shadow-sm); +} + +.mcp-detail-btn:active { + transform: translateY(0); +} + +/* 模态框样式 */ +.modal { + display: none; + position: fixed; + z-index: 1000; + left: 0; + top: 0; + width: 100%; + height: 100%; + background-color: rgba(0, 0, 0, 0.6); + backdrop-filter: blur(4px); + overflow: auto; + animation: fadeIn 0.2s ease-in; +} + +.modal-content { + background-color: var(--bg-primary); + margin: 5% auto; + padding: 0; + border-radius: 12px; + width: 90%; + max-width: 900px; + max-height: 85vh; + display: flex; + flex-direction: column; + box-shadow: var(--shadow-lg); + border: 1px solid var(--border-color); + animation: slideDown 0.3s ease-out; +} + +@keyframes slideDown { + from { + opacity: 0; + transform: translateY(-20px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +.modal-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 20px 24px; + border-bottom: 1px solid var(--border-color); + background: var(--bg-primary); +} + +.modal-header h2 { + margin: 0; + font-size: 1.25rem; + font-weight: 600; + color: var(--text-primary); +} + +.modal-close { + width: 32px; + height: 32px; + display: flex; + align-items: center; + justify-content: center; + border-radius: 6px; + cursor: pointer; + color: var(--text-secondary); + font-size: 1.5rem; + line-height: 1; + transition: all 0.2s; + border: none; + background: transparent; +} + +.modal-close:hover { + background: var(--bg-tertiary); + color: var(--text-primary); +} + +.modal-body { + padding: 24px; + overflow-y: auto; + flex: 1; +} + +.detail-section { + margin-bottom: 24px; +} + +.detail-section:last-child { + margin-bottom: 0; +} + +.detail-section h3 { + color: var(--text-primary); + margin-bottom: 12px; + padding-bottom: 8px; + border-bottom: 1px solid var(--border-color); + font-size: 0.9375rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +.detail-item { + margin-bottom: 10px; + padding: 8px 0; + display: flex; + align-items: baseline; + gap: 8px; +} + +.detail-item strong { + color: var(--text-secondary); + font-weight: 500; + font-size: 0.875rem; + min-width: 80px; +} + +.detail-item span { + color: var(--text-primary); + font-size: 0.875rem; +} + +.code-block { + background: var(--bg-tertiary); + border: 1px solid var(--border-color); + border-radius: 8px; + padding: 16px; + font-family: 'SF Mono', 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', monospace; + font-size: 0.8125rem; + line-height: 1.6; + overflow-x: auto; + white-space: pre-wrap; + word-wrap: break-word; + max-height: 400px; + overflow-y: auto; + color: var(--text-primary); +} + +.code-block.error { + background: #fff5f5; + border-color: var(--error-color); + color: var(--error-color); +} + +/* 滚动条样式 */ +::-webkit-scrollbar { + width: 8px; + height: 8px; +} + +::-webkit-scrollbar-track { + background: var(--bg-secondary); + border-radius: 4px; +} + +::-webkit-scrollbar-thumb { + background: var(--text-muted); + border-radius: 4px; +} + +::-webkit-scrollbar-thumb:hover { + background: var(--text-secondary); +} + +/* 响应式设计 */ +@media (max-width: 768px) { + body { + height: 100vh; + overflow: hidden; + } + + .container { + border-radius: 0; + height: 100vh; + overflow: hidden; + } + + header { + padding: 16px 20px; + flex-shrink: 0; + } + + .logo h1 { + font-size: 1.5rem; + } + + .header-subtitle { + display: none; + } + + .main-layout { + overflow: hidden; + min-height: 0; + } + + .sidebar { + height: 100%; + overflow: hidden; + } + + .sidebar-content { + min-height: 0; + } + + .chat-container { + height: 100%; + overflow: hidden; + } + + .chat-messages { + padding: 16px; + min-height: 0; + } + + .message-content { + max-width: 85%; + } + + .chat-input-container { + padding: 16px; + flex-shrink: 0; + } + + .modal-content { + width: 95%; + margin: 10% auto; + } +} diff --git a/web/static/js/app.js b/web/static/js/app.js new file mode 100644 index 00000000..fbb16ba2 --- /dev/null +++ b/web/static/js/app.js @@ -0,0 +1,400 @@ + +// 当前对话ID +let currentConversationId = null; + +// 发送消息 +async function sendMessage() { + const input = document.getElementById('chat-input'); + const message = input.value.trim(); + + if (!message) { + return; + } + + // 显示用户消息 + addMessage('user', message); + input.value = ''; + + // 显示加载状态 + const loadingId = addMessage('system', '正在处理中...'); + + try { + const response = await fetch('/api/agent-loop', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + message: message, + conversationId: currentConversationId + }), + }); + + const data = await response.json(); + + // 移除加载消息 + removeMessage(loadingId); + + if (response.ok) { + // 更新当前对话ID + if (data.conversationId) { + currentConversationId = data.conversationId; + updateActiveConversation(); + } + + // 如果有MCP执行ID,显示所有调用 + const mcpIds = data.mcpExecutionIds || []; + addMessage('assistant', data.response, mcpIds); + + // 刷新对话列表 + loadConversations(); + } else { + addMessage('system', '错误: ' + (data.error || '未知错误')); + } + } catch (error) { + removeMessage(loadingId); + addMessage('system', '错误: ' + error.message); + } +} + +// 消息计数器,确保ID唯一 +let messageCounter = 0; + +// 添加消息 +function addMessage(role, content, mcpExecutionIds = null) { + const messagesDiv = document.getElementById('chat-messages'); + const messageDiv = document.createElement('div'); + messageCounter++; + const id = 'msg-' + Date.now() + '-' + messageCounter + '-' + Math.random().toString(36).substr(2, 9); + messageDiv.id = id; + messageDiv.className = 'message ' + role; + + // 创建头像 + const avatar = document.createElement('div'); + avatar.className = 'message-avatar'; + if (role === 'user') { + avatar.textContent = 'U'; + } else if (role === 'assistant') { + avatar.textContent = 'A'; + } else { + avatar.textContent = 'S'; + } + messageDiv.appendChild(avatar); + + // 创建消息内容容器 + const contentWrapper = document.createElement('div'); + contentWrapper.className = 'message-content'; + + // 创建消息气泡 + const bubble = document.createElement('div'); + bubble.className = 'message-bubble'; + // 处理换行和格式化 + const formattedContent = content.replace(/\n/g, '
'); + bubble.innerHTML = formattedContent; + contentWrapper.appendChild(bubble); + + // 添加时间戳 + const timeDiv = document.createElement('div'); + timeDiv.className = 'message-time'; + timeDiv.textContent = new Date().toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' }); + contentWrapper.appendChild(timeDiv); + + // 如果有MCP执行ID,添加查看详情区域 + if (mcpExecutionIds && Array.isArray(mcpExecutionIds) && mcpExecutionIds.length > 0 && role === 'assistant') { + const mcpSection = document.createElement('div'); + mcpSection.className = 'mcp-call-section'; + + const mcpLabel = document.createElement('div'); + mcpLabel.className = 'mcp-call-label'; + mcpLabel.textContent = `工具调用 (${mcpExecutionIds.length})`; + mcpSection.appendChild(mcpLabel); + + const buttonsContainer = document.createElement('div'); + buttonsContainer.className = 'mcp-call-buttons'; + + mcpExecutionIds.forEach((execId, index) => { + const detailBtn = document.createElement('button'); + detailBtn.className = 'mcp-detail-btn'; + detailBtn.innerHTML = `调用 #${index + 1}`; + detailBtn.onclick = () => showMCPDetail(execId); + buttonsContainer.appendChild(detailBtn); + }); + + mcpSection.appendChild(buttonsContainer); + contentWrapper.appendChild(mcpSection); + } + + messageDiv.appendChild(contentWrapper); + messagesDiv.appendChild(messageDiv); + messagesDiv.scrollTop = messagesDiv.scrollHeight; + return id; +} + +// 移除消息 +function removeMessage(id) { + const messageDiv = document.getElementById(id); + if (messageDiv) { + messageDiv.remove(); + } +} + +// 回车发送消息 +document.getElementById('chat-input').addEventListener('keypress', function(e) { + if (e.key === 'Enter') { + sendMessage(); + } +}); + +// 显示MCP调用详情 +async function showMCPDetail(executionId) { + try { + const response = await fetch(`/api/monitor/execution/${executionId}`); + const exec = await response.json(); + + if (response.ok) { + // 填充模态框内容 + document.getElementById('detail-tool-name').textContent = exec.toolName || 'Unknown'; + document.getElementById('detail-execution-id').textContent = exec.id || 'N/A'; + document.getElementById('detail-status').textContent = getStatusText(exec.status); + document.getElementById('detail-time').textContent = new Date(exec.startTime).toLocaleString('zh-CN'); + + // 请求参数 + const requestData = { + tool: exec.toolName, + arguments: exec.arguments + }; + document.getElementById('detail-request').textContent = JSON.stringify(requestData, null, 2); + + // 响应结果 + if (exec.result) { + const responseData = { + content: exec.result.content, + isError: exec.result.isError + }; + document.getElementById('detail-response').textContent = JSON.stringify(responseData, null, 2); + document.getElementById('detail-response').className = exec.result.isError ? 'code-block error' : 'code-block'; + } else { + document.getElementById('detail-response').textContent = '暂无响应数据'; + } + + // 错误信息 + if (exec.error) { + document.getElementById('detail-error-section').style.display = 'block'; + document.getElementById('detail-error').textContent = exec.error; + } else { + document.getElementById('detail-error-section').style.display = 'none'; + } + + // 显示模态框 + document.getElementById('mcp-detail-modal').style.display = 'block'; + } else { + alert('获取详情失败: ' + (exec.error || '未知错误')); + } + } catch (error) { + alert('获取详情失败: ' + error.message); + } +} + +// 关闭MCP详情模态框 +function closeMCPDetail() { + document.getElementById('mcp-detail-modal').style.display = 'none'; +} + +// 点击模态框外部关闭 +window.onclick = function(event) { + const modal = document.getElementById('mcp-detail-modal'); + if (event.target == modal) { + closeMCPDetail(); + } +} + +// 工具函数 +function getStatusText(status) { + const statusMap = { + 'pending': '等待中', + 'running': '执行中', + 'completed': '已完成', + 'failed': '失败' + }; + return statusMap[status] || status; +} + +function formatDuration(ms) { + const seconds = Math.floor(ms / 1000); + const minutes = Math.floor(seconds / 60); + const hours = Math.floor(minutes / 60); + + if (hours > 0) { + return `${hours}小时${minutes % 60}分钟`; + } else if (minutes > 0) { + return `${minutes}分钟${seconds % 60}秒`; + } else { + return `${seconds}秒`; + } +} + +function escapeHtml(text) { + const div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; +} + +// 开始新对话 +function startNewConversation() { + currentConversationId = null; + document.getElementById('chat-messages').innerHTML = ''; + addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。'); + updateActiveConversation(); +} + +// 加载对话列表 +async function loadConversations() { + try { + const response = await fetch('/api/conversations?limit=50'); + const conversations = await response.json(); + + const listContainer = document.getElementById('conversations-list'); + listContainer.innerHTML = ''; + + if (conversations.length === 0) { + listContainer.innerHTML = '
暂无历史对话
'; + return; + } + + conversations.forEach(conv => { + const item = document.createElement('div'); + item.className = 'conversation-item'; + item.dataset.conversationId = conv.id; + if (conv.id === currentConversationId) { + item.classList.add('active'); + } + + const title = document.createElement('div'); + title.className = 'conversation-title'; + title.textContent = conv.title || '未命名对话'; + item.appendChild(title); + + const time = document.createElement('div'); + time.className = 'conversation-time'; + // 解析时间,支持多种格式 + let dateObj; + if (conv.updatedAt) { + dateObj = new Date(conv.updatedAt); + // 检查日期是否有效 + if (isNaN(dateObj.getTime())) { + // 如果解析失败,尝试其他格式 + console.warn('时间解析失败:', conv.updatedAt); + dateObj = new Date(); + } + } else { + dateObj = new Date(); + } + + // 格式化时间显示 + const now = new Date(); + const today = new Date(now.getFullYear(), now.getMonth(), now.getDate()); + const yesterday = new Date(today); + yesterday.setDate(yesterday.getDate() - 1); + const messageDate = new Date(dateObj.getFullYear(), dateObj.getMonth(), dateObj.getDate()); + + let timeText; + if (messageDate.getTime() === today.getTime()) { + // 今天:只显示时间 + timeText = dateObj.toLocaleTimeString('zh-CN', { + hour: '2-digit', + minute: '2-digit' + }); + } else if (messageDate.getTime() === yesterday.getTime()) { + // 昨天 + timeText = '昨天 ' + dateObj.toLocaleTimeString('zh-CN', { + hour: '2-digit', + minute: '2-digit' + }); + } else if (now.getFullYear() === dateObj.getFullYear()) { + // 今年:显示月日和时间 + timeText = dateObj.toLocaleString('zh-CN', { + month: 'short', + day: 'numeric', + hour: '2-digit', + minute: '2-digit' + }); + } else { + // 去年或更早:显示完整日期和时间 + timeText = dateObj.toLocaleString('zh-CN', { + year: 'numeric', + month: 'short', + day: 'numeric', + hour: '2-digit', + minute: '2-digit' + }); + } + + time.textContent = timeText; + item.appendChild(time); + + item.onclick = () => loadConversation(conv.id); + listContainer.appendChild(item); + }); + } catch (error) { + console.error('加载对话列表失败:', error); + } +} + +// 加载对话 +async function loadConversation(conversationId) { + try { + const response = await fetch(`/api/conversations/${conversationId}`); + const conversation = await response.json(); + + if (!response.ok) { + alert('加载对话失败: ' + (conversation.error || '未知错误')); + return; + } + + // 更新当前对话ID + currentConversationId = conversationId; + updateActiveConversation(); + + // 清空消息区域 + const messagesDiv = document.getElementById('chat-messages'); + messagesDiv.innerHTML = ''; + + // 加载消息 + if (conversation.messages && conversation.messages.length > 0) { + conversation.messages.forEach(msg => { + addMessage(msg.role, msg.content, msg.mcpExecutionIds || []); + }); + } else { + addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。'); + } + + // 滚动到底部 + messagesDiv.scrollTop = messagesDiv.scrollHeight; + + // 刷新对话列表 + loadConversations(); + } catch (error) { + console.error('加载对话失败:', error); + alert('加载对话失败: ' + error.message); + } +} + +// 更新活动对话样式 +function updateActiveConversation() { + document.querySelectorAll('.conversation-item').forEach(item => { + item.classList.remove('active'); + if (currentConversationId && item.dataset.conversationId === currentConversationId) { + item.classList.add('active'); + } + }); +} + +// 页面加载时初始化 +document.addEventListener('DOMContentLoaded', function() { + // 加载对话列表 + loadConversations(); + + // 添加欢迎消息 + addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。'); +}); + diff --git a/web/templates/index.html b/web/templates/index.html new file mode 100644 index 00000000..bc563f32 --- /dev/null +++ b/web/templates/index.html @@ -0,0 +1,92 @@ + + + + + + CyberStrikeAI - 自主渗透测试平台 + + + +
+
+
+ +

安全测试平台

+
+
+ +
+ + + + +
+
+
+ + +
+
+
+
+ + + + + + + +