mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-03-31 00:09:29 +02:00
Add files via upload
This commit is contained in:
319
README.md
Normal file
319
README.md
Normal file
@@ -0,0 +1,319 @@
|
||||
# CyberStrikeAI
|
||||
|
||||
基于Golang和Gin框架的AI驱动自主渗透测试平台,使用MCP协议集成安全工具。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 🤖 **AI代理连接** - 支持Claude、GPT等兼容MCP的AI代理通过FastMCP协议连接
|
||||
- 🧠 **智能分析** - 决策引擎分析目标并选择最佳测试策略
|
||||
- ⚡ **自主执行** - AI代理执行全面的安全评估
|
||||
- 🔄 **实时适应** - 系统根据结果和发现的漏洞进行调整
|
||||
- 📊 **高级报告** - 可视化方式输出漏洞卡片和风险分析
|
||||
- 💬 **对话式交互** - 前端以对话形式调用后端agent-loop
|
||||
- 📈 **实时监控** - 监控安全工具的执行状态、结果、调用次数等
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
CyberStrikeAI/
|
||||
├── cmd/
|
||||
│ └── server/
|
||||
│ └── main.go # 程序入口
|
||||
├── internal/
|
||||
│ ├── agent/ # AI代理模块
|
||||
│ ├── app/ # 应用初始化
|
||||
│ ├── config/ # 配置管理
|
||||
│ ├── handler/ # HTTP处理器
|
||||
│ ├── logger/ # 日志系统
|
||||
│ ├── mcp/ # MCP协议实现
|
||||
│ └── security/ # 安全工具执行器
|
||||
├── web/
|
||||
│ ├── static/ # 静态资源
|
||||
│ │ ├── css/
|
||||
│ │ └── js/
|
||||
│ └── templates/ # HTML模板
|
||||
├── config.yaml # 配置文件
|
||||
├── go.mod # Go模块文件
|
||||
└── README.md # 说明文档
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 前置要求
|
||||
|
||||
- Go 1.21 或更高版本
|
||||
- OpenAI API Key(或其他兼容OpenAI协议的API)
|
||||
- 安全工具(可选):nmap, sqlmap, nikto, dirb
|
||||
|
||||
### 安装步骤
|
||||
|
||||
1. **克隆项目**
|
||||
```bash
|
||||
cd /Users/temp/Desktop/wenjian/tools/CyberStrikeAI
|
||||
```
|
||||
|
||||
2. **安装依赖**
|
||||
```bash
|
||||
go mod download
|
||||
```
|
||||
|
||||
3. **配置**
|
||||
编辑 `config.yaml` 文件,设置您的OpenAI API Key:
|
||||
```yaml
|
||||
openai:
|
||||
api_key: "sk-your-api-key-here"
|
||||
base_url: "https://api.openai.com/v1"
|
||||
model: "gpt-4"
|
||||
```
|
||||
|
||||
4. **启动服务器**
|
||||
|
||||
#### 方式一:使用启动脚本
|
||||
```bash
|
||||
./run.sh
|
||||
```
|
||||
|
||||
#### 方式二:直接运行
|
||||
```bash
|
||||
go run cmd/server/main.go
|
||||
```
|
||||
|
||||
#### 方式三:编译后运行
|
||||
```bash
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
./cyberstrike-ai
|
||||
```
|
||||
|
||||
5. **访问应用**
|
||||
打开浏览器访问:http://localhost:8080
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 服务器配置
|
||||
```yaml
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
```
|
||||
|
||||
### MCP配置
|
||||
```yaml
|
||||
mcp:
|
||||
enabled: true
|
||||
host: "0.0.0.0"
|
||||
port: 8081
|
||||
```
|
||||
|
||||
### 安全工具配置
|
||||
```yaml
|
||||
security:
|
||||
tools:
|
||||
- name: "nmap"
|
||||
command: "nmap"
|
||||
args: ["-sV", "-sC"]
|
||||
description: "网络扫描工具"
|
||||
enabled: true
|
||||
```
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 对话式渗透测试
|
||||
|
||||
在"对话测试"标签页中,您可以:
|
||||
|
||||
1. **网络扫描**
|
||||
```
|
||||
扫描 192.168.1.1 的开放端口
|
||||
```
|
||||
|
||||
2. **SQL注入检测**
|
||||
```
|
||||
检测 https://example.com 的SQL注入漏洞
|
||||
```
|
||||
|
||||
3. **Web漏洞扫描**
|
||||
```
|
||||
扫描 https://example.com 的Web服务器漏洞
|
||||
```
|
||||
|
||||
4. **目录扫描**
|
||||
```
|
||||
扫描 https://example.com 的隐藏目录
|
||||
```
|
||||
|
||||
### 监控工具执行
|
||||
|
||||
在"工具监控"标签页中,您可以:
|
||||
|
||||
- 查看所有工具的执行统计
|
||||
- 查看详细的执行记录
|
||||
- 查看发现的漏洞列表
|
||||
- 实时监控工具状态
|
||||
|
||||
## API接口
|
||||
|
||||
### Agent Loop API
|
||||
|
||||
**POST** `/api/agent-loop`
|
||||
|
||||
请求体:
|
||||
```json
|
||||
{
|
||||
"message": "扫描 192.168.1.1"
|
||||
}
|
||||
```
|
||||
|
||||
使用示例:
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/agent-loop \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"message": "扫描 192.168.1.1"}'
|
||||
```
|
||||
|
||||
### 监控API
|
||||
|
||||
- **GET** `/api/monitor` - 获取所有监控信息
|
||||
- **GET** `/api/monitor/execution/:id` - 获取特定执行记录
|
||||
- **GET** `/api/monitor/stats` - 获取统计信息
|
||||
- **GET** `/api/monitor/vulnerabilities` - 获取漏洞列表
|
||||
|
||||
使用示例:
|
||||
```bash
|
||||
# 获取所有监控信息
|
||||
curl http://localhost:8080/api/monitor
|
||||
|
||||
# 获取统计信息
|
||||
curl http://localhost:8080/api/monitor/stats
|
||||
|
||||
# 获取漏洞列表
|
||||
curl http://localhost:8080/api/monitor/vulnerabilities
|
||||
```
|
||||
|
||||
### MCP接口
|
||||
|
||||
**POST** `/api/mcp` - MCP协议端点
|
||||
|
||||
## MCP协议
|
||||
|
||||
本项目实现了MCP(Model Context Protocol)协议,支持:
|
||||
|
||||
- `initialize` - 初始化连接
|
||||
- `tools/list` - 列出可用工具
|
||||
- `tools/call` - 调用工具
|
||||
|
||||
工具调用是异步执行的,系统会跟踪每个工具的执行状态和结果。
|
||||
|
||||
### MCP协议使用示例
|
||||
|
||||
#### 初始化连接
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/mcp \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "1",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "test-client",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
#### 列出工具
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/mcp \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "2",
|
||||
"method": "tools/list"
|
||||
}'
|
||||
```
|
||||
|
||||
#### 调用工具
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/api/mcp \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "3",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": "nmap",
|
||||
"arguments": {
|
||||
"target": "192.168.1.1",
|
||||
"ports": "1-1000"
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## 安全工具支持
|
||||
|
||||
当前支持的安全工具:
|
||||
- **nmap** - 网络扫描
|
||||
- **sqlmap** - SQL注入检测
|
||||
- **nikto** - Web服务器扫描
|
||||
- **dirb** - 目录扫描
|
||||
|
||||
可以通过修改 `config.yaml` 添加更多工具。
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 问题:无法连接到OpenAI API
|
||||
|
||||
- 检查API Key是否正确
|
||||
- 检查网络连接
|
||||
- 检查base_url配置
|
||||
|
||||
### 问题:工具执行失败
|
||||
|
||||
- 确保已安装相应的安全工具(nmap, sqlmap等)
|
||||
- 检查工具是否在PATH中
|
||||
- 某些工具可能需要root权限
|
||||
|
||||
### 问题:前端无法加载
|
||||
|
||||
- 检查服务器是否正常运行
|
||||
- 检查端口8080是否被占用
|
||||
- 查看浏览器控制台错误信息
|
||||
|
||||
## 安全注意事项
|
||||
|
||||
⚠️ **重要提示**:
|
||||
|
||||
- 仅对您拥有或已获得授权的系统进行测试
|
||||
- 遵守相关法律法规
|
||||
- 建议在隔离的测试环境中使用
|
||||
- 不要在生产环境中使用
|
||||
- 某些安全工具可能需要root权限
|
||||
|
||||
## 开发
|
||||
|
||||
### 添加新工具
|
||||
|
||||
1. 在 `config.yaml` 中添加工具配置
|
||||
2. 在 `internal/security/executor.go` 的 `buildCommandArgs` 方法中添加参数构建逻辑
|
||||
3. 在 `internal/agent/agent.go` 的 `getAvailableTools` 方法中添加工具定义
|
||||
|
||||
### 构建
|
||||
|
||||
```bash
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目仅供学习和研究使用。
|
||||
|
||||
## 贡献
|
||||
|
||||
欢迎提交Issue和Pull Request!
|
||||
36
cmd/server/main.go
Normal file
36
cmd/server/main.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"cyberstrike-ai/internal/app"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"flag"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var configPath = flag.String("config", "config.yaml", "配置文件路径")
|
||||
flag.Parse()
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("加载配置失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化日志
|
||||
log := logger.New(cfg.Log.Level, cfg.Log.Output)
|
||||
|
||||
// 创建应用
|
||||
application, err := app.New(cfg, log)
|
||||
if err != nil {
|
||||
log.Fatal("应用初始化失败", "error", err)
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
if err := application.Run(); err != nil {
|
||||
log.Fatal("服务器启动失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
174
config.yaml
Normal file
174
config.yaml
Normal file
@@ -0,0 +1,174 @@
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
|
||||
log:
|
||||
level: "info"
|
||||
output: "stdout"
|
||||
|
||||
mcp:
|
||||
enabled: true
|
||||
host: "0.0.0.0"
|
||||
port: 8081
|
||||
|
||||
openai:
|
||||
api_key: "sk-xxx" # 请设置您的OpenAI API Key
|
||||
base_url: "https://api.deepseek.com/v1"
|
||||
model: "deepseek-chat"
|
||||
|
||||
database:
|
||||
path: "data/conversations.db"
|
||||
|
||||
security:
|
||||
tools:
|
||||
# 示例1: 使用参数定义的工具(推荐方式)
|
||||
- name: "nmap"
|
||||
command: "nmap"
|
||||
args: ["-sT", "-sV", "-sC"] # 固定参数
|
||||
description: "网络扫描工具,用于发现网络主机和服务"
|
||||
enabled: true
|
||||
parameters:
|
||||
- name: "target"
|
||||
type: "string"
|
||||
description: "目标IP地址或域名"
|
||||
required: true
|
||||
position: 0 # 位置参数,放在最后
|
||||
format: "positional"
|
||||
- name: "ports"
|
||||
type: "string"
|
||||
description: "端口范围,例如: 1-1000, 80,443,8080"
|
||||
required: false
|
||||
flag: "-p"
|
||||
format: "flag"
|
||||
|
||||
# 示例2: 标志参数工具
|
||||
- name: "sqlmap"
|
||||
command: "sqlmap"
|
||||
description: "SQL注入检测和利用工具"
|
||||
enabled: true
|
||||
parameters:
|
||||
- name: "url"
|
||||
type: "string"
|
||||
description: "目标URL,例如: http://example.com/page?id=1"
|
||||
required: true
|
||||
flag: "-u"
|
||||
format: "flag"
|
||||
- name: "batch"
|
||||
type: "bool"
|
||||
description: "非交互模式"
|
||||
required: false
|
||||
default: true
|
||||
flag: "--batch"
|
||||
format: "flag"
|
||||
- name: "level"
|
||||
type: "int"
|
||||
description: "测试级别 (1-5)"
|
||||
required: false
|
||||
default: 3
|
||||
flag: "--level"
|
||||
format: "combined" # --level=3
|
||||
|
||||
# 示例3: 位置参数工具
|
||||
- name: "nikto"
|
||||
command: "nikto"
|
||||
description: "Web服务器扫描工具"
|
||||
enabled: true
|
||||
parameters:
|
||||
- name: "target"
|
||||
type: "string"
|
||||
description: "目标URL或IP地址"
|
||||
required: true
|
||||
flag: "-h"
|
||||
format: "flag"
|
||||
|
||||
# 示例4: 简单位置参数
|
||||
- name: "dirb"
|
||||
command: "dirb"
|
||||
description: "Web目录扫描工具"
|
||||
enabled: true
|
||||
parameters:
|
||||
- name: "url"
|
||||
type: "string"
|
||||
description: "目标URL"
|
||||
required: true
|
||||
position: 0
|
||||
format: "positional"
|
||||
- name: "wordlist"
|
||||
type: "string"
|
||||
description: "字典文件路径"
|
||||
required: false
|
||||
flag: "-w"
|
||||
format: "flag"
|
||||
|
||||
# 示例5: 执行系统命令
|
||||
- name: "exec"
|
||||
command: "sh"
|
||||
args: ["-c"]
|
||||
description: "执行系统命令(谨慎使用)"
|
||||
enabled: true
|
||||
parameters:
|
||||
- name: "command"
|
||||
type: "string"
|
||||
description: "要执行的系统命令"
|
||||
required: true
|
||||
position: 0
|
||||
format: "positional"
|
||||
- name: "shell"
|
||||
type: "string"
|
||||
description: "使用的shell(可选,默认为sh)"
|
||||
required: false
|
||||
default: "sh"
|
||||
- name: "workdir"
|
||||
type: "string"
|
||||
description: "工作目录"
|
||||
required: false
|
||||
|
||||
# 示例6: 自定义工具 - 使用模板格式
|
||||
- name: "custom_scanner"
|
||||
command: "my-scanner"
|
||||
description: "自定义扫描工具示例"
|
||||
enabled: false # 默认禁用,需要时启用
|
||||
parameters:
|
||||
- name: "target"
|
||||
type: "string"
|
||||
description: "扫描目标"
|
||||
required: true
|
||||
flag: "--target"
|
||||
format: "flag"
|
||||
- name: "mode"
|
||||
type: "string"
|
||||
description: "扫描模式"
|
||||
required: false
|
||||
default: "normal"
|
||||
options: ["normal", "aggressive", "stealth"] # 枚举值
|
||||
flag: "--mode"
|
||||
format: "combined" # --mode=normal
|
||||
- name: "threads"
|
||||
type: "int"
|
||||
description: "线程数"
|
||||
required: false
|
||||
default: 10
|
||||
flag: "-t"
|
||||
format: "flag"
|
||||
- name: "output"
|
||||
type: "string"
|
||||
description: "输出文件路径"
|
||||
required: false
|
||||
flag: "-o"
|
||||
format: "template"
|
||||
template: "-o {value}" # 自定义模板
|
||||
- name: "verbose"
|
||||
type: "bool"
|
||||
description: "详细输出"
|
||||
required: false
|
||||
default: false
|
||||
flag: "-v"
|
||||
format: "flag" # 布尔值:如果为true,只添加-v,不添加值
|
||||
|
||||
# 示例7: 向后兼容 - 不定义parameters,使用旧的硬编码逻辑
|
||||
# - name: "legacy_tool"
|
||||
# command: "legacy"
|
||||
# args: ["--option"]
|
||||
# description: "旧工具(使用硬编码逻辑)"
|
||||
# enabled: false
|
||||
|
||||
BIN
data/conversations.db
Normal file
BIN
data/conversations.db
Normal file
Binary file not shown.
BIN
data/conversations.db-shm
Normal file
BIN
data/conversations.db-shm
Normal file
Binary file not shown.
BIN
data/conversations.db-wal
Normal file
BIN
data/conversations.db-wal
Normal file
Binary file not shown.
38
go.mod
Normal file
38
go.mod
Normal file
@@ -0,0 +1,38 @@
|
||||
module cyberstrike-ai
|
||||
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/google/uuid v1.5.0
|
||||
github.com/mattn/go-sqlite3 v1.14.18
|
||||
go.uber.org/zap v1.26.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.14.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
)
|
||||
96
go.sum
Normal file
96
go.sum
Normal file
@@ -0,0 +1,96 @@
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
|
||||
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
|
||||
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
|
||||
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
576
internal/agent/agent.go
Normal file
576
internal/agent/agent.go
Normal file
@@ -0,0 +1,576 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Agent AI代理
|
||||
type Agent struct {
|
||||
openAIClient *http.Client
|
||||
config *config.OpenAIConfig
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAgent 创建新的Agent
|
||||
func NewAgent(cfg *config.OpenAIConfig, mcpServer *mcp.Server, logger *zap.Logger) *Agent {
|
||||
return &Agent{
|
||||
openAIClient: &http.Client{Timeout: 5 * time.Minute},
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatMessage 聊天消息
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
|
||||
func (cm ChatMessage) MarshalJSON() ([]byte, error) {
|
||||
// 构建序列化结构
|
||||
aux := map[string]interface{}{
|
||||
"role": cm.Role,
|
||||
}
|
||||
|
||||
// 添加content(如果存在)
|
||||
if cm.Content != "" {
|
||||
aux["content"] = cm.Content
|
||||
}
|
||||
|
||||
// 添加tool_call_id(如果存在)
|
||||
if cm.ToolCallID != "" {
|
||||
aux["tool_call_id"] = cm.ToolCallID
|
||||
}
|
||||
|
||||
// 转换tool_calls,将arguments转换为JSON字符串
|
||||
if len(cm.ToolCalls) > 0 {
|
||||
toolCallsJSON := make([]map[string]interface{}, len(cm.ToolCalls))
|
||||
for i, tc := range cm.ToolCalls {
|
||||
// 将arguments转换为JSON字符串
|
||||
argsJSON := ""
|
||||
if tc.Function.Arguments != nil {
|
||||
argsBytes, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
argsJSON = string(argsBytes)
|
||||
}
|
||||
|
||||
toolCallsJSON[i] = map[string]interface{}{
|
||||
"id": tc.ID,
|
||||
"type": tc.Type,
|
||||
"function": map[string]interface{}{
|
||||
"name": tc.Function.Name,
|
||||
"arguments": argsJSON,
|
||||
},
|
||||
}
|
||||
}
|
||||
aux["tool_calls"] = toolCallsJSON
|
||||
}
|
||||
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
// OpenAIRequest OpenAI API请求
|
||||
type OpenAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIResponse OpenAI API响应
|
||||
type OpenAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Choice 选择
|
||||
type Choice struct {
|
||||
Message MessageWithTools `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// MessageWithTools 带工具调用的消息
|
||||
type MessageWithTools struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// Tool OpenAI工具定义
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function FunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
// FunctionDefinition 函数定义
|
||||
type FunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
// Error OpenAI错误
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// ToolCall 工具调用
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
// FunctionCall 函数调用
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON 自定义JSON解析,处理arguments可能是字符串或对象的情况
|
||||
func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
|
||||
type Alias FunctionCall
|
||||
aux := &struct {
|
||||
Name string `json:"name"`
|
||||
Arguments interface{} `json:"arguments"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(fc),
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fc.Name = aux.Name
|
||||
|
||||
// 处理arguments可能是字符串或对象的情况
|
||||
switch v := aux.Arguments.(type) {
|
||||
case map[string]interface{}:
|
||||
fc.Arguments = v
|
||||
case string:
|
||||
// 如果是字符串,尝试解析为JSON
|
||||
if err := json.Unmarshal([]byte(v), &fc.Arguments); err != nil {
|
||||
// 如果解析失败,创建一个包含原始字符串的map
|
||||
fc.Arguments = map[string]interface{}{
|
||||
"raw": v,
|
||||
}
|
||||
}
|
||||
case nil:
|
||||
fc.Arguments = make(map[string]interface{})
|
||||
default:
|
||||
// 其他类型,尝试转换为map
|
||||
fc.Arguments = map[string]interface{}{
|
||||
"value": v,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AgentLoopResult Agent Loop执行结果
|
||||
type AgentLoopResult struct {
|
||||
Response string
|
||||
MCPExecutionIDs []string
|
||||
}
|
||||
|
||||
// AgentLoop 执行Agent循环
|
||||
func (a *Agent) AgentLoop(ctx context.Context, userInput string, historyMessages []ChatMessage) (*AgentLoopResult, error) {
|
||||
messages := []ChatMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "你是一个专业的网络安全渗透测试专家。你可以使用各种安全工具进行自主渗透测试。分析目标并选择最佳测试策略。当需要执行工具时,使用提供的工具函数。",
|
||||
},
|
||||
}
|
||||
|
||||
// 添加历史消息(数据库只保存user和assistant消息)
|
||||
a.logger.Info("处理历史消息",
|
||||
zap.Int("count", len(historyMessages)),
|
||||
)
|
||||
addedCount := 0
|
||||
for i, msg := range historyMessages {
|
||||
// 只添加有内容的消息
|
||||
if msg.Content != "" {
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
addedCount++
|
||||
contentPreview := msg.Content
|
||||
if len(contentPreview) > 50 {
|
||||
contentPreview = contentPreview[:50] + "..."
|
||||
}
|
||||
a.logger.Info("添加历史消息到上下文",
|
||||
zap.Int("index", i),
|
||||
zap.String("role", msg.Role),
|
||||
zap.String("content", contentPreview),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
a.logger.Info("构建消息数组",
|
||||
zap.Int("historyMessages", len(historyMessages)),
|
||||
zap.Int("addedMessages", addedCount),
|
||||
zap.Int("totalMessages", len(messages)),
|
||||
)
|
||||
|
||||
// 添加当前用户消息
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "user",
|
||||
Content: userInput,
|
||||
})
|
||||
|
||||
result := &AgentLoopResult{
|
||||
MCPExecutionIDs: make([]string, 0),
|
||||
}
|
||||
|
||||
maxIterations := 10
|
||||
for i := 0; i < maxIterations; i++ {
|
||||
// 获取可用工具
|
||||
tools := a.getAvailableTools()
|
||||
|
||||
// 记录每次调用OpenAI
|
||||
if i == 0 {
|
||||
a.logger.Info("调用OpenAI",
|
||||
zap.Int("iteration", i+1),
|
||||
zap.Int("messagesCount", len(messages)),
|
||||
)
|
||||
// 记录前几条消息的内容(用于调试)
|
||||
for j, msg := range messages {
|
||||
if j >= 5 { // 只记录前5条
|
||||
break
|
||||
}
|
||||
contentPreview := msg.Content
|
||||
if len(contentPreview) > 100 {
|
||||
contentPreview = contentPreview[:100] + "..."
|
||||
}
|
||||
a.logger.Debug("消息内容",
|
||||
zap.Int("index", j),
|
||||
zap.String("role", msg.Role),
|
||||
zap.String("content", contentPreview),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
a.logger.Info("调用OpenAI",
|
||||
zap.Int("iteration", i+1),
|
||||
zap.Int("messagesCount", len(messages)),
|
||||
)
|
||||
}
|
||||
|
||||
// 调用OpenAI
|
||||
response, err := a.callOpenAI(ctx, messages, tools)
|
||||
if err != nil {
|
||||
result.Response = ""
|
||||
return result, fmt.Errorf("调用OpenAI失败: %w", err)
|
||||
}
|
||||
|
||||
if response.Error != nil {
|
||||
result.Response = ""
|
||||
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
|
||||
}
|
||||
|
||||
if len(response.Choices) == 0 {
|
||||
result.Response = ""
|
||||
return result, fmt.Errorf("没有收到响应")
|
||||
}
|
||||
|
||||
choice := response.Choices[0]
|
||||
|
||||
// 检查是否有工具调用
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// 添加assistant消息(包含工具调用)
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: choice.Message.ToolCalls,
|
||||
})
|
||||
|
||||
// 执行所有工具调用
|
||||
for _, toolCall := range choice.Message.ToolCalls {
|
||||
// 执行工具
|
||||
execResult, err := a.executeToolViaMCP(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: fmt.Sprintf("工具执行失败: %v", err),
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: toolCall.ID,
|
||||
Content: execResult.Result,
|
||||
})
|
||||
// 收集执行ID
|
||||
if execResult.ExecutionID != "" {
|
||||
result.MCPExecutionIDs = append(result.MCPExecutionIDs, execResult.ExecutionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 添加assistant响应
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: choice.Message.Content,
|
||||
})
|
||||
|
||||
// 如果完成,返回结果
|
||||
if choice.FinishReason == "stop" {
|
||||
result.Response = choice.Message.Content
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
result.Response = "达到最大迭代次数"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// getAvailableTools 获取可用工具
|
||||
func (a *Agent) getAvailableTools() []Tool {
|
||||
// 从MCP服务器获取工具列表
|
||||
executions := a.mcpServer.GetAllExecutions()
|
||||
toolNames := make(map[string]bool)
|
||||
for _, exec := range executions {
|
||||
toolNames[exec.ToolName] = true
|
||||
}
|
||||
|
||||
tools := []Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: FunctionDefinition{
|
||||
Name: "nmap",
|
||||
Description: "使用nmap进行网络扫描,发现开放端口和服务。支持IP地址、域名或URL(会自动提取域名)。使用TCP连接扫描,不需要root权限。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "目标IP地址、域名或URL(如 https://example.com)。如果是URL,会自动提取域名部分。",
|
||||
},
|
||||
"ports": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要扫描的端口范围,例如: 1-1000 或 80,443,8080。如果不指定,将扫描常用端口。",
|
||||
},
|
||||
},
|
||||
"required": []string{"target"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: FunctionDefinition{
|
||||
Name: "sqlmap",
|
||||
Description: "使用sqlmap检测SQL注入漏洞",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "目标URL",
|
||||
},
|
||||
},
|
||||
"required": []string{"url"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: FunctionDefinition{
|
||||
Name: "nikto",
|
||||
Description: "使用nikto扫描Web服务器漏洞",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "目标URL",
|
||||
},
|
||||
},
|
||||
"required": []string{"target"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: FunctionDefinition{
|
||||
Name: "dirb",
|
||||
Description: "使用dirb进行目录扫描",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "目标URL",
|
||||
},
|
||||
},
|
||||
"required": []string{"url"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: FunctionDefinition{
|
||||
Name: "exec",
|
||||
Description: "执行系统命令(谨慎使用,仅用于必要的系统操作)",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"command": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要执行的系统命令",
|
||||
},
|
||||
"shell": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "使用的shell(可选,默认为sh)",
|
||||
},
|
||||
"workdir": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "工作目录(可选)",
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
// callOpenAI 调用OpenAI API
|
||||
func (a *Agent) callOpenAI(ctx context.Context, messages []ChatMessage, tools []Tool) (*OpenAIResponse, error) {
|
||||
reqBody := OpenAIRequest{
|
||||
Model: a.config.Model,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
reqBody.Tools = tools
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", a.config.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+a.config.APIKey)
|
||||
|
||||
resp, err := a.openAIClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录响应内容(用于调试)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
a.logger.Warn("OpenAI API返回非200状态码",
|
||||
zap.Int("status", resp.StatusCode),
|
||||
zap.String("body", string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
var response OpenAIResponse
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
a.logger.Error("解析OpenAI响应失败",
|
||||
zap.Error(err),
|
||||
zap.String("body", string(body)),
|
||||
)
|
||||
return nil, fmt.Errorf("解析响应失败: %w, 响应内容: %s", err, string(body))
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// parseToolCall 解析工具调用
|
||||
func (a *Agent) parseToolCall(content string) (map[string]interface{}, error) {
|
||||
// 简单解析,实际应该更复杂
|
||||
// 格式: [TOOL_CALL]tool_name:arg1=value1,arg2=value2
|
||||
if !strings.HasPrefix(content, "[TOOL_CALL]") {
|
||||
return nil, fmt.Errorf("不是有效的工具调用格式")
|
||||
}
|
||||
|
||||
parts := strings.Split(content[len("[TOOL_CALL]"):], ":")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("工具调用格式错误")
|
||||
}
|
||||
|
||||
toolName := strings.TrimSpace(parts[0])
|
||||
argsStr := strings.TrimSpace(parts[1])
|
||||
|
||||
args := make(map[string]interface{})
|
||||
argPairs := strings.Split(argsStr, ",")
|
||||
for _, pair := range argPairs {
|
||||
kv := strings.Split(pair, "=")
|
||||
if len(kv) == 2 {
|
||||
args[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1])
|
||||
}
|
||||
}
|
||||
|
||||
args["_tool_name"] = toolName
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// ToolExecutionResult 工具执行结果
|
||||
type ToolExecutionResult struct {
|
||||
Result string
|
||||
ExecutionID string
|
||||
}
|
||||
|
||||
// executeToolViaMCP 通过MCP执行工具
|
||||
func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map[string]interface{}) (*ToolExecutionResult, error) {
|
||||
a.logger.Info("通过MCP执行工具",
|
||||
zap.String("tool", toolName),
|
||||
zap.Any("args", args),
|
||||
)
|
||||
|
||||
// 通过MCP服务器调用工具
|
||||
result, executionID, err := a.mcpServer.CallTool(ctx, toolName, args)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("工具执行失败: %w", err)
|
||||
}
|
||||
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
for _, content := range result.Content {
|
||||
resultText.WriteString(content.Text)
|
||||
resultText.WriteString("\n")
|
||||
}
|
||||
|
||||
return &ToolExecutionResult{
|
||||
Result: resultText.String(),
|
||||
ExecutionID: executionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
163
internal/app/app.go
Normal file
163
internal/app/app.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/handler"
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// App 应用
|
||||
type App struct {
|
||||
config *config.Config
|
||||
logger *logger.Logger
|
||||
router *gin.Engine
|
||||
mcpServer *mcp.Server
|
||||
agent *agent.Agent
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
}
|
||||
|
||||
// New 创建新应用
|
||||
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.Default()
|
||||
|
||||
// CORS中间件
|
||||
router.Use(corsMiddleware())
|
||||
|
||||
// 初始化数据库
|
||||
dbPath := cfg.Database.Path
|
||||
if dbPath == "" {
|
||||
dbPath = "data/conversations.db"
|
||||
}
|
||||
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建数据库目录失败: %w", err)
|
||||
}
|
||||
|
||||
db, err := database.NewDB(dbPath, log.Logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建MCP服务器
|
||||
mcpServer := mcp.NewServer(log.Logger)
|
||||
|
||||
// 创建安全工具执行器
|
||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||
|
||||
// 注册工具
|
||||
executor.RegisterTools(mcpServer)
|
||||
|
||||
// 创建Agent
|
||||
agent := agent.NewAgent(&cfg.OpenAI, mcpServer, log.Logger)
|
||||
|
||||
// 创建处理器
|
||||
agentHandler := handler.NewAgentHandler(agent, db, log.Logger)
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, log.Logger)
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
|
||||
// 设置路由
|
||||
setupRoutes(router, agentHandler, monitorHandler, conversationHandler, mcpServer)
|
||||
|
||||
return &App{
|
||||
config: cfg,
|
||||
logger: log,
|
||||
router: router,
|
||||
mcpServer: mcpServer,
|
||||
agent: agent,
|
||||
executor: executor,
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run 启动应用
|
||||
func (a *App) Run() error {
|
||||
// 启动MCP服务器(如果启用)
|
||||
if a.config.MCP.Enabled {
|
||||
go func() {
|
||||
mcpAddr := fmt.Sprintf("%s:%d", a.config.MCP.Host, a.config.MCP.Port)
|
||||
a.logger.Info("启动MCP服务器", zap.String("address", mcpAddr))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/mcp", a.mcpServer.HandleHTTP)
|
||||
|
||||
if err := http.ListenAndServe(mcpAddr, mux); err != nil {
|
||||
a.logger.Error("MCP服务器启动失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 启动主服务器
|
||||
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
||||
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
|
||||
|
||||
return a.router.Run(addr)
|
||||
}
|
||||
|
||||
// setupRoutes 设置路由
|
||||
func setupRoutes(router *gin.Engine, agentHandler *handler.AgentHandler, monitorHandler *handler.MonitorHandler, conversationHandler *handler.ConversationHandler, mcpServer *mcp.Server) {
|
||||
// API路由
|
||||
api := router.Group("/api")
|
||||
{
|
||||
// Agent Loop
|
||||
api.POST("/agent-loop", agentHandler.AgentLoop)
|
||||
|
||||
// 对话历史
|
||||
api.POST("/conversations", conversationHandler.CreateConversation)
|
||||
api.GET("/conversations", conversationHandler.ListConversations)
|
||||
api.GET("/conversations/:id", conversationHandler.GetConversation)
|
||||
api.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
|
||||
|
||||
// 监控
|
||||
api.GET("/monitor", monitorHandler.Monitor)
|
||||
api.GET("/monitor/execution/:id", monitorHandler.GetExecution)
|
||||
api.GET("/monitor/stats", monitorHandler.GetStats)
|
||||
api.GET("/monitor/vulnerabilities", monitorHandler.GetVulnerabilities)
|
||||
|
||||
// MCP端点
|
||||
api.POST("/mcp", func(c *gin.Context) {
|
||||
mcpServer.HandleHTTP(c.Writer, c.Request)
|
||||
})
|
||||
}
|
||||
|
||||
// 静态文件
|
||||
router.Static("/static", "./web/static")
|
||||
router.LoadHTMLGlob("web/templates/*")
|
||||
|
||||
// 前端页面
|
||||
router.GET("/", func(c *gin.Context) {
|
||||
c.HTML(http.StatusOK, "index.html", nil)
|
||||
})
|
||||
}
|
||||
|
||||
// corsMiddleware CORS中间件
|
||||
func corsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
114
internal/config/config.go
Normal file
114
internal/config/config.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
MCP MCPConfig `yaml:"mcp"`
|
||||
OpenAI OpenAIConfig `yaml:"openai"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
Level string `yaml:"level"`
|
||||
Output string `yaml:"output"`
|
||||
}
|
||||
|
||||
type MCPConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
}
|
||||
|
||||
type OpenAIConfig struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
Model string `yaml:"model"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
Tools []ToolConfig `yaml:"tools"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string `yaml:"path"`
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Command string `yaml:"command"`
|
||||
Args []string `yaml:"args,omitempty"` // 固定参数(可选)
|
||||
Description string `yaml:"description"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Parameters []ParameterConfig `yaml:"parameters,omitempty"` // 参数定义(可选)
|
||||
ArgMapping string `yaml:"arg_mapping,omitempty"` // 参数映射方式: "auto", "manual", "template"(可选)
|
||||
}
|
||||
|
||||
// ParameterConfig 参数配置
|
||||
type ParameterConfig struct {
|
||||
Name string `yaml:"name"` // 参数名称
|
||||
Type string `yaml:"type"` // 参数类型: string, int, bool, array
|
||||
Description string `yaml:"description"` // 参数描述
|
||||
Required bool `yaml:"required,omitempty"` // 是否必需
|
||||
Default interface{} `yaml:"default,omitempty"` // 默认值
|
||||
Flag string `yaml:"flag,omitempty"` // 命令行标志,如 "-u", "--url", "-p"
|
||||
Position *int `yaml:"position,omitempty"` // 位置参数的位置(从0开始)
|
||||
Format string `yaml:"format,omitempty"` // 参数格式: "flag", "positional", "combined" (flag=value), "template"
|
||||
Template string `yaml:"template,omitempty"` // 模板字符串,如 "{flag} {value}" 或 "{value}"
|
||||
Options []string `yaml:"options,omitempty"` // 可选值列表(用于枚举)
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func Default() *Config {
|
||||
return &Config{
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
},
|
||||
Log: LogConfig{
|
||||
Level: "info",
|
||||
Output: "stdout",
|
||||
},
|
||||
MCP: MCPConfig{
|
||||
Enabled: true,
|
||||
Host: "0.0.0.0",
|
||||
Port: 8081,
|
||||
},
|
||||
OpenAI: OpenAIConfig{
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
Model: "gpt-4",
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 加载,不在此硬编码
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Path: "data/conversations.db",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
256
internal/database/conversation.go
Normal file
256
internal/database/conversation.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Conversation 对话
|
||||
type Conversation struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
// Message 消息
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (db *DB) CreateConversation(title string) (*Conversation, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
|
||||
id, title, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
}
|
||||
|
||||
return &Conversation{
|
||||
ID: id,
|
||||
Title: title,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, created_at, updated_at FROM conversations WHERE id = ?",
|
||||
id,
|
||||
).Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
// 加载消息
|
||||
messages, err := db.GetMessages(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||||
}
|
||||
conv.Messages = messages
|
||||
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// ListConversations 列出所有对话
|
||||
func (db *DB) ListConversations(limit, offset int) ([]*Conversation, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, created_at, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
||||
limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询对话列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conversations []*Conversation
|
||||
for rows.Next() {
|
||||
var conv Conversation
|
||||
var createdAt, updatedAt string
|
||||
|
||||
if err := rows.Scan(&conv.ID, &conv.Title, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描对话失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err1, err2 error
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err1 != nil {
|
||||
conv.CreatedAt, err1 = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt)
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
}
|
||||
if err2 != nil {
|
||||
conv.UpdatedAt, err2 = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
|
||||
conversations = append(conversations, &conv)
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// UpdateConversationTitle 更新对话标题
|
||||
func (db *DB) UpdateConversationTitle(id, title string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
|
||||
title, time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新对话标题失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateConversationTime 更新对话时间
|
||||
func (db *DB) UpdateConversationTime(id string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET updated_at = ? WHERE id = ?",
|
||||
time.Now(), id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新对话时间失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteConversation 删除对话
|
||||
func (db *DB) DeleteConversation(id string) error {
|
||||
_, err := db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMessage 添加消息
|
||||
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
|
||||
id := uuid.New().String()
|
||||
|
||||
var mcpIDsJSON string
|
||||
if len(mcpExecutionIDs) > 0 {
|
||||
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||||
if err != nil {
|
||||
db.logger.Warn("序列化MCP执行ID失败", zap.Error(err))
|
||||
} else {
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, mcpIDsJSON, time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新对话时间
|
||||
if err := db.UpdateConversationTime(conversationID); err != nil {
|
||||
db.logger.Warn("更新对话时间失败", zap.Error(err))
|
||||
}
|
||||
|
||||
message := &Message{
|
||||
ID: id,
|
||||
ConversationID: conversationID,
|
||||
Role: role,
|
||||
Content: content,
|
||||
MCPExecutionIDs: mcpExecutionIDs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// GetMessages 获取对话的所有消息
|
||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err != nil {
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err != nil {
|
||||
msg.CreatedAt, err = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
// 解析MCP执行ID
|
||||
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
|
||||
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
|
||||
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
90
internal/database/database.go
Normal file
90
internal/database/database.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DB 数据库连接
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接
|
||||
func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库失败: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
database := &DB{
|
||||
DB: db,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// 初始化表
|
||||
if err := database.initTables(); err != nil {
|
||||
return nil, fmt.Errorf("初始化表失败: %w", err)
|
||||
}
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// initTables 初始化数据库表
|
||||
func (db *DB) initTables() error {
|
||||
// 创建对话表
|
||||
createConversationsTable := `
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`
|
||||
|
||||
// 创建消息表
|
||||
createMessagesTable := `
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
mcp_execution_ids TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建索引
|
||||
createIndexes := `
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
return fmt.Errorf("创建conversations表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createMessagesTable); err != nil {
|
||||
return fmt.Errorf("创建messages表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createIndexes); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
|
||||
db.logger.Info("数据库表初始化完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
return db.DB.Close()
|
||||
}
|
||||
|
||||
134
internal/handler/agent.go
Normal file
134
internal/handler/agent.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AgentHandler Agent处理器
|
||||
type AgentHandler struct {
|
||||
agent *agent.Agent
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAgentHandler 创建新的Agent处理器
|
||||
func NewAgentHandler(agent *agent.Agent, db *database.DB, logger *zap.Logger) *AgentHandler {
|
||||
return &AgentHandler{
|
||||
agent: agent,
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
}
|
||||
|
||||
// ChatResponse 聊天响应
|
||||
type ChatResponse struct {
|
||||
Response string `json:"response"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` // 本次对话中执行的MCP调用ID列表
|
||||
ConversationID string `json:"conversationId"` // 对话ID
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
// AgentLoop 处理Agent Loop请求
|
||||
func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
var req ChatRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("收到Agent Loop请求",
|
||||
zap.String("message", req.Message),
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
// 如果没有对话ID,创建新对话
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := req.Message
|
||||
if len(title) > 50 {
|
||||
title = title[:50] + "..."
|
||||
}
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
conversationID = conv.ID
|
||||
}
|
||||
|
||||
// 获取历史消息(排除当前消息,因为还没保存)
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
historyMessages = []database.Message{}
|
||||
}
|
||||
|
||||
h.logger.Info("获取历史消息",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("count", len(historyMessages)),
|
||||
)
|
||||
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages := make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for i, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
contentPreview := msg.Content
|
||||
if len(contentPreview) > 50 {
|
||||
contentPreview = contentPreview[:50] + "..."
|
||||
}
|
||||
h.logger.Info("添加历史消息",
|
||||
zap.Int("index", i),
|
||||
zap.String("role", msg.Role),
|
||||
zap.String("content", contentPreview),
|
||||
)
|
||||
}
|
||||
|
||||
h.logger.Info("历史消息转换完成",
|
||||
zap.Int("originalCount", len(historyMessages)),
|
||||
zap.Int("convertedCount", len(agentHistoryMessages)),
|
||||
)
|
||||
|
||||
// 保存用户消息
|
||||
_, err = h.db.AddMessage(conversationID, "user", req.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 执行Agent Loop,传入历史消息
|
||||
result, err := h.agent.AgentLoop(c.Request.Context(), req.Message, agentHistoryMessages)
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 保存助手回复
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ChatResponse{
|
||||
Response: result.Response,
|
||||
MCPExecutionIDs: result.MCPExecutionIDs,
|
||||
ConversationID: conversationID,
|
||||
Time: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
102
internal/handler/conversation.go
Normal file
102
internal/handler/conversation.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConversationHandler 对话处理器
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
||||
return &ConversationHandler{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateConversationRequest 创建对话请求
|
||||
type CreateConversationRequest struct {
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
var req CreateConversationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
title := req.Title
|
||||
if title == "" {
|
||||
title = "新对话"
|
||||
}
|
||||
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// ListConversations 列出对话
|
||||
func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "50")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
conversations, err := h.db.ListConversations(limit, offset)
|
||||
if err != nil {
|
||||
h.logger.Error("获取对话列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conversations)
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (h *ConversationHandler) GetConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
conv, err := h.db.GetConversation(id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, conv)
|
||||
}
|
||||
|
||||
// DeleteConversation 删除对话
|
||||
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if err := h.db.DeleteConversation(id); err != nil {
|
||||
h.logger.Error("删除对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
92
internal/handler/monitor.go
Normal file
92
internal/handler/monitor.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// MonitorHandler 监控处理器
|
||||
type MonitorHandler struct {
|
||||
mcpServer *mcp.Server
|
||||
executor *security.Executor
|
||||
logger *zap.Logger
|
||||
vulns []security.Vulnerability
|
||||
}
|
||||
|
||||
// NewMonitorHandler 创建新的监控处理器
|
||||
func NewMonitorHandler(mcpServer *mcp.Server, executor *security.Executor, logger *zap.Logger) *MonitorHandler {
|
||||
return &MonitorHandler{
|
||||
mcpServer: mcpServer,
|
||||
executor: executor,
|
||||
logger: logger,
|
||||
vulns: []security.Vulnerability{},
|
||||
}
|
||||
}
|
||||
|
||||
// MonitorResponse 监控响应
|
||||
type MonitorResponse struct {
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Vulnerabilities []security.Vulnerability `json:"vulnerabilities"`
|
||||
Report map[string]interface{} `json:"report"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Monitor 获取监控信息
|
||||
func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
// 获取所有执行记录
|
||||
executions := h.mcpServer.GetAllExecutions()
|
||||
|
||||
// 分析执行结果,提取漏洞
|
||||
for _, exec := range executions {
|
||||
if exec.Status == "completed" && exec.Result != nil {
|
||||
vulns := h.executor.AnalyzeResults(exec.ToolName, exec.Result)
|
||||
h.vulns = append(h.vulns, vulns...)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取统计信息
|
||||
stats := h.mcpServer.GetStats()
|
||||
|
||||
// 生成报告
|
||||
report := h.executor.GetVulnerabilityReport(h.vulns)
|
||||
|
||||
c.JSON(http.StatusOK, MonitorResponse{
|
||||
Executions: executions,
|
||||
Stats: stats,
|
||||
Vulnerabilities: h.vulns,
|
||||
Report: report,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetExecution 获取特定执行记录
|
||||
func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
exec, exists := h.mcpServer.GetExecution(id)
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, exec)
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func (h *MonitorHandler) GetStats(c *gin.Context) {
|
||||
stats := h.mcpServer.GetStats()
|
||||
c.JSON(http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// GetVulnerabilities 获取漏洞列表
|
||||
func (h *MonitorHandler) GetVulnerabilities(c *gin.Context) {
|
||||
report := h.executor.GetVulnerabilityReport(h.vulns)
|
||||
c.JSON(http.StatusOK, report)
|
||||
}
|
||||
|
||||
60
internal/logger/logger.go
Normal file
60
internal/logger/logger.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
*zap.Logger
|
||||
}
|
||||
|
||||
func New(level, output string) *Logger {
|
||||
var zapLevel zapcore.Level
|
||||
switch level {
|
||||
case "debug":
|
||||
zapLevel = zapcore.DebugLevel
|
||||
case "info":
|
||||
zapLevel = zapcore.InfoLevel
|
||||
case "warn":
|
||||
zapLevel = zapcore.WarnLevel
|
||||
case "error":
|
||||
zapLevel = zapcore.ErrorLevel
|
||||
default:
|
||||
zapLevel = zapcore.InfoLevel
|
||||
}
|
||||
|
||||
config := zap.NewProductionConfig()
|
||||
config.Level = zap.NewAtomicLevelAt(zapLevel)
|
||||
config.EncoderConfig.TimeKey = "timestamp"
|
||||
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
|
||||
var writeSyncer zapcore.WriteSyncer
|
||||
if output == "stdout" {
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
} else {
|
||||
file, err := os.OpenFile(output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
} else {
|
||||
writeSyncer = zapcore.AddSync(file)
|
||||
}
|
||||
}
|
||||
|
||||
core := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(config.EncoderConfig),
|
||||
writeSyncer,
|
||||
zapLevel,
|
||||
)
|
||||
|
||||
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
|
||||
|
||||
return &Logger{Logger: logger}
|
||||
}
|
||||
|
||||
func (l *Logger) Fatal(msg string, fields ...interface{}) {
|
||||
l.Logger.Fatal(msg, zap.Any("fields", fields))
|
||||
}
|
||||
|
||||
798
internal/mcp/server.go
Normal file
798
internal/mcp/server.go
Normal file
@@ -0,0 +1,798 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Server MCP服务器
|
||||
type Server struct {
|
||||
tools map[string]ToolHandler
|
||||
toolDefs map[string]Tool // 工具定义
|
||||
executions map[string]*ToolExecution
|
||||
stats map[string]*ToolStats
|
||||
prompts map[string]*Prompt // 提示词模板
|
||||
resources map[string]*Resource // 资源
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// ToolHandler 工具处理函数
|
||||
type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
|
||||
|
||||
// NewServer 创建新的MCP服务器
|
||||
func NewServer(logger *zap.Logger) *Server {
|
||||
s := &Server{
|
||||
tools: make(map[string]ToolHandler),
|
||||
toolDefs: make(map[string]Tool),
|
||||
executions: make(map[string]*ToolExecution),
|
||||
stats: make(map[string]*ToolStats),
|
||||
prompts: make(map[string]*Prompt),
|
||||
resources: make(map[string]*Resource),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// 初始化默认提示词和资源
|
||||
s.initDefaultPrompts()
|
||||
s.initDefaultResources()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// RegisterTool 注册工具
|
||||
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tools[tool.Name] = handler
|
||||
s.toolDefs[tool.Name] = tool
|
||||
|
||||
// 自动为工具创建资源文档
|
||||
resourceURI := fmt.Sprintf("tool://%s", tool.Name)
|
||||
s.resources[resourceURI] = &Resource{
|
||||
URI: resourceURI,
|
||||
Name: fmt.Sprintf("%s工具文档", tool.Name),
|
||||
Description: tool.Description,
|
||||
MimeType: "text/plain",
|
||||
}
|
||||
}
|
||||
|
||||
// HandleHTTP 处理HTTP请求
|
||||
func (s *Server) HandleHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
s.sendError(w, nil, -32700, "Parse error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := json.Unmarshal(body, &msg); err != nil {
|
||||
s.sendError(w, nil, -32700, "Parse error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 处理消息
|
||||
response := s.handleMessage(&msg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// handleMessage 处理MCP消息
|
||||
func (s *Server) handleMessage(msg *Message) *Message {
|
||||
if msg.ID == "" {
|
||||
msg.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
switch msg.Method {
|
||||
case "initialize":
|
||||
return s.handleInitialize(msg)
|
||||
case "tools/list":
|
||||
return s.handleListTools(msg)
|
||||
case "tools/call":
|
||||
return s.handleCallTool(msg)
|
||||
case "prompts/list":
|
||||
return s.handleListPrompts(msg)
|
||||
case "prompts/get":
|
||||
return s.handleGetPrompt(msg)
|
||||
case "resources/list":
|
||||
return s.handleListResources(msg)
|
||||
case "resources/read":
|
||||
return s.handleReadResource(msg)
|
||||
case "sampling/request":
|
||||
return s.handleSamplingRequest(msg)
|
||||
default:
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32601, Message: "Method not found"},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleInitialize 处理初始化请求
|
||||
func (s *Server) handleInitialize(msg *Message) *Message {
|
||||
var req InitializeRequest
|
||||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||||
}
|
||||
}
|
||||
|
||||
response := InitializeResponse{
|
||||
ProtocolVersion: ProtocolVersion,
|
||||
Capabilities: ServerCapabilities{
|
||||
Tools: map[string]interface{}{
|
||||
"listChanged": true,
|
||||
},
|
||||
Prompts: map[string]interface{}{
|
||||
"listChanged": true,
|
||||
},
|
||||
Resources: map[string]interface{}{
|
||||
"subscribe": true,
|
||||
"listChanged": true,
|
||||
},
|
||||
Sampling: map[string]interface{}{},
|
||||
},
|
||||
ServerInfo: ServerInfo{
|
||||
Name: "CyberStrikeAI",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// handleListTools 处理列出工具请求
|
||||
func (s *Server) handleListTools(msg *Message) *Message {
|
||||
s.mu.RLock()
|
||||
tools := make([]Tool, 0, len(s.toolDefs))
|
||||
for _, tool := range s.toolDefs {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
response := ListToolsResponse{Tools: tools}
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// handleCallTool 处理工具调用请求
|
||||
func (s *Server) handleCallTool(msg *Message) *Message {
|
||||
var req CallToolRequest
|
||||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||||
}
|
||||
}
|
||||
|
||||
// 创建执行记录
|
||||
executionID := uuid.New().String()
|
||||
execution := &ToolExecution{
|
||||
ID: executionID,
|
||||
ToolName: req.Name,
|
||||
Arguments: req.Arguments,
|
||||
Status: "running",
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.executions[executionID] = execution
|
||||
s.mu.Unlock()
|
||||
|
||||
// 更新统计
|
||||
s.updateStats(req.Name, false)
|
||||
|
||||
// 执行工具
|
||||
s.mu.RLock()
|
||||
handler, exists := s.tools[req.Name]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
execution.Status = "failed"
|
||||
execution.Error = "Tool not found"
|
||||
now := time.Now()
|
||||
execution.EndTime = &now
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32601, Message: "Tool not found"},
|
||||
}
|
||||
}
|
||||
|
||||
// 同步执行所有工具,确保错误能正确返回
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
s.logger.Info("开始执行工具",
|
||||
zap.String("toolName", req.Name),
|
||||
zap.Any("arguments", req.Arguments),
|
||||
)
|
||||
|
||||
result, err := handler(ctx, req.Arguments)
|
||||
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
execution.EndTime = &now
|
||||
execution.Duration = now.Sub(execution.StartTime)
|
||||
|
||||
if err != nil {
|
||||
execution.Status = "failed"
|
||||
execution.Error = err.Error()
|
||||
s.updateStats(req.Name, true)
|
||||
s.mu.Unlock()
|
||||
|
||||
s.logger.Error("工具执行失败",
|
||||
zap.String("toolName", req.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
// 返回错误结果
|
||||
errorResult, _ := json.Marshal(CallToolResponse{
|
||||
Content: []Content{
|
||||
{Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)},
|
||||
},
|
||||
IsError: true,
|
||||
})
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: errorResult,
|
||||
}
|
||||
}
|
||||
|
||||
// 检查result是否为错误
|
||||
if result != nil && result.IsError {
|
||||
execution.Status = "failed"
|
||||
if len(result.Content) > 0 {
|
||||
execution.Error = result.Content[0].Text
|
||||
}
|
||||
s.updateStats(req.Name, true)
|
||||
} else {
|
||||
execution.Status = "completed"
|
||||
execution.Result = result
|
||||
s.updateStats(req.Name, false)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// 返回执行结果
|
||||
if result == nil {
|
||||
result = &ToolResult{
|
||||
Content: []Content{
|
||||
{Type: "text", Text: "工具执行完成,但未返回结果"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
resultJSON, _ := json.Marshal(CallToolResponse{
|
||||
Content: result.Content,
|
||||
IsError: result.IsError,
|
||||
})
|
||||
|
||||
s.logger.Info("工具执行完成",
|
||||
zap.String("toolName", req.Name),
|
||||
zap.Bool("isError", result.IsError),
|
||||
)
|
||||
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: resultJSON,
|
||||
}
|
||||
}
|
||||
|
||||
// updateStats 更新统计信息
|
||||
func (s *Server) updateStats(toolName string, failed bool) {
|
||||
if s.stats[toolName] == nil {
|
||||
s.stats[toolName] = &ToolStats{
|
||||
ToolName: toolName,
|
||||
}
|
||||
}
|
||||
|
||||
stats := s.stats[toolName]
|
||||
stats.TotalCalls++
|
||||
now := time.Now()
|
||||
stats.LastCallTime = &now
|
||||
|
||||
if failed {
|
||||
stats.FailedCalls++
|
||||
} else {
|
||||
stats.SuccessCalls++
|
||||
}
|
||||
}
|
||||
|
||||
// GetExecution 获取执行记录
|
||||
func (s *Server) GetExecution(id string) (*ToolExecution, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
exec, exists := s.executions[id]
|
||||
return exec, exists
|
||||
}
|
||||
|
||||
// GetAllExecutions 获取所有执行记录
|
||||
func (s *Server) GetAllExecutions() []*ToolExecution {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
executions := make([]*ToolExecution, 0, len(s.executions))
|
||||
for _, exec := range s.executions {
|
||||
executions = append(executions, exec)
|
||||
}
|
||||
return executions
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func (s *Server) GetStats() map[string]*ToolStats {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
stats := make(map[string]*ToolStats)
|
||||
for k, v := range s.stats {
|
||||
stats[k] = v
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// CallTool 直接调用工具(用于内部调用)
|
||||
func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]interface{}) (*ToolResult, string, error) {
|
||||
s.mu.RLock()
|
||||
handler, exists := s.tools[toolName]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, "", fmt.Errorf("工具 %s 未找到", toolName)
|
||||
}
|
||||
|
||||
// 创建执行记录
|
||||
executionID := uuid.New().String()
|
||||
execution := &ToolExecution{
|
||||
ID: executionID,
|
||||
ToolName: toolName,
|
||||
Arguments: args,
|
||||
Status: "running",
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.executions[executionID] = execution
|
||||
s.mu.Unlock()
|
||||
|
||||
// 更新统计
|
||||
s.updateStats(toolName, false)
|
||||
|
||||
// 执行工具
|
||||
result, err := handler(ctx, args)
|
||||
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
execution.EndTime = &now
|
||||
execution.Duration = now.Sub(execution.StartTime)
|
||||
|
||||
if err != nil {
|
||||
execution.Status = "failed"
|
||||
execution.Error = err.Error()
|
||||
s.updateStats(toolName, true)
|
||||
s.mu.Unlock()
|
||||
return nil, executionID, err
|
||||
} else {
|
||||
execution.Status = "completed"
|
||||
execution.Result = result
|
||||
s.updateStats(toolName, false)
|
||||
s.mu.Unlock()
|
||||
return result, executionID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// initDefaultPrompts 初始化默认提示词模板
|
||||
func (s *Server) initDefaultPrompts() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 网络安全测试提示词
|
||||
s.prompts["security_scan"] = &Prompt{
|
||||
Name: "security_scan",
|
||||
Description: "生成网络安全扫描任务的提示词",
|
||||
Arguments: []PromptArgument{
|
||||
{Name: "target", Description: "扫描目标(IP地址或域名)", Required: true},
|
||||
{Name: "scan_type", Description: "扫描类型(port, vuln, web等)", Required: false},
|
||||
},
|
||||
}
|
||||
|
||||
// 渗透测试提示词
|
||||
s.prompts["penetration_test"] = &Prompt{
|
||||
Name: "penetration_test",
|
||||
Description: "生成渗透测试任务的提示词",
|
||||
Arguments: []PromptArgument{
|
||||
{Name: "target", Description: "测试目标", Required: true},
|
||||
{Name: "scope", Description: "测试范围", Required: false},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// initDefaultResources 初始化默认资源
|
||||
// 注意:工具资源现在在 RegisterTool 时自动创建,此函数保留用于其他非工具资源
|
||||
func (s *Server) initDefaultResources() {
|
||||
// 工具资源已改为在 RegisterTool 时自动创建,无需在此硬编码
|
||||
}
|
||||
|
||||
// handleListPrompts 处理列出提示词请求
|
||||
func (s *Server) handleListPrompts(msg *Message) *Message {
|
||||
s.mu.RLock()
|
||||
prompts := make([]Prompt, 0, len(s.prompts))
|
||||
for _, prompt := range s.prompts {
|
||||
prompts = append(prompts, *prompt)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
response := ListPromptsResponse{
|
||||
Prompts: prompts,
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetPrompt 处理获取提示词请求
|
||||
func (s *Server) handleGetPrompt(msg *Message) *Message {
|
||||
var req GetPromptRequest
|
||||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
prompt, exists := s.prompts[req.Name]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32601, Message: "Prompt not found"},
|
||||
}
|
||||
}
|
||||
|
||||
// 根据提示词名称生成消息
|
||||
messages := s.generatePromptMessages(prompt, req.Arguments)
|
||||
|
||||
response := GetPromptResponse{
|
||||
Messages: messages,
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// generatePromptMessages 生成提示词消息
|
||||
func (s *Server) generatePromptMessages(prompt *Prompt, args map[string]interface{}) []PromptMessage {
|
||||
messages := []PromptMessage{}
|
||||
|
||||
switch prompt.Name {
|
||||
case "security_scan":
|
||||
target, _ := args["target"].(string)
|
||||
scanType, _ := args["scan_type"].(string)
|
||||
if scanType == "" {
|
||||
scanType = "comprehensive"
|
||||
}
|
||||
|
||||
content := fmt.Sprintf(`请对目标 %s 执行%s安全扫描。包括:
|
||||
1. 端口扫描和服务识别
|
||||
2. 漏洞检测
|
||||
3. Web应用安全测试
|
||||
4. 生成详细的安全报告`, target, scanType)
|
||||
|
||||
messages = append(messages, PromptMessage{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
})
|
||||
|
||||
case "penetration_test":
|
||||
target, _ := args["target"].(string)
|
||||
scope, _ := args["scope"].(string)
|
||||
|
||||
content := fmt.Sprintf(`请对目标 %s 执行渗透测试。`, target)
|
||||
if scope != "" {
|
||||
content += fmt.Sprintf("测试范围:%s", scope)
|
||||
}
|
||||
content += "\n请按照OWASP Top 10进行全面的安全测试。"
|
||||
|
||||
messages = append(messages, PromptMessage{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
})
|
||||
|
||||
default:
|
||||
messages = append(messages, PromptMessage{
|
||||
Role: "user",
|
||||
Content: "请执行安全测试任务",
|
||||
})
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// handleListResources 处理列出资源请求
|
||||
func (s *Server) handleListResources(msg *Message) *Message {
|
||||
s.mu.RLock()
|
||||
resources := make([]Resource, 0, len(s.resources))
|
||||
for _, resource := range s.resources {
|
||||
resources = append(resources, *resource)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
response := ListResourcesResponse{
|
||||
Resources: resources,
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// handleReadResource 处理读取资源请求
|
||||
func (s *Server) handleReadResource(msg *Message) *Message {
|
||||
var req ReadResourceRequest
|
||||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
resource, exists := s.resources[req.URI]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32601, Message: "Resource not found"},
|
||||
}
|
||||
}
|
||||
|
||||
// 生成资源内容
|
||||
content := s.generateResourceContent(resource)
|
||||
|
||||
response := ReadResourceResponse{
|
||||
Contents: []ResourceContent{content},
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// generateResourceContent 生成资源内容
|
||||
func (s *Server) generateResourceContent(resource *Resource) ResourceContent {
|
||||
content := ResourceContent{
|
||||
URI: resource.URI,
|
||||
MimeType: resource.MimeType,
|
||||
}
|
||||
|
||||
// 如果是工具资源,生成详细文档
|
||||
if strings.HasPrefix(resource.URI, "tool://") {
|
||||
toolName := strings.TrimPrefix(resource.URI, "tool://")
|
||||
content.Text = s.generateToolDocumentation(toolName, resource)
|
||||
} else {
|
||||
// 其他资源使用描述或默认内容
|
||||
content.Text = resource.Description
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// generateToolDocumentation 生成工具文档
|
||||
func (s *Server) generateToolDocumentation(toolName string, resource *Resource) string {
|
||||
// 获取工具定义以获取更详细的信息
|
||||
s.mu.RLock()
|
||||
tool, hasTool := s.toolDefs[toolName]
|
||||
s.mu.RUnlock()
|
||||
|
||||
// 为常见工具生成详细文档
|
||||
switch toolName {
|
||||
case "nmap":
|
||||
return `Nmap (Network Mapper) 是一个强大的网络扫描工具。
|
||||
|
||||
主要功能:
|
||||
- 端口扫描:发现目标主机开放的端口
|
||||
- 服务识别:识别运行在端口上的服务
|
||||
- 版本检测:检测服务版本信息
|
||||
- 操作系统检测:识别目标操作系统
|
||||
|
||||
常用命令:
|
||||
- nmap -sT target # TCP连接扫描
|
||||
- nmap -sV target # 版本检测
|
||||
- nmap -sC target # 默认脚本扫描
|
||||
- nmap -p 1-1000 target # 扫描指定端口范围
|
||||
|
||||
参数说明:
|
||||
- target: 目标IP地址或域名(必需)
|
||||
- ports: 端口范围,例如: 1-1000(可选)`
|
||||
|
||||
case "sqlmap":
|
||||
return `SQLMap 是一个自动化的SQL注入检测和利用工具。
|
||||
|
||||
主要功能:
|
||||
- 自动检测SQL注入漏洞
|
||||
- 数据库指纹识别
|
||||
- 数据提取
|
||||
- 文件系统访问
|
||||
|
||||
常用命令:
|
||||
- sqlmap -u "http://target.com/page?id=1" # 检测URL参数
|
||||
- sqlmap -u "http://target.com" --forms # 检测表单
|
||||
- sqlmap -u "http://target.com" --dbs # 列出数据库
|
||||
|
||||
参数说明:
|
||||
- url: 目标URL(必需)`
|
||||
|
||||
case "nikto":
|
||||
return `Nikto 是一个Web服务器扫描工具。
|
||||
|
||||
主要功能:
|
||||
- Web服务器漏洞扫描
|
||||
- 检测过时的服务器软件
|
||||
- 检测危险文件和程序
|
||||
- 检测服务器配置问题
|
||||
|
||||
常用命令:
|
||||
- nikto -h target # 扫描目标主机
|
||||
- nikto -h target -p 80,443 # 扫描指定端口
|
||||
|
||||
参数说明:
|
||||
- target: 目标URL(必需)`
|
||||
|
||||
case "dirb":
|
||||
return `Dirb 是一个Web内容扫描器。
|
||||
|
||||
主要功能:
|
||||
- 扫描Web目录和文件
|
||||
- 发现隐藏的目录和文件
|
||||
- 支持自定义字典
|
||||
|
||||
常用命令:
|
||||
- dirb url # 扫描目标URL
|
||||
- dirb url -w wordlist.txt # 使用自定义字典
|
||||
|
||||
参数说明:
|
||||
- target: 目标URL(必需)`
|
||||
|
||||
case "exec":
|
||||
return `Exec 工具用于执行系统命令。
|
||||
|
||||
⚠️ 警告:此工具可以执行任意系统命令,请谨慎使用!
|
||||
|
||||
参数说明:
|
||||
- command: 要执行的系统命令(必需)
|
||||
- shell: 使用的shell,默认为sh(可选)
|
||||
- workdir: 工作目录(可选)`
|
||||
|
||||
default:
|
||||
// 对于其他工具,使用工具定义中的描述信息
|
||||
if hasTool {
|
||||
doc := fmt.Sprintf("%s\n\n", resource.Description)
|
||||
if tool.InputSchema != nil {
|
||||
if props, ok := tool.InputSchema["properties"].(map[string]interface{}); ok {
|
||||
doc += "参数说明:\n"
|
||||
for paramName, paramInfo := range props {
|
||||
if paramMap, ok := paramInfo.(map[string]interface{}); ok {
|
||||
if desc, ok := paramMap["description"].(string); ok {
|
||||
doc += fmt.Sprintf("- %s: %s\n", paramName, desc)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return doc
|
||||
}
|
||||
return resource.Description
|
||||
}
|
||||
}
|
||||
|
||||
// handleSamplingRequest 处理采样请求
|
||||
func (s *Server) handleSamplingRequest(msg *Message) *Message {
|
||||
var req SamplingRequest
|
||||
if err := json.Unmarshal(msg.Params, &req); err != nil {
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: -32602, Message: "Invalid params"},
|
||||
}
|
||||
}
|
||||
|
||||
// 注意:采样功能通常需要连接到实际的LLM服务
|
||||
// 这里返回一个占位符响应,实际实现需要集成LLM API
|
||||
s.logger.Warn("Sampling request received but not fully implemented",
|
||||
zap.Any("request", req),
|
||||
)
|
||||
|
||||
response := SamplingResponse{
|
||||
Content: []SamplingContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "采样功能需要配置LLM服务。请使用Agent Loop API进行AI对话。",
|
||||
},
|
||||
},
|
||||
StopReason: "length",
|
||||
}
|
||||
result, _ := json.Marshal(response)
|
||||
return &Message{
|
||||
ID: msg.ID,
|
||||
Type: MessageTypeResponse,
|
||||
Version: "2.0",
|
||||
Result: result,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterPrompt 注册提示词模板
|
||||
func (s *Server) RegisterPrompt(prompt *Prompt) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.prompts[prompt.Name] = prompt
|
||||
}
|
||||
|
||||
// RegisterResource 注册资源
|
||||
func (s *Server) RegisterResource(resource *Resource) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.resources[resource.URI] = resource
|
||||
}
|
||||
|
||||
// sendError 发送错误响应
|
||||
func (s *Server) sendError(w http.ResponseWriter, id interface{}, code int, message, data string) {
|
||||
response := Message{
|
||||
ID: fmt.Sprintf("%v", id),
|
||||
Type: MessageTypeError,
|
||||
Error: &Error{Code: code, Message: message, Data: data},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
232
internal/mcp/types.go
Normal file
232
internal/mcp/types.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MCP消息类型
|
||||
const (
|
||||
MessageTypeRequest = "request"
|
||||
MessageTypeResponse = "response"
|
||||
MessageTypeError = "error"
|
||||
MessageTypeNotify = "notify"
|
||||
)
|
||||
|
||||
// MCP协议版本
|
||||
const ProtocolVersion = "2024-11-05"
|
||||
|
||||
// Message 表示MCP消息
|
||||
type Message struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
Version string `json:"jsonrpc,omitempty"`
|
||||
}
|
||||
|
||||
// Error 表示MCP错误
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Tool 表示MCP工具定义
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema map[string]interface{} `json:"inputSchema"`
|
||||
}
|
||||
|
||||
// ToolCall 表示工具调用
|
||||
type ToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// ToolResult 表示工具执行结果
|
||||
type ToolResult struct {
|
||||
Content []Content `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
// Content 表示内容
|
||||
type Content struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// InitializeRequest 初始化请求
|
||||
type InitializeRequest struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]interface{} `json:"capabilities"`
|
||||
ClientInfo ClientInfo `json:"clientInfo"`
|
||||
}
|
||||
|
||||
// ClientInfo 客户端信息
|
||||
type ClientInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// InitializeResponse 初始化响应
|
||||
type InitializeResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo ServerInfo `json:"serverInfo"`
|
||||
}
|
||||
|
||||
// ServerCapabilities 服务器能力
|
||||
type ServerCapabilities struct {
|
||||
Tools map[string]interface{} `json:"tools,omitempty"`
|
||||
Prompts map[string]interface{} `json:"prompts,omitempty"`
|
||||
Resources map[string]interface{} `json:"resources,omitempty"`
|
||||
Sampling map[string]interface{} `json:"sampling,omitempty"`
|
||||
}
|
||||
|
||||
// ServerInfo 服务器信息
|
||||
type ServerInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// ListToolsRequest 列出工具请求
|
||||
type ListToolsRequest struct{}
|
||||
|
||||
// ListToolsResponse 列出工具响应
|
||||
type ListToolsResponse struct {
|
||||
Tools []Tool `json:"tools"`
|
||||
}
|
||||
|
||||
// ListPromptsResponse 列出提示词响应
|
||||
type ListPromptsResponse struct {
|
||||
Prompts []Prompt `json:"prompts"`
|
||||
}
|
||||
|
||||
// ListResourcesResponse 列出资源响应
|
||||
type ListResourcesResponse struct {
|
||||
Resources []Resource `json:"resources"`
|
||||
}
|
||||
|
||||
// CallToolRequest 调用工具请求
|
||||
type CallToolRequest struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// CallToolResponse 调用工具响应
|
||||
type CallToolResponse struct {
|
||||
Content []Content `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
// ToolExecution 工具执行记录
|
||||
type ToolExecution struct {
|
||||
ID string `json:"id"`
|
||||
ToolName string `json:"toolName"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
Result *ToolResult `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartTime time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
}
|
||||
|
||||
// ToolStats 工具统计信息
|
||||
type ToolStats struct {
|
||||
ToolName string `json:"toolName"`
|
||||
TotalCalls int `json:"totalCalls"`
|
||||
SuccessCalls int `json:"successCalls"`
|
||||
FailedCalls int `json:"failedCalls"`
|
||||
LastCallTime *time.Time `json:"lastCallTime,omitempty"`
|
||||
}
|
||||
|
||||
// Prompt 提示词模板
|
||||
type Prompt struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Arguments []PromptArgument `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// PromptArgument 提示词参数
|
||||
type PromptArgument struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Required bool `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
// GetPromptRequest 获取提示词请求
|
||||
type GetPromptRequest struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// GetPromptResponse 获取提示词响应
|
||||
type GetPromptResponse struct {
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// PromptMessage 提示词消息
|
||||
type PromptMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// Resource 资源
|
||||
type Resource struct {
|
||||
URI string `json:"uri"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
}
|
||||
|
||||
// ReadResourceRequest 读取资源请求
|
||||
type ReadResourceRequest struct {
|
||||
URI string `json:"uri"`
|
||||
}
|
||||
|
||||
// ReadResourceResponse 读取资源响应
|
||||
type ReadResourceResponse struct {
|
||||
Contents []ResourceContent `json:"contents"`
|
||||
}
|
||||
|
||||
// ResourceContent 资源内容
|
||||
type ResourceContent struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingRequest 采样请求
|
||||
type SamplingRequest struct {
|
||||
Messages []SamplingMessage `json:"messages"`
|
||||
Model string `json:"model,omitempty"`
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingMessage 采样消息
|
||||
type SamplingMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// SamplingResponse 采样响应
|
||||
type SamplingResponse struct {
|
||||
Content []SamplingContent `json:"content"`
|
||||
Model string `json:"model,omitempty"`
|
||||
StopReason string `json:"stopReason,omitempty"`
|
||||
}
|
||||
|
||||
// SamplingContent 采样内容
|
||||
type SamplingContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
730
internal/security/executor.go
Normal file
730
internal/security/executor.go
Normal file
@@ -0,0 +1,730 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Executor 安全工具执行器
|
||||
type Executor struct {
|
||||
config *config.SecurityConfig
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewExecutor 创建新的执行器
|
||||
func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.Logger) *Executor {
|
||||
return &Executor{
|
||||
config: cfg,
|
||||
mcpServer: mcpServer,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteTool 执行安全工具
|
||||
func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
e.logger.Info("ExecuteTool被调用",
|
||||
zap.String("toolName", toolName),
|
||||
zap.Any("args", args),
|
||||
)
|
||||
|
||||
// 特殊处理:exec工具直接执行系统命令
|
||||
if toolName == "exec" {
|
||||
e.logger.Info("执行exec工具")
|
||||
return e.executeSystemCommand(ctx, args)
|
||||
}
|
||||
|
||||
// 查找工具配置
|
||||
var toolConfig *config.ToolConfig
|
||||
for i := range e.config.Tools {
|
||||
if e.config.Tools[i].Name == toolName && e.config.Tools[i].Enabled {
|
||||
toolConfig = &e.config.Tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if toolConfig == nil {
|
||||
e.logger.Error("工具未找到或未启用",
|
||||
zap.String("toolName", toolName),
|
||||
zap.Int("totalTools", len(e.config.Tools)),
|
||||
)
|
||||
return nil, fmt.Errorf("工具 %s 未找到或未启用", toolName)
|
||||
}
|
||||
|
||||
e.logger.Info("找到工具配置",
|
||||
zap.String("toolName", toolName),
|
||||
zap.String("command", toolConfig.Command),
|
||||
zap.Strings("args", toolConfig.Args),
|
||||
)
|
||||
|
||||
// 构建命令 - 根据工具类型使用不同的参数格式
|
||||
cmdArgs := e.buildCommandArgs(toolName, toolConfig, args)
|
||||
|
||||
e.logger.Info("构建命令参数完成",
|
||||
zap.String("toolName", toolName),
|
||||
zap.Strings("cmdArgs", cmdArgs),
|
||||
zap.Int("argsCount", len(cmdArgs)),
|
||||
)
|
||||
|
||||
// 验证命令参数
|
||||
if len(cmdArgs) == 0 {
|
||||
e.logger.Warn("命令参数为空",
|
||||
zap.String("toolName", toolName),
|
||||
zap.Any("inputArgs", args),
|
||||
)
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("错误: 工具 %s 缺少必需的参数。接收到的参数: %v", toolName, args),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 执行命令
|
||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||
|
||||
e.logger.Info("执行安全工具",
|
||||
zap.String("tool", toolName),
|
||||
zap.Strings("args", cmdArgs),
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
e.logger.Error("工具执行失败",
|
||||
zap.String("tool", toolName),
|
||||
zap.Error(err),
|
||||
zap.String("output", string(output)),
|
||||
)
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("工具执行失败: %v\n输出: %s", err, string(output)),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
e.logger.Info("工具执行成功",
|
||||
zap.String("tool", toolName),
|
||||
zap.String("output", string(output)),
|
||||
)
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: string(output),
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RegisterTools 注册工具到MCP服务器
|
||||
func (e *Executor) RegisterTools(mcpServer *mcp.Server) {
|
||||
e.logger.Info("开始注册工具",
|
||||
zap.Int("totalTools", len(e.config.Tools)),
|
||||
)
|
||||
|
||||
for i, toolConfig := range e.config.Tools {
|
||||
if !toolConfig.Enabled {
|
||||
e.logger.Debug("跳过未启用的工具",
|
||||
zap.String("tool", toolConfig.Name),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// 创建工具配置的副本,避免闭包问题
|
||||
toolName := toolConfig.Name
|
||||
toolConfigCopy := toolConfig
|
||||
|
||||
tool := mcp.Tool{
|
||||
Name: toolConfigCopy.Name,
|
||||
Description: toolConfigCopy.Description,
|
||||
InputSchema: e.buildInputSchema(&toolConfigCopy),
|
||||
}
|
||||
|
||||
handler := func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
e.logger.Info("工具handler被调用",
|
||||
zap.String("toolName", toolName),
|
||||
zap.Any("args", args),
|
||||
)
|
||||
return e.ExecuteTool(ctx, toolName, args)
|
||||
}
|
||||
|
||||
mcpServer.RegisterTool(tool, handler)
|
||||
e.logger.Info("注册安全工具成功",
|
||||
zap.String("tool", toolConfigCopy.Name),
|
||||
zap.String("command", toolConfigCopy.Command),
|
||||
zap.Int("index", i),
|
||||
)
|
||||
}
|
||||
|
||||
e.logger.Info("工具注册完成",
|
||||
zap.Int("registeredCount", len(e.config.Tools)),
|
||||
)
|
||||
}
|
||||
|
||||
// buildCommandArgs 构建命令参数
|
||||
func (e *Executor) buildCommandArgs(toolName string, toolConfig *config.ToolConfig, args map[string]interface{}) []string {
|
||||
cmdArgs := make([]string, 0)
|
||||
|
||||
// 如果配置中定义了参数映射,使用配置中的映射规则
|
||||
if len(toolConfig.Parameters) > 0 {
|
||||
// 先添加固定参数
|
||||
cmdArgs = append(cmdArgs, toolConfig.Args...)
|
||||
|
||||
// 按位置参数排序
|
||||
positionalParams := make([]config.ParameterConfig, 0)
|
||||
flagParams := make([]config.ParameterConfig, 0)
|
||||
|
||||
for _, param := range toolConfig.Parameters {
|
||||
if param.Position != nil {
|
||||
positionalParams = append(positionalParams, param)
|
||||
} else {
|
||||
flagParams = append(flagParams, param)
|
||||
}
|
||||
}
|
||||
|
||||
// 对位置参数按位置排序
|
||||
for i := 0; i < len(positionalParams); i++ {
|
||||
for _, param := range positionalParams {
|
||||
if param.Position != nil && *param.Position == i {
|
||||
value := e.getParamValue(args, param)
|
||||
if value == nil {
|
||||
if param.Required {
|
||||
// 必需参数缺失,返回空数组让上层处理错误
|
||||
e.logger.Warn("缺少必需的位置参数",
|
||||
zap.String("tool", toolName),
|
||||
zap.String("param", param.Name),
|
||||
zap.Int("position", *param.Position),
|
||||
)
|
||||
return []string{}
|
||||
}
|
||||
break
|
||||
}
|
||||
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理标志参数
|
||||
for _, param := range flagParams {
|
||||
value := e.getParamValue(args, param)
|
||||
if value == nil {
|
||||
if param.Required {
|
||||
// 必需参数缺失,返回空数组让上层处理错误
|
||||
e.logger.Warn("缺少必需的标志参数",
|
||||
zap.String("tool", toolName),
|
||||
zap.String("param", param.Name),
|
||||
)
|
||||
return []string{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 布尔值特殊处理:如果为 false,跳过;如果为 true,只添加标志
|
||||
if param.Type == "bool" {
|
||||
if boolVal, ok := value.(bool); ok {
|
||||
if !boolVal {
|
||||
continue // false 时不添加任何参数
|
||||
}
|
||||
// true 时只添加标志,不添加值
|
||||
if param.Flag != "" {
|
||||
cmdArgs = append(cmdArgs, param.Flag)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
format := param.Format
|
||||
if format == "" {
|
||||
format = "flag" // 默认格式
|
||||
}
|
||||
|
||||
switch format {
|
||||
case "flag":
|
||||
// --flag value 或 -f value
|
||||
if param.Flag != "" {
|
||||
cmdArgs = append(cmdArgs, param.Flag)
|
||||
}
|
||||
formattedValue := e.formatParamValue(param, value)
|
||||
if formattedValue != "" {
|
||||
cmdArgs = append(cmdArgs, formattedValue)
|
||||
}
|
||||
case "combined":
|
||||
// --flag=value 或 -f=value
|
||||
if param.Flag != "" {
|
||||
cmdArgs = append(cmdArgs, fmt.Sprintf("%s=%s", param.Flag, e.formatParamValue(param, value)))
|
||||
} else {
|
||||
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
||||
}
|
||||
case "template":
|
||||
// 使用模板字符串
|
||||
if param.Template != "" {
|
||||
template := param.Template
|
||||
template = strings.ReplaceAll(template, "{flag}", param.Flag)
|
||||
template = strings.ReplaceAll(template, "{value}", e.formatParamValue(param, value))
|
||||
template = strings.ReplaceAll(template, "{name}", param.Name)
|
||||
cmdArgs = append(cmdArgs, strings.Fields(template)...)
|
||||
} else {
|
||||
// 如果没有模板,使用默认格式
|
||||
if param.Flag != "" {
|
||||
cmdArgs = append(cmdArgs, param.Flag)
|
||||
}
|
||||
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
||||
}
|
||||
case "positional":
|
||||
// 位置参数(已在上面处理)
|
||||
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
||||
default:
|
||||
// 默认:直接添加值
|
||||
cmdArgs = append(cmdArgs, e.formatParamValue(param, value))
|
||||
}
|
||||
}
|
||||
|
||||
return cmdArgs
|
||||
}
|
||||
|
||||
// 向后兼容:如果没有定义参数,使用旧的硬编码逻辑
|
||||
switch toolName {
|
||||
case "nmap":
|
||||
// nmap -sT -sV -sC target [ports]
|
||||
// 使用 -sT (TCP连接扫描) 而不是 -sS (SYN扫描),因为 -sS 需要root权限
|
||||
e.logger.Debug("处理nmap参数",
|
||||
zap.Any("args", args),
|
||||
)
|
||||
|
||||
// 尝试多种方式获取target参数
|
||||
var target string
|
||||
var ok bool
|
||||
|
||||
// 方式1: 直接获取target
|
||||
if target, ok = args["target"].(string); !ok || target == "" {
|
||||
// 方式2: 尝试从tool字段获取(兼容某些格式)
|
||||
if toolVal, exists := args["tool"]; exists {
|
||||
if toolMap, ok := toolVal.(map[string]interface{}); ok {
|
||||
if t, ok := toolMap["target"].(string); ok {
|
||||
target = t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if target == "" {
|
||||
e.logger.Warn("nmap缺少target参数",
|
||||
zap.Any("args", args),
|
||||
)
|
||||
return cmdArgs // 返回空数组,让上层处理错误
|
||||
}
|
||||
|
||||
e.logger.Debug("提取到target",
|
||||
zap.String("target", target),
|
||||
)
|
||||
|
||||
// 处理URL格式的目标(提取域名)
|
||||
if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") {
|
||||
// 提取域名部分
|
||||
target = strings.TrimPrefix(target, "http://")
|
||||
target = strings.TrimPrefix(target, "https://")
|
||||
// 移除路径部分
|
||||
if idx := strings.Index(target, "/"); idx != -1 {
|
||||
target = target[:idx]
|
||||
}
|
||||
}
|
||||
|
||||
// 添加扫描选项:-sT (TCP连接扫描,不需要root权限), -sV (版本检测), -sC (默认脚本)
|
||||
cmdArgs = append(cmdArgs, "-sT", "-sV", "-sC")
|
||||
|
||||
// 添加端口范围(如果指定)
|
||||
if ports, ok := args["ports"].(string); ok && ports != "" {
|
||||
cmdArgs = append(cmdArgs, "-p", ports)
|
||||
}
|
||||
|
||||
// 添加目标
|
||||
cmdArgs = append(cmdArgs, target)
|
||||
|
||||
e.logger.Debug("nmap命令参数构建完成",
|
||||
zap.Strings("cmdArgs", cmdArgs),
|
||||
)
|
||||
case "sqlmap":
|
||||
// sqlmap -u url
|
||||
if url, ok := args["url"].(string); ok {
|
||||
cmdArgs = append(cmdArgs, "-u", url, "--batch", "--level=3", "--risk=2")
|
||||
}
|
||||
case "nikto":
|
||||
// nikto -h target
|
||||
if target, ok := args["target"].(string); ok {
|
||||
cmdArgs = append(cmdArgs, "-h", target)
|
||||
}
|
||||
case "dirb":
|
||||
// dirb url
|
||||
if url, ok := args["url"].(string); ok {
|
||||
cmdArgs = append(cmdArgs, url)
|
||||
}
|
||||
default:
|
||||
// 通用处理
|
||||
cmdArgs = append(cmdArgs, toolConfig.Args...)
|
||||
for key, value := range args {
|
||||
if key == "_tool_name" {
|
||||
continue
|
||||
}
|
||||
cmdArgs = append(cmdArgs, fmt.Sprintf("--%s", key))
|
||||
if strValue, ok := value.(string); ok {
|
||||
cmdArgs = append(cmdArgs, strValue)
|
||||
} else {
|
||||
cmdArgs = append(cmdArgs, fmt.Sprintf("%v", value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cmdArgs
|
||||
}
|
||||
|
||||
// getParamValue 获取参数值,支持默认值
|
||||
func (e *Executor) getParamValue(args map[string]interface{}, param config.ParameterConfig) interface{} {
|
||||
// 从参数中获取值
|
||||
if value, ok := args[param.Name]; ok && value != nil {
|
||||
return value
|
||||
}
|
||||
|
||||
// 如果参数是必需的但没有提供,返回 nil(让上层处理错误)
|
||||
if param.Required {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 返回默认值
|
||||
return param.Default
|
||||
}
|
||||
|
||||
// formatParamValue 格式化参数值
|
||||
func (e *Executor) formatParamValue(param config.ParameterConfig, value interface{}) string {
|
||||
switch param.Type {
|
||||
case "bool":
|
||||
// 布尔值应该在上层处理,这里不应该被调用
|
||||
if boolVal, ok := value.(bool); ok {
|
||||
return fmt.Sprintf("%v", boolVal)
|
||||
}
|
||||
return "false"
|
||||
case "array":
|
||||
// 数组:转换为逗号分隔的字符串
|
||||
if arr, ok := value.([]interface{}); ok {
|
||||
strs := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
strs = append(strs, fmt.Sprintf("%v", item))
|
||||
}
|
||||
return strings.Join(strs, ",")
|
||||
}
|
||||
return fmt.Sprintf("%v", value)
|
||||
default:
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// executeSystemCommand 执行系统命令
|
||||
func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
// 获取命令
|
||||
command, ok := args["command"].(string)
|
||||
if !ok {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: 缺少command参数",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if command == "" {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "错误: command参数不能为空",
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 安全检查:记录执行的命令
|
||||
e.logger.Warn("执行系统命令",
|
||||
zap.String("command", command),
|
||||
)
|
||||
|
||||
// 获取shell类型(可选,默认为sh)
|
||||
shell := "sh"
|
||||
if s, ok := args["shell"].(string); ok && s != "" {
|
||||
shell = s
|
||||
}
|
||||
|
||||
// 获取工作目录(可选)
|
||||
workDir := ""
|
||||
if wd, ok := args["workdir"].(string); ok && wd != "" {
|
||||
workDir = wd
|
||||
}
|
||||
|
||||
// 构建命令
|
||||
var cmd *exec.Cmd
|
||||
if workDir != "" {
|
||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||
cmd.Dir = workDir
|
||||
} else {
|
||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||
}
|
||||
|
||||
// 执行命令
|
||||
e.logger.Info("执行系统命令",
|
||||
zap.String("command", command),
|
||||
zap.String("shell", shell),
|
||||
zap.String("workdir", workDir),
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
e.logger.Error("系统命令执行失败",
|
||||
zap.String("command", command),
|
||||
zap.Error(err),
|
||||
zap.String("output", string(output)),
|
||||
)
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)),
|
||||
},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
e.logger.Info("系统命令执行成功",
|
||||
zap.String("command", command),
|
||||
zap.String("output_length", fmt.Sprintf("%d", len(output))),
|
||||
)
|
||||
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
{
|
||||
Type: "text",
|
||||
Text: string(output),
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildInputSchema 构建输入模式
|
||||
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
|
||||
schema := map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
"required": []string{},
|
||||
}
|
||||
|
||||
// 如果配置中定义了参数,优先使用配置中的参数定义
|
||||
if len(toolConfig.Parameters) > 0 {
|
||||
properties := make(map[string]interface{})
|
||||
required := []string{}
|
||||
|
||||
for _, param := range toolConfig.Parameters {
|
||||
prop := map[string]interface{}{
|
||||
"type": param.Type,
|
||||
"description": param.Description,
|
||||
}
|
||||
|
||||
// 添加默认值
|
||||
if param.Default != nil {
|
||||
prop["default"] = param.Default
|
||||
}
|
||||
|
||||
// 添加枚举选项
|
||||
if len(param.Options) > 0 {
|
||||
prop["enum"] = param.Options
|
||||
}
|
||||
|
||||
properties[param.Name] = prop
|
||||
|
||||
// 添加到必需参数列表
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
schema["properties"] = properties
|
||||
schema["required"] = required
|
||||
return schema
|
||||
}
|
||||
|
||||
// 向后兼容:如果没有定义参数,使用旧的硬编码逻辑
|
||||
switch toolConfig.Name {
|
||||
case "nmap":
|
||||
schema["properties"] = map[string]interface{}{
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "目标IP地址或域名",
|
||||
},
|
||||
"ports": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "端口范围,例如: 1-1000",
|
||||
},
|
||||
}
|
||||
schema["required"] = []string{"target"}
|
||||
case "sqlmap":
|
||||
schema["properties"] = map[string]interface{}{
|
||||
"url": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "目标URL",
|
||||
},
|
||||
}
|
||||
schema["required"] = []string{"url"}
|
||||
case "nikto", "dirb":
|
||||
schema["properties"] = map[string]interface{}{
|
||||
"target": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "目标URL",
|
||||
},
|
||||
}
|
||||
schema["required"] = []string{"target"}
|
||||
case "exec":
|
||||
schema["properties"] = map[string]interface{}{
|
||||
"command": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要执行的系统命令",
|
||||
},
|
||||
"shell": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "使用的shell(可选,默认为sh)",
|
||||
},
|
||||
"workdir": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "工作目录(可选)",
|
||||
},
|
||||
}
|
||||
schema["required"] = []string{"command"}
|
||||
}
|
||||
|
||||
return schema
|
||||
}
|
||||
|
||||
// Vulnerability 漏洞信息
|
||||
type Vulnerability struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Severity string `json:"severity"` // low, medium, high, critical
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Target string `json:"target"`
|
||||
FoundAt time.Time `json:"foundAt"`
|
||||
Details string `json:"details"`
|
||||
}
|
||||
|
||||
// AnalyzeResults 分析工具执行结果,提取漏洞信息
|
||||
func (e *Executor) AnalyzeResults(toolName string, result *mcp.ToolResult) []Vulnerability {
|
||||
vulnerabilities := []Vulnerability{}
|
||||
|
||||
if result.IsError {
|
||||
return vulnerabilities
|
||||
}
|
||||
|
||||
// 分析输出内容
|
||||
for _, content := range result.Content {
|
||||
if content.Type == "text" {
|
||||
vulns := e.parseToolOutput(toolName, content.Text)
|
||||
vulnerabilities = append(vulnerabilities, vulns...)
|
||||
}
|
||||
}
|
||||
|
||||
return vulnerabilities
|
||||
}
|
||||
|
||||
// parseToolOutput 解析工具输出
|
||||
func (e *Executor) parseToolOutput(toolName, output string) []Vulnerability {
|
||||
vulnerabilities := []Vulnerability{}
|
||||
|
||||
// 简单的漏洞检测逻辑
|
||||
outputLower := strings.ToLower(output)
|
||||
|
||||
// SQL注入检测
|
||||
if strings.Contains(outputLower, "sql injection") || strings.Contains(outputLower, "sqli") {
|
||||
vulnerabilities = append(vulnerabilities, Vulnerability{
|
||||
ID: fmt.Sprintf("sql-%d", time.Now().Unix()),
|
||||
Type: "SQL Injection",
|
||||
Severity: "high",
|
||||
Title: "SQL注入漏洞",
|
||||
Description: "检测到潜在的SQL注入漏洞",
|
||||
FoundAt: time.Now(),
|
||||
Details: output,
|
||||
})
|
||||
}
|
||||
|
||||
// XSS检测
|
||||
if strings.Contains(outputLower, "xss") || strings.Contains(outputLower, "cross-site scripting") {
|
||||
vulnerabilities = append(vulnerabilities, Vulnerability{
|
||||
ID: fmt.Sprintf("xss-%d", time.Now().Unix()),
|
||||
Type: "XSS",
|
||||
Severity: "medium",
|
||||
Title: "跨站脚本攻击漏洞",
|
||||
Description: "检测到潜在的XSS漏洞",
|
||||
FoundAt: time.Now(),
|
||||
Details: output,
|
||||
})
|
||||
}
|
||||
|
||||
// 开放端口检测
|
||||
if toolName == "nmap" {
|
||||
lines := strings.Split(output, "\n")
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "open") && strings.Contains(line, "port") {
|
||||
vulnerabilities = append(vulnerabilities, Vulnerability{
|
||||
ID: fmt.Sprintf("port-%d", time.Now().Unix()),
|
||||
Type: "Open Port",
|
||||
Severity: "low",
|
||||
Title: "开放端口",
|
||||
Description: fmt.Sprintf("发现开放端口: %s", line),
|
||||
FoundAt: time.Now(),
|
||||
Details: line,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vulnerabilities
|
||||
}
|
||||
|
||||
// GetVulnerabilityReport 生成漏洞报告
|
||||
func (e *Executor) GetVulnerabilityReport(vulnerabilities []Vulnerability) map[string]interface{} {
|
||||
severityCount := map[string]int{
|
||||
"critical": 0,
|
||||
"high": 0,
|
||||
"medium": 0,
|
||||
"low": 0,
|
||||
}
|
||||
|
||||
for _, vuln := range vulnerabilities {
|
||||
severityCount[vuln.Severity]++
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total": len(vulnerabilities),
|
||||
"severityCount": severityCount,
|
||||
"vulnerabilities": vulnerabilities,
|
||||
"generatedAt": time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
35
run.sh
Normal file
35
run.sh
Normal file
@@ -0,0 +1,35 @@
|
||||
#!/bin/bash
|
||||
|
||||
# CyberStrikeAI 启动脚本
|
||||
|
||||
echo "🚀 启动 CyberStrikeAI..."
|
||||
|
||||
# 检查配置文件
|
||||
if [ ! -f "config.yaml" ]; then
|
||||
echo "❌ 配置文件 config.yaml 不存在"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 检查Go环境
|
||||
if ! command -v go &> /dev/null; then
|
||||
echo "❌ Go 未安装,请先安装 Go 1.21 或更高版本"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 下载依赖
|
||||
echo "📦 下载依赖..."
|
||||
go mod download
|
||||
|
||||
# 构建项目
|
||||
echo "🔨 构建项目..."
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "❌ 构建失败"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 运行服务器
|
||||
echo "✅ 启动服务器..."
|
||||
./cyberstrike-ai
|
||||
|
||||
691
web/static/css/style.css
Normal file
691
web/static/css/style.css
Normal file
@@ -0,0 +1,691 @@
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
:root {
|
||||
--primary-color: #1a1a1a;
|
||||
--secondary-color: #2d2d2d;
|
||||
--accent-color: #0066ff;
|
||||
--accent-hover: #0052cc;
|
||||
--bg-primary: #ffffff;
|
||||
--bg-secondary: #f8f9fa;
|
||||
--bg-tertiary: #f1f3f5;
|
||||
--text-primary: #1a1a1a;
|
||||
--text-secondary: #6c757d;
|
||||
--text-muted: #adb5bd;
|
||||
--border-color: #e9ecef;
|
||||
--success-color: #28a745;
|
||||
--warning-color: #ffc107;
|
||||
--error-color: #dc3545;
|
||||
--shadow-sm: 0 1px 3px rgba(0, 0, 0, 0.05);
|
||||
--shadow-md: 0 4px 6px rgba(0, 0, 0, 0.1);
|
||||
--shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Helvetica Neue', Arial, sans-serif;
|
||||
background: var(--bg-secondary);
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
color: var(--text-primary);
|
||||
line-height: 1.6;
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 100%;
|
||||
margin: 0;
|
||||
background: var(--bg-primary);
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
box-shadow: var(--shadow-lg);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.main-layout {
|
||||
display: flex;
|
||||
flex: 1;
|
||||
overflow: hidden;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
header {
|
||||
background: var(--primary-color);
|
||||
color: white;
|
||||
padding: 24px 32px;
|
||||
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.header-content {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.logo {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.logo svg {
|
||||
color: var(--accent-color);
|
||||
}
|
||||
|
||||
.logo h1 {
|
||||
font-size: 1.75rem;
|
||||
font-weight: 600;
|
||||
letter-spacing: -0.5px;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.header-subtitle {
|
||||
font-size: 0.875rem;
|
||||
color: rgba(255, 255, 255, 0.7);
|
||||
margin: 0;
|
||||
font-weight: 400;
|
||||
}
|
||||
|
||||
/* 侧边栏样式 */
|
||||
.sidebar {
|
||||
width: 280px;
|
||||
background: var(--bg-secondary);
|
||||
border-right: 1px solid var(--border-color);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex-shrink: 0;
|
||||
height: 100%;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar-header {
|
||||
padding: 16px;
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.new-chat-btn {
|
||||
width: 100%;
|
||||
padding: 10px 16px;
|
||||
background: var(--accent-color);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
font-size: 0.9375rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.new-chat-btn:hover {
|
||||
background: var(--accent-hover);
|
||||
transform: translateY(-1px);
|
||||
box-shadow: var(--shadow-sm);
|
||||
}
|
||||
|
||||
.new-chat-btn span {
|
||||
font-size: 1.2em;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.sidebar-content {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
padding: 16px;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.sidebar-title {
|
||||
font-size: 0.8125rem;
|
||||
font-weight: 600;
|
||||
color: var(--text-secondary);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 12px;
|
||||
padding: 0 8px;
|
||||
}
|
||||
|
||||
.conversations-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.conversation-item {
|
||||
padding: 12px;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
border: 1px solid transparent;
|
||||
}
|
||||
|
||||
.conversation-item:hover {
|
||||
background: var(--bg-tertiary);
|
||||
}
|
||||
|
||||
.conversation-item.active {
|
||||
background: var(--bg-primary);
|
||||
border-color: var(--accent-color);
|
||||
}
|
||||
|
||||
.conversation-title {
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
color: var(--text-primary);
|
||||
margin-bottom: 4px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.conversation-time {
|
||||
font-size: 0.75rem;
|
||||
color: var(--text-muted);
|
||||
}
|
||||
|
||||
/* 对话界面样式 */
|
||||
.chat-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
background: var(--bg-primary);
|
||||
overflow: hidden;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.chat-messages {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden;
|
||||
padding: 24px;
|
||||
background: var(--bg-secondary);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.message {
|
||||
margin-bottom: 24px;
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
animation: fadeIn 0.3s ease-in;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(10px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.message.user {
|
||||
flex-direction: row-reverse;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.message.system {
|
||||
justify-content: center;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.message-avatar {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border-radius: 6px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
flex-shrink: 0;
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.message.user .message-avatar {
|
||||
background: var(--accent-color);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.message.assistant .message-avatar {
|
||||
background: var(--bg-tertiary);
|
||||
color: var(--text-secondary);
|
||||
border: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.message.system .message-avatar {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.message-content {
|
||||
flex: 0 1 auto;
|
||||
max-width: 70%;
|
||||
min-width: 120px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.message.user .message-content {
|
||||
align-items: flex-end;
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.message.assistant .message-content {
|
||||
align-items: flex-start;
|
||||
margin-right: auto;
|
||||
}
|
||||
|
||||
.message.system .message-content {
|
||||
max-width: 90%;
|
||||
align-items: center;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.message-bubble {
|
||||
padding: 12px 16px;
|
||||
border-radius: 8px;
|
||||
word-wrap: break-word;
|
||||
word-break: break-word;
|
||||
line-height: 1.6;
|
||||
box-shadow: var(--shadow-sm);
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.message.user .message-bubble {
|
||||
background: var(--accent-color);
|
||||
color: white;
|
||||
border-bottom-right-radius: 2px;
|
||||
}
|
||||
|
||||
.message.assistant .message-bubble {
|
||||
background: var(--bg-primary);
|
||||
color: var(--text-primary);
|
||||
border: 1px solid var(--border-color);
|
||||
border-bottom-left-radius: 2px;
|
||||
}
|
||||
|
||||
.message.assistant .message-bubble pre {
|
||||
margin: 0;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
.message.system .message-bubble {
|
||||
background: var(--bg-tertiary);
|
||||
color: var(--text-secondary);
|
||||
border: 1px solid var(--border-color);
|
||||
text-align: center;
|
||||
font-size: 0.875rem;
|
||||
padding: 10px 16px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.message-time {
|
||||
font-size: 0.6875rem;
|
||||
color: var(--text-muted);
|
||||
margin-top: 4px;
|
||||
padding: 0 2px;
|
||||
font-weight: 400;
|
||||
}
|
||||
|
||||
/* MCP调用区域 */
|
||||
.mcp-call-section {
|
||||
margin-top: 12px;
|
||||
padding-top: 12px;
|
||||
border-top: 1px solid var(--border-color);
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.mcp-call-label {
|
||||
font-size: 0.75rem;
|
||||
color: var(--text-secondary);
|
||||
margin-bottom: 8px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.mcp-call-label::before {
|
||||
content: '';
|
||||
width: 4px;
|
||||
height: 4px;
|
||||
background: var(--accent-color);
|
||||
border-radius: 50%;
|
||||
display: inline-block;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.mcp-call-buttons {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.chat-input-container {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
padding: 20px 24px;
|
||||
background: var(--bg-primary);
|
||||
border-top: 1px solid var(--border-color);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.chat-input-container input {
|
||||
flex: 1;
|
||||
padding: 12px 16px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 8px;
|
||||
font-size: 0.9375rem;
|
||||
outline: none;
|
||||
transition: all 0.2s;
|
||||
background: var(--bg-primary);
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.chat-input-container input:focus {
|
||||
border-color: var(--accent-color);
|
||||
box-shadow: 0 0 0 3px rgba(0, 102, 255, 0.1);
|
||||
}
|
||||
|
||||
.chat-input-container input::placeholder {
|
||||
color: var(--text-muted);
|
||||
}
|
||||
|
||||
.chat-input-container button {
|
||||
padding: 12px 24px;
|
||||
background: var(--accent-color);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
font-size: 0.9375rem;
|
||||
font-weight: 500;
|
||||
transition: all 0.2s;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.chat-input-container button:hover {
|
||||
background: var(--accent-hover);
|
||||
transform: translateY(-1px);
|
||||
box-shadow: var(--shadow-md);
|
||||
}
|
||||
|
||||
.chat-input-container button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
/* MCP调用详情按钮 */
|
||||
.mcp-detail-btn {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
background: var(--bg-primary);
|
||||
color: var(--accent-color);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 6px;
|
||||
font-size: 0.8125rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.mcp-detail-btn:hover {
|
||||
background: var(--accent-color);
|
||||
color: white;
|
||||
border-color: var(--accent-color);
|
||||
transform: translateY(-1px);
|
||||
box-shadow: var(--shadow-sm);
|
||||
}
|
||||
|
||||
.mcp-detail-btn:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
/* 模态框样式 */
|
||||
.modal {
|
||||
display: none;
|
||||
position: fixed;
|
||||
z-index: 1000;
|
||||
left: 0;
|
||||
top: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: rgba(0, 0, 0, 0.6);
|
||||
backdrop-filter: blur(4px);
|
||||
overflow: auto;
|
||||
animation: fadeIn 0.2s ease-in;
|
||||
}
|
||||
|
||||
.modal-content {
|
||||
background-color: var(--bg-primary);
|
||||
margin: 5% auto;
|
||||
padding: 0;
|
||||
border-radius: 12px;
|
||||
width: 90%;
|
||||
max-width: 900px;
|
||||
max-height: 85vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
box-shadow: var(--shadow-lg);
|
||||
border: 1px solid var(--border-color);
|
||||
animation: slideDown 0.3s ease-out;
|
||||
}
|
||||
|
||||
@keyframes slideDown {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.modal-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 20px 24px;
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
background: var(--bg-primary);
|
||||
}
|
||||
|
||||
.modal-header h2 {
|
||||
margin: 0;
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.modal-close {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
color: var(--text-secondary);
|
||||
font-size: 1.5rem;
|
||||
line-height: 1;
|
||||
transition: all 0.2s;
|
||||
border: none;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.modal-close:hover {
|
||||
background: var(--bg-tertiary);
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.modal-body {
|
||||
padding: 24px;
|
||||
overflow-y: auto;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.detail-section {
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
|
||||
.detail-section:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.detail-section h3 {
|
||||
color: var(--text-primary);
|
||||
margin-bottom: 12px;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
font-size: 0.9375rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.detail-item {
|
||||
margin-bottom: 10px;
|
||||
padding: 8px 0;
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.detail-item strong {
|
||||
color: var(--text-secondary);
|
||||
font-weight: 500;
|
||||
font-size: 0.875rem;
|
||||
min-width: 80px;
|
||||
}
|
||||
|
||||
.detail-item span {
|
||||
color: var(--text-primary);
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.code-block {
|
||||
background: var(--bg-tertiary);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 8px;
|
||||
padding: 16px;
|
||||
font-family: 'SF Mono', 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', monospace;
|
||||
font-size: 0.8125rem;
|
||||
line-height: 1.6;
|
||||
overflow-x: auto;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.code-block.error {
|
||||
background: #fff5f5;
|
||||
border-color: var(--error-color);
|
||||
color: var(--error-color);
|
||||
}
|
||||
|
||||
/* 滚动条样式 */
|
||||
::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-track {
|
||||
background: var(--bg-secondary);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: var(--text-muted);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: var(--text-secondary);
|
||||
}
|
||||
|
||||
/* 响应式设计 */
|
||||
@media (max-width: 768px) {
|
||||
body {
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.container {
|
||||
border-radius: 0;
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
header {
|
||||
padding: 16px 20px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.logo h1 {
|
||||
font-size: 1.5rem;
|
||||
}
|
||||
|
||||
.header-subtitle {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.main-layout {
|
||||
overflow: hidden;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.sidebar {
|
||||
height: 100%;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.sidebar-content {
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.chat-container {
|
||||
height: 100%;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.chat-messages {
|
||||
padding: 16px;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.message-content {
|
||||
max-width: 85%;
|
||||
}
|
||||
|
||||
.chat-input-container {
|
||||
padding: 16px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.modal-content {
|
||||
width: 95%;
|
||||
margin: 10% auto;
|
||||
}
|
||||
}
|
||||
400
web/static/js/app.js
Normal file
400
web/static/js/app.js
Normal file
@@ -0,0 +1,400 @@
|
||||
|
||||
// 当前对话ID
|
||||
let currentConversationId = null;
|
||||
|
||||
// 发送消息
|
||||
async function sendMessage() {
|
||||
const input = document.getElementById('chat-input');
|
||||
const message = input.value.trim();
|
||||
|
||||
if (!message) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 显示用户消息
|
||||
addMessage('user', message);
|
||||
input.value = '';
|
||||
|
||||
// 显示加载状态
|
||||
const loadingId = addMessage('system', '正在处理中...');
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/agent-loop', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message: message,
|
||||
conversationId: currentConversationId
|
||||
}),
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
// 移除加载消息
|
||||
removeMessage(loadingId);
|
||||
|
||||
if (response.ok) {
|
||||
// 更新当前对话ID
|
||||
if (data.conversationId) {
|
||||
currentConversationId = data.conversationId;
|
||||
updateActiveConversation();
|
||||
}
|
||||
|
||||
// 如果有MCP执行ID,显示所有调用
|
||||
const mcpIds = data.mcpExecutionIds || [];
|
||||
addMessage('assistant', data.response, mcpIds);
|
||||
|
||||
// 刷新对话列表
|
||||
loadConversations();
|
||||
} else {
|
||||
addMessage('system', '错误: ' + (data.error || '未知错误'));
|
||||
}
|
||||
} catch (error) {
|
||||
removeMessage(loadingId);
|
||||
addMessage('system', '错误: ' + error.message);
|
||||
}
|
||||
}
|
||||
|
||||
// 消息计数器,确保ID唯一
|
||||
let messageCounter = 0;
|
||||
|
||||
// 添加消息
|
||||
function addMessage(role, content, mcpExecutionIds = null) {
|
||||
const messagesDiv = document.getElementById('chat-messages');
|
||||
const messageDiv = document.createElement('div');
|
||||
messageCounter++;
|
||||
const id = 'msg-' + Date.now() + '-' + messageCounter + '-' + Math.random().toString(36).substr(2, 9);
|
||||
messageDiv.id = id;
|
||||
messageDiv.className = 'message ' + role;
|
||||
|
||||
// 创建头像
|
||||
const avatar = document.createElement('div');
|
||||
avatar.className = 'message-avatar';
|
||||
if (role === 'user') {
|
||||
avatar.textContent = 'U';
|
||||
} else if (role === 'assistant') {
|
||||
avatar.textContent = 'A';
|
||||
} else {
|
||||
avatar.textContent = 'S';
|
||||
}
|
||||
messageDiv.appendChild(avatar);
|
||||
|
||||
// 创建消息内容容器
|
||||
const contentWrapper = document.createElement('div');
|
||||
contentWrapper.className = 'message-content';
|
||||
|
||||
// 创建消息气泡
|
||||
const bubble = document.createElement('div');
|
||||
bubble.className = 'message-bubble';
|
||||
// 处理换行和格式化
|
||||
const formattedContent = content.replace(/\n/g, '<br>');
|
||||
bubble.innerHTML = formattedContent;
|
||||
contentWrapper.appendChild(bubble);
|
||||
|
||||
// 添加时间戳
|
||||
const timeDiv = document.createElement('div');
|
||||
timeDiv.className = 'message-time';
|
||||
timeDiv.textContent = new Date().toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' });
|
||||
contentWrapper.appendChild(timeDiv);
|
||||
|
||||
// 如果有MCP执行ID,添加查看详情区域
|
||||
if (mcpExecutionIds && Array.isArray(mcpExecutionIds) && mcpExecutionIds.length > 0 && role === 'assistant') {
|
||||
const mcpSection = document.createElement('div');
|
||||
mcpSection.className = 'mcp-call-section';
|
||||
|
||||
const mcpLabel = document.createElement('div');
|
||||
mcpLabel.className = 'mcp-call-label';
|
||||
mcpLabel.textContent = `工具调用 (${mcpExecutionIds.length})`;
|
||||
mcpSection.appendChild(mcpLabel);
|
||||
|
||||
const buttonsContainer = document.createElement('div');
|
||||
buttonsContainer.className = 'mcp-call-buttons';
|
||||
|
||||
mcpExecutionIds.forEach((execId, index) => {
|
||||
const detailBtn = document.createElement('button');
|
||||
detailBtn.className = 'mcp-detail-btn';
|
||||
detailBtn.innerHTML = `<span>调用 #${index + 1}</span>`;
|
||||
detailBtn.onclick = () => showMCPDetail(execId);
|
||||
buttonsContainer.appendChild(detailBtn);
|
||||
});
|
||||
|
||||
mcpSection.appendChild(buttonsContainer);
|
||||
contentWrapper.appendChild(mcpSection);
|
||||
}
|
||||
|
||||
messageDiv.appendChild(contentWrapper);
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
||||
return id;
|
||||
}
|
||||
|
||||
// 移除消息
|
||||
function removeMessage(id) {
|
||||
const messageDiv = document.getElementById(id);
|
||||
if (messageDiv) {
|
||||
messageDiv.remove();
|
||||
}
|
||||
}
|
||||
|
||||
// 回车发送消息
|
||||
document.getElementById('chat-input').addEventListener('keypress', function(e) {
|
||||
if (e.key === 'Enter') {
|
||||
sendMessage();
|
||||
}
|
||||
});
|
||||
|
||||
// 显示MCP调用详情
|
||||
async function showMCPDetail(executionId) {
|
||||
try {
|
||||
const response = await fetch(`/api/monitor/execution/${executionId}`);
|
||||
const exec = await response.json();
|
||||
|
||||
if (response.ok) {
|
||||
// 填充模态框内容
|
||||
document.getElementById('detail-tool-name').textContent = exec.toolName || 'Unknown';
|
||||
document.getElementById('detail-execution-id').textContent = exec.id || 'N/A';
|
||||
document.getElementById('detail-status').textContent = getStatusText(exec.status);
|
||||
document.getElementById('detail-time').textContent = new Date(exec.startTime).toLocaleString('zh-CN');
|
||||
|
||||
// 请求参数
|
||||
const requestData = {
|
||||
tool: exec.toolName,
|
||||
arguments: exec.arguments
|
||||
};
|
||||
document.getElementById('detail-request').textContent = JSON.stringify(requestData, null, 2);
|
||||
|
||||
// 响应结果
|
||||
if (exec.result) {
|
||||
const responseData = {
|
||||
content: exec.result.content,
|
||||
isError: exec.result.isError
|
||||
};
|
||||
document.getElementById('detail-response').textContent = JSON.stringify(responseData, null, 2);
|
||||
document.getElementById('detail-response').className = exec.result.isError ? 'code-block error' : 'code-block';
|
||||
} else {
|
||||
document.getElementById('detail-response').textContent = '暂无响应数据';
|
||||
}
|
||||
|
||||
// 错误信息
|
||||
if (exec.error) {
|
||||
document.getElementById('detail-error-section').style.display = 'block';
|
||||
document.getElementById('detail-error').textContent = exec.error;
|
||||
} else {
|
||||
document.getElementById('detail-error-section').style.display = 'none';
|
||||
}
|
||||
|
||||
// 显示模态框
|
||||
document.getElementById('mcp-detail-modal').style.display = 'block';
|
||||
} else {
|
||||
alert('获取详情失败: ' + (exec.error || '未知错误'));
|
||||
}
|
||||
} catch (error) {
|
||||
alert('获取详情失败: ' + error.message);
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭MCP详情模态框
|
||||
function closeMCPDetail() {
|
||||
document.getElementById('mcp-detail-modal').style.display = 'none';
|
||||
}
|
||||
|
||||
// 点击模态框外部关闭
|
||||
window.onclick = function(event) {
|
||||
const modal = document.getElementById('mcp-detail-modal');
|
||||
if (event.target == modal) {
|
||||
closeMCPDetail();
|
||||
}
|
||||
}
|
||||
|
||||
// 工具函数
|
||||
function getStatusText(status) {
|
||||
const statusMap = {
|
||||
'pending': '等待中',
|
||||
'running': '执行中',
|
||||
'completed': '已完成',
|
||||
'failed': '失败'
|
||||
};
|
||||
return statusMap[status] || status;
|
||||
}
|
||||
|
||||
function formatDuration(ms) {
|
||||
const seconds = Math.floor(ms / 1000);
|
||||
const minutes = Math.floor(seconds / 60);
|
||||
const hours = Math.floor(minutes / 60);
|
||||
|
||||
if (hours > 0) {
|
||||
return `${hours}小时${minutes % 60}分钟`;
|
||||
} else if (minutes > 0) {
|
||||
return `${minutes}分钟${seconds % 60}秒`;
|
||||
} else {
|
||||
return `${seconds}秒`;
|
||||
}
|
||||
}
|
||||
|
||||
function escapeHtml(text) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
// 开始新对话
|
||||
function startNewConversation() {
|
||||
currentConversationId = null;
|
||||
document.getElementById('chat-messages').innerHTML = '';
|
||||
addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。');
|
||||
updateActiveConversation();
|
||||
}
|
||||
|
||||
// 加载对话列表
|
||||
async function loadConversations() {
|
||||
try {
|
||||
const response = await fetch('/api/conversations?limit=50');
|
||||
const conversations = await response.json();
|
||||
|
||||
const listContainer = document.getElementById('conversations-list');
|
||||
listContainer.innerHTML = '';
|
||||
|
||||
if (conversations.length === 0) {
|
||||
listContainer.innerHTML = '<div style="padding: 20px; text-align: center; color: var(--text-muted); font-size: 0.875rem;">暂无历史对话</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
conversations.forEach(conv => {
|
||||
const item = document.createElement('div');
|
||||
item.className = 'conversation-item';
|
||||
item.dataset.conversationId = conv.id;
|
||||
if (conv.id === currentConversationId) {
|
||||
item.classList.add('active');
|
||||
}
|
||||
|
||||
const title = document.createElement('div');
|
||||
title.className = 'conversation-title';
|
||||
title.textContent = conv.title || '未命名对话';
|
||||
item.appendChild(title);
|
||||
|
||||
const time = document.createElement('div');
|
||||
time.className = 'conversation-time';
|
||||
// 解析时间,支持多种格式
|
||||
let dateObj;
|
||||
if (conv.updatedAt) {
|
||||
dateObj = new Date(conv.updatedAt);
|
||||
// 检查日期是否有效
|
||||
if (isNaN(dateObj.getTime())) {
|
||||
// 如果解析失败,尝试其他格式
|
||||
console.warn('时间解析失败:', conv.updatedAt);
|
||||
dateObj = new Date();
|
||||
}
|
||||
} else {
|
||||
dateObj = new Date();
|
||||
}
|
||||
|
||||
// 格式化时间显示
|
||||
const now = new Date();
|
||||
const today = new Date(now.getFullYear(), now.getMonth(), now.getDate());
|
||||
const yesterday = new Date(today);
|
||||
yesterday.setDate(yesterday.getDate() - 1);
|
||||
const messageDate = new Date(dateObj.getFullYear(), dateObj.getMonth(), dateObj.getDate());
|
||||
|
||||
let timeText;
|
||||
if (messageDate.getTime() === today.getTime()) {
|
||||
// 今天:只显示时间
|
||||
timeText = dateObj.toLocaleTimeString('zh-CN', {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
});
|
||||
} else if (messageDate.getTime() === yesterday.getTime()) {
|
||||
// 昨天
|
||||
timeText = '昨天 ' + dateObj.toLocaleTimeString('zh-CN', {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
});
|
||||
} else if (now.getFullYear() === dateObj.getFullYear()) {
|
||||
// 今年:显示月日和时间
|
||||
timeText = dateObj.toLocaleString('zh-CN', {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
});
|
||||
} else {
|
||||
// 去年或更早:显示完整日期和时间
|
||||
timeText = dateObj.toLocaleString('zh-CN', {
|
||||
year: 'numeric',
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit'
|
||||
});
|
||||
}
|
||||
|
||||
time.textContent = timeText;
|
||||
item.appendChild(time);
|
||||
|
||||
item.onclick = () => loadConversation(conv.id);
|
||||
listContainer.appendChild(item);
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('加载对话列表失败:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// 加载对话
|
||||
async function loadConversation(conversationId) {
|
||||
try {
|
||||
const response = await fetch(`/api/conversations/${conversationId}`);
|
||||
const conversation = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
alert('加载对话失败: ' + (conversation.error || '未知错误'));
|
||||
return;
|
||||
}
|
||||
|
||||
// 更新当前对话ID
|
||||
currentConversationId = conversationId;
|
||||
updateActiveConversation();
|
||||
|
||||
// 清空消息区域
|
||||
const messagesDiv = document.getElementById('chat-messages');
|
||||
messagesDiv.innerHTML = '';
|
||||
|
||||
// 加载消息
|
||||
if (conversation.messages && conversation.messages.length > 0) {
|
||||
conversation.messages.forEach(msg => {
|
||||
addMessage(msg.role, msg.content, msg.mcpExecutionIds || []);
|
||||
});
|
||||
} else {
|
||||
addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。');
|
||||
}
|
||||
|
||||
// 滚动到底部
|
||||
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
||||
|
||||
// 刷新对话列表
|
||||
loadConversations();
|
||||
} catch (error) {
|
||||
console.error('加载对话失败:', error);
|
||||
alert('加载对话失败: ' + error.message);
|
||||
}
|
||||
}
|
||||
|
||||
// 更新活动对话样式
|
||||
function updateActiveConversation() {
|
||||
document.querySelectorAll('.conversation-item').forEach(item => {
|
||||
item.classList.remove('active');
|
||||
if (currentConversationId && item.dataset.conversationId === currentConversationId) {
|
||||
item.classList.add('active');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 页面加载时初始化
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
// 加载对话列表
|
||||
loadConversations();
|
||||
|
||||
// 添加欢迎消息
|
||||
addMessage('assistant', '系统已就绪。请输入您的测试需求,系统将自动执行相应的安全测试。');
|
||||
});
|
||||
|
||||
92
web/templates/index.html
Normal file
92
web/templates/index.html
Normal file
@@ -0,0 +1,92 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>CyberStrikeAI - 自主渗透测试平台</title>
|
||||
<link rel="stylesheet" href="/static/css/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<div class="header-content">
|
||||
<div class="logo">
|
||||
<svg width="32" height="32" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M12 2L2 7L12 12L22 7L12 2Z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M2 17L12 22L22 17" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M2 12L12 17L22 12" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
<h1>CyberStrike</h1>
|
||||
</div>
|
||||
<p class="header-subtitle">安全测试平台</p>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<div class="main-layout">
|
||||
<!-- 历史对话侧边栏 -->
|
||||
<aside class="sidebar">
|
||||
<div class="sidebar-header">
|
||||
<button class="new-chat-btn" onclick="startNewConversation()">
|
||||
<span>+</span> 新对话
|
||||
</button>
|
||||
</div>
|
||||
<div class="sidebar-content">
|
||||
<div class="sidebar-title">历史对话</div>
|
||||
<div id="conversations-list" class="conversations-list"></div>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
<!-- 对话界面 -->
|
||||
<div class="chat-container">
|
||||
<div id="chat-messages" class="chat-messages"></div>
|
||||
<div class="chat-input-container">
|
||||
<input type="text" id="chat-input" placeholder="输入测试目标或命令..." />
|
||||
<button onclick="sendMessage()">发送</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- MCP调用详情模态框 -->
|
||||
<div id="mcp-detail-modal" class="modal">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h2>工具调用详情</h2>
|
||||
<span class="modal-close" onclick="closeMCPDetail()">×</span>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<div class="detail-section">
|
||||
<h3>执行信息</h3>
|
||||
<div class="detail-item">
|
||||
<strong>工具:</strong> <span id="detail-tool-name"></span>
|
||||
</div>
|
||||
<div class="detail-item">
|
||||
<strong>状态:</strong> <span id="detail-status"></span>
|
||||
</div>
|
||||
<div class="detail-item">
|
||||
<strong>时间:</strong> <span id="detail-time"></span>
|
||||
</div>
|
||||
<div class="detail-item">
|
||||
<strong>ID:</strong> <span id="detail-execution-id" style="font-family: monospace; font-size: 0.8125rem; color: var(--text-secondary);"></span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="detail-section">
|
||||
<h3>请求参数</h3>
|
||||
<pre id="detail-request" class="code-block"></pre>
|
||||
</div>
|
||||
<div class="detail-section">
|
||||
<h3>响应结果</h3>
|
||||
<pre id="detail-response" class="code-block"></pre>
|
||||
</div>
|
||||
<div class="detail-section" id="detail-error-section" style="display: none;">
|
||||
<h3>错误信息</h3>
|
||||
<pre id="detail-error" class="code-block error"></pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="/static/js/app.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
Reference in New Issue
Block a user