mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-04-07 12:32:04 +02:00
Compare commits
483 Commits
add-missin
...
ip_blocks
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e9a1225fc | ||
|
|
afe7804a9b | ||
|
|
d7904580ed | ||
|
|
593805bf6f | ||
|
|
ae37c56467 | ||
|
|
41597609c8 | ||
|
|
1f619a669a | ||
|
|
37c3331559 | ||
|
|
f334993f79 | ||
|
|
3ca559e5a4 | ||
|
|
0e3f764299 | ||
|
|
e52402eb0c | ||
|
|
2133f31854 | ||
|
|
a198a5cd65 | ||
|
|
eb2b231bd2 | ||
|
|
7af29cfbc0 | ||
|
|
ce1a165348 | ||
|
|
fd48e6d795 | ||
|
|
d71d1341b6 | ||
|
|
21855df4af | ||
|
|
66e2d3a40a | ||
|
|
26257cf24a | ||
|
|
36a7423634 | ||
|
|
e616091249 | ||
|
|
0948161529 | ||
|
|
ce29b5d217 | ||
|
|
de24fa293e | ||
|
|
6663925c4d | ||
|
|
b9ece6d7b9 | ||
|
|
c4efa1ab97 | ||
|
|
7cea5305e1 | ||
|
|
a20fbf95de | ||
|
|
628c4302aa | ||
|
|
8dc34f8bf5 | ||
|
|
b4faf82f76 | ||
|
|
a983dfaee2 | ||
|
|
62f73bcaa2 | ||
|
|
00e9d2bdd3 | ||
|
|
ace3b1e66e | ||
|
|
d1ea1ba08c | ||
|
|
c06c8aa859 | ||
|
|
0c2cc00c4f | ||
|
|
8d6ea91f35 | ||
|
|
7dfb77228f | ||
|
|
24910f1fa6 | ||
|
|
433a61d2ee | ||
|
|
3937e885f0 | ||
|
|
c651003cc4 | ||
|
|
b7ccfcb8b4 | ||
|
|
a9ed70200b | ||
|
|
c6365e6b74 | ||
|
|
dacc67e50f | ||
|
|
c60cf33af3 | ||
|
|
f27cbe3525 | ||
|
|
2de1b9929a | ||
|
|
8bf654aece | ||
|
|
84376ed719 | ||
|
|
7a136b8874 | ||
|
|
58c0e4f15a | ||
|
|
e0d35d8ba2 | ||
|
|
3b2e48761e | ||
|
|
b27064008e | ||
|
|
1ad63827e1 | ||
|
|
20e61550c2 | ||
|
|
020b814402 | ||
|
|
e578867118 | ||
|
|
46a1039f21 | ||
|
|
cc9e27de5f | ||
|
|
6ab3ab9faf | ||
|
|
e68bfa795a | ||
|
|
e60a92e93e | ||
|
|
62fe14f76b | ||
|
|
a0c5062e3a | ||
|
|
49eb152d02 | ||
|
|
b05056423a | ||
|
|
c7168739c7 | ||
|
|
5b1faf1ce3 | ||
|
|
513a6f9ec7 | ||
|
|
8db6fa4232 | ||
|
|
5036de2602 | ||
|
|
332f8ccc37 | ||
|
|
a582195cec | ||
|
|
9fe36ae984 | ||
|
|
54cb455522 | ||
|
|
8bd3b9e474 | ||
|
|
eff5ff580b | ||
|
|
c45f863ed8 | ||
|
|
414d4e356d | ||
|
|
ef697eb781 | ||
|
|
0631ffe831 | ||
|
|
7444d8517a | ||
|
|
3480043e40 | ||
|
|
619b6e7516 | ||
|
|
0123ca44fb | ||
|
|
7929aafe2a | ||
|
|
dc433f8dc9 | ||
|
|
8ccaeeab60 | ||
|
|
043a28eb33 | ||
|
|
c329402f5d | ||
|
|
23e6ad6e1f | ||
|
|
e6de78c1fa | ||
|
|
a670708f93 | ||
|
|
4ebe2fb5f4 | ||
|
|
3403b2039d | ||
|
|
e30ad31e0f | ||
|
|
81e0bad739 | ||
|
|
7d07d738dc | ||
|
|
0fae584e65 | ||
|
|
9e83085f2a | ||
|
|
41a00c68ac | ||
|
|
e3b99bf339 | ||
|
|
5007a87d3a | ||
|
|
60e65a37a6 | ||
|
|
d37d0e942c | ||
|
|
98042d8dbd | ||
|
|
af4b826b68 | ||
|
|
253a57ca01 | ||
|
|
caf98b4dfe | ||
|
|
398f71fd00 | ||
|
|
e1301ade96 | ||
|
|
7a23f82192 | ||
|
|
715bcc4aa1 | ||
|
|
0c74838740 | ||
|
|
4b05b6da7b | ||
|
|
375844ff1a | ||
|
|
1d207379cb | ||
|
|
fb49cb71e3 | ||
|
|
9618efbcde | ||
|
|
bb2210b06a | ||
|
|
917052723d | ||
|
|
fef85cadeb | ||
|
|
4a05fb6b28 | ||
|
|
6644ce53f2 | ||
|
|
72f0b89fdc | ||
|
|
41a97a6609 | ||
|
|
38064d6ad5 | ||
|
|
ae6945cedf | ||
|
|
3132d1b032 | ||
|
|
2716ae29bd | ||
|
|
1c50c2b6af | ||
|
|
cf6d16b439 | ||
|
|
60686f55ff | ||
|
|
47d7ace3a7 | ||
|
|
2d3779ec27 | ||
|
|
595071b608 | ||
|
|
57ef717080 | ||
|
|
eb27d1482b | ||
|
|
f57972ead7 | ||
|
|
168eaf538b | ||
|
|
1560455ca3 | ||
|
|
028475a193 | ||
|
|
f7a6dbe39b | ||
|
|
e573a490c9 | ||
|
|
ce3281e70d | ||
|
|
0fbfd160c9 | ||
|
|
20759017e6 | ||
|
|
69e0aab73e | ||
|
|
7ed6733fb7 | ||
|
|
9718ab8579 | ||
|
|
2687a4a018 | ||
|
|
2d9c60dea1 | ||
|
|
841be069b7 | ||
|
|
7833132917 | ||
|
|
e9e63b0983 | ||
|
|
4df470b869 | ||
|
|
89600f6091 | ||
|
|
f986a575e8 | ||
|
|
9c2fe8d21f | ||
|
|
8bcbb9249e | ||
|
|
a95d50c0af | ||
|
|
5db7d3577b | ||
|
|
c53a0ca1c4 | ||
|
|
6fd3d1788a | ||
|
|
087c1975e5 | ||
|
|
3713cbecc3 | ||
|
|
6046789fa4 | ||
|
|
3ea69b180c | ||
|
|
db6e977e3a | ||
|
|
a5c776c846 | ||
|
|
5a566c028a | ||
|
|
ff43c74d8d | ||
|
|
3c7255569c | ||
|
|
4a92ec4d2d | ||
|
|
9bbccb4082 | ||
|
|
4f62314646 | ||
|
|
cb49d0d947 | ||
|
|
89f7874fc6 | ||
|
|
221917e80b | ||
|
|
37d41bd215 | ||
|
|
8a96b8bec4 | ||
|
|
02ee113b95 | ||
|
|
f71dd78915 | ||
|
|
cd5619a05b | ||
|
|
a63a30c76b | ||
|
|
f5ba8be182 | ||
|
|
a9f76322bd | ||
|
|
ed39269c80 | ||
|
|
09426dcd36 | ||
|
|
17941882a9 | ||
|
|
70ab8032a0 | ||
|
|
8360bdc50a | ||
|
|
6837176ec7 | ||
|
|
5e9b4244e7 | ||
|
|
9b6a308958 | ||
|
|
71e327653a | ||
|
|
a56711796f | ||
|
|
09495f2a7c | ||
|
|
484643e114 | ||
|
|
da91aabc35 | ||
|
|
c654398981 | ||
|
|
47a90ec2a1 | ||
|
|
2875e22d0b | ||
|
|
c5d14e0075 | ||
|
|
84e06c363c | ||
|
|
5b9ccc5065 | ||
|
|
6ca1a7ccc7 | ||
|
|
9d666be5d4 | ||
|
|
65de7edcde | ||
|
|
0cdff0d368 | ||
|
|
f87220a908 | ||
|
|
30ea0c6499 | ||
|
|
9501e35c60 | ||
|
|
5ac9d17bdf | ||
|
|
cb14992ddc | ||
|
|
e88372fc8c | ||
|
|
b320662d67 | ||
|
|
ce353cd4d9 | ||
|
|
4befd33866 | ||
|
|
4b36e3ac44 | ||
|
|
f507bc8f9e | ||
|
|
14c88f4a6d | ||
|
|
3e388c2857 | ||
|
|
cfe1209d61 | ||
|
|
5a88a7c22c | ||
|
|
8c661c4401 | ||
|
|
e6f256d640 | ||
|
|
ede354166b | ||
|
|
282a8ce78e | ||
|
|
08fe04f1ee | ||
|
|
082d14a9ba | ||
|
|
617674ce43 | ||
|
|
7088df58dd | ||
|
|
9cbd9b3e44 | ||
|
|
e6586fd360 | ||
|
|
33a6db2599 | ||
|
|
70b0c4f7b9 | ||
|
|
5af3ec4f7b | ||
|
|
79476add12 | ||
|
|
1634a06330 | ||
|
|
a007394f60 | ||
|
|
62a0ba8731 | ||
|
|
e8d3ed1acd | ||
|
|
8b98faa441 | ||
|
|
30320ec9c7 | ||
|
|
5f4a399850 | ||
|
|
82e0d4b0c4 | ||
|
|
95a9df826d | ||
|
|
3b71d26cf3 | ||
|
|
c233ad9b1b | ||
|
|
12d6484b1c | ||
|
|
bc7b1cc6d8 | ||
|
|
ec684348ed | ||
|
|
18a19a3aa2 | ||
|
|
905f2d08c5 | ||
|
|
04947b4d87 | ||
|
|
72bf80533e | ||
|
|
9ddedf926e | ||
|
|
139dd62ff3 | ||
|
|
50ef00526e | ||
|
|
80cf79b9cb | ||
|
|
e6ad39b070 | ||
|
|
56f9c72569 | ||
|
|
dc48c908b8 | ||
|
|
9b0f0e792a | ||
|
|
b3eebb19b6 | ||
|
|
c24589a5be | ||
|
|
1e1c5a4dc8 | ||
|
|
339023421a | ||
|
|
a00d2a431a | ||
|
|
5aca118dbb | ||
|
|
411f7434f4 | ||
|
|
34801382f5 | ||
|
|
b9f2259ae4 | ||
|
|
19020a96bf | ||
|
|
96085147ff | ||
|
|
f3dd344026 | ||
|
|
486096416f | ||
|
|
5710f2e984 | ||
|
|
09936f1f07 | ||
|
|
0d6ca57536 | ||
|
|
3ddcb84db8 | ||
|
|
1012bf063f | ||
|
|
b8155e6182 | ||
|
|
9a34df61bb | ||
|
|
fbb879edf9 | ||
|
|
ac97c88876 | ||
|
|
a1fda2c0de | ||
|
|
f499770d45 | ||
|
|
4769da4ef4 | ||
|
|
c2556a8e39 | ||
|
|
29bf329f6a | ||
|
|
1dee4305bc | ||
|
|
429a98b690 | ||
|
|
da01a146d2 | ||
|
|
dd9f2465be | ||
|
|
b5cf0e2b31 | ||
|
|
1db159ad34 | ||
|
|
6604f973ac | ||
|
|
69ee6582e2 | ||
|
|
6f12667e8c | ||
|
|
b002dff624 | ||
|
|
affef963c1 | ||
|
|
56b2056190 | ||
|
|
c1e6f5126a | ||
|
|
1a8c1ec73d | ||
|
|
52954b8ceb | ||
|
|
a5025e35ea | ||
|
|
07f80c9ebf | ||
|
|
13db23553d | ||
|
|
3963fce43b | ||
|
|
ea4e5147bd | ||
|
|
7a491a4cc5 | ||
|
|
5ba90748f6 | ||
|
|
20f8f22bae | ||
|
|
b50cccac85 | ||
|
|
34ebe9b054 | ||
|
|
43d82cf1a7 | ||
|
|
ab88174091 | ||
|
|
ebcbf85373 | ||
|
|
87513cba6d | ||
|
|
64bcd2f00d | ||
|
|
cc6ae290f8 | ||
|
|
3e62bd3dbd | ||
|
|
8491f9c455 | ||
|
|
3ca754b438 | ||
|
|
8c7c3901e8 | ||
|
|
a9672dfff5 | ||
|
|
203a2ec8b8 | ||
|
|
810cbd1f4f | ||
|
|
49eebcdcbc | ||
|
|
e89021ec3a | ||
|
|
73a697b2fa | ||
|
|
9319d08046 | ||
|
|
7dc5138e91 | ||
|
|
8f189c919a | ||
|
|
906479a15c | ||
|
|
dabbf2037b | ||
|
|
b496147ce7 | ||
|
|
583718f234 | ||
|
|
fdb82f6ec3 | ||
|
|
5145729ab1 | ||
|
|
4d810261a4 | ||
|
|
18e8616834 | ||
|
|
d55563cac5 | ||
|
|
bb481d9bcc | ||
|
|
a163be3584 | ||
|
|
891b7cb2c6 | ||
|
|
176c22f229 | ||
|
|
faa0ed06b6 | ||
|
|
9515db7faf | ||
|
|
d822bf4257 | ||
|
|
0826671809 | ||
|
|
67d74774a9 | ||
|
|
5d65416227 | ||
|
|
49441f62f3 | ||
|
|
99651f6e5b | ||
|
|
edca1f4f89 | ||
|
|
3d834f00f6 | ||
|
|
6bb9e7a766 | ||
|
|
61fb71b1fa | ||
|
|
f8967c376f | ||
|
|
6d3c86c0be | ||
|
|
e42554f892 | ||
|
|
28984090e5 | ||
|
|
251255c746 | ||
|
|
32709dc64c | ||
|
|
71f26a6d81 | ||
|
|
44352f8006 | ||
|
|
af38623590 | ||
|
|
9c1665a759 | ||
|
|
eaad24e5e5 | ||
|
|
cfaf32f71a | ||
|
|
51b235b61a | ||
|
|
0a6d9d4454 | ||
|
|
dc700bbd52 | ||
|
|
cb445825f4 | ||
|
|
4d996e317b | ||
|
|
30c9012004 | ||
|
|
2a23feaf4b | ||
|
|
b82ad3720c | ||
|
|
8d2cb6091e | ||
|
|
3023f33dff | ||
|
|
22e97e981a | ||
|
|
44484e1231 | ||
|
|
eac60b87c7 | ||
|
|
8db28cb76e | ||
|
|
8dbe828b99 | ||
|
|
5c24acd952 | ||
|
|
998b9a5c5d | ||
|
|
0084e9ef26 | ||
|
|
122600bff2 | ||
|
|
41846b6d4c | ||
|
|
dfbcb1489d | ||
|
|
684019c2e3 | ||
|
|
e92619620d | ||
|
|
cebfd12d5c | ||
|
|
874ff01ab8 | ||
|
|
0bb8703f78 | ||
|
|
0bb51aa71d | ||
|
|
af2c1c87e0 | ||
|
|
8939debbc0 | ||
|
|
7591a0ccc6 | ||
|
|
c3ff8182af | ||
|
|
5897c174d3 | ||
|
|
f9a3f4c045 | ||
|
|
a2cb895cdc | ||
|
|
2bebe93e47 | ||
|
|
28ec1869fc | ||
|
|
17f6d7a77b | ||
|
|
9e6e647ff8 | ||
|
|
a2116e5eb5 | ||
|
|
564c9ef712 | ||
|
|
856abb71b7 | ||
|
|
0a30fdea69 | ||
|
|
4f125cf107 | ||
|
|
494d8be777 | ||
|
|
cd9c750884 | ||
|
|
91d319804b | ||
|
|
180eae60f2 | ||
|
|
d01f5c2777 | ||
|
|
294a90a807 | ||
|
|
c3b4ae9c79 | ||
|
|
09188bedf7 | ||
|
|
4614b98e94 | ||
|
|
990bc620f7 | ||
|
|
efb5a92571 | ||
|
|
8e0a96a44c | ||
|
|
43ff2f648c | ||
|
|
4816a09e3a | ||
|
|
3fea92c8b1 | ||
|
|
63f959c951 | ||
|
|
44ba6aadd9 | ||
|
|
d88cf52b4e | ||
|
|
58a00ea24a | ||
|
|
712b23a4bb | ||
|
|
baf836557c | ||
|
|
904b23eeac | ||
|
|
6aafe445f5 | ||
|
|
ebd516855b | ||
|
|
df4e04719e | ||
|
|
2440d922c6 | ||
|
|
f1b8d1c4ad | ||
|
|
79076bda35 | ||
|
|
9d2ea15346 | ||
|
|
77c1113ff7 | ||
|
|
e03ad4cd77 | ||
|
|
6e28517454 | ||
|
|
8ddbf881b3 | ||
|
|
c58516cfb0 | ||
|
|
34758f6205 | ||
|
|
a9959a6f3d | ||
|
|
511c4e696f | ||
|
|
bed7435b0c | ||
|
|
507c1afd59 | ||
|
|
2765487f10 | ||
|
|
80a88811cd | ||
|
|
823195c504 | ||
|
|
0f3e8c7ada | ||
|
|
ee5eb4fc4e | ||
|
|
d58d8074f4 | ||
|
|
94a0530991 | ||
|
|
073af0f89c | ||
|
|
6028b8f186 | ||
|
|
126477ef88 | ||
|
|
13391fd469 | ||
|
|
82e44b01af | ||
|
|
e355fd70ab | ||
|
|
d5c171735e | ||
|
|
b175368794 | ||
|
|
bcf4c25ba8 | ||
|
|
11b09af76d | ||
|
|
af0380a96a | ||
|
|
44c0a06996 |
6
.github/workflows/ci.yml
vendored
6
.github/workflows/ci.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["windows-latest", "ubuntu-latest", "macOS-latest"]
|
||||
go: ["1.20.x"]
|
||||
go: ["1.24.x"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
@@ -19,8 +19,8 @@ jobs:
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: "go test -race ./..."
|
||||
- uses: dominikh/staticcheck-action@v1.2.0
|
||||
- uses: dominikh/staticcheck-action@v1.3.1
|
||||
with:
|
||||
version: "2023.1.2"
|
||||
version: "2025.1"
|
||||
install-go: false
|
||||
cache-key: ${{ matrix.go }}
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -3,3 +3,12 @@ gon.hcl
|
||||
|
||||
/Build
|
||||
.DS_Store
|
||||
|
||||
# Release folder
|
||||
dist/
|
||||
|
||||
# Binaries
|
||||
ctrld-*
|
||||
|
||||
# generated file
|
||||
cmd/cli/rsrc_*.syso
|
||||
|
||||
206
README.md
206
README.md
@@ -4,13 +4,16 @@
|
||||
[](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
|
||||
[](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
|
||||
|
||||

|
||||
|
||||
A highly configurable DNS forwarding proxy with support for:
|
||||
- Multiple listeners for incoming queries
|
||||
- Multiple upstreams with fallbacks
|
||||
- Multiple network policy driven DNS query steering
|
||||
- Multiple network policy driven DNS query steering (via network cidr, MAC address or FQDN)
|
||||
- Policy driven domain based "split horizon" DNS with wildcard support
|
||||
- Integrations with common router vendors and firmware
|
||||
- LAN client discovery via DHCP, mDNS, and ARP
|
||||
- LAN client discovery via DHCP, mDNS, ARP, NDP, hosts file parsing
|
||||
- Prometheus metrics exporter
|
||||
|
||||
## TLDR
|
||||
Proxy legacy DNS traffic to secure DNS upstreams in highly configurable ways.
|
||||
@@ -32,13 +35,29 @@ All DNS protocols are supported, including:
|
||||
|
||||
## OS Support
|
||||
- Windows (386, amd64, arm)
|
||||
- Mac (amd64, arm64)
|
||||
- Windows Server (386, amd64)
|
||||
- MacOS (amd64, arm64)
|
||||
- Linux (386, amd64, arm, mips)
|
||||
- FreeBSD
|
||||
- Common routers (See Router Mode below)
|
||||
- FreeBSD (386, amd64, arm)
|
||||
- Common routers (See below)
|
||||
|
||||
|
||||
### Supported Routers
|
||||
You can run `ctrld` on any supported router. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- Firewalla
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense / OPNsense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
|
||||
`ctrld` will attempt to interface with dnsmasq (or Windows Server) whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly.
|
||||
|
||||
# Install
|
||||
There are several ways to download and install `ctrld.
|
||||
There are several ways to download and install `ctrld`.
|
||||
|
||||
## Quick Install
|
||||
The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform:
|
||||
@@ -47,42 +66,41 @@ The simplest way to download and install `ctrld` is to use the following install
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"'
|
||||
```
|
||||
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative cmd:
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative PowerShell:
|
||||
```shell
|
||||
powershell -Command "(Invoke-WebRequest -Uri 'https://api.controld.com/dl' -UseBasicParsing).Content | Set-Content 'ctrld_install.bat'" && ctrld_install.bat
|
||||
(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'"
|
||||
```
|
||||
|
||||
Or you can pull and run a Docker container from [Docker Hub](https://hub.docker.com/r/controldns/ctrld)
|
||||
```
|
||||
$ docker pull controldns/ctrld
|
||||
```shell
|
||||
docker run -d --name=ctrld -p 127.0.0.1:53:53/tcp -p 127.0.0.1:53:53/udp controldns/ctrld:latest
|
||||
```
|
||||
|
||||
## Download Manually
|
||||
Alternatively, if you know what you're doing you can download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section for the appropriate platform.
|
||||
|
||||
## Build
|
||||
Lastly, you can build `ctrld` from source which requires `go1.19+`:
|
||||
Lastly, you can build `ctrld` from source which requires `go1.21+`:
|
||||
|
||||
```shell
|
||||
$ go build ./cmd/ctrld
|
||||
go build ./cmd/ctrld
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
$ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
$ docker build -t controldns/ctrld .
|
||||
$ docker run -d --name=ctrld -p 53:53/tcp -p 53:53/udp controldns/ctrld --cd=RESOLVER_ID_GOES_HERE -vv
|
||||
```shell
|
||||
docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
```
|
||||
|
||||
|
||||
# Usage
|
||||
The cli is self documenting, so free free to run `--help` on any sub-command to get specific usages.
|
||||
The cli is self documenting, so feel free to run `--help` on any sub-command to get specific usages.
|
||||
|
||||
## Arguments
|
||||
```
|
||||
@@ -98,13 +116,16 @@ Usage:
|
||||
|
||||
Available Commands:
|
||||
run Run the DNS proxy server
|
||||
service Manage ctrld service
|
||||
start Quick start service and configure DNS on interface
|
||||
stop Quick stop service and remove DNS from interface
|
||||
restart Restart the ctrld service
|
||||
reload Reload the ctrld service
|
||||
status Show status of the ctrld service
|
||||
uninstall Stop and uninstall the ctrld service
|
||||
service Manage ctrld service
|
||||
clients Manage clients
|
||||
upgrade Upgrading ctrld to latest version
|
||||
log Manage runtime debug logs
|
||||
|
||||
Flags:
|
||||
-h, --help help for ctrld
|
||||
@@ -116,81 +137,99 @@ Use "ctrld [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Basic Run Mode
|
||||
To start the server with default configuration, simply run: `./ctrld run`. This will create a generic `ctrld.toml` file in the **working directory** and start the application in foreground.
|
||||
1. Start the server
|
||||
```
|
||||
$ sudo ./ctrld run
|
||||
This is the most basic way to run `ctrld`, in foreground mode. Unless you already have a config file, a default one will be generated.
|
||||
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe run
|
||||
```
|
||||
|
||||
2. Run a test query using a DNS client, for example, `dig`:
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld run
|
||||
```
|
||||
|
||||
You can then run a test query using a DNS client, for example, `dig`:
|
||||
```
|
||||
$ dig verify.controld.com @127.0.0.1 +short
|
||||
api.controld.com.
|
||||
147.185.34.1
|
||||
```
|
||||
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file and go nuts with it. To enforce a new config, restart the server.
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file that was generated. To enforce a new config, restart the server.
|
||||
|
||||
## Service Mode
|
||||
To run the application in service mode on any Windows, MacOS, Linux distibution or supported router, simply run: `./ctrld start` as system/root user. This will create a generic `ctrld.toml` file in the **user home** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and configure the listener on the default network interface. Service will start on OS boot.
|
||||
This mode will run the application as a background system service on any Windows, MacOS, Linux, FreeBSD distribution or supported router. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot.
|
||||
|
||||
When Control D upstreams are used, `ctrld` willl [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
When Control D upstreams are used on a router type device, `ctrld` will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `./ctrld stop`. If you wish to stop and uninstall the service permanently, run `./ctrld uninstall`.
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start
|
||||
```
|
||||
|
||||
### Supported Routers
|
||||
You can run `ctrld` on any supported router, which will function similarly to the Service Mode mentioned above. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- Firewalla
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense / OPNsense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
Linux or Macos
|
||||
```
|
||||
sudo ctrld start
|
||||
```
|
||||
|
||||
`ctrld` will attempt to interface with dnsmasq whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly.
|
||||
If `ctrld` is not in your system path (you installed it manually), you will need to run the above commands from the directory where you installed `ctrld`.
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `ctrld stop`. If you wish to stop and uninstall the service permanently, run `ctrld uninstall`.
|
||||
|
||||
### Control D Auto Configuration
|
||||
Application can be started with a specific resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `run` (foreground) or `start` (service) modes.
|
||||
## Unmanaged Service Mode
|
||||
This mode functions similarly to the "Service Mode" above except it will simply start a system service and the config defined listeners, but **will not make any changes to any network interfaces**. You can then set the `ctrld` listener(s) IP on the desired network interfaces manually.
|
||||
|
||||
The following command will start the application in foreground mode, using the free "p2" resolver, which blocks Ads & Trackers.
|
||||
### Command
|
||||
|
||||
```shell
|
||||
./ctrld run --cd p2
|
||||
```
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe service start
|
||||
```
|
||||
|
||||
Alternatively, you can use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Device.
|
||||
|
||||
```shell
|
||||
./ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above commands (in service mode only), the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- Your default network interface will be updated to use the listener started by the service
|
||||
- All OS DNS queries will be sent to the listener
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld service start
|
||||
```
|
||||
|
||||
# Configuration
|
||||
See [Configuration Docs](docs/config.md).
|
||||
`ctrld` can be configured in variety of different ways, which include: API, local config file or via cli launch args.
|
||||
|
||||
## Example
|
||||
- Start `listener.0` on 127.0.0.1:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `upstream.0` via DoH protocol
|
||||
## API Based Auto Configuration
|
||||
Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. In this mode, the application will automatically choose a non-conflicting IP and/or port and configure itself as the upstream to whatever process is running on port 53 (like dnsmasq or Windows DNS Server). This mode is used when the 1 liner installer command from the Control D onboarding guide is executed.
|
||||
|
||||
### Default Config
|
||||
The following command will use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Endpoint.
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --cd abcd1234
|
||||
```
|
||||
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above command, the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- All physical network interface will be updated to use the listener started by the service or dnsmasq upstream will be switched to `ctrld`
|
||||
- All DNS queries will be sent to the listener
|
||||
|
||||
## Manual Configuration
|
||||
`ctrld` is entirely config driven and can be configured in many different ways, please see [Configuration Docs](docs/config.md).
|
||||
|
||||
### Example
|
||||
```toml
|
||||
[listener]
|
||||
|
||||
[listener.0]
|
||||
ip = "127.0.0.1"
|
||||
ip = '0.0.0.0'
|
||||
port = 53
|
||||
restricted = false
|
||||
|
||||
[network]
|
||||
|
||||
@@ -198,10 +237,6 @@ See [Configuration Docs](docs/config.md).
|
||||
cidrs = ["0.0.0.0/0"]
|
||||
name = "Network 0"
|
||||
|
||||
[service]
|
||||
log_level = "info"
|
||||
log_path = ""
|
||||
|
||||
[upstream]
|
||||
|
||||
[upstream.0]
|
||||
@@ -210,27 +245,26 @@ See [Configuration Docs](docs/config.md).
|
||||
name = "Control D - Anti-Malware"
|
||||
timeout = 5000
|
||||
type = "doh"
|
||||
|
||||
[upstream.1]
|
||||
bootstrap_ip = "76.76.2.11"
|
||||
endpoint = "p2.freedns.controld.com"
|
||||
name = "Control D - No Ads"
|
||||
timeout = 3000
|
||||
type = "doq"
|
||||
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
The above basic config will:
|
||||
- Start listener on 0.0.0.0:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `https://freedns.controld.com/p1` using DoH protocol
|
||||
|
||||
You can also supply configuration via launch argeuments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
## CLI Args
|
||||
If you're unable to use a config file, `ctrld` can be be supplied with basic configuration via launch arguments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
|
||||
### Example
|
||||
```
|
||||
ctrld run --listen=127.0.0.1:53 --primary_upstream=https://freedns.controld.com/p2 --secondary_upstream=10.0.10.1:53 --domains=*.company.int,very-secure.local --log /path/to/log.log
|
||||
```
|
||||
|
||||
The above will start a foreground process and:
|
||||
- Listen on `127.0.0.1:53` for DNS queries
|
||||
- Forward all queries to `https://freedns.controld.com/p2` using DoH protocol, while...
|
||||
- Excluding `*.company.int` and `very-secure.local` matching queries, that are forwarded to `10.0.10.1:53`
|
||||
- Write a debug log to `/path/to/log.log`
|
||||
|
||||
## Contributing
|
||||
See [Contribution Guideline](./docs/contributing.md)
|
||||
|
||||
## Roadmap
|
||||
The following functionality is on the roadmap and will be available in future releases.
|
||||
- Prometheus metrics exporter
|
||||
- DNS intercept mode
|
||||
- Direct listener mode
|
||||
- Support for more routers (let us know which ones)
|
||||
|
||||
@@ -5,9 +5,11 @@ type ClientInfoCtxKey struct{}
|
||||
|
||||
// ClientInfo represents ctrld's clients information.
|
||||
type ClientInfo struct {
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Self bool
|
||||
ClientIDPref string
|
||||
}
|
||||
|
||||
// LeaseFileFormat specifies the format of DHCP lease file.
|
||||
@@ -16,4 +18,5 @@ type LeaseFileFormat string
|
||||
const (
|
||||
Dnsmasq LeaseFileFormat = "dnsmasq"
|
||||
IscDhcpd LeaseFileFormat = "isc-dhcpd"
|
||||
KeaDHCP4 LeaseFileFormat = "kea-dhcp4"
|
||||
)
|
||||
|
||||
4
client_info_darwin.go
Normal file
4
client_info_darwin.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package ctrld
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return true }
|
||||
6
client_info_others.go
Normal file
6
client_info_others.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return false }
|
||||
18
client_info_windows.go
Normal file
18
client_info_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine.
|
||||
func isWindowsWorkStation() bool {
|
||||
// From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa
|
||||
const VER_NT_WORKSTATION = 0x0000001
|
||||
osvi := windows.RtlGetVersion()
|
||||
return osvi.ProductType == VER_NT_WORKSTATION
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
15
cmd/cli/ad_others.go
Normal file
15
cmd/cli/ad_others.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule if present.
|
||||
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
73
cmd/cli/ad_windows.go
Normal file
73
cmd/cli/ad_windows.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
hh "github.com/microsoft/wmi/pkg/hardware/host"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule for domain if it's part of active directory.
|
||||
func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Msgf("unable to get active directory domain: %v", err)
|
||||
return false
|
||||
}
|
||||
if domain == "" {
|
||||
mainLog.Load().Debug().Msg("no active directory domain found")
|
||||
return false
|
||||
}
|
||||
// Network rules are lowercase during toml config marshaling,
|
||||
// lowercase the domain here too for consistency.
|
||||
domain = strings.ToLower(domain)
|
||||
domainRuleAdded := addSplitDnsRule(cfg, domain)
|
||||
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
|
||||
return domainRuleAdded || wildcardDomainRuleRuleAdded
|
||||
}
|
||||
|
||||
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
|
||||
// The return value indicates whether the split-rule was added or not.
|
||||
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
|
||||
for n, lc := range cfg.Listener {
|
||||
if lc.Policy == nil {
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{}
|
||||
}
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
if _, ok := rule[domain]; ok {
|
||||
mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n)
|
||||
return false
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
cs, err := hh.GetComputerSystem(whost)
|
||||
if cs != nil {
|
||||
defer cs.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pod, err := cs.GetPropertyPartOfDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pod {
|
||||
return cs.GetPropertyDomain()
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
71
cmd/cli/ad_windows_test.go
Normal file
71
cmd/cli/ad_windows_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_getActiveDirectoryDomain(t *testing.T) {
|
||||
start := time.Now()
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
domainPowershell, err := getActiveDirectoryDomainPowershell()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if domain != domainPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func getActiveDirectoryDomainPowershell() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
func Test_addSplitDnsRule(t *testing.T) {
|
||||
newCfg := func(domains ...string) *ctrld.Config {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
lc := cfg.Listener["0"]
|
||||
for _, domain := range domains {
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *ctrld.Config
|
||||
domain string
|
||||
added bool
|
||||
}{
|
||||
{"added", newCfg(), "example.com", true},
|
||||
{"TLD existed", newCfg("example.com"), "*.example.com", true},
|
||||
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
|
||||
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
|
||||
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
added := addSplitDnsRule(tc.cfg, tc.domain)
|
||||
assert.Equal(t, tc.added, added)
|
||||
})
|
||||
}
|
||||
}
|
||||
5
cmd/cli/cgo.go
Normal file
5
cmd/cli/cgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = true
|
||||
2078
cmd/cli/cli.go
2078
cmd/cli/cli.go
File diff suppressed because it is too large
Load Diff
@@ -16,8 +16,31 @@ func Test_writeConfigFile(t *testing.T) {
|
||||
_, err := os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
assert.NoError(t, writeConfigFile())
|
||||
assert.NoError(t, writeConfigFile(&cfg))
|
||||
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_isStableVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ver string
|
||||
isStable bool
|
||||
}{
|
||||
{"stable", "v1.3.5", true},
|
||||
{"pre", "v1.3.5-next", false},
|
||||
{"pre with commit hash", "v1.3.5-next-asdf", false},
|
||||
{"dev", "dev", false},
|
||||
{"empty", "dev", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isStableVersion(tc.ver); got != tc.isStable {
|
||||
t.Errorf("unexpected result for %s, want: %v, got: %v", tc.ver, tc.isStable, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
1397
cmd/cli/commands.go
Normal file
1397
cmd/cli/commands.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -25,5 +25,14 @@ func newControlClient(addr string) *controlClient {
|
||||
}
|
||||
|
||||
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
|
||||
// for log/send, set the timeout to 5 minutes
|
||||
if path == sendLogsPath {
|
||||
c.c.Timeout = time.Minute * 5
|
||||
}
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
// deactivationRequest represents request for validating deactivation pin.
|
||||
type deactivationRequest struct {
|
||||
Pin int64 `json:"pin"`
|
||||
}
|
||||
|
||||
@@ -3,19 +3,40 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
)
|
||||
|
||||
const (
|
||||
contentTypeJson = "application/json"
|
||||
listClientsPath = "/clients"
|
||||
startedPath = "/started"
|
||||
contentTypeJson = "application/json"
|
||||
listClientsPath = "/clients"
|
||||
startedPath = "/started"
|
||||
reloadPath = "/reload"
|
||||
deactivationPath = "/deactivation"
|
||||
cdPath = "/cd"
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
type controlServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
@@ -58,14 +79,81 @@ func (s *controlServer) register(pattern string, handler http.Handler) {
|
||||
|
||||
func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
mainLog.Load().Debug().Msg("handling list clients request")
|
||||
|
||||
clients := p.ciTable.ListClients()
|
||||
mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
|
||||
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
mainLog.Load().Debug().Msg("sorted clients by IP address")
|
||||
|
||||
if p.metricsQueryStats.Load() {
|
||||
mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts")
|
||||
|
||||
for idx, client := range clients {
|
||||
mainLog.Load().Debug().
|
||||
Int("index", idx).
|
||||
Str("ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("processing client metrics")
|
||||
|
||||
client.IncludeQueryCount = true
|
||||
dm := &dto.Metric{}
|
||||
|
||||
if statsClientQueriesCount.MetricVec == nil {
|
||||
mainLog.Load().Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("skipping metrics collection: MetricVec is nil")
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
|
||||
client.IP.String(),
|
||||
client.Mac,
|
||||
client.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("failed to get metrics for client")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
||||
client.QueryCount = int64(dm.Counter.GetValue())
|
||||
mainLog.Load().Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Int64("query_count", client.QueryCount).
|
||||
Msg("successfully collected query count")
|
||||
} else if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("failed to write metric")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts")
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
||||
mainLog.Load().Error().
|
||||
Err(err).
|
||||
Int("client_count", len(clients)).
|
||||
Msg("failed to encode clients response")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Int("client_count", len(clients)).
|
||||
Msg("successfully sent clients list response")
|
||||
}))
|
||||
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
select {
|
||||
@@ -75,6 +163,177 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusRequestTimeout)
|
||||
}
|
||||
}))
|
||||
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
listeners := make(map[string]*ctrld.ListenerConfig)
|
||||
p.mu.Lock()
|
||||
for k, v := range p.cfg.Listener {
|
||||
listeners[k] = &ctrld.ListenerConfig{
|
||||
IP: v.IP,
|
||||
Port: v.Port,
|
||||
}
|
||||
}
|
||||
oldSvc := p.cfg.Service
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.reloadDoneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Checking for cases that we could not do a reload.
|
||||
|
||||
// 1. Listener config ip or port changes.
|
||||
for k, v := range p.cfg.Listener {
|
||||
l := listeners[k]
|
||||
if l == nil || l.IP != v.IP || l.Port != v.Port {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Service config changes.
|
||||
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, reload is done.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
p.cs.register(deactivationPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
// Non-cd mode always allowing deactivation.
|
||||
if cdUID == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch pin code from API.
|
||||
if rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); rc != nil {
|
||||
if rc.DeactivationPin != nil {
|
||||
cdDeactivationPin.Store(*rc.DeactivationPin)
|
||||
} else {
|
||||
cdDeactivationPin.Store(defaultDeactivationPin)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not re-fetch deactivation pin code")
|
||||
}
|
||||
|
||||
// If pin code not set, allowing deactivation.
|
||||
if !deactivationPinSet() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
var req deactivationRequest
|
||||
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusPreconditionFailed)
|
||||
mainLog.Load().Err(err).Msg("invalid deactivation request")
|
||||
return
|
||||
}
|
||||
|
||||
code := http.StatusForbidden
|
||||
switch req.Pin {
|
||||
case cdDeactivationPin.Load():
|
||||
code = http.StatusOK
|
||||
select {
|
||||
case p.pinCodeValidCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
case defaultDeactivationPin:
|
||||
// If the pin code was set, but users do not provide --pin, return proper code to client.
|
||||
code = http.StatusBadRequest
|
||||
}
|
||||
w.WriteHeader(code)
|
||||
}))
|
||||
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if cdUID != "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(cdUID))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
res := &ifaceResponse{Name: iface}
|
||||
// p.setDNS is only called when running as a service
|
||||
if !service.Interactive() {
|
||||
<-p.csSetDnsDone
|
||||
if p.csSetDnsOk {
|
||||
res.Name = p.runningIface
|
||||
res.All = p.requiredMultiNICsConfig
|
||||
res.OK = true
|
||||
}
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(res); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
lr, err := p.logReader()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer lr.r.Close()
|
||||
if lr.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(lr.r)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if time.Since(p.internalLogSent) < logWriterSentInterval {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r, err := p.logReader()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
req := &controld.LogsRequest{
|
||||
UID: cdUID,
|
||||
Data: r.r,
|
||||
}
|
||||
mainLog.Load().Debug().Msg("sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
if err := controld.SendLogs(req, cdDev); err != nil {
|
||||
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("sending log file successfully")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
p.internalLogSent = time.Now()
|
||||
}))
|
||||
}
|
||||
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
|
||||
1368
cmd/cli/dns_proxy.go
1368
cmd/cli/dns_proxy.go
File diff suppressed because it is too large
Load Diff
@@ -22,14 +22,22 @@ func Test_wildcardMatches(t *testing.T) {
|
||||
domain string
|
||||
match bool
|
||||
}{
|
||||
{"prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
||||
{"prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
||||
{"prefix not match other domain", "*.windscribe.com", "example.com", false},
|
||||
{"prefix not match domain in name", "*.windscribe.com", "wwindscribe.com", false},
|
||||
{"suffix", "suffix.*", "suffix.windscribe.com", true},
|
||||
{"suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
||||
{"both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
||||
{"both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
||||
{"domain - prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
||||
{"domain - prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
||||
{"domain - prefix not match other s", "*.windscribe.com", "example.com", false},
|
||||
{"domain - prefix not match s in name", "*.windscribe.com", "wwindscribe.com", false},
|
||||
{"domain - suffix", "suffix.*", "suffix.windscribe.com", true},
|
||||
{"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
||||
{"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
||||
{"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
||||
{"domain - case-insensitive", "*.WINDSCRIBE.com", "anything.windscribe.com", true},
|
||||
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
|
||||
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
|
||||
{"mac - suffix", "d4:67:98:*", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - suffix not match other", "d4:67:98:*", "d4:67:97:15:b4:2b", false},
|
||||
{"mac - both", "d4:67:98:*:b4:2b", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - both not match", "d4:67:98:*:b4:2b", "d4:67:97:05:c4:2b", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -67,8 +75,12 @@ func Test_canonicalName(t *testing.T) {
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
for _, nc := range p.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
@@ -81,6 +93,7 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
mac string
|
||||
defaultUpstreamNum string
|
||||
lc *ctrld.ListenerConfig
|
||||
domain string
|
||||
@@ -88,11 +101,14 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
|
||||
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -111,9 +127,13 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
|
||||
assert.Equal(t, tc.matched, matched)
|
||||
assert.Equal(t, tc.upstreams, upstreams)
|
||||
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
|
||||
p.proxy(ctx, &proxyRequest{
|
||||
msg: newDnsMsgWithHostname("foo", dns.TypeA),
|
||||
ufr: ufr,
|
||||
})
|
||||
assert.Equal(t, tc.matched, ufr.matched)
|
||||
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
@@ -149,26 +169,58 @@ func TestCache(t *testing.T) {
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg, nil)
|
||||
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg, nil)
|
||||
req1 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.1"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
req2 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.0"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
got1 := prog.proxy(context.Background(), req1)
|
||||
got2 := prog.proxy(context.Background(), req2)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
assert.Equal(t, answer1.Rcode, got1.answer.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.answer.Rcode)
|
||||
}
|
||||
|
||||
func Test_macFromMsg(t *testing.T) {
|
||||
func Test_ipAndMacFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
wantIp bool
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has mac", "4c:20:b8:ab:87:1b", true},
|
||||
{"no mac", "4c:20:b8:ab:87:1b", false},
|
||||
{"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
{"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatal("missing IP")
|
||||
}
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -180,13 +232,23 @@ func Test_macFromMsg(t *testing.T) {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
got := macFromMsg(m)
|
||||
if tc.wantMac && got != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, got)
|
||||
if tc.wantIp {
|
||||
ec2 := &dns.EDNS0_SUBNET{Address: ip}
|
||||
o.Option = append(o.Option, ec2)
|
||||
}
|
||||
if !tc.wantMac && got != "" {
|
||||
t.Errorf("unexpected mac: %q", got)
|
||||
m.Extra = append(m.Extra, o)
|
||||
gotIP, gotMac := ipAndMacFromMsg(m)
|
||||
if tc.wantMac && gotMac != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac)
|
||||
}
|
||||
if !tc.wantMac && gotMac != "" {
|
||||
t.Errorf("unexpected mac: %q", gotMac)
|
||||
}
|
||||
if tc.wantIp && gotIP != tc.ip {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP)
|
||||
}
|
||||
if !tc.wantIp && gotIP != "" {
|
||||
t.Errorf("unexpected ip: %q", gotIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -216,3 +278,189 @@ func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipFromARPA(t *testing.T) {
|
||||
tests := []struct {
|
||||
IP string
|
||||
ARPA string
|
||||
}{
|
||||
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
|
||||
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
|
||||
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
|
||||
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
|
||||
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
|
||||
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"", "asd.in-addr.arpa."},
|
||||
{"", "asd.ip6.arpa."},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.IP, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
|
||||
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithClientIP(ip string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_stripClientSubnet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
wantSubnet bool
|
||||
}{
|
||||
{"no edns0", new(dns.Msg), false},
|
||||
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
|
||||
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
|
||||
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
|
||||
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
|
||||
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
|
||||
{"invalid IP", newDnsMsgWithClientIP(""), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
stripClientSubnet(tc.msg)
|
||||
hasSubnet := false
|
||||
if opt := tc.msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
|
||||
hasSubnet = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if tc.wantSubnet != hasSubnet {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(hostname, typ)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isLanHostnameQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isLanHostnameQuery bool
|
||||
}{
|
||||
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
|
||||
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
||||
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
||||
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
|
||||
t.Helper()
|
||||
m := new(dns.Msg)
|
||||
ptr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.SetQuestion(ptr, dns.TypePTR)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isPrivatePtrLookup bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
|
||||
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
|
||||
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
|
||||
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
|
||||
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
|
||||
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
|
||||
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSrvLanLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isSrvLookup bool
|
||||
}{
|
||||
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
||||
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
isWanClient bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
|
||||
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
|
||||
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
|
||||
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
|
||||
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
|
||||
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
|
||||
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isWanClient(tc.addr); tc.isWanClient != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
14
cmd/cli/hostname.go
Normal file
14
cmd/cli/hostname.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cli
|
||||
|
||||
import "regexp"
|
||||
|
||||
// validHostname reports whether hostname is a valid hostname.
|
||||
// A valid hostname contains 3 -> 64 characters and conform to RFC1123.
|
||||
func validHostname(hostname string) bool {
|
||||
hostnameLen := len(hostname)
|
||||
if hostnameLen < 3 || hostnameLen > 64 {
|
||||
return false
|
||||
}
|
||||
validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
|
||||
return validHostnameRfc1123.MatchString(hostname)
|
||||
}
|
||||
35
cmd/cli/hostname_test.go
Normal file
35
cmd/cli/hostname_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_validHostname(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
valid bool
|
||||
}{
|
||||
{"localhost", "localhost", true},
|
||||
{"localdomain", "localhost.localdomain", true},
|
||||
{"localhost6", "localhost6.localdomain6", true},
|
||||
{"ip6", "ip6-localhost", true},
|
||||
{"non-domain", "controld", true},
|
||||
{"domain", "controld.com", true},
|
||||
{"empty", "", false},
|
||||
{"min length", "fo", false},
|
||||
{"max length", strings.Repeat("a", 65), false},
|
||||
{"special char", "foo!", false},
|
||||
{"non-ascii", "fooΩ", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.hostname, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, validHostname(tc.hostname) == tc.valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
95
cmd/cli/library.go
Normal file
95
cmd/cli/library.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
type AppCallback struct {
|
||||
HostName func() string
|
||||
LanIp func() string
|
||||
MacAddress func() string
|
||||
Exit func(error string)
|
||||
}
|
||||
|
||||
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
|
||||
type AppConfig struct {
|
||||
CdUID string
|
||||
ProvisionID string
|
||||
CustomHostname string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHTTPTimeout = 30 * time.Second
|
||||
defaultMaxRetries = 3
|
||||
downloadServerIp = "23.171.240.151"
|
||||
)
|
||||
|
||||
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
|
||||
func httpClientWithFallback(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
// Prefer IPv4 over IPv6
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
FallbackDelay: 1 * time.Millisecond, // Very small delay to prefer IPv4
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry performs an HTTP request with retries
|
||||
func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) {
|
||||
var lastErr error
|
||||
client := httpClientWithFallback(defaultHTTPTimeout)
|
||||
var ipReq *http.Request
|
||||
if ip != "" {
|
||||
ipReq = req.Clone(req.Context())
|
||||
ipReq.Host = ip
|
||||
ipReq.URL.Host = ip
|
||||
}
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if ipReq != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("dial to %q failed", req.Host)
|
||||
mainLog.Load().Warn().Msgf("fallback to direct IP to download prod version: %q", ip)
|
||||
resp, err = client.Do(ipReq)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
mainLog.Load().Debug().Err(err).
|
||||
Str("method", req.Method).
|
||||
Str("url", req.URL.String()).
|
||||
Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries)
|
||||
}
|
||||
return nil, fmt.Errorf("failed after %d attempts to %s %s: %v", maxRetries, req.Method, req.URL, lastErr)
|
||||
}
|
||||
|
||||
// Helper for making GET requests with retries
|
||||
func getWithRetry(url string, ip string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doWithRetry(req, defaultMaxRetries, ip)
|
||||
}
|
||||
204
cmd/cli/log_writer.go
Normal file
204
cmd/cli/log_writer.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
logWriterSentInterval = time.Minute
|
||||
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
)
|
||||
|
||||
type logViewResponse struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type logSentResponse struct {
|
||||
Size int64 `json:"size"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type logReader struct {
|
||||
r io.ReadCloser
|
||||
size int64
|
||||
}
|
||||
|
||||
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
|
||||
type logWriter struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
}
|
||||
|
||||
// newLogWriter creates an internal log writer.
|
||||
func newLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSize)
|
||||
}
|
||||
|
||||
// newSmallLogWriter creates an internal log writer with small buffer size.
|
||||
func newSmallLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSmallSize)
|
||||
}
|
||||
|
||||
// newLogWriterWithSize creates an internal log writer with a given buffer size.
|
||||
func newLogWriterWithSize(size int) *logWriter {
|
||||
lw := &logWriter{size: size}
|
||||
return lw
|
||||
}
|
||||
|
||||
func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
|
||||
// If writing p causes overflows, discard old data.
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
haveEndMarker := false
|
||||
// If there's init end marker already, preserve the data til the marker.
|
||||
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
|
||||
buf = buf[:idx+len(logWriterInitEndMarker)]
|
||||
haveEndMarker = true
|
||||
} else {
|
||||
// Otherwise, preserve the initial size data.
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
}
|
||||
}
|
||||
lw.buf.Reset()
|
||||
lw.buf.Write(buf)
|
||||
if !haveEndMarker {
|
||||
lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated.
|
||||
}
|
||||
}
|
||||
// If p is bigger than buffer size, truncate p by half until its size is smaller.
|
||||
for len(p)+lw.buf.Len() > lw.size {
|
||||
p = p[len(p)/2:]
|
||||
}
|
||||
return lw.buf.Write(p)
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func (p *prog) initLogging(backup bool) {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
logWriters := initLoggingWithBackup(backup)
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logWriters)
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
func (p *prog) initInternalLogging(writers []io.Writer) {
|
||||
if !p.needInternalLogging() {
|
||||
return
|
||||
}
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
mainLog.Load().Notice().Msg("internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
})
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
// If ctrld was run without explicit verbose level,
|
||||
// run the internal logging at debug level, so we could
|
||||
// have enough information for troubleshooting.
|
||||
if verbose == 0 {
|
||||
for i := range writers {
|
||||
w := &zerolog.FilteredLevelWriter{
|
||||
Writer: zerolog.LevelWriterAdapter{Writer: writers[i]},
|
||||
Level: zerolog.NoticeLevel,
|
||||
}
|
||||
writers[i] = w
|
||||
}
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
writers = append(writers, lw)
|
||||
writers = append(writers, &zerolog.FilteredLevelWriter{
|
||||
Writer: zerolog.LevelWriterAdapter{Writer: wlw},
|
||||
Level: zerolog.WarnLevel,
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// needInternalLogging reports whether prog needs to run internal logging.
|
||||
func (p *prog) needInternalLogging() bool {
|
||||
// Do not run in non-cd mode.
|
||||
if cdUID == "" {
|
||||
return false
|
||||
}
|
||||
// Do not run if there's already log file.
|
||||
if p.cfg.Service.LogPath != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *prog) logReader() (*logReader, error) {
|
||||
if p.needInternalLogging() {
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
return nil, errors.New("nil internal log writer")
|
||||
}
|
||||
if wlw == nil {
|
||||
return nil, errors.New("nil internal warn log writer")
|
||||
}
|
||||
// Normal log content.
|
||||
lw.mu.Lock()
|
||||
lwReader := bytes.NewReader(lw.buf.Bytes())
|
||||
lwSize := lw.buf.Len()
|
||||
lw.mu.Unlock()
|
||||
// Warn log content.
|
||||
wlw.mu.Lock()
|
||||
wlwReader := bytes.NewReader(wlw.buf.Bytes())
|
||||
wlwSize := wlw.buf.Len()
|
||||
wlw.mu.Unlock()
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
|
||||
lr := &logReader{r: io.NopCloser(reader)}
|
||||
lr.size = int64(lwSize + wlwSize)
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("internal log is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
if p.cfg.Service.LogPath == "" {
|
||||
return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil
|
||||
}
|
||||
f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lr := &logReader{r: f}
|
||||
if st, err := f.Stat(); err == nil {
|
||||
lr.size = st.Size()
|
||||
} else {
|
||||
return nil, fmt.Errorf("f.Stat: %w", err)
|
||||
}
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("log file is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
85
cmd/cli/log_writer_test.go
Normal file
85
cmd/cli/log_writer_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_logWriter_Write(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
data := strings.Repeat("A", size)
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
newData := "B"
|
||||
halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker
|
||||
lw.Write([]byte(newData))
|
||||
if lw.buf.String() != halfData+newData {
|
||||
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
|
||||
}
|
||||
|
||||
bigData := strings.Repeat("B", 256*1024)
|
||||
expected := halfData + strings.Repeat("B", 16*1024)
|
||||
lw.Write([]byte(bigData))
|
||||
if lw.buf.String() != expected {
|
||||
t.Fatalf("unexpected big buf content: %v", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_ConcurrentWrite(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lw.Write([]byte(strings.Repeat("A", i)))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if lw.buf.Len() > lw.size {
|
||||
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_MarkerInitEnd(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
|
||||
paddingSize := 10
|
||||
// Writing half of the size, minus len(end marker) and padding size.
|
||||
dataSize := size/2 - len(logWriterInitEndMarker) - paddingSize
|
||||
data := strings.Repeat("A", dataSize)
|
||||
// Inserting newline for making partial init data
|
||||
data += "\n"
|
||||
// Filling left over buffer to make the log full.
|
||||
// The data length: len(end marker) + padding size - 1 (for newline above) + size/2
|
||||
data += strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2))
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
lw.Write([]byte("B"))
|
||||
lw.Write([]byte(strings.Repeat("B", 256*1024)))
|
||||
firstIdx := strings.Index(lw.buf.String(), logWriterInitEndMarker)
|
||||
lastIdx := strings.LastIndex(lw.buf.String(), logWriterInitEndMarker)
|
||||
// Check if init end marker present.
|
||||
if firstIdx == -1 || lastIdx == -1 {
|
||||
t.Fatalf("missing init end marker: %s", lw.buf.String())
|
||||
}
|
||||
// Check if init end marker appears only once.
|
||||
if firstIdx != lastIdx {
|
||||
t.Fatalf("log init end marker appears more than once: %s", lw.buf.String())
|
||||
}
|
||||
// Ensure that we have the correct init log data.
|
||||
if !strings.Contains(lw.buf.String(), strings.Repeat("A", dataSize)+logWriterInitEndMarker) {
|
||||
t.Fatalf("unexpected log content: %s", lw.buf.String())
|
||||
}
|
||||
}
|
||||
145
cmd/cli/loop.go
Normal file
145
cmd/cli/loop.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
loopTestDomain = ".test"
|
||||
loopTestQtype = dns.TypeTXT
|
||||
)
|
||||
|
||||
// newLoopGuard returns new loopGuard.
|
||||
func newLoopGuard() *loopGuard {
|
||||
return &loopGuard{inflight: make(map[string]struct{})}
|
||||
}
|
||||
|
||||
// loopGuard guards against DNS loop, ensuring only one query
|
||||
// for a given domain is processed at a time.
|
||||
type loopGuard struct {
|
||||
mu sync.Mutex
|
||||
inflight map[string]struct{}
|
||||
}
|
||||
|
||||
// TryLock marks the domain as being processed.
|
||||
func (lg *loopGuard) TryLock(domain string) bool {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
if _, inflight := lg.inflight[domain]; !inflight {
|
||||
lg.inflight[domain] = struct{}{}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Unlock marks the domain as being done.
|
||||
func (lg *loopGuard) Unlock(domain string) {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
delete(lg.inflight, domain)
|
||||
}
|
||||
|
||||
// isLoop reports whether the given upstream config is detected as having DNS loop.
|
||||
func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool {
|
||||
p.loopMu.Lock()
|
||||
defer p.loopMu.Unlock()
|
||||
return p.loop[uc.UID()]
|
||||
}
|
||||
|
||||
// detectLoop checks if the given DNS message is initialized sent by ctrld.
|
||||
// If yes, marking the corresponding upstream as loop, prevent infinite DNS
|
||||
// forwarding loop.
|
||||
//
|
||||
// See p.checkDnsLoop for more details how it works.
|
||||
func (p *prog) detectLoop(msg *dns.Msg) {
|
||||
if len(msg.Question) != 1 {
|
||||
return
|
||||
}
|
||||
q := msg.Question[0]
|
||||
if q.Qtype != loopTestQtype {
|
||||
return
|
||||
}
|
||||
unFQDNname := strings.TrimSuffix(q.Name, ".")
|
||||
uid := strings.TrimSuffix(unFQDNname, loopTestDomain)
|
||||
p.loopMu.Lock()
|
||||
if _, loop := p.loop[uid]; loop {
|
||||
p.loop[uid] = loop
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
}
|
||||
|
||||
// checkDnsLoop sends a message to check if there's any DNS forwarding loop
|
||||
// with all the upstreams. The way it works based on dnsmasq --dns-loop-detect.
|
||||
//
|
||||
// - Generating a TXT test query and sending it to all upstream.
|
||||
// - The test query is formed by upstream UID and test domain: <uid>.test
|
||||
// - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop).
|
||||
//
|
||||
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
|
||||
func (p *prog) checkDnsLoop() {
|
||||
mainLog.Load().Debug().Msg("start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for n, uc := range p.cfg.Upstream {
|
||||
if p.um.isDown("upstream." + n) {
|
||||
continue
|
||||
}
|
||||
// Do not send test query to external upstream.
|
||||
if !canBeLocalUpstream(uc.Domain) {
|
||||
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
|
||||
continue
|
||||
}
|
||||
uid := uc.UID()
|
||||
p.loop[uid] = false
|
||||
upstream[uid] = uc
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
// Skipping upstream which is being marked as down.
|
||||
if uc == nil {
|
||||
continue
|
||||
}
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
}
|
||||
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msg("end checking DNS loop")
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
func (p *prog) checkDnsLoopTicker(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
p.checkDnsLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loopTestMsg generates DNS message for checking loop.
|
||||
func loopTestMsg(uid string) *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)
|
||||
return msg
|
||||
}
|
||||
42
cmd/cli/loop_test.go
Normal file
42
cmd/cli/loop_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_loopGuard(t *testing.T) {
|
||||
lg := newLoopGuard()
|
||||
key := "foo"
|
||||
|
||||
var i atomic.Int64
|
||||
var started atomic.Int64
|
||||
n := 1000
|
||||
do := func() {
|
||||
locked := lg.TryLock(key)
|
||||
defer lg.Unlock(key)
|
||||
started.Add(1)
|
||||
for started.Load() < 2 {
|
||||
// Wait until at least 2 goroutines started, otherwise, on system with heavy load,
|
||||
// or having only 1 CPU, all goroutines can be scheduled to run consequently.
|
||||
}
|
||||
if locked {
|
||||
i.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
do()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if i.Load() == int64(n) {
|
||||
t.Fatalf("i must not be increased %d times", n)
|
||||
}
|
||||
}
|
||||
@@ -29,12 +29,28 @@ var (
|
||||
silent bool
|
||||
cdUID string
|
||||
cdOrg string
|
||||
customHostname string
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
nextdns string
|
||||
cdUpstreamProto string
|
||||
deactivationPin int64
|
||||
skipSelfChecks bool
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
rfc1918 bool
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
noConfigStart bool
|
||||
)
|
||||
|
||||
const (
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
customHostnameFlagName = "custom-hostname"
|
||||
nextdnsFlagName = "nextdns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -65,6 +81,7 @@ func normalizeLogFilePath(logFilePath string) string {
|
||||
return filepath.Join(dir, logFilePath)
|
||||
}
|
||||
|
||||
// initConsoleLogging initializes console logging, then storing to mainLog.
|
||||
func initConsoleLogging() {
|
||||
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
@@ -72,21 +89,33 @@ func initConsoleLogging() {
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
mainLog.Store(&l)
|
||||
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
case verbose == 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
case verbose > 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func initLogging() {
|
||||
initLoggingWithBackup(true)
|
||||
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
|
||||
// to be used for all interactive commands.
|
||||
//
|
||||
// Current log file config will also be ignored.
|
||||
func initInteractiveLogging() {
|
||||
old := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(false)
|
||||
cfg.Service.LogPath = old
|
||||
l := zerolog.New(io.Discard)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// initLoggingWithBackup initializes log setup base on current config.
|
||||
@@ -95,8 +124,8 @@ func initLogging() {
|
||||
// This is only used in runCmd for special handling in case of logging config
|
||||
// change in cd mode. Without special reason, the caller should use initLogging
|
||||
// wrapper instead of calling this function directly.
|
||||
func initLoggingWithBackup(doBackup bool) {
|
||||
writers := []io.Writer{io.Discard}
|
||||
func initLoggingWithBackup(doBackup bool) []io.Writer {
|
||||
var writers []io.Writer
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
@@ -108,14 +137,14 @@ func initLoggingWithBackup(doBackup bool) {
|
||||
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
|
||||
if doBackup {
|
||||
// Backup old log file with .1 suffix.
|
||||
if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
|
||||
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
|
||||
} else {
|
||||
// Backup was created, set flags for truncating old log file.
|
||||
flags = os.O_CREATE | os.O_RDWR
|
||||
}
|
||||
}
|
||||
logFile, err := os.OpenFile(logFilePath, flags, os.FileMode(0o600))
|
||||
logFile, err := openLogFile(logFilePath, flags)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
|
||||
os.Exit(1)
|
||||
@@ -124,7 +153,7 @@ func initLoggingWithBackup(doBackup bool) {
|
||||
}
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
@@ -134,21 +163,22 @@ func initLoggingWithBackup(doBackup bool) {
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
return
|
||||
return writers
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
if logLevel == "" {
|
||||
return
|
||||
return writers
|
||||
}
|
||||
level, err := zerolog.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set log level")
|
||||
return
|
||||
return writers
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
return writers
|
||||
}
|
||||
|
||||
func initCache() {
|
||||
|
||||
150
cmd/cli/metrics.go
Normal file
150
cmd/cli/metrics.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/prometheus/prom2json"
|
||||
)
|
||||
|
||||
// metricsServer represents a server to expose Prometheus metrics via HTTP.
|
||||
type metricsServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
reg *prometheus.Registry
|
||||
addr string
|
||||
started bool
|
||||
}
|
||||
|
||||
// newMetricsServer returns new metrics server.
|
||||
func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
ms := &metricsServer{
|
||||
server: &http.Server{Handler: mux},
|
||||
mux: mux,
|
||||
reg: reg,
|
||||
}
|
||||
ms.addr = addr
|
||||
ms.registerMetricsServerHandler()
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// register adds handlers for given pattern.
|
||||
func (ms *metricsServer) register(pattern string, handler http.Handler) {
|
||||
ms.mux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
// registerMetricsServerHandler adds handlers for metrics server.
|
||||
func (ms *metricsServer) registerMetricsServerHandler() {
|
||||
ms.register("/metrics", promhttp.HandlerFor(
|
||||
ms.reg,
|
||||
promhttp.HandlerOpts{
|
||||
EnableOpenMetrics: true,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
))
|
||||
ms.register("/metrics/json", jsonResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
g := prometheus.ToTransactionalGatherer(ms.reg)
|
||||
mfs, done, err := g.Gather()
|
||||
defer done()
|
||||
if err != nil {
|
||||
msg := "could not gather metrics"
|
||||
mainLog.Load().Warn().Err(err).Msg(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
result := make([]*prom2json.Family, 0, len(mfs))
|
||||
for _, mf := range mfs {
|
||||
result = append(result, prom2json.NewFamily(mf))
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(result); err != nil {
|
||||
msg := "could not marshal metrics result"
|
||||
mainLog.Load().Warn().Err(err).Msg(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
// start runs the metricsServer.
|
||||
func (ms *metricsServer) start() error {
|
||||
listener, err := net.Listen("tcp", ms.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go ms.server.Serve(listener)
|
||||
ms.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop shutdowns the metricsServer within 2 seconds timeout.
|
||||
func (ms *metricsServer) stop() error {
|
||||
if !ms.started {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
|
||||
defer cancel()
|
||||
return ms.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// runMetricsServer initializes metrics stats and runs the metrics server if enabled.
|
||||
func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
if !p.metricsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset all stats.
|
||||
statsVersion.Reset()
|
||||
statsQueriesCount.Reset()
|
||||
statsClientQueriesCount.Reset()
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
// Register queries count stats if enabled.
|
||||
if p.metricsQueryStats.Load() {
|
||||
reg.MustRegister(statsQueriesCount)
|
||||
reg.MustRegister(statsClientQueriesCount)
|
||||
}
|
||||
|
||||
addr := p.cfg.Service.MetricsListener
|
||||
ms, err := newMetricsServer(addr, reg)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create new metrics server")
|
||||
return
|
||||
}
|
||||
// Only start listener address if defined.
|
||||
if addr != "" {
|
||||
// Go runtime stats.
|
||||
reg.MustRegister(collectors.NewBuildInfoCollector())
|
||||
reg.MustRegister(collectors.NewGoCollector(
|
||||
collectors.WithGoCollectorRuntimeMetrics(collectors.MetricsAll),
|
||||
))
|
||||
// ctrld stats.
|
||||
reg.MustRegister(statsVersion)
|
||||
statsVersion.WithLabelValues(commit, runtime.Version(), curVersion()).Inc()
|
||||
reg.MustRegister(statsTimeStart)
|
||||
statsTimeStart.Set(float64(time.Now().Unix()))
|
||||
mainLog.Load().Debug().Msgf("starting metrics server on: %s", addr)
|
||||
if err := ms.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
case <-ctx.Done():
|
||||
case <-reloadCh:
|
||||
}
|
||||
|
||||
if err := ms.stop(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not stop metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface")
|
||||
}
|
||||
return nil
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
@@ -42,3 +43,34 @@ func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set of all valid hardware ports.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return parseListAllHardwarePorts(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports"
|
||||
// and returns map presents all hardware ports.
|
||||
func parseListAllHardwarePorts(r io.Reader) map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
after, ok := strings.CutPrefix(line, "Device: ")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m[after] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
52
cmd/cli/net_linux.go
Normal file
52
cmd/cli/net_linux.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// Only non-virtual interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set containing non virtual interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
vis := virtualInterfaces()
|
||||
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
||||
if _, existed := vis[i.Name]; existed {
|
||||
return
|
||||
}
|
||||
m[i.Name] = struct{}{}
|
||||
})
|
||||
// Fallback to default route interface if found nothing.
|
||||
if len(m) == 0 {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return m
|
||||
}
|
||||
m[defaultRoute.InterfaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// virtualInterfaces returns a map of virtual interfaces on current machine.
|
||||
func virtualInterfaces() map[string]struct{} {
|
||||
s := make(map[string]struct{})
|
||||
entries, _ := os.ReadDir("/sys/devices/virtual/net")
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
s[strings.TrimSpace(entry.Name())] = struct{}{}
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -1,7 +1,22 @@
|
||||
//go:build !darwin
|
||||
//go:build !darwin && !windows && !linux
|
||||
|
||||
package cli
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||
|
||||
// validInterfacesMap returns a set containing only default route interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]struct{}{defaultRoute.InterfaceName: {}}
|
||||
}
|
||||
|
||||
93
cmd/cli/net_windows.go
Normal file
93
cmd/cli/net_windows.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// On Windows, only physical interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set of all physical interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, ifaceName := range validInterfaces() {
|
||||
m[ifaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// validInterfaces returns a list of all physical interfaces.
|
||||
func validInterfaces() []string {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
|
||||
return nil
|
||||
}
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := adapter.GetPropertyName()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
|
||||
continue
|
||||
}
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||
//
|
||||
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
||||
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||
physical, err := adapter.GetPropertyConnectorPresent()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
|
||||
continue
|
||||
}
|
||||
if !physical {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a hardware interface. Checking only for connector present is not enough
|
||||
// because some interfaces are not physical but have a connector.
|
||||
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
|
||||
continue
|
||||
}
|
||||
if !hardware {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
|
||||
continue
|
||||
}
|
||||
|
||||
adapters = append(adapters, name)
|
||||
}
|
||||
return adapters
|
||||
}
|
||||
42
cmd/cli/net_windows_test.go
Normal file
42
cmd/cli/net_windows_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_validInterfaces(t *testing.T) {
|
||||
verbose = 3
|
||||
initConsoleLogging()
|
||||
start := time.Now()
|
||||
ifaces := validInterfaces()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
ifacesPowershell := validInterfacesPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(ifaces)
|
||||
slices.Sort(ifacesPowershell)
|
||||
if !slices.Equal(ifaces, ifacesPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces)
|
||||
}
|
||||
}
|
||||
|
||||
func validInterfacesPowershell() []string {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var res []string
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
res = append(res, ifaceName)
|
||||
}
|
||||
return res
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState() {
|
||||
func (p *prog) watchLinkState(ctx context.Context) {
|
||||
ch := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
@@ -13,14 +15,19 @@ func (p *prog) watchLinkState() {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
|
||||
return
|
||||
}
|
||||
for lu := range ch {
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case lu := <-ch:
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,4 +2,6 @@
|
||||
|
||||
package cli
|
||||
|
||||
func (p *prog) watchLinkState() {}
|
||||
import "context"
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {}
|
||||
|
||||
@@ -3,6 +3,7 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
@@ -16,13 +17,21 @@ const (
|
||||
dns=none
|
||||
systemd-resolved=false
|
||||
`
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
systemdEnabledState = "enabled"
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
)
|
||||
|
||||
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
|
||||
|
||||
// hasNetworkManager reports whether NetworkManager executable found.
|
||||
func hasNetworkManager() bool {
|
||||
exe, _ := exec.LookPath("NetworkManager")
|
||||
return exe != ""
|
||||
}
|
||||
|
||||
func setupNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
|
||||
mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
return nil
|
||||
@@ -43,6 +52,9 @@ func setupNetworkManager() error {
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
err := os.Remove(networkManagerCtrldConfFile)
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
@@ -71,6 +83,7 @@ func reloadNetworkManager() {
|
||||
waitCh := make(chan string)
|
||||
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
}
|
||||
|
||||
31
cmd/cli/nextdns.go
Normal file
31
cmd/cli/nextdns.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const nextdnsURL = "https://dns.nextdns.io"
|
||||
|
||||
func generateNextDNSConfig(uid string) {
|
||||
if uid == "" {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver")
|
||||
cfg = ctrld.Config{
|
||||
Listener: map[string]*ctrld.ListenerConfig{
|
||||
"0": {
|
||||
IP: "0.0.0.0",
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Upstream: map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
Type: ctrld.ResolverTypeDOH3,
|
||||
Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, uid),
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
5
cmd/cli/nocgo.go
Normal file
5
cmd/cli/nocgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = false
|
||||
@@ -1,8 +1,12 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
@@ -27,16 +31,41 @@ func deAllocateIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
if err := setDNS(iface, nameservers); err != nil {
|
||||
// TODO: investiate whether we can detect this without relying on error message.
|
||||
if strings.Contains(err.Error(), " is not a recognized network service") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
|
||||
// TODO(cuonglm): use system API
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
// Note that networksetup won't modify search domains settings,
|
||||
// This assignment is just a placeholder to silent linter.
|
||||
_ = searchDomains
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name}
|
||||
args = append(args, nameservers...)
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := exec.Command(cmd, args...).Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("setDNS failed, ips = %q", nameservers)
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
if err := resetDNS(iface); err != nil {
|
||||
// TODO: investiate whether we can detect this without relying on error message.
|
||||
if strings.Contains(err.Error(), " is not a recognized network service") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -46,14 +75,40 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name, "empty"}
|
||||
|
||||
if err := exec.Command(cmd, args...).Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("resetDNS failed")
|
||||
return err
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
err = setDNS(iface, ns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
cmd := "networksetup"
|
||||
args := []string{"-getdnsservers", iface.Name}
|
||||
out, err := exec.Command(cmd, args...).Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
var ns []string
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if ip := net.ParseIP(line); ip != nil {
|
||||
ns = append(ns, ip.String())
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,10 @@ import (
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
@@ -29,9 +33,14 @@ func deAllocateIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
return err
|
||||
@@ -42,15 +51,30 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
if err := r.SetDNS(dns.OSConfig{Nameservers: ns}); err != nil {
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
|
||||
}
|
||||
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to set DNS")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
return err
|
||||
@@ -63,6 +87,17 @@ func resetDNS(iface *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
return currentDNS(iface), nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/client6"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
@@ -48,9 +50,13 @@ func deAllocateIP(ip string) error {
|
||||
|
||||
const maxSetDNSAttempts = 5
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
return err
|
||||
@@ -65,31 +71,31 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
|
||||
}
|
||||
trySystemdResolve := false
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
break
|
||||
}
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
goto systemdResolve
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if reflect.DeepEqual(currentNS, nameservers) {
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
systemdResolve:
|
||||
if trySystemdResolve {
|
||||
// Stop systemd-networkd and retry setting DNS.
|
||||
if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil {
|
||||
@@ -104,16 +110,21 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if reflect.DeepEqual(currentNS, nameservers) {
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
func resetDNS(iface *net.Interface) (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
@@ -123,7 +134,7 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
if exe, _ := exec.LookPath("/lib/systemd/systemd-networkd"); exe != "" {
|
||||
_ = exec.Command("systemctl", "start", "systemd-networkd").Run()
|
||||
}
|
||||
if r, oerr := dns.NewOSConfigurator(logf, iface.Name); oerr == nil {
|
||||
if r, oerr := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name); oerr == nil {
|
||||
_ = r.SetDNS(dns.OSConfig{})
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
|
||||
@@ -154,6 +165,7 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
// TODO(cuonglm): handle DHCPv6 properly.
|
||||
mainLog.Load().Debug().Msg("checking for IPv6 availability")
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
@@ -173,6 +185,8 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("IPv6 is not available")
|
||||
}
|
||||
|
||||
return ignoringEINTR(func() error {
|
||||
@@ -180,8 +194,15 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
})
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
|
||||
resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() }
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
return ns
|
||||
}
|
||||
@@ -189,6 +210,11 @@ func currentDNS(iface *net.Interface) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
return currentDNS(iface), nil
|
||||
}
|
||||
|
||||
func getDNSByResolvectl(iface string) []string {
|
||||
b, err := exec.Command("resolvectl", "dns", "-i", iface).Output()
|
||||
if err != nil {
|
||||
@@ -265,3 +291,16 @@ func ignoringEINTR(fn func() error) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isSubSet reports whether s2 contains all elements of s1.
|
||||
func isSubSet(s1, s2 []string) bool {
|
||||
ok := true
|
||||
for _, ns := range s1 {
|
||||
if slices.Contains(s2, ns) {
|
||||
continue
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
@@ -1,79 +1,207 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const (
|
||||
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
var (
|
||||
setDNSOnce sync.Once
|
||||
resetDNSOnce sync.Once
|
||||
)
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
// setDNS sets the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
if len(nameservers) == 0 {
|
||||
return errors.New("empty DNS nameservers")
|
||||
}
|
||||
primaryDNS := nameservers[0]
|
||||
if err := setPrimaryDNS(iface, primaryDNS); err != nil {
|
||||
return err
|
||||
setDNSOnce.Do(func() {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if hasLocalDnsServerRunning() {
|
||||
mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders")
|
||||
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
mainLog.Load().Debug().Msgf("Using forwarders file: %s", file)
|
||||
|
||||
oldForwardersContent, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent))
|
||||
}
|
||||
|
||||
hasLocalIPv6Listener := needLocalIPv6Listener()
|
||||
mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status")
|
||||
|
||||
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
||||
if !hasLocalIPv6Listener {
|
||||
return false
|
||||
}
|
||||
return s == "::1"
|
||||
})
|
||||
mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list")
|
||||
|
||||
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully wrote new forwarders file")
|
||||
}
|
||||
|
||||
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
||||
mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders")
|
||||
|
||||
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders")
|
||||
}
|
||||
}
|
||||
})
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("setDNS: %w", err)
|
||||
}
|
||||
if len(nameservers) > 1 {
|
||||
secondaryDNS := nameservers[1]
|
||||
_ = addSecondaryDNS(iface, secondaryDNS)
|
||||
var (
|
||||
serversV4 []netip.Addr
|
||||
serversV6 []netip.Addr
|
||||
)
|
||||
for _, ns := range nameservers {
|
||||
if addr, err := netip.ParseAddr(ns); err == nil {
|
||||
if addr.Is4() {
|
||||
serversV4 = append(serversV4, addr)
|
||||
} else {
|
||||
serversV6 = append(serversV6, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function.
|
||||
// searchDomains is still implemented for Windows just in case Windows API changes in future versions.
|
||||
_ = searchDomains
|
||||
|
||||
if len(serversV4) == 0 && len(serversV6) == 0 {
|
||||
return errors.New("invalid DNS nameservers")
|
||||
}
|
||||
if len(serversV4) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
||||
}
|
||||
}
|
||||
if len(serversV6) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
// TODO(cuonglm): should we use system API?
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
if ctrldnet.SupportsIPv6ListenLocal() {
|
||||
if output, err := netsh("interface", "ipv6", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp"); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("failed to reset ipv6 DNS: %s", string(output))
|
||||
resetDNSOnce.Do(func() {
|
||||
// See corresponding comment in setDNS.
|
||||
if hasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
||||
return
|
||||
}
|
||||
nameservers := strings.Split(string(content), ",")
|
||||
if err := removeDnsServerForwarders(nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("resetDNS: %w", err)
|
||||
}
|
||||
// Restoring DHCP settings.
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
||||
}
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
||||
v4ns := make([]string, 0, 2)
|
||||
v6ns := make([]string, 0, 2)
|
||||
for _, ns := range nss {
|
||||
if ctrldnet.IsIPv6(ns) {
|
||||
v6ns = append(v6ns, ns)
|
||||
} else {
|
||||
v4ns = append(v4ns, ns)
|
||||
}
|
||||
}
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoreDNS: %w", err)
|
||||
}
|
||||
|
||||
if len(v4ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
|
||||
if err := setDNS(iface, v4ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(v6ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
|
||||
if err := setDNS(iface, v6ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
output, err := netsh("interface", "ipv4", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp")
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("failed to reset ipv4 DNS: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setPrimaryDNS(iface *net.Interface, dns string) error {
|
||||
ipVer := "ipv4"
|
||||
if ctrldnet.IsIPv6(dns) {
|
||||
ipVer = "ipv6"
|
||||
}
|
||||
idx := strconv.Itoa(iface.Index)
|
||||
output, err := netsh("interface", ipVer, "set", "dnsserver", idx, "static", dns)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("failed to set primary DNS: %s", string(output))
|
||||
return err
|
||||
}
|
||||
if ipVer == "ipv4" && ctrldnet.SupportsIPv6ListenLocal() {
|
||||
// Disable IPv6 DNS, so the query will be fallback to IPv4.
|
||||
_, _ = netsh("interface", "ipv6", "set", "dnsserver", idx, "static", "::1", "primary")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addSecondaryDNS(iface *net.Interface, dns string) error {
|
||||
ipVer := "ipv4"
|
||||
if ctrldnet.IsIPv6(dns) {
|
||||
ipVer = "ipv6"
|
||||
}
|
||||
output, err := netsh("interface", ipVer, "add", "dns", strconv.Itoa(iface.Index), dns, "index=2")
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("failed to add secondary DNS: %s", string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func netsh(args ...string) ([]byte, error) {
|
||||
return exec.Command("netsh", args...).Output()
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
@@ -93,3 +221,112 @@ func currentDNS(iface *net.Interface) []string {
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
|
||||
// like "NameServer" and "ProfileNameServer".
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
|
||||
}
|
||||
|
||||
var ns []string
|
||||
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
|
||||
for _, path := range keyPaths {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer k.Close()
|
||||
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
|
||||
value, _, err := k.GetStringValue(keyName)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName)
|
||||
continue
|
||||
}
|
||||
if len(value) > 0 {
|
||||
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
|
||||
parsed := parseDNSServers(value)
|
||||
for _, pns := range parsed {
|
||||
if !slices.Contains(ns, pns) {
|
||||
ns = append(ns, pns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name)
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// parseDNSServers splits a DNS server string that may be comma- or space-separated,
|
||||
// and trims any extraneous whitespace or null characters.
|
||||
func parseDNSServers(val string) []string {
|
||||
fields := strings.FieldsFunc(val, func(r rune) bool {
|
||||
return r == ' ' || r == ','
|
||||
})
|
||||
var servers []string
|
||||
for _, f := range fields {
|
||||
trimmed := strings.TrimSpace(f)
|
||||
if len(trimmed) > 0 {
|
||||
servers = append(servers, trimmed)
|
||||
}
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
||||
// and also removing old forwarders if provided.
|
||||
func addDnsServerForwarders(nameservers, old []string) error {
|
||||
newForwardersMap := make(map[string]struct{})
|
||||
newForwarders := make([]string, len(nameservers))
|
||||
for i := range nameservers {
|
||||
newForwardersMap[nameservers[i]] = struct{}{}
|
||||
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
|
||||
}
|
||||
oldForwarders := old[:0]
|
||||
for _, fwd := range old {
|
||||
if _, ok := newForwardersMap[fwd]; !ok {
|
||||
oldForwarders = append(oldForwarders, fwd)
|
||||
}
|
||||
}
|
||||
// NOTE: It is important to add new forwarder before removing old one.
|
||||
// Testing on Windows Server 2022 shows that removing forwarder1
|
||||
// then adding forwarder2 sometimes ends up adding both of them
|
||||
// to the forwarders list.
|
||||
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
|
||||
if len(oldForwarders) > 0 {
|
||||
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
|
||||
}
|
||||
if out, err := powershell(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeDnsServerForwarders removes given nameservers from DNS server forwarders list.
|
||||
func removeDnsServerForwarders(nameservers []string) error {
|
||||
for _, ns := range nameservers {
|
||||
cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns)
|
||||
if out, err := powershell(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
|
||||
68
cmd/cli/os_windows_test.go
Normal file
68
cmd/cli/os_windows_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func Test_currentStaticDNS(t *testing.T) {
|
||||
iface, err := net.InterfaceByName(defaultIfaceName())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
start := time.Now()
|
||||
staticDns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(staticDns)
|
||||
slices.Sort(staticDnsPowershell)
|
||||
if !slices.Equal(staticDns, staticDnsPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
|
||||
}
|
||||
}
|
||||
|
||||
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
1504
cmd/cli/prog.go
1504
cmd/cli/prog.go
File diff suppressed because it is too large
Load Diff
@@ -1,16 +1,26 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo"); err == nil {
|
||||
if r, err := newLoopbackOSConfigurator(); err == nil {
|
||||
useSystemdResolved = r.Mode() == "systemd-resolved"
|
||||
}
|
||||
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911
|
||||
if os.Getenv("QUIC_GO_DISABLE_ECN") == "" {
|
||||
os.Setenv("QUIC_GO_DISABLE_ECN", "true")
|
||||
}
|
||||
}
|
||||
|
||||
func setDependencies(svc *service.Config) {
|
||||
@@ -19,8 +29,13 @@ func setDependencies(svc *service.Config) {
|
||||
"After=network-online.target",
|
||||
"Wants=NetworkManager-wait-online.service",
|
||||
"After=NetworkManager-wait-online.service",
|
||||
"Wants=systemd-networkd-wait-online.service",
|
||||
"After=systemd-networkd-wait-online.service",
|
||||
"Wants=nss-lookup.target",
|
||||
"After=nss-lookup.target",
|
||||
}
|
||||
if out, _ := exec.Command("networkctl", "--no-pager").CombinedOutput(); len(out) > 0 {
|
||||
if wantsSystemDNetworkdWaitOnline(bytes.NewReader(out)) {
|
||||
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
|
||||
}
|
||||
}
|
||||
if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 {
|
||||
svc.Dependencies = append(svc.Dependencies, routerDeps...)
|
||||
@@ -30,3 +45,21 @@ func setDependencies(svc *service.Config) {
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
// wantsSystemDNetworkdWaitOnline reports whether "systemd-networkd-wait-online" service
|
||||
// is required to be added to ctrld dependencies services.
|
||||
// The input reader r is the output of "networkctl --no-pager" command.
|
||||
func wantsSystemDNetworkdWaitOnline(r io.Reader) bool {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Skip header
|
||||
scanner.Scan()
|
||||
configured := false
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if len(fields) > 0 && fields[len(fields)-1] == "configured" {
|
||||
configured = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return configured
|
||||
}
|
||||
|
||||
48
cmd/cli/prog_linux_test.go
Normal file
48
cmd/cli/prog_linux_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
networkctlUnmanagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||
1 lo loopback carrier unmanaged
|
||||
2 wlp0s20f3 wlan routable unmanaged
|
||||
3 tailscale0 none routable unmanaged
|
||||
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||
5 docker0 bridge no-carrier unmanaged
|
||||
|
||||
5 links listed.
|
||||
`
|
||||
networkctlManagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||
1 lo loopback carrier unmanaged
|
||||
2 wlp0s20f3 wlan routable configured
|
||||
3 tailscale0 none routable unmanaged
|
||||
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||
5 docker0 bridge no-carrier unmanaged
|
||||
|
||||
5 links listed.
|
||||
`
|
||||
)
|
||||
|
||||
func Test_wantsSystemDNetworkdWaitOnline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
r io.Reader
|
||||
required bool
|
||||
}{
|
||||
{"unmanaged", strings.NewReader(networkctlUnmanagedOutput), false},
|
||||
{"managed", strings.NewReader(networkctlManagedOutput), true},
|
||||
{"empty", strings.NewReader(""), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if required := wantsSystemDNetworkdWaitOnline(tc.r); required != tc.required {
|
||||
t.Errorf("wants %v got %v", tc.required, required)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !freebsd && !darwin
|
||||
//go:build !linux && !freebsd && !darwin && !windows
|
||||
|
||||
package cli
|
||||
|
||||
|
||||
273
cmd/cli/prog_test.go
Normal file
273
cmd/cli/prog_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is true.
|
||||
assert.True(t, p.dnsWatchdogEnabled())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
}{
|
||||
{"enabled", true},
|
||||
{"disabled", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogEnabled = &tc.enabled
|
||||
assert.Equal(t, tc.enabled, p.dnsWatchdogEnabled())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_dnsWatchdogInterval(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is 20s.
|
||||
assert.Equal(t, dnsWatchdogDefaultInterval, p.dnsWatchdogDuration())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
expected time.Duration
|
||||
}{
|
||||
{"valid", time.Minute, time.Minute},
|
||||
{"zero", 0, dnsWatchdogDefaultInterval},
|
||||
{"nagative", time.Duration(-1 * time.Minute), dnsWatchdogDefaultInterval},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogInvterval = &tc.duration
|
||||
assert.Equal(t, tc.expected, p.dnsWatchdogDuration())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldUpgrade(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is empty",
|
||||
},
|
||||
{
|
||||
name: "invalid version target",
|
||||
versionTarget: "invalid-version",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is invalid",
|
||||
},
|
||||
{
|
||||
name: "same version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version equals current version",
|
||||
},
|
||||
{
|
||||
name: "older version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.1.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version is older than current version",
|
||||
},
|
||||
{
|
||||
name: "patch upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow patch version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "minor upgrade allowed",
|
||||
versionTarget: "v1.1.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow minor version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "major upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "major downgrade blocked",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v2.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version downgrade",
|
||||
},
|
||||
{
|
||||
name: "version without v prefix",
|
||||
versionTarget: "1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should handle version target without v prefix",
|
||||
},
|
||||
{
|
||||
name: "complex version upgrade allowed",
|
||||
versionTarget: "v1.5.3",
|
||||
currentVersion: makeVersion("v1.4.2"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow complex version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "complex major upgrade blocked",
|
||||
versionTarget: "v3.1.0",
|
||||
currentVersion: makeVersion("v2.5.3"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block complex major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "pre-release version upgrade allowed",
|
||||
versionTarget: "v1.0.1-beta.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow pre-release version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "pre-release major upgrade blocked",
|
||||
versionTarget: "v2.0.0-alpha.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block pre-release major version upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
|
||||
|
||||
// Call the function and capture the result
|
||||
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_selfUpgradeCheck(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipped due to Windows file locking issue on Github Action runners")
|
||||
}
|
||||
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow upgrade and attempt to perform it",
|
||||
},
|
||||
{
|
||||
name: "upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block upgrade and not attempt to perform it",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
|
||||
|
||||
// Call the function and capture the result
|
||||
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_performUpgrade(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipped due to Windows file locking issue on Github Action runners")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid version target",
|
||||
versionTarget: "v1.0.1",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade with valid version target",
|
||||
},
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade even with empty version target",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Call the function and capture the result
|
||||
result := performUpgrade(tc.versionTarget)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
14
cmd/cli/prog_windows.go
Normal file
14
cmd/cli/prog_windows.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
func setDependencies(svc *service.Config) {
|
||||
if hasLocalDnsServerRunning() {
|
||||
svc.Dependencies = []string{"DNS"}
|
||||
}
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
57
cmd/cli/prometheus.go
Normal file
57
cmd/cli/prometheus.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package cli
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
const (
|
||||
metricsLabelListener = "listener"
|
||||
metricsLabelClientSourceIP = "client_source_ip"
|
||||
metricsLabelClientMac = "client_mac"
|
||||
metricsLabelClientHostname = "client_hostname"
|
||||
metricsLabelUpstream = "upstream"
|
||||
metricsLabelRRType = "rr_type"
|
||||
metricsLabelRCode = "rcode"
|
||||
)
|
||||
|
||||
// statsVersion represent ctrld version.
|
||||
var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_build_info",
|
||||
Help: "Version of ctrld process.",
|
||||
}, []string{"gitref", "goversion", "version"})
|
||||
|
||||
// statsTimeStart represents start time of ctrld service.
|
||||
var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ctrld_time_seconds",
|
||||
Help: "Start time of the ctrld process since unix epoch in seconds.",
|
||||
})
|
||||
|
||||
var statsQueriesCountLabels = []string{
|
||||
metricsLabelListener,
|
||||
metricsLabelClientSourceIP,
|
||||
metricsLabelClientMac,
|
||||
metricsLabelClientHostname,
|
||||
metricsLabelUpstream,
|
||||
metricsLabelRRType,
|
||||
metricsLabelRCode,
|
||||
}
|
||||
|
||||
// statsQueriesCount counts total number of queries.
|
||||
var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_queries_count",
|
||||
Help: "Total number of queries.",
|
||||
}, statsQueriesCountLabels)
|
||||
|
||||
// statsClientQueriesCount counts total number of queries of a client.
|
||||
//
|
||||
// The labels "client_source_ip", "client_mac", "client_hostname" are unbounded,
|
||||
// thus this stat is highly inefficient if there are many devices.
|
||||
var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_client_queries_count",
|
||||
Help: "Total number queries of a client.",
|
||||
}, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname})
|
||||
|
||||
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
|
||||
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
|
||||
if p.metricsQueryStats.Load() {
|
||||
c.WithLabelValues(lvs...).Inc()
|
||||
}
|
||||
}
|
||||
17
cmd/cli/reload_others.go
Normal file
17
cmd/cli/reload_others.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}
|
||||
18
cmd/cli/reload_windows.go
Normal file
18
cmd/cli/reload_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
select {
|
||||
case p.reloadCh <- struct{}{}:
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
return errors.New("timeout while sending reload signal")
|
||||
}
|
||||
164
cmd/cli/resolvconf.go
Normal file
164
cmd/cli/resolvconf.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
resolvConfPath := "/etc/resolv.conf"
|
||||
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
|
||||
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
|
||||
resolvConfPath = rp
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
// We watch /etc instead of /etc/resolv.conf directly,
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Name != resolvConfPath { // skip if not /etc/resolv.conf changes.
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
if retry < maxRetries-1 {
|
||||
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
|
||||
}
|
||||
}
|
||||
}
|
||||
49
cmd/cli/resolvconf_darwin.go
Normal file
49
cmd/cli/resolvconf_darwin.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
servers := make([]string, len(ns))
|
||||
for i := range ns {
|
||||
servers[i] = ns[i].String()
|
||||
}
|
||||
if err := setDNS(iface, servers); err != nil {
|
||||
return err
|
||||
}
|
||||
slices.Sort(servers)
|
||||
curNs := currentDNS(iface)
|
||||
slices.Sort(curNs)
|
||||
if !slices.Equal(curNs, servers) {
|
||||
c, err := resolvconffile.ParseFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Nameservers = ns
|
||||
f, err := os.Create(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := c.Write(f); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
return true
|
||||
}
|
||||
52
cmd/cli/resolvconf_not_darwin_unix.go
Normal file
52
cmd/cli/resolvconf_not_darwin_unix.go
Normal file
@@ -0,0 +1,52 @@
|
||||
//go:build unix && !darwin
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oc := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
oc.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file")
|
||||
}
|
||||
return r.SetDNS(oc)
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
switch r.Mode() {
|
||||
case "direct", "resolvconf":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface.
|
||||
func newLoopbackOSConfigurator() (dns.OSConfigurator, error) {
|
||||
return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo")
|
||||
}
|
||||
16
cmd/cli/resolvconf_windows.go
Normal file
16
cmd/cli/resolvconf_windows.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
return false
|
||||
}
|
||||
14
cmd/cli/search_domains_unix.go
Normal file
14
cmd/cli/search_domains_unix.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
return resolvconffile.SearchDomains()
|
||||
}
|
||||
43
cmd/cli/search_domains_windows.go
Normal file
43
cmd/cli/search_domains_windows.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err)
|
||||
}
|
||||
|
||||
var sds []dnsname.FQDN
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
|
||||
d, err := dnsname.ToFQDN(a.String())
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String())
|
||||
continue
|
||||
}
|
||||
sds = append(sds, d)
|
||||
}
|
||||
}
|
||||
return sds, nil
|
||||
}
|
||||
7
cmd/cli/self_delete_others.go
Normal file
7
cmd/cli/self_delete_others.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
var supportedSelfDelete = true
|
||||
|
||||
func selfDeleteExe() error { return nil }
|
||||
134
cmd/cli/self_delete_windows.go
Normal file
134
cmd/cli/self_delete_windows.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copied from https://github.com/secur30nly/go-self-delete
|
||||
// with modification to suitable for ctrld usage.
|
||||
|
||||
/*
|
||||
License: MIT Licence
|
||||
|
||||
References:
|
||||
- https://github.com/LloydLabs/delete-self-poc
|
||||
- https://twitter.com/jonasLyk/status/1350401461985955840
|
||||
*/
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var supportedSelfDelete = false
|
||||
|
||||
type FILE_RENAME_INFO struct {
|
||||
Union struct {
|
||||
ReplaceIfExists bool
|
||||
Flags uint32
|
||||
}
|
||||
RootDirectory windows.Handle
|
||||
FileNameLength uint32
|
||||
FileName [1]uint16
|
||||
}
|
||||
|
||||
type FILE_DISPOSITION_INFO struct {
|
||||
DeleteFile bool
|
||||
}
|
||||
|
||||
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
|
||||
handle, err := windows.CreateFile(
|
||||
pwPath,
|
||||
windows.DELETE,
|
||||
0,
|
||||
nil,
|
||||
windows.OPEN_EXISTING,
|
||||
windows.FILE_ATTRIBUTE_NORMAL,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func dsRenameHandle(hHandle windows.Handle) error {
|
||||
var fRename FILE_RENAME_INFO
|
||||
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lpwStream := &DS_STREAM_RENAME[0]
|
||||
fRename.FileNameLength = uint32(unsafe.Sizeof(lpwStream))
|
||||
|
||||
windows.NewLazyDLL("kernel32.dll").NewProc("RtlCopyMemory").Call(
|
||||
uintptr(unsafe.Pointer(&fRename.FileName[0])),
|
||||
uintptr(unsafe.Pointer(lpwStream)),
|
||||
unsafe.Sizeof(lpwStream),
|
||||
)
|
||||
|
||||
err = windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileRenameInfo,
|
||||
(*byte)(unsafe.Pointer(&fRename)),
|
||||
uint32(unsafe.Sizeof(fRename)+unsafe.Sizeof(lpwStream)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func dsDepositeHandle(hHandle windows.Handle) error {
|
||||
var fDelete FILE_DISPOSITION_INFO
|
||||
fDelete.DeleteFile = true
|
||||
|
||||
err := windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileDispositionInfo,
|
||||
(*byte)(unsafe.Pointer(&fDelete)),
|
||||
uint32(unsafe.Sizeof(fDelete)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func selfDeleteExe() error {
|
||||
var wcPath [windows.MAX_PATH + 1]uint16
|
||||
var hCurrent windows.Handle
|
||||
|
||||
_, err := windows.GetModuleFileName(0, &wcPath[0], windows.MAX_PATH)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsRenameHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsDepositeHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
|
||||
return windows.CloseHandle(hCurrent)
|
||||
}
|
||||
16
cmd/cli/self_kill_others.go
Normal file
16
cmd/cli/self_kill_others.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, false) {
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
45
cmd/cli/self_kill_unix.go
Normal file
45
cmd/cli/self_kill_unix.go
Normal file
@@ -0,0 +1,45 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
if runtime.GOOS == "linux" {
|
||||
selfUninstallLinux(p, logger)
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("could not determine executable")
|
||||
}
|
||||
args := []string{"uninstall"}
|
||||
if deactivationPinSet() {
|
||||
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin.Load()))
|
||||
}
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
if err := cmd.Start(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("could not start self uninstall command")
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
_ = cmd.Wait()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func selfUninstallLinux(p *prog, logger zerolog.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, true) {
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
12
cmd/cli/self_upgrade_others.go
Normal file
12
cmd/cli/self_upgrade_others.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running a detached child command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{Setsid: true}
|
||||
}
|
||||
18
cmd/cli/self_upgrade_windows.go
Normal file
18
cmd/cli/self_upgrade_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags?redirectedfrom=MSDN
|
||||
|
||||
// SYSCALL_CREATE_NO_WINDOW set flag to run process without a console window.
|
||||
const SYSCALL_CREATE_NO_WINDOW = 0x08000000
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running self-upgrade command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | SYSCALL_CREATE_NO_WINDOW,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
@@ -4,12 +4,16 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/unit"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/openwrt"
|
||||
)
|
||||
|
||||
// newService wraps service.New call to return service.Service
|
||||
@@ -20,14 +24,17 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case router.IsOldOpenwrt():
|
||||
return &procd{&sysV{s}}, nil
|
||||
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
|
||||
return &procd{sysV: &sysV{s}, svcConfig: c}, nil
|
||||
case router.IsGLiNet():
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "unix-systemv":
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "linux-systemd":
|
||||
return &systemd{s}, nil
|
||||
case s.Platform() == "darwin-launchd":
|
||||
return newLaunchd(s), nil
|
||||
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -89,25 +96,31 @@ func (s *sysV) Status() (service.Status, error) {
|
||||
// like old GL.iNET Opal router.
|
||||
type procd struct {
|
||||
*sysV
|
||||
svcConfig *service.Config
|
||||
}
|
||||
|
||||
func (s *procd) Status() (service.Status, error) {
|
||||
if !s.installed() {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, nil
|
||||
bin := s.svcConfig.Executable
|
||||
if bin == "" {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
bin = exe
|
||||
}
|
||||
|
||||
// Looking for something like "/sbin/ctrld run ".
|
||||
shellCmd := fmt.Sprintf("ps | grep -q %q", exe+" [r]un ")
|
||||
shellCmd := fmt.Sprintf("ps | grep -q %q", bin+" [r]un ")
|
||||
if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
return service.StatusRunning, nil
|
||||
}
|
||||
|
||||
// procd wraps a service.Service, and provide status command to
|
||||
// systemd wraps a service.Service, and provide status command to
|
||||
// report the status correctly.
|
||||
type systemd struct {
|
||||
service.Service
|
||||
@@ -121,20 +134,101 @@ func (s *systemd) Status() (service.Status, error) {
|
||||
return s.Service.Status()
|
||||
}
|
||||
|
||||
func (s *systemd) Start() error {
|
||||
const systemdUnitFile = "/etc/systemd/system/ctrld.service"
|
||||
f, err := os.Open(systemdUnitFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if opts, change := ensureSystemdKillMode(f); change {
|
||||
mode := os.FileMode(0644)
|
||||
buf, err := io.ReadAll(unit.Serialize(opts))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(systemdUnitFile, buf, mode); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("systemctl daemon-reload failed: %w\n%s", err, string(out))
|
||||
}
|
||||
mainLog.Load().Debug().Msg("set KillMode=process successfully")
|
||||
}
|
||||
return s.Service.Start()
|
||||
}
|
||||
|
||||
// ensureSystemdKillMode ensure systemd unit file is configured with KillMode=process.
|
||||
// This is necessary for running self-upgrade flow.
|
||||
func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
|
||||
opts, err := unit.DeserializeOptions(r)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to deserialize options")
|
||||
return
|
||||
}
|
||||
change = true
|
||||
needKillModeOpt := true
|
||||
killModeOpt := unit.NewUnitOption("Service", "KillMode", "process")
|
||||
for _, opt := range opts {
|
||||
if opt.Match(killModeOpt) {
|
||||
needKillModeOpt = false
|
||||
change = false
|
||||
break
|
||||
}
|
||||
if opt.Section == killModeOpt.Section && opt.Name == killModeOpt.Name {
|
||||
opt.Value = killModeOpt.Value
|
||||
needKillModeOpt = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needKillModeOpt {
|
||||
opts = append(opts, killModeOpt)
|
||||
}
|
||||
return opts, change
|
||||
}
|
||||
|
||||
func newLaunchd(s service.Service) *launchd {
|
||||
return &launchd{
|
||||
Service: s,
|
||||
statusErrMsg: "Permission denied",
|
||||
}
|
||||
}
|
||||
|
||||
// launchd wraps a service.Service, and provide status command to
|
||||
// report the status correctly when not running as root on Darwin.
|
||||
//
|
||||
// TODO: remove this wrapper once https://github.com/kardianos/service/issues/400 fixed.
|
||||
type launchd struct {
|
||||
service.Service
|
||||
statusErrMsg string
|
||||
}
|
||||
|
||||
func (l *launchd) Status() (service.Status, error) {
|
||||
if os.Geteuid() != 0 {
|
||||
return service.StatusUnknown, errors.New(l.statusErrMsg)
|
||||
}
|
||||
return l.Service.Status()
|
||||
}
|
||||
|
||||
type task struct {
|
||||
f func() error
|
||||
abortOnError bool
|
||||
Name string
|
||||
}
|
||||
|
||||
func doTasks(tasks []task) bool {
|
||||
var prevErr error
|
||||
for _, task := range tasks {
|
||||
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error())
|
||||
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
|
||||
return false
|
||||
}
|
||||
prevErr = err
|
||||
// if this is darwin stop command, dont print debug
|
||||
// since launchctl complains on every start
|
||||
if runtime.GOOS != "darwin" || task.Name != "Stop" {
|
||||
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -155,6 +249,13 @@ func checkHasElevatedPrivilege() {
|
||||
func unixSystemVServiceStatus() (service.Status, error) {
|
||||
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
|
||||
if err != nil {
|
||||
// Specific case for openwrt >= 24.10, it returns non-success code
|
||||
// for above status command, which may not right.
|
||||
if router.Name() == openwrt.Name {
|
||||
if string(bytes.ToLower(bytes.TrimSpace(out))) == "inactive" {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,3 +9,14 @@ import (
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
return os.Geteuid() == 0, nil
|
||||
}
|
||||
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool { return false }
|
||||
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 }
|
||||
|
||||
28
cmd/cli/service_test.go
Normal file
28
cmd/cli/service_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_ensureSystemdKillMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
unitFile string
|
||||
wantChange bool
|
||||
}{
|
||||
{"no KillMode", "[Service]\nExecStart=/bin/sleep 1", true},
|
||||
{"not KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=mixed", true},
|
||||
{"KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
{"invalid unit file", "[Service\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, change := ensureSystemdKillMode(strings.NewReader(tc.unitFile)); tc.wantChange != change {
|
||||
t.Errorf("ensureSystemdKillMode(%q) = %v, want %v", tc.unitFile, change, tc.wantChange)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,22 @@
|
||||
package cli
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
var sid *windows.SID
|
||||
@@ -22,3 +38,190 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
token := windows.Token(0)
|
||||
return token.IsMember(sid)
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions checks if the given service
|
||||
// has the correct failure actions configured, and updates them if not.
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil // no-op on non-Windows
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// 1. Retrieve the current config
|
||||
cfg, err := s.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update the Description
|
||||
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||
|
||||
// 3. Apply the updated config
|
||||
if err := s.UpdateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then proceed with existing actions, e.g. setting failure actions
|
||||
actions := []mgr.RecoveryAction{
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
}
|
||||
|
||||
// Set the recovery actions (3 restarts, reset period = 120).
|
||||
err = s.SetRecoveryActions(actions, 120)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that failure actions are NOT triggered on user-initiated stops.
|
||||
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
|
||||
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
|
||||
|
||||
if err := windows.ChangeServiceConfig2(
|
||||
s.Handle,
|
||||
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
|
||||
(*byte)(unsafe.Pointer(&failureActionsFlag)),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openLogFile(path string, mode int) (*os.File, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||
}
|
||||
|
||||
pathP, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var access uint32
|
||||
switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
|
||||
case os.O_RDONLY:
|
||||
access = windows.GENERIC_READ
|
||||
case os.O_WRONLY:
|
||||
access = windows.GENERIC_WRITE
|
||||
case os.O_RDWR:
|
||||
access = windows.GENERIC_READ | windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_CREATE != 0 {
|
||||
access |= windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_APPEND != 0 {
|
||||
access &^= windows.GENERIC_WRITE
|
||||
access |= windows.FILE_APPEND_DATA
|
||||
}
|
||||
|
||||
shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
|
||||
|
||||
var sa *syscall.SecurityAttributes
|
||||
|
||||
var createMode uint32
|
||||
switch {
|
||||
case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL):
|
||||
createMode = windows.CREATE_NEW
|
||||
case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC):
|
||||
createMode = windows.CREATE_ALWAYS
|
||||
case mode&os.O_CREATE == os.O_CREATE:
|
||||
createMode = windows.OPEN_ALWAYS
|
||||
case mode&os.O_TRUNC == os.O_TRUNC:
|
||||
createMode = windows.TRUNCATE_EXISTING
|
||||
default:
|
||||
createMode = windows.OPEN_EXISTING
|
||||
}
|
||||
|
||||
handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0)
|
||||
if err != nil {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: err}
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(handle), path), nil
|
||||
}
|
||||
|
||||
const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool {
|
||||
h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
p := windows.ProcessEntry32{Size: processEntrySize}
|
||||
for {
|
||||
e := windows.Process32Next(h, &p)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) {
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("WMI query failed")
|
||||
return false, 0
|
||||
}
|
||||
if instances == nil {
|
||||
mainLog.Load().Debug().Msg("WMI query returned nil instances")
|
||||
return false, 0
|
||||
}
|
||||
defer instances.Close()
|
||||
|
||||
if len(instances) == 0 {
|
||||
mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
val, err := instances[0].GetProperty("DomainRole")
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property")
|
||||
return false, 0
|
||||
}
|
||||
if val == nil {
|
||||
mainLog.Load().Debug().Msg("DomainRole property is nil")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Safely handle varied types: string or integer
|
||||
var roleInt int
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
// "4", "5", etc.
|
||||
parsed, parseErr := strconv.Atoi(v)
|
||||
if parseErr != nil {
|
||||
mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v)
|
||||
return false, 0
|
||||
}
|
||||
roleInt = parsed
|
||||
case int8, int16, int32, int64:
|
||||
roleInt = int(reflect.ValueOf(v).Int())
|
||||
case uint8, uint16, uint32, uint64:
|
||||
roleInt = int(reflect.ValueOf(v).Uint())
|
||||
default:
|
||||
mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v)
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Check if role indicates a domain controller
|
||||
isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController
|
||||
return isDC, roleInt
|
||||
}
|
||||
|
||||
25
cmd/cli/service_windows_test.go
Normal file
25
cmd/cli/service_windows_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_hasLocalDnsServerRunning(t *testing.T) {
|
||||
start := time.Now()
|
||||
hasDns := hasLocalDnsServerRunning()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
hasDnsPowershell := hasLocalDnsServerRunningPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if hasDns != hasDnsPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns)
|
||||
}
|
||||
}
|
||||
|
||||
func hasLocalDnsServerRunningPowershell() bool {
|
||||
_, err := powershell("Get-Process -Name DNS")
|
||||
return err == nil
|
||||
}
|
||||
126
cmd/cli/upstream_monitor.go
Normal file
126
cmd/cli/upstream_monitor.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 50
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
|
||||
mu sync.RWMutex
|
||||
checking map[string]bool
|
||||
down map[string]bool
|
||||
failureReq map[string]uint64
|
||||
recovered map[string]bool
|
||||
|
||||
// failureTimerActive tracks if a timer is already running for a given upstream.
|
||||
failureTimerActive map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
recovered: make(map[string]bool),
|
||||
failureTimerActive: make(map[string]bool),
|
||||
}
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
um.reset(upstream)
|
||||
}
|
||||
um.reset(upstreamOS)
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down
|
||||
// within 10 seconds if failures persist, without spawning duplicate goroutines.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
if um.recovered[upstream] {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream)
|
||||
return
|
||||
}
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
|
||||
// Log the updated failure count.
|
||||
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// If this is the first failure and no timer is running, start a 10-second timer.
|
||||
if failedCount == 1 && !um.failureTimerActive[upstream] {
|
||||
um.failureTimerActive[upstream] = true
|
||||
go func(upstream string) {
|
||||
time.Sleep(10 * time.Second)
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
// If no success occurred during the 10-second window (i.e. counter remains > 0)
|
||||
// and the upstream is not in a recovered state, mark it as down.
|
||||
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
}
|
||||
// Reset the timer flag so that a new timer can be spawned if needed.
|
||||
um.failureTimerActive[upstream] = false
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
return um.down[upstream]
|
||||
}
|
||||
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.mu.Lock()
|
||||
um.failureReq[upstream] = 0
|
||||
um.down[upstream] = false
|
||||
um.recovered[upstream] = true
|
||||
um.mu.Unlock()
|
||||
go func() {
|
||||
// debounce the recovery to avoid incrementing failure counts already in flight
|
||||
time.Sleep(1 * time.Second)
|
||||
um.mu.Lock()
|
||||
um.recovered[upstream] = false
|
||||
um.mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// countHealthy returns the number of upstreams in the provided map that are considered healthy.
|
||||
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
|
||||
var count int
|
||||
um.mu.RLock()
|
||||
for _, upstream := range upstreams {
|
||||
if !um.down[upstream] {
|
||||
count++
|
||||
}
|
||||
}
|
||||
um.mu.RUnlock()
|
||||
return count
|
||||
}
|
||||
20
cmd/cli/winres/winres.json
Normal file
20
cmd/cli/winres/winres.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"RT_VERSION": {
|
||||
"#1": {
|
||||
"0000": {
|
||||
"fixed": {
|
||||
"file_version": "0.0.0.1"
|
||||
},
|
||||
"info": {
|
||||
"0409": {
|
||||
"CompanyName": "ControlD Inc",
|
||||
"FileDescription": "Control D DNS daemon",
|
||||
"ProductName": "ctrld",
|
||||
"InternalName": "ctrld",
|
||||
"LegalCopyright": "ControlD Inc 2024"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
4
cmd/cli/winres_windows.go
Normal file
4
cmd/cli/winres_windows.go
Normal file
@@ -0,0 +1,4 @@
|
||||
//go:generate go-winres make --product-version=git-tag --file-version=git-tag
|
||||
package cli
|
||||
|
||||
// Placeholder file for windows builds.
|
||||
@@ -1,7 +1,13 @@
|
||||
package main
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cli.Main()
|
||||
// make sure we exit with 0 if there are no errors
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
82
cmd/ctrld_library/main.go
Normal file
82
cmd/ctrld_library/main.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package ctrld_library
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
// Controller holds global state
|
||||
type Controller struct {
|
||||
stopCh chan struct{}
|
||||
AppCallback AppCallback
|
||||
Config cli.AppConfig
|
||||
}
|
||||
|
||||
// NewController provides reference to global state to be managed by android vpn service and iOS network extension.
|
||||
// reference is not safe for concurrent use.
|
||||
func NewController(appCallback AppCallback) *Controller {
|
||||
return &Controller{AppCallback: appCallback}
|
||||
}
|
||||
|
||||
// AppCallback provides access to app instance.
|
||||
type AppCallback interface {
|
||||
Hostname() string
|
||||
LanIp() string
|
||||
MacAddress() string
|
||||
Exit(error string)
|
||||
}
|
||||
|
||||
// Start configures utility with config.toml from provided directory.
|
||||
// This function will block until Stop is called
|
||||
// Check port availability prior to calling it.
|
||||
func (c *Controller) Start(CdUID string, ProvisionID string, CustomHostname string, HomeDir string, UpstreamProto string, logLevel int, logPath string) {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
c.Config = cli.AppConfig{
|
||||
CdUID: CdUID,
|
||||
ProvisionID: ProvisionID,
|
||||
CustomHostname: CustomHostname,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
}
|
||||
appCallback := mapCallback(c.AppCallback)
|
||||
cli.RunMobile(&c.Config, &appCallback, c.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
// As workaround to avoid circular dependency between cli and ctrld_library module
|
||||
func mapCallback(callback AppCallback) cli.AppCallback {
|
||||
return cli.AppCallback{
|
||||
HostName: func() string {
|
||||
return callback.Hostname()
|
||||
},
|
||||
LanIp: func() string {
|
||||
return callback.LanIp()
|
||||
},
|
||||
MacAddress: func() string {
|
||||
return callback.MacAddress()
|
||||
},
|
||||
Exit: func(err string) {
|
||||
callback.Exit(err)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) Stop(restart bool, pin int64) int {
|
||||
var errorCode = 0
|
||||
// Force disconnect without checking pin.
|
||||
// In iOS restart is required if vpn detects no connectivity after network change.
|
||||
if !restart {
|
||||
errorCode = cli.CheckDeactivationPin(pin, c.stopCh)
|
||||
}
|
||||
if errorCode == 0 && c.stopCh != nil {
|
||||
close(c.stopCh)
|
||||
c.stopCh = nil
|
||||
}
|
||||
return errorCode
|
||||
}
|
||||
|
||||
func (c *Controller) IsRunning() bool {
|
||||
return c.stopCh != nil
|
||||
}
|
||||
585
cmd/ctrld_library/netstack/README.md
Normal file
585
cmd/ctrld_library/netstack/README.md
Normal file
@@ -0,0 +1,585 @@
|
||||
# Netstack - Full Packet Capture for Mobile VPN
|
||||
|
||||
Complete TCP/UDP/DNS packet capture implementation using gVisor netstack for Android and iOS.
|
||||
|
||||
## Overview
|
||||
|
||||
Provides full packet capture for mobile VPN applications:
|
||||
- **DNS filtering** through ControlD proxy
|
||||
- **IP whitelisting** - only allows connections to DNS-resolved IPs
|
||||
- **TCP forwarding** for all TCP traffic (with whitelist enforcement)
|
||||
- **UDP forwarding** with session tracking (with whitelist enforcement)
|
||||
- **QUIC blocking** for better content filtering
|
||||
|
||||
## Master Architecture Diagram
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ MOBILE APP (Android/iOS) │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
||||
│ │ VPN Configuration │ │
|
||||
│ │ │ │
|
||||
│ │ Android: iOS: │ │
|
||||
│ │ ┌──────────────────────┐ ┌──────────────────────┐ │ │
|
||||
│ │ │ Builder() │ │ NEIPv4Settings │ │ │
|
||||
│ │ │ .addAddress( │ │ addresses: [ │ │ │
|
||||
│ │ │ "10.0.0.2", 24) │ │ "10.0.0.2"] │ │ │
|
||||
│ │ │ .addDnsServer( │ │ │ │ │
|
||||
│ │ │ "10.0.0.1") │ │ NEDNSSettings │ │ │
|
||||
│ │ │ │ │ servers: [ │ │ │
|
||||
│ │ │ FIREWALL MODE: │ │ "10.0.0.1"] │ │ │
|
||||
│ │ │ .addRoute( │ │ │ │ │
|
||||
│ │ │ "0.0.0.0", 0) │ │ FIREWALL MODE: │ │ │
|
||||
│ │ │ │ │ includedRoutes: │ │ │
|
||||
│ │ │ DNS-ONLY MODE: │ │ [.default()] │ │ │
|
||||
│ │ │ .addRoute( │ │ │ │ │
|
||||
│ │ │ "10.0.0.1", 32) │ │ DNS-ONLY MODE: │ │ │
|
||||
│ │ │ │ │ includedRoutes: │ │ │
|
||||
│ │ │ .addDisallowedApp( │ │ [10.0.0.1/32] │ │ │
|
||||
│ │ │ "com.controld.*") │ │ │ │ │
|
||||
│ │ └──────────────────────┘ └──────────────────────┘ │ │
|
||||
│ │ │ │
|
||||
│ │ Result: │ │
|
||||
│ │ • Firewall: ALL traffic → VPN │ │
|
||||
│ │ • DNS-only: ONLY DNS (port 53) → VPN │ │
|
||||
│ └─────────────────────────────────────────────────────────────────────┘ │
|
||||
└──────────────────────────┬───────────────────────────────────────────────────┘
|
||||
│ Packets
|
||||
↓
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ GOMOBILE LIBRARY (ctrld_library) │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
||||
│ │ PacketCaptureController.StartWithPacketCapture() │ │
|
||||
│ │ │ │
|
||||
│ │ Parameters: │ │
|
||||
│ │ • tunAddress: "10.0.0.1" (gateway) │ │
|
||||
│ │ • deviceAddress: "10.0.0.2" (device IP) │ │
|
||||
│ │ • dnsProxyAddress: "127.0.0.1:5354" (Android) / ":53" (iOS) │ │
|
||||
│ │ • cdUID, upstreamProto, etc. │ │
|
||||
│ └──────────────────────────┬──────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ↓ │
|
||||
│ ┌──────────────────────────────────────────────────────────────────┐ │
|
||||
│ │ NETSTACK CONTROLLER │ │
|
||||
│ │ │ │
|
||||
│ │ Components: │ │
|
||||
│ │ ┌────────────────┐ ┌─────────────┐ ┌──────────────┐ │ │
|
||||
│ │ │ DNS Filter │ │ IP Tracker │ │ TCP Forwarder│ │ │
|
||||
│ │ │ (port 53) │ │ (5min TTL) │ │ (firewall) │ │ │
|
||||
│ │ └────────────────┘ └─────────────┘ └──────────────┘ │ │
|
||||
│ │ ┌────────────────┐ │ │
|
||||
│ │ │ UDP Forwarder │ │ │
|
||||
│ │ │ (firewall) │ │ │
|
||||
│ │ └────────────────┘ │ │
|
||||
│ └──────────────────────────┬───────────────────────────────────────┘ │
|
||||
└─────────────────────────────┼───────────────────────────────────────────────┘
|
||||
│
|
||||
↓
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ PACKET FLOW DETAILS │
|
||||
│ │
|
||||
│ INCOMING PACKET (from TUN) │
|
||||
│ │ │
|
||||
│ ├──→ Is DNS? (port 53) │
|
||||
│ │ ├─ YES → DNS Filter │
|
||||
│ │ │ ├─→ Forward to ControlD DNS Proxy │
|
||||
│ │ │ │ (127.0.0.1:5354 or 127.0.0.1:53) │
|
||||
│ │ │ ├─→ Get DNS response │
|
||||
│ │ │ ├─→ Extract A/AAAA records │
|
||||
│ │ │ ├─→ TrackIP() for each resolved IP │
|
||||
│ │ │ │ • Store: resolvedIPs["93.184.216.34"] = now+5min │
|
||||
│ │ │ └─→ Return DNS response to app │
|
||||
│ │ │ │
|
||||
│ │ └─ NO → Is TCP/UDP? │
|
||||
│ │ │ │
|
||||
│ │ ├──→ TCP Packet │
|
||||
│ │ │ ├─→ Extract destination IP │
|
||||
│ │ │ ├─→ Check: ipTracker.IsTracked(destIP) │
|
||||
│ │ │ │ ├─ NOT TRACKED → BLOCK │
|
||||
│ │ │ │ │ Log: "BLOCKED hardcoded IP" │
|
||||
│ │ │ │ │ Return (connection reset) │
|
||||
│ │ │ │ │ │
|
||||
│ │ │ │ └─ TRACKED → ALLOW │
|
||||
│ │ │ │ net.Dial("tcp", destIP) │
|
||||
│ │ │ │ Bidirectional copy (app ↔ internet) │
|
||||
│ │ │ │ │
|
||||
│ │ └──→ UDP Packet │
|
||||
│ │ ├─→ Is QUIC? (port 443/80) │
|
||||
│ │ │ └─ YES → BLOCK (force TCP fallback) │
|
||||
│ │ │ │
|
||||
│ │ ├─→ Extract destination IP │
|
||||
│ │ ├─→ Check: ipTracker.IsTracked(destIP) │
|
||||
│ │ │ ├─ NOT TRACKED → BLOCK │
|
||||
│ │ │ │ Log: "BLOCKED hardcoded IP" │
|
||||
│ │ │ │ Return (drop packet) │
|
||||
│ │ │ │ │
|
||||
│ │ │ └─ TRACKED → ALLOW │
|
||||
│ │ │ net.Dial("udp", destIP) │
|
||||
│ │ │ Forward packets (app ↔ internet) │
|
||||
│ │ │ 30s timeout per session │
|
||||
│ │ │
|
||||
│ IP TRACKER STATE (in-memory map): │
|
||||
│ ┌────────────────────────────────────────────────────────────┐ │
|
||||
│ │ resolvedIPs map: │ │
|
||||
│ │ │ │
|
||||
│ │ "93.184.216.34" → expires: 2026-03-20 23:35:00 │ │
|
||||
│ │ "2606:2800:220::1" → expires: 2026-03-20 23:36:15 │ │
|
||||
│ │ "8.8.8.8" → expires: 2026-03-20 23:37:42 │ │
|
||||
│ │ │ │
|
||||
│ │ Cleanup: Every 30 seconds, remove expired entries │ │
|
||||
│ │ TTL: 5 minutes (configurable) │ │
|
||||
│ └────────────────────────────────────────────────────────────┘ │
|
||||
│ │
|
||||
│ EXAMPLE SCENARIO: │
|
||||
│ ─────────────────────────────────────────────────────────────────────── │
|
||||
│ │
|
||||
│ T=0s: App tries: connect(1.2.3.4:443) │
|
||||
│ → IsTracked(1.2.3.4)? NO │
|
||||
│ → ❌ BLOCKED │
|
||||
│ │
|
||||
│ T=1s: App queries: DNS "example.com" │
|
||||
│ → Response: A 93.184.216.34 │
|
||||
│ → TrackIP(93.184.216.34) with TTL=5min │
|
||||
│ │
|
||||
│ T=2s: App tries: connect(93.184.216.34:443) │
|
||||
│ → IsTracked(93.184.216.34)? YES (expires T+301s) │
|
||||
│ → ✅ ALLOWED │
|
||||
│ │
|
||||
│ T=302s: App tries: connect(93.184.216.34:443) │
|
||||
│ → IsTracked(93.184.216.34)? NO (expired) │
|
||||
│ → ❌ BLOCKED (must do DNS again) │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ MODE COMPARISON (Firewall vs DNS-only) │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────┬─────────────────────────────────┐ │
|
||||
│ │ FIREWALL MODE │ DNS-ONLY MODE │ │
|
||||
│ │ (Default Routes Configured) │ (Only DNS Route Configured) │ │
|
||||
│ ├─────────────────────────────────┼─────────────────────────────────┤ │
|
||||
│ │ Routes (Android): │ Routes (Android): │ │
|
||||
│ │ • addRoute("0.0.0.0", 0) │ • addRoute("10.0.0.1", 32) │ │
|
||||
│ │ │ │ │
|
||||
│ │ Routes (iOS): │ Routes (iOS): │ │
|
||||
│ │ • includedRoutes: [.default()] │ • includedRoutes: │ │
|
||||
│ │ │ [10.0.0.1/32] │ │
|
||||
│ ├─────────────────────────────────┼─────────────────────────────────┤ │
|
||||
│ │ Traffic Sent to VPN: │ Traffic Sent to VPN: │ │
|
||||
│ │ ✅ DNS (port 53) │ ✅ DNS (port 53) │ │
|
||||
│ │ ✅ TCP (all ports) │ ❌ TCP (bypasses VPN) │ │
|
||||
│ │ ✅ UDP (all ports) │ ❌ UDP (bypasses VPN) │ │
|
||||
│ ├─────────────────────────────────┼─────────────────────────────────┤ │
|
||||
│ │ IP Tracker Behavior: │ IP Tracker Behavior: │ │
|
||||
│ │ • Tracks DNS-resolved IPs │ • Tracks DNS-resolved IPs │ │
|
||||
│ │ • Blocks hardcoded TCP/UDP IPs │ • No TCP/UDP to block │ │
|
||||
│ │ • Enforces DNS-first policy │ • N/A (no non-DNS traffic) │ │
|
||||
│ ├─────────────────────────────────┼─────────────────────────────────┤ │
|
||||
│ │ Use Case: │ Use Case: │ │
|
||||
│ │ • Full content filtering │ • DNS filtering only │ │
|
||||
│ │ • Block DNS bypass attempts │ • Minimal battery impact │ │
|
||||
│ │ • Enforce ControlD policies │ • Fast web browsing │ │
|
||||
│ └─────────────────────────────────┴─────────────────────────────────┘ │
|
||||
│ │
|
||||
│ MODE SWITCHING: │
|
||||
│ • Android: VpnController.setFirewallMode(enabled) → recreates VPN │
|
||||
│ • iOS: sendProviderMessage("set_firewall_mode") → updates routes │
|
||||
│ • Both: No app restart needed │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ DETAILED PACKET FLOW (Firewall Mode) │
|
||||
│ │
|
||||
│ 1. APP MAKES REQUEST │
|
||||
│ ─────────────────────────────────────────────────────────────────────── │
|
||||
│ App: connect("example.com", 443) │
|
||||
│ ↓ │
|
||||
│ OS: Perform DNS lookup for "example.com" │
|
||||
│ ↓ │
|
||||
│ OS: Send DNS query to VPN DNS server (10.0.0.1) │
|
||||
│ │
|
||||
│ 2. DNS PACKET FLOW │
|
||||
│ ─────────────────────────────────────────────────────────────────────── │
|
||||
│ [DNS Query Packet: 10.0.0.2:12345 → 10.0.0.1:53] │
|
||||
│ ↓ │
|
||||
│ TUN Interface → readPacket() │
|
||||
│ ↓ │
|
||||
│ DNSFilter.ProcessPacket() │
|
||||
│ ├─ Detect port 53 (DNS) │
|
||||
│ ├─ Extract DNS payload │
|
||||
│ ├─ Forward to ControlD DNS proxy (127.0.0.1:5354 or :53) │
|
||||
│ │ ↓ │
|
||||
│ │ ControlD DNS Proxy │
|
||||
│ │ ├─ Apply filtering rules │
|
||||
│ │ ├─ Query upstream DNS (DoH/DoT/DoQ) │
|
||||
│ │ └─ Return response: A 93.184.216.34 │
|
||||
│ │ ↓ │
|
||||
│ ├─ Parse DNS response │
|
||||
│ ├─ extractAndTrackIPs() │
|
||||
│ │ └─ IPTracker.TrackIP(93.184.216.34) │
|
||||
│ │ • Store: resolvedIPs["93.184.216.34"] = now + 5min │
|
||||
│ ├─ Build DNS response packet │
|
||||
│ └─ writePacket() → TUN → App │
|
||||
│ │
|
||||
│ OS receives DNS response → resolves "example.com" to 93.184.216.34 │
|
||||
│ │
|
||||
│ 3. TCP CONNECTION FLOW │
|
||||
│ ─────────────────────────────────────────────────────────────────────── │
|
||||
│ OS: connect(93.184.216.34:443) │
|
||||
│ ↓ │
|
||||
│ [TCP SYN Packet: 10.0.0.2:54321 → 93.184.216.34:443] │
|
||||
│ ↓ │
|
||||
│ TUN Interface → readPacket() │
|
||||
│ ↓ │
|
||||
│ gVisor Netstack → TCPForwarder.handleConnection() │
|
||||
│ ├─ Extract destination IP: 93.184.216.34 │
|
||||
│ ├─ Check internal VPN subnet (10.0.0.0/24)? │
|
||||
│ │ └─ NO (skip check) │
|
||||
│ ├─ ipTracker.IsTracked(93.184.216.34)? │
|
||||
│ │ ├─ Check resolvedIPs map │
|
||||
│ │ ├─ Found: expires at T+300s │
|
||||
│ │ ├─ Not expired yet │
|
||||
│ │ └─ YES ✅ │
|
||||
│ ├─ ALLOWED - create upstream connection │
|
||||
│ ├─ net.Dial("tcp", "93.184.216.34:443") │
|
||||
│ │ ↓ │
|
||||
│ │ [Real Network Connection] │
|
||||
│ │ ↓ │
|
||||
│ └─ Bidirectional copy (TUN ↔ Internet) │
|
||||
│ │
|
||||
│ 4. BLOCKED SCENARIO (Hardcoded IP) │
|
||||
│ ─────────────────────────────────────────────────────────────────────── │
|
||||
│ App: connect(1.2.3.4:443) // Hardcoded IP, no DNS! │
|
||||
│ ↓ │
|
||||
│ [TCP SYN Packet: 10.0.0.2:54322 → 1.2.3.4:443] │
|
||||
│ ↓ │
|
||||
│ TUN Interface → readPacket() │
|
||||
│ ↓ │
|
||||
│ gVisor Netstack → TCPForwarder.handleConnection() │
|
||||
│ ├─ Extract destination IP: 1.2.3.4 │
|
||||
│ ├─ ipTracker.IsTracked(1.2.3.4)? │
|
||||
│ │ └─ Check resolvedIPs map → NOT FOUND │
|
||||
│ │ └─ NO ❌ │
|
||||
│ ├─ BLOCKED │
|
||||
│ ├─ Log: "[TCP] BLOCKED hardcoded IP: 10.0.0.2:54322 → 1.2.3.4:443" │
|
||||
│ └─ Return (send TCP RST to app) │
|
||||
│ │
|
||||
│ App receives connection refused/reset │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
||||
│ PLATFORM-SPECIFIC DETAILS │
|
||||
│ │
|
||||
│ ANDROID │
|
||||
│ ──────────────────────────────────────────────────────────────────────── │
|
||||
│ • VPN Config: ControlDService.kt │
|
||||
│ • Packet I/O: FileInputStream/FileOutputStream on VPN fd │
|
||||
│ • DNS Proxy: Listens on 0.0.0.0:5354 (connects via 127.0.0.1:5354) │
|
||||
│ • Self-Exclusion: addDisallowedApplication(packageName) │
|
||||
│ • Mode Switch: Recreates VPN interface with new routes │
|
||||
│ • No routing loops: App traffic bypasses VPN │
|
||||
│ │
|
||||
│ IOS │
|
||||
│ ──────────────────────────────────────────────────────────────────────── │
|
||||
│ • VPN Config: PacketTunnelProvider.swift │
|
||||
│ • Packet I/O: NEPacketTunnelFlow (async → blocking via PacketQueue) │
|
||||
│ • DNS Proxy: Listens on 127.0.0.1:53 │
|
||||
│ • Self-Exclusion: Network Extension sockets auto-bypass │
|
||||
│ • Mode Switch: setTunnelNetworkSettings() with new routes │
|
||||
│ • Write Batching: 16 packets per batch, 5ms flush timer │
|
||||
│ • No routing loops: Extension traffic bypasses VPN │
|
||||
└──────────────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
## Components
|
||||
|
||||
### DNS Filter (`dns_filter.go`)
|
||||
- Detects DNS packets on port 53 (UDP/TCP)
|
||||
- Forwards to ControlD DNS proxy (via DNS bridge)
|
||||
- Parses DNS responses to extract A/AAAA records
|
||||
- Automatically tracks resolved IPs via IP Tracker
|
||||
- Builds DNS response packets and sends back to TUN
|
||||
|
||||
### DNS Bridge (`dns_bridge.go`)
|
||||
- Bridges between netstack and ControlD DNS proxy
|
||||
- Tracks DNS queries by transaction ID
|
||||
- 5-second timeout per query
|
||||
- Returns responses to DNS filter
|
||||
|
||||
### IP Tracker (`ip_tracker.go`)
|
||||
- **Always enabled** - tracks all DNS-resolved IPs
|
||||
- In-memory whitelist with 5-minute TTL per IP
|
||||
- Background cleanup every 30 seconds (removes expired IPs)
|
||||
- Thread-safe with RWMutex (optimized for read-heavy workload)
|
||||
- Used by TCP/UDP forwarders to enforce DNS-first policy
|
||||
|
||||
### TCP Forwarder (`tcp_forwarder.go`)
|
||||
- Handles TCP connections via gVisor's `tcp.NewForwarder()`
|
||||
- Checks `ipTracker != nil` (always true) for firewall enforcement
|
||||
- Allows internal VPN subnet (10.0.0.0/24) without checks
|
||||
- Blocks connections to non-tracked IPs (logs: "BLOCKED hardcoded IP")
|
||||
- Forwards allowed connections via `net.Dial("tcp")` to real network
|
||||
- Bidirectional copy between TUN and internet
|
||||
|
||||
### UDP Forwarder (`udp_forwarder.go`)
|
||||
- Handles UDP packets via gVisor's `udp.NewForwarder()`
|
||||
- Session tracking with 30-second read timeout
|
||||
- Checks `ipTracker != nil` (always true) for firewall enforcement
|
||||
- Blocks QUIC (UDP/443, UDP/80) to force TCP fallback
|
||||
- Blocks connections to non-tracked IPs (logs: "BLOCKED hardcoded IP")
|
||||
- Forwards allowed packets via `net.Dial("udp")` to real network
|
||||
|
||||
### Packet Handler (`packet_handler.go`)
|
||||
- Interface for TUN I/O operations (read, write, close)
|
||||
- `MobilePacketHandler` wraps mobile platform callbacks
|
||||
- Bridges gomobile interface with netstack
|
||||
|
||||
### Netstack Controller (`netstack.go`)
|
||||
- Manages gVisor TCP/IP stack
|
||||
- Coordinates DNS Filter, IP Tracker, TCP/UDP Forwarders
|
||||
- Always creates IP Tracker (firewall always on)
|
||||
- Reads packets from TUN → injects into netstack
|
||||
- Writes packets from netstack → sends to TUN
|
||||
- Filters outbound packets (source = 10.0.0.x)
|
||||
- Blocks QUIC before injection into netstack
|
||||
|
||||
## Platform Configuration
|
||||
|
||||
### Android
|
||||
|
||||
```kotlin
|
||||
// Base VPN configuration (same for both modes)
|
||||
Builder()
|
||||
.addAddress("10.0.0.2", 24)
|
||||
.addDnsServer("10.0.0.1")
|
||||
.setMtu(1500)
|
||||
.setBlocking(true)
|
||||
.addDisallowedApplication(packageName) // Exclude self from VPN!
|
||||
|
||||
// Firewall mode - route ALL traffic
|
||||
if (isFirewallMode) {
|
||||
vpnBuilder.addRoute("0.0.0.0", 0)
|
||||
}
|
||||
// DNS-only mode - route ONLY DNS server IP
|
||||
else {
|
||||
vpnBuilder.addRoute("10.0.0.1", 32)
|
||||
}
|
||||
|
||||
vpnInterface = vpnBuilder.establish()
|
||||
|
||||
// DNS Proxy listens on: 0.0.0.0:5354
|
||||
// Library connects to: 127.0.0.1:5354
|
||||
```
|
||||
|
||||
**Important:**
|
||||
- App MUST exclude itself using `addDisallowedApplication()` to prevent routing loops
|
||||
- Mode switching: Call `setFirewallMode(enabled)` to recreate VPN interface with new routes
|
||||
|
||||
### iOS
|
||||
|
||||
```swift
|
||||
// Base configuration (same for both modes)
|
||||
let ipv4Settings = NEIPv4Settings(
|
||||
addresses: ["10.0.0.2"],
|
||||
subnetMasks: ["255.255.255.0"]
|
||||
)
|
||||
|
||||
// Firewall mode - route ALL traffic
|
||||
if isFirewallMode {
|
||||
ipv4Settings.includedRoutes = [NEIPv4Route.default()]
|
||||
}
|
||||
// DNS-only mode - route ONLY DNS server IP
|
||||
else {
|
||||
ipv4Settings.includedRoutes = [
|
||||
NEIPv4Route(destinationAddress: "10.0.0.1", subnetMask: "255.255.255.255")
|
||||
]
|
||||
}
|
||||
|
||||
let dnsSettings = NEDNSSettings(servers: ["10.0.0.1"])
|
||||
dnsSettings.matchDomains = [""]
|
||||
|
||||
networkSettings.ipv4Settings = ipv4Settings
|
||||
networkSettings.dnsSettings = dnsSettings
|
||||
networkSettings.mtu = 1500
|
||||
|
||||
setTunnelNetworkSettings(networkSettings)
|
||||
|
||||
// DNS Proxy listens on: 127.0.0.1:53
|
||||
// Library connects to: 127.0.0.1:53
|
||||
```
|
||||
|
||||
**Note:**
|
||||
- Network Extension sockets automatically bypass VPN - no routing loops
|
||||
- Mode switching: Send message `{"action": "set_firewall_mode", "enabled": "true"}` to extension
|
||||
|
||||
## Protocol Support
|
||||
|
||||
| Protocol | Support |
|
||||
|----------|---------|
|
||||
| DNS (UDP/TCP port 53) | ✅ Full |
|
||||
| TCP (all ports) | ✅ Full |
|
||||
| UDP (except 53, 80, 443) | ✅ Full |
|
||||
| QUIC (UDP/443, UDP/80) | 🚫 Blocked |
|
||||
| ICMP | ⚠️ Partial |
|
||||
| IPv4 | ✅ Full |
|
||||
| IPv6 | ✅ Full |
|
||||
|
||||
## QUIC Blocking
|
||||
|
||||
Blocks UDP packets on ports 443 and 80 to force TCP fallback.
|
||||
|
||||
**Where it's blocked:**
|
||||
- `netstack.go:354-369` - Blocks QUIC **before** injection into gVisor stack
|
||||
- Early blocking (pre-netstack) for efficiency
|
||||
- Checks destination port (UDP/443, UDP/80) in raw packet
|
||||
|
||||
**Why:**
|
||||
- QUIC/HTTP3 can use cached IPs, bypassing DNS filtering entirely
|
||||
- TCP/TLS provides visible SNI for content filtering
|
||||
- Ensures consistent ControlD policy enforcement
|
||||
- IP tracker alone isn't enough (apps cache QUIC IPs aggressively)
|
||||
|
||||
**Result:**
|
||||
- Apps automatically fallback to TCP/TLS (HTTP/2, HTTP/1.1)
|
||||
- No user-visible errors (fallback is seamless)
|
||||
- Slightly slower initial connection, then normal performance
|
||||
|
||||
**Note:** IP tracker ALSO blocks hardcoded IPs, but QUIC blocking provides additional layer of protection since QUIC apps often cache IPs longer than 5 minutes.
|
||||
|
||||
## IP Blocking (DNS Bypass Prevention)
|
||||
|
||||
**Firewall is ALWAYS enabled.** The IP tracker runs in all modes and tracks all DNS-resolved IPs.
|
||||
|
||||
**How it works:**
|
||||
1. DNS responses are parsed to extract A and AAAA records
|
||||
2. Resolved IPs are tracked in memory whitelist for 5 minutes (TTL)
|
||||
3. In **firewall mode**: TCP/UDP connections to **non-whitelisted** IPs are **BLOCKED**
|
||||
4. In **DNS-only mode**: Only DNS traffic reaches the VPN, so IP blocking is inactive
|
||||
|
||||
**Mode Behavior:**
|
||||
- **Firewall mode** (default routes): OS sends ALL traffic to VPN
|
||||
- DNS queries → tracked IPs
|
||||
- TCP/UDP connections → checked against tracker → blocked if not tracked
|
||||
|
||||
- **DNS-only mode** (DNS route only): OS sends ONLY DNS to VPN
|
||||
- DNS queries → tracked IPs
|
||||
- TCP/UDP connections → bypass VPN entirely (never reach tracker)
|
||||
|
||||
**Why IP tracker is always on:**
|
||||
- Simplifies implementation (no enable/disable logic)
|
||||
- Ready for mode switching at runtime
|
||||
- In DNS-only mode, tracker tracks IPs but never blocks (no TCP/UDP traffic)
|
||||
|
||||
**Example (Firewall Mode):**
|
||||
```
|
||||
T=0s: App connects to 1.2.3.4 directly
|
||||
→ ❌ BLOCKED (not in tracker)
|
||||
|
||||
T=1s: App queries "example.com" → DNS returns 93.184.216.34
|
||||
→ Tracker stores: 93.184.216.34 (expires in 5min)
|
||||
|
||||
T=2s: App connects to 93.184.216.34
|
||||
→ ✅ ALLOWED (found in tracker, not expired)
|
||||
|
||||
T=302s: App connects to 93.184.216.34
|
||||
→ ❌ BLOCKED (expired, must query DNS again)
|
||||
```
|
||||
|
||||
**Components:**
|
||||
- `ip_tracker.go` - Always-on whitelist with 5min TTL, 30s cleanup
|
||||
- `dns_filter.go` - Extracts A/AAAA records, tracks IPs automatically
|
||||
- `tcp_forwarder.go` - Checks `ipTracker != nil` (always true)
|
||||
- `udp_forwarder.go` - Checks `ipTracker != nil` (always true)
|
||||
|
||||
## Usage (Android)
|
||||
|
||||
```kotlin
|
||||
// Create callback
|
||||
val callback = object : PacketAppCallback {
|
||||
override fun readPacket(): ByteArray { ... }
|
||||
override fun writePacket(packet: ByteArray) { ... }
|
||||
override fun closePacketIO() { ... }
|
||||
override fun exit(s: String) { ... }
|
||||
override fun hostname(): String = "android-device"
|
||||
override fun lanIp(): String = "10.0.0.2"
|
||||
override fun macAddress(): String = "00:00:00:00:00:00"
|
||||
}
|
||||
|
||||
// Create controller
|
||||
val controller = Ctrld_library.newPacketCaptureController(callback)
|
||||
|
||||
// Start with all parameters
|
||||
controller.startWithPacketCapture(
|
||||
callback, // PacketAppCallback
|
||||
"10.0.0.1", // TUN address (gateway)
|
||||
"10.0.0.2", // Device address
|
||||
1500, // MTU
|
||||
"127.0.0.1:5354", // DNS proxy address
|
||||
"your-cd-uid", // ControlD UID
|
||||
"", // Provision ID (optional)
|
||||
"", // Custom hostname (optional)
|
||||
filesDir.absolutePath, // Home directory
|
||||
"doh", // Upstream protocol (doh/dot/doq)
|
||||
2, // Log level (0-3)
|
||||
"$filesDir/ctrld.log" // Log path
|
||||
)
|
||||
|
||||
// Stop
|
||||
controller.stop(false, 0)
|
||||
|
||||
// Runtime mode switching (no restart needed)
|
||||
VpnController.instance?.setFirewallMode(context, isFirewallMode = true)
|
||||
```
|
||||
|
||||
## Usage (iOS)
|
||||
|
||||
```swift
|
||||
// Start LocalProxy with all parameters
|
||||
let proxy = LocalProxy()
|
||||
proxy.mode = .firewall // or .dnsOnly
|
||||
|
||||
proxy.start(
|
||||
tunAddress: "10.0.0.1", // TUN address (gateway)
|
||||
deviceAddress: "10.0.0.2", // Device address
|
||||
mtu: 1500, // MTU
|
||||
dnsProxyAddress: "127.0.0.1:53", // DNS proxy address
|
||||
cUID: cdUID, // ControlD UID
|
||||
provisionID: "", // Provision ID (optional)
|
||||
customHostname: "", // Custom hostname (optional)
|
||||
homeDir: FileManager().temporaryDirectory.path, // Home directory
|
||||
upstreamProto: "doh", // Upstream protocol
|
||||
logLevel: 2, // Log level (0-3)
|
||||
logPath: FileManager().temporaryDirectory.appendingPathComponent("ctrld.log").path,
|
||||
deviceName: UIDevice.current.name, // Device name
|
||||
packetFlow: packetFlow // NEPacketTunnelFlow
|
||||
)
|
||||
|
||||
// Stop
|
||||
proxy.stop()
|
||||
|
||||
// Runtime mode switching (no restart needed)
|
||||
// Send message from main app to extension:
|
||||
let message = ["action": "set_firewall_mode", "enabled": "true"]
|
||||
session.sendProviderMessage(JSONEncoder().encode(message)) { response in }
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
- **Android**: API 24+ (Android 7.0+)
|
||||
- **iOS**: iOS 12.0+
|
||||
- **Go**: 1.23+
|
||||
- **gVisor**: v0.0.0-20240722211153-64c016c92987
|
||||
|
||||
## Files
|
||||
|
||||
- `packet_handler.go` - TUN I/O interface
|
||||
- `netstack.go` - gVisor controller
|
||||
- `dns_filter.go` - DNS packet detection and IP extraction
|
||||
- `dns_bridge.go` - Transaction tracking
|
||||
- `ip_tracker.go` - DNS-resolved IP whitelist with TTL
|
||||
- `tcp_forwarder.go` - TCP forwarding with whitelist enforcement
|
||||
- `udp_forwarder.go` - UDP forwarding with whitelist enforcement
|
||||
|
||||
## License
|
||||
|
||||
Same as parent ctrld project.
|
||||
228
cmd/ctrld_library/netstack/dns_bridge.go
Normal file
228
cmd/ctrld_library/netstack/dns_bridge.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// DNSBridge provides a bridge between the netstack DNS filter and the existing ctrld DNS proxy.
|
||||
// It allows DNS queries captured from packets to be processed by the same logic as traditional DNS queries.
|
||||
type DNSBridge struct {
|
||||
// Channel for sending DNS queries
|
||||
queryCh chan *DNSQuery
|
||||
|
||||
// Channel for receiving DNS responses
|
||||
responseCh chan *DNSResponse
|
||||
|
||||
// Map to track pending queries by transaction ID
|
||||
pendingQueries map[uint16]*PendingQuery
|
||||
mu sync.RWMutex
|
||||
|
||||
// Timeout for DNS queries
|
||||
queryTimeout time.Duration
|
||||
|
||||
// Running state
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// DNSQuery represents a DNS query to be processed
|
||||
type DNSQuery struct {
|
||||
ID uint16 // Transaction ID for matching response
|
||||
Query []byte // Raw DNS query bytes
|
||||
RespCh chan []byte // Response channel
|
||||
SrcIP string // Source IP for logging
|
||||
SrcPort uint16 // Source port
|
||||
}
|
||||
|
||||
// DNSResponse represents a DNS response
|
||||
type DNSResponse struct {
|
||||
ID uint16
|
||||
Response []byte
|
||||
}
|
||||
|
||||
// PendingQuery tracks a query waiting for response
|
||||
type PendingQuery struct {
|
||||
Query *DNSQuery
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// NewDNSBridge creates a new DNS bridge
|
||||
func NewDNSBridge() *DNSBridge {
|
||||
return &DNSBridge{
|
||||
queryCh: make(chan *DNSQuery, 100),
|
||||
responseCh: make(chan *DNSResponse, 100),
|
||||
pendingQueries: make(map[uint16]*PendingQuery),
|
||||
queryTimeout: 5 * time.Second,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the DNS bridge
|
||||
func (b *DNSBridge) Start() {
|
||||
b.mu.Lock()
|
||||
if b.running {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
b.running = true
|
||||
b.mu.Unlock()
|
||||
|
||||
// Start response handler
|
||||
b.wg.Add(1)
|
||||
go b.handleResponses()
|
||||
|
||||
// Start timeout checker
|
||||
b.wg.Add(1)
|
||||
go b.checkTimeouts()
|
||||
}
|
||||
|
||||
// Stop stops the DNS bridge
|
||||
func (b *DNSBridge) Stop() {
|
||||
b.mu.Lock()
|
||||
if !b.running {
|
||||
b.mu.Unlock()
|
||||
return
|
||||
}
|
||||
b.running = false
|
||||
b.mu.Unlock()
|
||||
|
||||
close(b.stopCh)
|
||||
b.wg.Wait()
|
||||
|
||||
// Clean up pending queries
|
||||
b.mu.Lock()
|
||||
for _, pending := range b.pendingQueries {
|
||||
close(pending.Query.RespCh)
|
||||
}
|
||||
b.pendingQueries = make(map[uint16]*PendingQuery)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
// ProcessQuery processes a DNS query and waits for response
|
||||
func (b *DNSBridge) ProcessQuery(query []byte, srcIP string, srcPort uint16) ([]byte, error) {
|
||||
if len(query) < 12 {
|
||||
return nil, fmt.Errorf("invalid DNS query: too short")
|
||||
}
|
||||
|
||||
// Parse DNS message to get transaction ID
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(query); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse DNS query: %v", err)
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
respCh := make(chan []byte, 1)
|
||||
|
||||
// Create query
|
||||
dnsQuery := &DNSQuery{
|
||||
ID: msg.Id,
|
||||
Query: query,
|
||||
RespCh: respCh,
|
||||
SrcIP: srcIP,
|
||||
SrcPort: srcPort,
|
||||
}
|
||||
|
||||
// Store as pending
|
||||
b.mu.Lock()
|
||||
b.pendingQueries[msg.Id] = &PendingQuery{
|
||||
Query: dnsQuery,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
// Send query
|
||||
select {
|
||||
case b.queryCh <- dnsQuery:
|
||||
case <-time.After(time.Second):
|
||||
b.mu.Lock()
|
||||
delete(b.pendingQueries, msg.Id)
|
||||
b.mu.Unlock()
|
||||
return nil, fmt.Errorf("query channel full")
|
||||
}
|
||||
|
||||
// Wait for response with timeout
|
||||
select {
|
||||
case response := <-respCh:
|
||||
b.mu.Lock()
|
||||
delete(b.pendingQueries, msg.Id)
|
||||
b.mu.Unlock()
|
||||
return response, nil
|
||||
|
||||
case <-time.After(b.queryTimeout):
|
||||
b.mu.Lock()
|
||||
delete(b.pendingQueries, msg.Id)
|
||||
b.mu.Unlock()
|
||||
return nil, fmt.Errorf("DNS query timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// GetQueryChannel returns the channel for receiving DNS queries
|
||||
func (b *DNSBridge) GetQueryChannel() <-chan *DNSQuery {
|
||||
return b.queryCh
|
||||
}
|
||||
|
||||
// SendResponse sends a DNS response back to the waiting query
|
||||
func (b *DNSBridge) SendResponse(id uint16, response []byte) error {
|
||||
b.mu.RLock()
|
||||
pending, exists := b.pendingQueries[id]
|
||||
b.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("no pending query for ID %d", id)
|
||||
}
|
||||
|
||||
select {
|
||||
case pending.Query.RespCh <- response:
|
||||
return nil
|
||||
case <-time.After(time.Second):
|
||||
return fmt.Errorf("failed to send response: channel blocked")
|
||||
}
|
||||
}
|
||||
|
||||
// handleResponses handles incoming responses
|
||||
func (b *DNSBridge) handleResponses() {
|
||||
defer b.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.stopCh:
|
||||
return
|
||||
|
||||
case resp := <-b.responseCh:
|
||||
if err := b.SendResponse(resp.ID, resp.Response); err != nil {
|
||||
// Log error but continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkTimeouts periodically checks for and removes timed out queries
|
||||
func (b *DNSBridge) checkTimeouts() {
|
||||
defer b.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.stopCh:
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
b.mu.Lock()
|
||||
for id, pending := range b.pendingQueries {
|
||||
if now.Sub(pending.Timestamp) > b.queryTimeout {
|
||||
close(pending.Query.RespCh)
|
||||
delete(b.pendingQueries, id)
|
||||
}
|
||||
}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
365
cmd/ctrld_library/netstack/dns_filter.go
Normal file
365
cmd/ctrld_library/netstack/dns_filter.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
|
||||
// DNSFilter intercepts and processes DNS packets.
|
||||
type DNSFilter struct {
|
||||
dnsHandler func([]byte) ([]byte, error)
|
||||
ipTracker *IPTracker
|
||||
}
|
||||
|
||||
// NewDNSFilter creates a new DNS filter with the given handler.
|
||||
func NewDNSFilter(handler func([]byte) ([]byte, error), ipTracker *IPTracker) *DNSFilter {
|
||||
return &DNSFilter{
|
||||
dnsHandler: handler,
|
||||
ipTracker: ipTracker,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessPacket checks if a packet is a DNS query and processes it.
|
||||
// Returns:
|
||||
// - isDNS: true if this is a DNS packet
|
||||
// - response: DNS response packet (if handled), nil otherwise
|
||||
// - error: any error that occurred
|
||||
func (df *DNSFilter) ProcessPacket(packet []byte) (isDNS bool, response []byte, err error) {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse IP version
|
||||
ipVersion := packet[0] >> 4
|
||||
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
return df.processIPv4(packet)
|
||||
case 6:
|
||||
return df.processIPv6(packet)
|
||||
default:
|
||||
return false, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// processIPv4 processes an IPv4 packet and checks if it's DNS.
|
||||
func (df *DNSFilter) processIPv4(packet []byte) (bool, []byte, error) {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse IPv4 header
|
||||
ipHdr := header.IPv4(packet)
|
||||
if !ipHdr.IsValid(len(packet)) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Check if it's UDP
|
||||
if ipHdr.TransportProtocol() != header.UDPProtocolNumber {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Get IP header length
|
||||
ihl := int(ipHdr.HeaderLength())
|
||||
if len(packet) < ihl+header.UDPMinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse UDP header
|
||||
udpHdr := header.UDP(packet[ihl:])
|
||||
srcPort := udpHdr.SourcePort()
|
||||
dstPort := udpHdr.DestinationPort()
|
||||
|
||||
// Check if destination port is 53 (DNS)
|
||||
if dstPort != 53 {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
srcIP := ipHdr.SourceAddress()
|
||||
dstIP := ipHdr.DestinationAddress()
|
||||
|
||||
// Extract DNS payload
|
||||
udpPayloadOffset := ihl + header.UDPMinimumSize
|
||||
if len(packet) <= udpPayloadOffset {
|
||||
return true, nil, fmt.Errorf("invalid UDP packet length")
|
||||
}
|
||||
|
||||
dnsQuery := packet[udpPayloadOffset:]
|
||||
if len(dnsQuery) == 0 {
|
||||
return true, nil, fmt.Errorf("empty DNS query")
|
||||
}
|
||||
|
||||
// Process DNS query
|
||||
if df.dnsHandler == nil {
|
||||
return true, nil, fmt.Errorf("no DNS handler configured")
|
||||
}
|
||||
|
||||
dnsResponse, err := df.dnsHandler(dnsQuery)
|
||||
if err != nil {
|
||||
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
||||
}
|
||||
|
||||
// Track IPs from DNS response
|
||||
if df.ipTracker != nil {
|
||||
df.extractAndTrackIPs(dnsResponse)
|
||||
}
|
||||
|
||||
// Build response packet
|
||||
responsePacket := df.buildIPv4UDPPacket(
|
||||
dstIP.As4(), // Swap src/dst
|
||||
srcIP.As4(),
|
||||
dstPort, // Swap ports
|
||||
srcPort,
|
||||
dnsResponse,
|
||||
)
|
||||
|
||||
return true, responsePacket, nil
|
||||
}
|
||||
|
||||
// processIPv6 processes an IPv6 packet and checks if it's DNS.
|
||||
func (df *DNSFilter) processIPv6(packet []byte) (bool, []byte, error) {
|
||||
if len(packet) < header.IPv6MinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse IPv6 header
|
||||
ipHdr := header.IPv6(packet)
|
||||
if !ipHdr.IsValid(len(packet)) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Check if it's UDP
|
||||
if ipHdr.TransportProtocol() != header.UDPProtocolNumber {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// IPv6 header is fixed size
|
||||
if len(packet) < header.IPv6MinimumSize+header.UDPMinimumSize {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Parse UDP header
|
||||
udpHdr := header.UDP(packet[header.IPv6MinimumSize:])
|
||||
srcPort := udpHdr.SourcePort()
|
||||
dstPort := udpHdr.DestinationPort()
|
||||
|
||||
// Check if destination port is 53 (DNS)
|
||||
if dstPort != 53 {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
// Extract DNS payload
|
||||
udpPayloadOffset := header.IPv6MinimumSize + header.UDPMinimumSize
|
||||
if len(packet) <= udpPayloadOffset {
|
||||
return true, nil, fmt.Errorf("invalid UDP packet length")
|
||||
}
|
||||
|
||||
dnsQuery := packet[udpPayloadOffset:]
|
||||
if len(dnsQuery) == 0 {
|
||||
return true, nil, fmt.Errorf("empty DNS query")
|
||||
}
|
||||
|
||||
// Process DNS query
|
||||
if df.dnsHandler == nil {
|
||||
return true, nil, fmt.Errorf("no DNS handler configured")
|
||||
}
|
||||
|
||||
dnsResponse, err := df.dnsHandler(dnsQuery)
|
||||
if err != nil {
|
||||
return true, nil, fmt.Errorf("DNS handler error: %v", err)
|
||||
}
|
||||
|
||||
// Track IPs from DNS response
|
||||
if df.ipTracker != nil {
|
||||
df.extractAndTrackIPs(dnsResponse)
|
||||
}
|
||||
|
||||
// Build response packet
|
||||
srcIP := ipHdr.SourceAddress()
|
||||
dstIP := ipHdr.DestinationAddress()
|
||||
|
||||
responsePacket := df.buildIPv6UDPPacket(
|
||||
dstIP.As16(), // Swap src/dst
|
||||
srcIP.As16(),
|
||||
dstPort, // Swap ports
|
||||
srcPort,
|
||||
dnsResponse,
|
||||
)
|
||||
|
||||
return true, responsePacket, nil
|
||||
}
|
||||
|
||||
// buildIPv4UDPPacket builds a complete IPv4/UDP packet with the given payload.
|
||||
func (df *DNSFilter) buildIPv4UDPPacket(srcIP, dstIP [4]byte, srcPort, dstPort uint16, payload []byte) []byte {
|
||||
// Calculate lengths
|
||||
udpLen := header.UDPMinimumSize + len(payload)
|
||||
ipLen := header.IPv4MinimumSize + udpLen
|
||||
packet := make([]byte, ipLen)
|
||||
|
||||
// Build IPv4 header
|
||||
ipHdr := header.IPv4(packet)
|
||||
ipHdr.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(ipLen),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.UDPProtocolNumber),
|
||||
SrcAddr: tcpip.AddrFrom4(srcIP),
|
||||
DstAddr: tcpip.AddrFrom4(dstIP),
|
||||
})
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
|
||||
// Build UDP header
|
||||
udpHdr := header.UDP(packet[header.IPv4MinimumSize:])
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Length: uint16(udpLen),
|
||||
})
|
||||
|
||||
// Copy payload
|
||||
copy(packet[header.IPv4MinimumSize+header.UDPMinimumSize:], payload)
|
||||
|
||||
// Calculate UDP checksum
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.UDPProtocolNumber,
|
||||
tcpip.AddrFrom4(srcIP),
|
||||
tcpip.AddrFrom4(dstIP),
|
||||
uint16(udpLen),
|
||||
)
|
||||
xsum = checksum(payload, xsum)
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
// buildIPv6UDPPacket builds a complete IPv6/UDP packet with the given payload.
|
||||
func (df *DNSFilter) buildIPv6UDPPacket(srcIP, dstIP [16]byte, srcPort, dstPort uint16, payload []byte) []byte {
|
||||
// Calculate lengths
|
||||
udpLen := header.UDPMinimumSize + len(payload)
|
||||
ipLen := header.IPv6MinimumSize + udpLen
|
||||
packet := make([]byte, ipLen)
|
||||
|
||||
// Build IPv6 header
|
||||
ipHdr := header.IPv6(packet)
|
||||
ipHdr.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(udpLen),
|
||||
TransportProtocol: header.UDPProtocolNumber,
|
||||
HopLimit: 64,
|
||||
SrcAddr: tcpip.AddrFrom16(srcIP),
|
||||
DstAddr: tcpip.AddrFrom16(dstIP),
|
||||
})
|
||||
|
||||
// Build UDP header
|
||||
udpHdr := header.UDP(packet[header.IPv6MinimumSize:])
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Length: uint16(udpLen),
|
||||
})
|
||||
|
||||
// Copy payload
|
||||
copy(packet[header.IPv6MinimumSize+header.UDPMinimumSize:], payload)
|
||||
|
||||
// Calculate UDP checksum
|
||||
xsum := header.PseudoHeaderChecksum(
|
||||
header.UDPProtocolNumber,
|
||||
tcpip.AddrFrom16(srcIP),
|
||||
tcpip.AddrFrom16(dstIP),
|
||||
uint16(udpLen),
|
||||
)
|
||||
xsum = checksum(payload, xsum)
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
// checksum calculates the checksum for the given data.
|
||||
func checksum(buf []byte, initial uint16) uint16 {
|
||||
v := uint32(initial)
|
||||
l := len(buf)
|
||||
if l&1 != 0 {
|
||||
l--
|
||||
v += uint32(buf[l]) << 8
|
||||
}
|
||||
for i := 0; i < l; i += 2 {
|
||||
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
|
||||
}
|
||||
return reduceChecksum(v)
|
||||
}
|
||||
|
||||
// reduceChecksum reduces a 32-bit checksum to 16 bits.
|
||||
func reduceChecksum(v uint32) uint16 {
|
||||
v = (v >> 16) + (v & 0xffff)
|
||||
v = (v >> 16) + (v & 0xffff)
|
||||
return uint16(v)
|
||||
}
|
||||
|
||||
// IPv4Address is a helper to create an IPv4 address from a byte array.
|
||||
func IPv4Address(b [4]byte) net.IP {
|
||||
return net.IPv4(b[0], b[1], b[2], b[3])
|
||||
}
|
||||
|
||||
// IPv6Address is a helper to create an IPv6 address from a byte array.
|
||||
func IPv6Address(b [16]byte) net.IP {
|
||||
return net.IP(b[:])
|
||||
}
|
||||
|
||||
// parseIPv4 extracts source and destination IPs from an IPv4 packet.
|
||||
func parseIPv4(packet []byte) (srcIP, dstIP [4]byte, ok bool) {
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
return
|
||||
}
|
||||
ipHdr := header.IPv4(packet)
|
||||
if !ipHdr.IsValid(len(packet)) {
|
||||
return
|
||||
}
|
||||
srcAddr := ipHdr.SourceAddress().As4()
|
||||
dstAddr := ipHdr.DestinationAddress().As4()
|
||||
copy(srcIP[:], srcAddr[:])
|
||||
copy(dstIP[:], dstAddr[:])
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
// parseUDP extracts UDP header information.
|
||||
func parseUDP(udpHeader []byte) (srcPort, dstPort uint16, ok bool) {
|
||||
if len(udpHeader) < header.UDPMinimumSize {
|
||||
return
|
||||
}
|
||||
srcPort = binary.BigEndian.Uint16(udpHeader[0:2])
|
||||
dstPort = binary.BigEndian.Uint16(udpHeader[2:4])
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
// extractAndTrackIPs parses DNS response and tracks resolved IP addresses
|
||||
func (df *DNSFilter) extractAndTrackIPs(dnsResponse []byte) {
|
||||
if len(dnsResponse) < 12 {
|
||||
return // Invalid DNS response
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(dnsResponse); err != nil {
|
||||
return // Failed to parse DNS response
|
||||
}
|
||||
|
||||
// Extract IPs from answer section
|
||||
for _, answer := range msg.Answer {
|
||||
switch rr := answer.(type) {
|
||||
case *dns.A:
|
||||
// IPv4 address
|
||||
if rr.A != nil {
|
||||
df.ipTracker.TrackIP(rr.A)
|
||||
}
|
||||
case *dns.AAAA:
|
||||
// IPv6 address
|
||||
if rr.AAAA != nil {
|
||||
df.ipTracker.TrackIP(rr.AAAA)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
150
cmd/ctrld_library/netstack/ip_tracker.go
Normal file
150
cmd/ctrld_library/netstack/ip_tracker.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IPTracker tracks IP addresses that have been resolved through DNS.
|
||||
// This allows blocking direct IP connections that bypass DNS filtering.
|
||||
type IPTracker struct {
|
||||
// Map of IP address string -> expiration time
|
||||
resolvedIPs map[string]time.Time
|
||||
mu sync.RWMutex
|
||||
|
||||
// TTL for tracked IPs (how long to remember them)
|
||||
ttl time.Duration
|
||||
|
||||
// Running state
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewIPTracker creates a new IP tracker with the specified TTL
|
||||
func NewIPTracker(ttl time.Duration) *IPTracker {
|
||||
if ttl == 0 {
|
||||
ttl = 5 * time.Minute // Default 5 minutes
|
||||
}
|
||||
|
||||
return &IPTracker{
|
||||
resolvedIPs: make(map[string]time.Time),
|
||||
ttl: ttl,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the IP tracker cleanup routine
|
||||
func (t *IPTracker) Start() {
|
||||
t.mu.Lock()
|
||||
if t.running {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
t.mu.Unlock()
|
||||
|
||||
// Start cleanup goroutine to remove expired IPs
|
||||
t.wg.Add(1)
|
||||
go t.cleanupExpiredIPs()
|
||||
}
|
||||
|
||||
// Stop stops the IP tracker
|
||||
func (t *IPTracker) Stop() {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
if !t.running {
|
||||
t.mu.Unlock()
|
||||
return
|
||||
}
|
||||
t.running = false
|
||||
t.mu.Unlock()
|
||||
|
||||
// Close stop channel (protected against double close)
|
||||
select {
|
||||
case <-t.stopCh:
|
||||
// Already closed
|
||||
default:
|
||||
close(t.stopCh)
|
||||
}
|
||||
|
||||
t.wg.Wait()
|
||||
|
||||
// Clear all tracked IPs
|
||||
t.mu.Lock()
|
||||
t.resolvedIPs = make(map[string]time.Time)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// TrackIP adds an IP address to the tracking list
|
||||
func (t *IPTracker) TrackIP(ip net.IP) {
|
||||
if ip == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Normalize to string format
|
||||
ipStr := ip.String()
|
||||
|
||||
t.mu.Lock()
|
||||
t.resolvedIPs[ipStr] = time.Now().Add(t.ttl)
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// IsTracked checks if an IP address is in the tracking list
|
||||
// Optimized to minimize lock contention by avoiding write locks in the hot path
|
||||
func (t *IPTracker) IsTracked(ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ipStr := ip.String()
|
||||
|
||||
t.mu.RLock()
|
||||
expiration, exists := t.resolvedIPs[ipStr]
|
||||
t.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if expired - but DON'T delete here to avoid write lock
|
||||
// Let the cleanup goroutine handle expired entries
|
||||
// This keeps IsTracked fast with only read locks
|
||||
return !time.Now().After(expiration)
|
||||
}
|
||||
|
||||
// GetTrackedCount returns the number of currently tracked IPs
|
||||
func (t *IPTracker) GetTrackedCount() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return len(t.resolvedIPs)
|
||||
}
|
||||
|
||||
// cleanupExpiredIPs periodically removes expired IP entries
|
||||
func (t *IPTracker) cleanupExpiredIPs() {
|
||||
defer t.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.stopCh:
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
t.mu.Lock()
|
||||
for ip, expiration := range t.resolvedIPs {
|
||||
if now.After(expiration) {
|
||||
delete(t.resolvedIPs, ip)
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
417
cmd/ctrld_library/netstack/netstack.go
Normal file
417
cmd/ctrld_library/netstack/netstack.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
// Default MTU for the TUN interface
|
||||
defaultMTU = 1500
|
||||
|
||||
// NICID is the ID of the network interface
|
||||
NICID = 1
|
||||
|
||||
// Channel capacity for packet buffers
|
||||
channelCapacity = 512
|
||||
)
|
||||
|
||||
// NetstackController manages the gVisor netstack integration for mobile packet capture.
|
||||
type NetstackController struct {
|
||||
stack *stack.Stack
|
||||
linkEP *channel.Endpoint
|
||||
packetHandler PacketHandler
|
||||
dnsFilter *DNSFilter
|
||||
ipTracker *IPTracker
|
||||
tcpForwarder *TCPForwarder
|
||||
udpForwarder *UDPForwarder
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
started bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Config holds configuration for NetstackController.
|
||||
type Config struct {
|
||||
// MTU is the maximum transmission unit
|
||||
MTU uint32
|
||||
|
||||
// TUNIPv4 is the IPv4 address assigned to the TUN interface
|
||||
TUNIPv4 netip.Addr
|
||||
|
||||
// TUNIPv6 is the IPv6 address assigned to the TUN interface (optional)
|
||||
TUNIPv6 netip.Addr
|
||||
|
||||
// DNSHandler is the function to process DNS queries
|
||||
DNSHandler func([]byte) ([]byte, error)
|
||||
|
||||
// UpstreamInterface is the real network interface for routing non-DNS traffic
|
||||
UpstreamInterface *net.Interface
|
||||
}
|
||||
|
||||
// NewNetstackController creates a new netstack controller.
|
||||
func NewNetstackController(handler PacketHandler, cfg *Config) (*NetstackController, error) {
|
||||
if handler == nil {
|
||||
return nil, fmt.Errorf("packet handler cannot be nil")
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
cfg = &Config{
|
||||
MTU: defaultMTU,
|
||||
TUNIPv4: netip.MustParseAddr("10.0.0.1"),
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.MTU == 0 {
|
||||
cfg.MTU = defaultMTU
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create gVisor stack
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
},
|
||||
})
|
||||
|
||||
// Create link endpoint
|
||||
linkEP := channel.New(channelCapacity, cfg.MTU, "")
|
||||
|
||||
// Always create IP tracker (5 minute TTL for tracked IPs)
|
||||
// In firewall mode (default routes): blocks direct IP connections
|
||||
// In DNS-only mode: no non-DNS traffic to block
|
||||
ipTracker := NewIPTracker(5 * time.Minute)
|
||||
|
||||
// Create DNS filter with IP tracker
|
||||
dnsFilter := NewDNSFilter(cfg.DNSHandler, ipTracker)
|
||||
|
||||
// Create TCP forwarder with IP tracker
|
||||
tcpForwarder := NewTCPForwarder(s, ctx, ipTracker)
|
||||
|
||||
// Create UDP forwarder with IP tracker
|
||||
udpForwarder := NewUDPForwarder(s, ctx, ipTracker)
|
||||
|
||||
// Create NIC
|
||||
if err := s.CreateNIC(NICID, linkEP); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Enable spoofing to allow packets with any source IP
|
||||
if err := s.SetSpoofing(NICID, true); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to enable spoofing: %v", err)
|
||||
}
|
||||
|
||||
// Enable promiscuous mode to accept all packets
|
||||
if err := s.SetPromiscuousMode(NICID, true); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to enable promiscuous mode: %v", err)
|
||||
}
|
||||
|
||||
// Add IPv4 address
|
||||
protocolAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(cfg.TUNIPv4.AsSlice()),
|
||||
PrefixLen: 24,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(NICID, protocolAddr, stack.AddressProperties{}); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to add IPv4 address: %v", err)
|
||||
}
|
||||
|
||||
// Add IPv6 address if provided
|
||||
if cfg.TUNIPv6.IsValid() {
|
||||
protocolAddr6 := tcpip.ProtocolAddress{
|
||||
Protocol: ipv6.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(cfg.TUNIPv6.AsSlice()),
|
||||
PrefixLen: 64,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(NICID, protocolAddr6, stack.AddressProperties{}); err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("failed to add IPv6 address: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add default routes
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: NICID,
|
||||
},
|
||||
{
|
||||
Destination: header.IPv6EmptySubnet,
|
||||
NIC: NICID,
|
||||
},
|
||||
})
|
||||
|
||||
// Register forwarders with the stack
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.forwarder.HandlePacket)
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.forwarder.HandlePacket)
|
||||
|
||||
nc := &NetstackController{
|
||||
stack: s,
|
||||
linkEP: linkEP,
|
||||
packetHandler: handler,
|
||||
dnsFilter: dnsFilter,
|
||||
ipTracker: ipTracker,
|
||||
tcpForwarder: tcpForwarder,
|
||||
udpForwarder: udpForwarder,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
started: false,
|
||||
}
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Controller created with TCP/UDP forwarders")
|
||||
|
||||
return nc, nil
|
||||
}
|
||||
|
||||
// Start starts the netstack controller and begins processing packets.
|
||||
func (nc *NetstackController) Start() error {
|
||||
nc.mu.Lock()
|
||||
defer nc.mu.Unlock()
|
||||
|
||||
if nc.started {
|
||||
return fmt.Errorf("netstack controller already started")
|
||||
}
|
||||
|
||||
nc.started = true
|
||||
|
||||
// Start IP tracker
|
||||
nc.ipTracker.Start()
|
||||
|
||||
// Start packet reader goroutine (TUN -> netstack)
|
||||
nc.wg.Add(1)
|
||||
go nc.readPackets()
|
||||
|
||||
// Start packet writer goroutine (netstack -> TUN)
|
||||
nc.wg.Add(1)
|
||||
go nc.writePackets()
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Packet processing started (read/write goroutines + IP tracker)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the netstack controller and waits for all goroutines to finish.
|
||||
func (nc *NetstackController) Stop() error {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() called - starting shutdown")
|
||||
|
||||
nc.mu.Lock()
|
||||
if !nc.started {
|
||||
nc.mu.Unlock()
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - already stopped, returning")
|
||||
return nil
|
||||
}
|
||||
nc.mu.Unlock()
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - canceling context")
|
||||
nc.cancel()
|
||||
|
||||
// Close packet handler FIRST to unblock all pending reads
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - closing packet handler to unblock goroutines")
|
||||
if err := nc.packetHandler.Close(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Msgf("[Netstack] Stop() - failed to close packet handler: %v", err)
|
||||
// Continue shutdown even if close fails
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - packet handler closed")
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - waiting for goroutines (max 2 seconds)")
|
||||
|
||||
// Wait for goroutines with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
nc.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - all goroutines finished")
|
||||
case <-time.After(2 * time.Second):
|
||||
ctrld.ProxyLogger.Load().Warn().Msg("[Netstack] Stop() - timeout waiting for goroutines, proceeding anyway")
|
||||
}
|
||||
|
||||
// Stop IP tracker
|
||||
if nc.ipTracker != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - stopping IP tracker")
|
||||
nc.ipTracker.Stop()
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - IP tracker stopped")
|
||||
}
|
||||
|
||||
// Close UDP forwarder
|
||||
if nc.udpForwarder != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - closing UDP forwarder")
|
||||
nc.udpForwarder.Close()
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - UDP forwarder closed")
|
||||
}
|
||||
|
||||
nc.mu.Lock()
|
||||
nc.started = false
|
||||
nc.mu.Unlock()
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[Netstack] Stop() - shutdown complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// readPackets reads packets from the TUN interface and injects them into the netstack.
|
||||
func (nc *NetstackController) readPackets() {
|
||||
defer nc.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-nc.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read packet from TUN
|
||||
packet, err := nc.packetHandler.ReadPacket()
|
||||
if err != nil {
|
||||
if nc.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(packet) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a DNS packet
|
||||
isDNS, response, err := nc.dnsFilter.ProcessPacket(packet)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if isDNS && response != nil {
|
||||
// DNS packet was handled, send response back to TUN
|
||||
nc.packetHandler.WritePacket(response)
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("[Netstack] DNS response sent (%d bytes)", len(response))
|
||||
continue
|
||||
}
|
||||
|
||||
if isDNS {
|
||||
continue
|
||||
}
|
||||
|
||||
// Not a DNS packet - check if it's an OUTBOUND packet (source = 10.0.0.x)
|
||||
// We should ONLY inject outbound packets, not return packets
|
||||
if len(packet) >= 20 {
|
||||
// Check if source is in our VPN subnet (10.0.0.x)
|
||||
isOutbound := packet[12] == 10 && packet[13] == 0 && packet[14] == 0
|
||||
|
||||
if !isOutbound {
|
||||
// This is a return packet (server -> mobile)
|
||||
// Drop it - return packets come through forwarder's upstream connection
|
||||
continue
|
||||
}
|
||||
|
||||
// Block QUIC protocol (UDP on port 443)
|
||||
// QUIC runs over UDP and bypasses DNS, so we block it to force HTTP/2 or HTTP/3 over TCP
|
||||
protocol := packet[9]
|
||||
if protocol == 17 { // UDP
|
||||
// Get IP header length
|
||||
ihl := int(packet[0]&0x0f) * 4
|
||||
if len(packet) >= ihl+4 {
|
||||
// Parse UDP destination port (bytes 2-3 of UDP header)
|
||||
dstPort := uint16(packet[ihl+2])<<8 | uint16(packet[ihl+3])
|
||||
if dstPort == 443 || dstPort == 80 {
|
||||
// Block QUIC (UDP/443) and HTTP/3 (UDP/80)
|
||||
// Apps will fallback to TCP automatically
|
||||
dstIP := net.IPv4(packet[16], packet[17], packet[18], packet[19])
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("[Netstack] Blocked QUIC packet to %s:%d", dstIP, dstPort)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create packet buffer
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
// Determine protocol number
|
||||
var proto tcpip.NetworkProtocolNumber
|
||||
if len(packet) > 0 {
|
||||
version := packet[0] >> 4
|
||||
switch version {
|
||||
case 4:
|
||||
proto = header.IPv4ProtocolNumber
|
||||
case 6:
|
||||
proto = header.IPv6ProtocolNumber
|
||||
default:
|
||||
pkt.DecRef()
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
pkt.DecRef()
|
||||
continue
|
||||
}
|
||||
|
||||
// Inject into netstack - TCP/UDP forwarders will handle it
|
||||
nc.linkEP.InjectInbound(proto, pkt)
|
||||
}
|
||||
}
|
||||
|
||||
// writePackets reads packets from netstack and writes them to the TUN interface.
|
||||
func (nc *NetstackController) writePackets() {
|
||||
defer nc.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-nc.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read packet from netstack
|
||||
pkt := nc.linkEP.ReadContext(nc.ctx)
|
||||
if pkt == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert packet to bytes
|
||||
vv := pkt.ToView()
|
||||
packet := vv.AsSlice()
|
||||
|
||||
// Write to TUN
|
||||
if err := nc.packetHandler.WritePacket(packet); err != nil {
|
||||
// Log error
|
||||
continue
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
97
cmd/ctrld_library/netstack/packet_handler.go
Normal file
97
cmd/ctrld_library/netstack/packet_handler.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PacketHandler defines the interface for reading and writing raw IP packets
|
||||
// from/to the mobile TUN interface.
|
||||
type PacketHandler interface {
|
||||
// ReadPacket reads a raw IP packet from the TUN interface.
|
||||
// This should be a blocking call.
|
||||
ReadPacket() ([]byte, error)
|
||||
|
||||
// WritePacket writes a raw IP packet back to the TUN interface.
|
||||
WritePacket(packet []byte) error
|
||||
|
||||
// Close closes the packet handler and releases resources.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// MobilePacketHandler implements PacketHandler using callbacks from mobile platforms.
|
||||
// This bridges Go Mobile interface with the netstack implementation.
|
||||
type MobilePacketHandler struct {
|
||||
readFunc func() ([]byte, error)
|
||||
writeFunc func([]byte) error
|
||||
closeFunc func() error
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewMobilePacketHandler creates a new packet handler with mobile callbacks.
|
||||
func NewMobilePacketHandler(
|
||||
readFunc func() ([]byte, error),
|
||||
writeFunc func([]byte) error,
|
||||
closeFunc func() error,
|
||||
) *MobilePacketHandler {
|
||||
return &MobilePacketHandler{
|
||||
readFunc: readFunc,
|
||||
writeFunc: writeFunc,
|
||||
closeFunc: closeFunc,
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPacket reads a packet from mobile TUN interface.
|
||||
func (m *MobilePacketHandler) ReadPacket() ([]byte, error) {
|
||||
m.mu.Lock()
|
||||
closed := m.closed
|
||||
m.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return nil, fmt.Errorf("packet handler is closed")
|
||||
}
|
||||
|
||||
if m.readFunc == nil {
|
||||
return nil, fmt.Errorf("read function not set")
|
||||
}
|
||||
|
||||
return m.readFunc()
|
||||
}
|
||||
|
||||
// WritePacket writes a packet back to mobile TUN interface.
|
||||
func (m *MobilePacketHandler) WritePacket(packet []byte) error {
|
||||
m.mu.Lock()
|
||||
closed := m.closed
|
||||
m.mu.Unlock()
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("packet handler is closed")
|
||||
}
|
||||
|
||||
if m.writeFunc == nil {
|
||||
return fmt.Errorf("write function not set")
|
||||
}
|
||||
|
||||
return m.writeFunc(packet)
|
||||
}
|
||||
|
||||
// Close closes the packet handler.
|
||||
func (m *MobilePacketHandler) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.closed = true
|
||||
|
||||
if m.closeFunc != nil {
|
||||
return m.closeFunc()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
130
cmd/ctrld_library/netstack/tcp_forwarder.go
Normal file
130
cmd/ctrld_library/netstack/tcp_forwarder.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
// TCPForwarder handles TCP connections from the TUN interface
|
||||
type TCPForwarder struct {
|
||||
ctx context.Context
|
||||
forwarder *tcp.Forwarder
|
||||
ipTracker *IPTracker
|
||||
}
|
||||
|
||||
// NewTCPForwarder creates a new TCP forwarder
|
||||
func NewTCPForwarder(s *stack.Stack, ctx context.Context, ipTracker *IPTracker) *TCPForwarder {
|
||||
f := &TCPForwarder{
|
||||
ctx: ctx,
|
||||
ipTracker: ipTracker,
|
||||
}
|
||||
|
||||
// Create gVisor TCP forwarder with handler callback
|
||||
// rcvWnd=0 (use default), maxInFlight=1024
|
||||
f.forwarder = tcp.NewForwarder(s, 0, 1024, f.handleRequest)
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// GetForwarder returns the underlying gVisor forwarder
|
||||
func (f *TCPForwarder) GetForwarder() *tcp.Forwarder {
|
||||
return f.forwarder
|
||||
}
|
||||
|
||||
// handleRequest handles an incoming TCP connection request
|
||||
func (f *TCPForwarder) handleRequest(req *tcp.ForwarderRequest) {
|
||||
// Get the endpoint ID
|
||||
id := req.ID()
|
||||
|
||||
// Create waiter queue
|
||||
var wq waiter.Queue
|
||||
|
||||
// Create endpoint from request
|
||||
ep, err := req.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
req.Complete(true) // Send RST
|
||||
return
|
||||
}
|
||||
|
||||
// Accept the connection
|
||||
req.Complete(false)
|
||||
|
||||
// Cast to TCP endpoint
|
||||
tcpEP, ok := ep.(*tcp.Endpoint)
|
||||
if !ok {
|
||||
ep.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Handle in goroutine
|
||||
go f.handleConnection(tcpEP, &wq, id)
|
||||
}
|
||||
|
||||
func (f *TCPForwarder) handleConnection(ep *tcp.Endpoint, wq *waiter.Queue, id stack.TransportEndpointID) {
|
||||
// Convert endpoint to Go net.Conn
|
||||
tunConn := gonet.NewTCPConn(wq, ep)
|
||||
defer tunConn.Close()
|
||||
|
||||
// In gVisor's TransportEndpointID for an inbound connection:
|
||||
// - LocalAddress/LocalPort = the destination (where packet is going TO)
|
||||
// - RemoteAddress/RemotePort = the source (where packet is coming FROM)
|
||||
// We want to dial the DESTINATION (LocalAddress/LocalPort)
|
||||
dstIP := net.IP(id.LocalAddress.AsSlice())
|
||||
dstAddr := net.TCPAddr{
|
||||
IP: dstIP,
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
|
||||
// Check if IP blocking is enabled (firewall mode only)
|
||||
// Skip blocking for internal VPN subnet (10.0.0.0/24)
|
||||
if f.ipTracker != nil {
|
||||
// Allow internal VPN traffic (10.0.0.0/24)
|
||||
if !(dstIP[0] == 10 && dstIP[1] == 0 && dstIP[2] == 0) {
|
||||
// Check if destination IP was resolved through ControlD DNS
|
||||
// ONLY allow connections to IPs that went through DNS (whitelist approach)
|
||||
if !f.ipTracker.IsTracked(dstIP) {
|
||||
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[TCP] BLOCKED hardcoded IP: %s:%d -> %s:%d (not resolved via DNS)",
|
||||
srcAddr, id.RemotePort, dstIP, id.LocalPort)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create outbound connection
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
upstreamConn, err := dialer.DialContext(f.ctx, "tcp", dstAddr.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer upstreamConn.Close()
|
||||
|
||||
// Log successful TCP connection
|
||||
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("[TCP] %s:%d -> %s:%d", srcAddr, id.RemotePort, dstAddr.IP, dstAddr.Port)
|
||||
|
||||
// Bidirectional copy
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
io.Copy(upstreamConn, tunConn)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(tunConn, upstreamConn)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
// Wait for one direction to finish
|
||||
<-done
|
||||
}
|
||||
238
cmd/ctrld_library/netstack/udp_forwarder.go
Normal file
238
cmd/ctrld_library/netstack/udp_forwarder.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package netstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
// UDPForwarder handles UDP packets from the TUN interface
|
||||
type UDPForwarder struct {
|
||||
ctx context.Context
|
||||
forwarder *udp.Forwarder
|
||||
ipTracker *IPTracker
|
||||
|
||||
// Track UDP "connections" (address pairs)
|
||||
connections map[string]*udpConn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type udpConn struct {
|
||||
tunEP *gonet.UDPConn
|
||||
upstreamConn *net.UDPConn
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewUDPForwarder creates a new UDP forwarder
|
||||
func NewUDPForwarder(s *stack.Stack, ctx context.Context, ipTracker *IPTracker) *UDPForwarder {
|
||||
f := &UDPForwarder{
|
||||
ctx: ctx,
|
||||
ipTracker: ipTracker,
|
||||
connections: make(map[string]*udpConn),
|
||||
}
|
||||
|
||||
// Create gVisor UDP forwarder with handler callback
|
||||
f.forwarder = udp.NewForwarder(s, f.handlePacket)
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// GetForwarder returns the underlying gVisor forwarder
|
||||
func (f *UDPForwarder) GetForwarder() *udp.Forwarder {
|
||||
return f.forwarder
|
||||
}
|
||||
|
||||
// handlePacket handles an incoming UDP packet
|
||||
func (f *UDPForwarder) handlePacket(req *udp.ForwarderRequest) {
|
||||
// Get the endpoint ID
|
||||
id := req.ID()
|
||||
|
||||
// Create connection key (source -> destination)
|
||||
connKey := fmt.Sprintf("%s:%d->%s:%d",
|
||||
net.IP(id.RemoteAddress.AsSlice()),
|
||||
id.RemotePort,
|
||||
net.IP(id.LocalAddress.AsSlice()),
|
||||
id.LocalPort,
|
||||
)
|
||||
|
||||
f.mu.Lock()
|
||||
conn, exists := f.connections[connKey]
|
||||
if !exists {
|
||||
// Create new connection
|
||||
conn = f.createConnection(req, connKey)
|
||||
if conn == nil {
|
||||
f.mu.Unlock()
|
||||
return
|
||||
}
|
||||
f.connections[connKey] = conn
|
||||
|
||||
// Log new UDP session
|
||||
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
||||
dstAddr := net.IP(id.LocalAddress.AsSlice())
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] New session: %s:%d -> %s:%d (total: %d)",
|
||||
srcAddr, id.RemotePort, dstAddr, id.LocalPort, len(f.connections))
|
||||
}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) createConnection(req *udp.ForwarderRequest, connKey string) *udpConn {
|
||||
id := req.ID()
|
||||
|
||||
// Create waiter queue
|
||||
var wq waiter.Queue
|
||||
|
||||
// Create endpoint from request
|
||||
ep, err := req.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert to Go UDP conn
|
||||
tunConn := gonet.NewUDPConn(&wq, ep)
|
||||
|
||||
// Extract destination address
|
||||
// LocalAddress/LocalPort = destination (where packet is going TO)
|
||||
// RemoteAddress/RemotePort = source (where packet is coming FROM)
|
||||
dstIP := net.IP(id.LocalAddress.AsSlice())
|
||||
dstAddr := &net.UDPAddr{
|
||||
IP: dstIP,
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
|
||||
// Check if IP blocking is enabled (firewall mode only)
|
||||
// Skip blocking for internal VPN subnet (10.0.0.0/24)
|
||||
if f.ipTracker != nil {
|
||||
// Allow internal VPN traffic (10.0.0.0/24)
|
||||
if !(dstIP[0] == 10 && dstIP[1] == 0 && dstIP[2] == 0) {
|
||||
// Check if destination IP was resolved through ControlD DNS
|
||||
// ONLY allow connections to IPs that went through DNS (whitelist approach)
|
||||
if !f.ipTracker.IsTracked(dstIP) {
|
||||
srcAddr := net.IP(id.RemoteAddress.AsSlice())
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[UDP] BLOCKED hardcoded IP: %s:%d -> %s:%d (not resolved via DNS)",
|
||||
srcAddr, id.RemotePort, dstIP, id.LocalPort)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create dialer
|
||||
dialer := &net.Dialer{}
|
||||
|
||||
// Create outbound UDP connection
|
||||
dialConn, dialErr := dialer.Dial("udp", dstAddr.String())
|
||||
if dialErr != nil {
|
||||
tunConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
upstreamConn, ok := dialConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
dialConn.Close()
|
||||
tunConn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create connection context
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
|
||||
udpConnection := &udpConn{
|
||||
tunEP: tunConn,
|
||||
upstreamConn: upstreamConn,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Start forwarding goroutines
|
||||
go f.forwardTunToUpstream(udpConnection, ctx)
|
||||
go f.forwardUpstreamToTun(udpConnection, ctx, connKey)
|
||||
|
||||
return udpConnection
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) forwardTunToUpstream(conn *udpConn, ctx context.Context) {
|
||||
buffer := make([]byte, 65535)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read from TUN
|
||||
n, err := conn.tunEP.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Write to upstream
|
||||
_, err = conn.upstreamConn.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) forwardUpstreamToTun(conn *udpConn, ctx context.Context, connKey string) {
|
||||
defer func() {
|
||||
conn.tunEP.Close()
|
||||
conn.upstreamConn.Close()
|
||||
|
||||
f.mu.Lock()
|
||||
delete(f.connections, connKey)
|
||||
f.mu.Unlock()
|
||||
}()
|
||||
|
||||
buffer := make([]byte, 65535)
|
||||
|
||||
// Set read timeout
|
||||
conn.upstreamConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read from upstream
|
||||
n, err := conn.upstreamConn.Read(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset read deadline
|
||||
conn.upstreamConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
// Write to TUN
|
||||
_, err = conn.tunEP.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all UDP connections
|
||||
func (f *UDPForwarder) Close() {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[UDP] Close() called - closing all connections")
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[UDP] Close() - closing %d connections", len(f.connections))
|
||||
for key, conn := range f.connections {
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("[UDP] Close() - closing connection: %s", key)
|
||||
conn.cancel()
|
||||
conn.tunEP.Close()
|
||||
conn.upstreamConn.Close()
|
||||
}
|
||||
f.connections = make(map[string]*udpConn)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[UDP] Close() - all connections closed")
|
||||
}
|
||||
277
cmd/ctrld_library/packet_capture.go
Normal file
277
cmd/ctrld_library/packet_capture.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package ctrld_library
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
"github.com/Control-D-Inc/ctrld/cmd/ctrld_library/netstack"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// PacketAppCallback extends AppCallback with packet read/write capabilities.
|
||||
// Mobile platforms implementing full packet capture should use this interface.
|
||||
type PacketAppCallback interface {
|
||||
AppCallback
|
||||
|
||||
// ReadPacket reads a raw IP packet from the TUN interface.
|
||||
// This should be a blocking call that returns when a packet is available.
|
||||
ReadPacket() ([]byte, error)
|
||||
|
||||
// WritePacket writes a raw IP packet back to the TUN interface.
|
||||
WritePacket(packet []byte) error
|
||||
|
||||
// ClosePacketIO closes packet I/O resources.
|
||||
ClosePacketIO() error
|
||||
}
|
||||
|
||||
// PacketCaptureController holds state for packet capture mode
|
||||
type PacketCaptureController struct {
|
||||
baseController *Controller
|
||||
|
||||
// Packet capture mode fields
|
||||
netstackCtrl *netstack.NetstackController
|
||||
dnsBridge *netstack.DNSBridge
|
||||
packetStopCh chan struct{}
|
||||
dnsProxyAddress string
|
||||
}
|
||||
|
||||
// NewPacketCaptureController creates a new packet capture controller
|
||||
func NewPacketCaptureController(appCallback PacketAppCallback) *PacketCaptureController {
|
||||
return &PacketCaptureController{
|
||||
baseController: &Controller{AppCallback: appCallback},
|
||||
packetStopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// StartWithPacketCapture starts ctrld in full packet capture mode for mobile.
|
||||
// This method enables full IP packet processing with DNS filtering and upstream routing.
|
||||
// It requires a PacketAppCallback that provides packet read/write capabilities.
|
||||
func (pc *PacketCaptureController) StartWithPacketCapture(
|
||||
packetCallback PacketAppCallback,
|
||||
tunAddress string,
|
||||
deviceAddress string,
|
||||
mtu int64,
|
||||
dnsProxyAddress string,
|
||||
CdUID string,
|
||||
ProvisionID string,
|
||||
CustomHostname string,
|
||||
HomeDir string,
|
||||
UpstreamProto string,
|
||||
logLevel int,
|
||||
logPath string,
|
||||
) error {
|
||||
if pc.baseController.stopCh != nil {
|
||||
return fmt.Errorf("controller already running")
|
||||
}
|
||||
|
||||
// Store DNS proxy address for handleDNSQuery
|
||||
pc.dnsProxyAddress = dnsProxyAddress
|
||||
|
||||
// Set defaults
|
||||
if mtu == 0 {
|
||||
mtu = 1500
|
||||
}
|
||||
|
||||
// Set up configuration
|
||||
pc.baseController.Config = cli.AppConfig{
|
||||
CdUID: CdUID,
|
||||
ProvisionID: ProvisionID,
|
||||
CustomHostname: CustomHostname,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
}
|
||||
pc.baseController.AppCallback = packetCallback
|
||||
|
||||
// Create DNS bridge for communication between netstack and DNS proxy
|
||||
pc.dnsBridge = netstack.NewDNSBridge()
|
||||
pc.dnsBridge.Start()
|
||||
|
||||
// Create packet handler that wraps the mobile callbacks
|
||||
packetHandler := netstack.NewMobilePacketHandler(
|
||||
packetCallback.ReadPacket,
|
||||
packetCallback.WritePacket,
|
||||
packetCallback.ClosePacketIO,
|
||||
)
|
||||
|
||||
// Create DNS handler that uses the bridge
|
||||
dnsHandler := func(query []byte) ([]byte, error) {
|
||||
// Use device address as the source of DNS queries
|
||||
return pc.dnsBridge.ProcessQuery(query, deviceAddress, 0)
|
||||
}
|
||||
|
||||
// Parse TUN IP address
|
||||
tunIPv4, err := netip.ParseAddr(tunAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse TUN IPv4 address '%s': %v", tunAddress, err)
|
||||
}
|
||||
|
||||
netstackCfg := &netstack.Config{
|
||||
MTU: uint32(mtu),
|
||||
TUNIPv4: tunIPv4,
|
||||
DNSHandler: dnsHandler,
|
||||
UpstreamInterface: nil, // Will use default interface
|
||||
}
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Network config - TUN: %s, Device: %s, MTU: %d, DNS Proxy: %s",
|
||||
tunAddress, deviceAddress, mtu, dnsProxyAddress)
|
||||
|
||||
// Create netstack controller
|
||||
netstackCtrl, err := netstack.NewNetstackController(packetHandler, netstackCfg)
|
||||
if err != nil {
|
||||
pc.dnsBridge.Stop()
|
||||
return fmt.Errorf("failed to create netstack controller: %v", err)
|
||||
}
|
||||
|
||||
pc.netstackCtrl = netstackCtrl
|
||||
|
||||
// Start netstack processing
|
||||
if err := pc.netstackCtrl.Start(); err != nil {
|
||||
pc.dnsBridge.Stop()
|
||||
return fmt.Errorf("failed to start netstack: %v", err)
|
||||
}
|
||||
|
||||
// Start regular ctrld DNS processing in background
|
||||
// This allows us to use existing DNS filtering logic
|
||||
pc.baseController.stopCh = make(chan struct{})
|
||||
|
||||
// Start DNS query processor that receives queries from the bridge
|
||||
// and sends them to the ctrld DNS proxy
|
||||
go pc.processDNSQueries()
|
||||
|
||||
// Start the main ctrld mobile runner
|
||||
go func() {
|
||||
appCallback := mapCallback(pc.baseController.AppCallback)
|
||||
cli.RunMobile(&pc.baseController.Config, &appCallback, pc.baseController.stopCh)
|
||||
}()
|
||||
|
||||
// BLOCK here until stopped (critical - Swift expects this to block!)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Blocking until stop signal...")
|
||||
<-pc.baseController.stopCh
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop signal received, exiting")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processDNSQueries processes DNS queries from the bridge using the ctrld DNS proxy
|
||||
func (pc *PacketCaptureController) processDNSQueries() {
|
||||
queryCh := pc.dnsBridge.GetQueryChannel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pc.packetStopCh:
|
||||
return
|
||||
case <-pc.baseController.stopCh:
|
||||
return
|
||||
case query := <-queryCh:
|
||||
go pc.handleDNSQuery(query)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleDNSQuery handles a single DNS query
|
||||
func (pc *PacketCaptureController) handleDNSQuery(query *netstack.DNSQuery) {
|
||||
// Parse DNS message
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(query.Query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send query to actual DNS proxy using configured address
|
||||
client := &dns.Client{
|
||||
Net: "udp",
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
response, _, err := client.Exchange(msg, pc.dnsProxyAddress)
|
||||
if err != nil {
|
||||
// Create SERVFAIL response
|
||||
response = new(dns.Msg)
|
||||
response.SetReply(msg)
|
||||
response.Rcode = dns.RcodeServerFailure
|
||||
}
|
||||
|
||||
// Pack response
|
||||
responseBytes, err := response.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send response back through bridge
|
||||
pc.dnsBridge.SendResponse(query.ID, responseBytes)
|
||||
}
|
||||
|
||||
// Stop stops the packet capture controller
|
||||
func (pc *PacketCaptureController) Stop(restart bool, pin int64) int {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() called - starting shutdown")
|
||||
var errorCode = 0
|
||||
|
||||
// Stop DNS bridge
|
||||
if pc.dnsBridge != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - stopping DNS bridge")
|
||||
pc.dnsBridge.Stop()
|
||||
pc.dnsBridge = nil
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - DNS bridge stopped")
|
||||
}
|
||||
|
||||
// Stop netstack
|
||||
if pc.netstackCtrl != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - stopping netstack controller")
|
||||
if err := pc.netstackCtrl.Stop(); err != nil {
|
||||
// Log error but continue shutdown
|
||||
ctrld.ProxyLogger.Load().Error().Msgf("[PacketCapture] Stop() - error stopping netstack: %v", err)
|
||||
}
|
||||
pc.netstackCtrl = nil
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - netstack controller stopped")
|
||||
}
|
||||
|
||||
// Close packet stop channel
|
||||
if pc.packetStopCh != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - closing packet stop channel")
|
||||
select {
|
||||
case <-pc.packetStopCh:
|
||||
// Already closed
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - packet stop channel already closed")
|
||||
default:
|
||||
close(pc.packetStopCh)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - packet stop channel closed")
|
||||
}
|
||||
pc.packetStopCh = make(chan struct{})
|
||||
}
|
||||
|
||||
// Stop base controller
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - stopping base controller (restart=%v, pin=%d)", restart, pin)
|
||||
if !restart {
|
||||
errorCode = cli.CheckDeactivationPin(pin, pc.baseController.stopCh)
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - deactivation pin check returned: %d", errorCode)
|
||||
}
|
||||
if errorCode == 0 && pc.baseController.stopCh != nil {
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - closing base controller stop channel")
|
||||
select {
|
||||
case <-pc.baseController.stopCh:
|
||||
// Already closed
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - base controller stop channel already closed")
|
||||
default:
|
||||
close(pc.baseController.stopCh)
|
||||
ctrld.ProxyLogger.Load().Info().Msg("[PacketCapture] Stop() - base controller stop channel closed")
|
||||
}
|
||||
pc.baseController.stopCh = nil
|
||||
}
|
||||
|
||||
ctrld.ProxyLogger.Load().Info().Msgf("[PacketCapture] Stop() - shutdown complete, errorCode=%d", errorCode)
|
||||
return errorCode
|
||||
}
|
||||
|
||||
// IsRunning returns true if the controller is running
|
||||
func (pc *PacketCaptureController) IsRunning() bool {
|
||||
return pc.baseController.stopCh != nil
|
||||
}
|
||||
|
||||
// IsPacketMode returns true (always in packet mode for this controller)
|
||||
func (pc *PacketCaptureController) IsPacketMode() bool {
|
||||
return true
|
||||
}
|
||||
453
config.go
453
config.go
@@ -2,13 +2,17 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
@@ -17,13 +21,17 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ameshkov/dnsstamps"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/tsaddr"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
@@ -42,9 +50,40 @@ const (
|
||||
// depending on the record type of the DNS query.
|
||||
IpStackSplit = "split"
|
||||
|
||||
// FreeDnsDomain is the domain name of free ControlD service.
|
||||
FreeDnsDomain = "freedns.controld.com"
|
||||
// FreeDNSBoostrapIP is the IP address of freedns.controld.com.
|
||||
FreeDNSBoostrapIP = "76.76.2.11"
|
||||
// FreeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com.
|
||||
FreeDNSBoostrapIPv6 = "2606:1a40::11"
|
||||
// PremiumDnsDomain is the domain name of premium ControlD service.
|
||||
PremiumDnsDomain = "dns.controld.com"
|
||||
// PremiumDNSBoostrapIP is the IP address of dns.controld.com.
|
||||
PremiumDNSBoostrapIP = "76.76.2.22"
|
||||
// PremiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.com.
|
||||
PremiumDNSBoostrapIPv6 = "2606:1a40::22"
|
||||
|
||||
// freeDnsDomainDev is the domain name of free ControlD service on dev env.
|
||||
freeDnsDomainDev = "freedns.controld.dev"
|
||||
// freeDNSBoostrapIP is the IP address of freedns.controld.dev.
|
||||
freeDNSBoostrapIP = "176.125.239.11"
|
||||
// freeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com.
|
||||
freeDNSBoostrapIPv6 = "2606:1a40:f000::11"
|
||||
// premiumDnsDomainDev is the domain name of premium ControlD service on dev env.
|
||||
premiumDnsDomainDev = "dns.controld.dev"
|
||||
// premiumDNSBoostrapIP is the IP address of dns.controld.dev.
|
||||
premiumDNSBoostrapIP = "176.125.239.22"
|
||||
// premiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.dev.
|
||||
premiumDNSBoostrapIPv6 = "2606:1a40:f000::22"
|
||||
|
||||
controlDComDomain = "controld.com"
|
||||
controlDNetDomain = "controld.net"
|
||||
controlDDevDomain = "controld.dev"
|
||||
|
||||
endpointPrefixHTTPS = "https://"
|
||||
endpointPrefixQUIC = "quic://"
|
||||
endpointPrefixH3 = "h3://"
|
||||
endpointPrefixSdns = "sdns://"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -78,8 +117,18 @@ func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
|
||||
func InitConfig(v *viper.Viper, name string) {
|
||||
v.SetDefault("listener", map[string]*ListenerConfig{
|
||||
"0": {
|
||||
IP: "127.0.0.1",
|
||||
Port: 53,
|
||||
IP: "",
|
||||
Port: 0,
|
||||
Policy: &ListenerPolicyConfig{
|
||||
Name: "Main Policy",
|
||||
Networks: []Rule{
|
||||
{"network.0": []string{"upstream.0"}},
|
||||
},
|
||||
Rules: []Rule{
|
||||
{"example.com": []string{"upstream.0"}},
|
||||
{"*.ads.com": []string{"upstream.1"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
v.SetDefault("network", map[string]*NetworkConfig{
|
||||
@@ -90,14 +139,14 @@ func InitConfig(v *viper.Viper, name string) {
|
||||
})
|
||||
v.SetDefault("upstream", map[string]*UpstreamConfig{
|
||||
"0": {
|
||||
BootstrapIP: "76.76.2.11",
|
||||
BootstrapIP: FreeDNSBoostrapIP,
|
||||
Name: "Control D - Anti-Malware",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p1",
|
||||
Timeout: 5000,
|
||||
},
|
||||
"1": {
|
||||
BootstrapIP: "76.76.2.11",
|
||||
BootstrapIP: FreeDNSBoostrapIP,
|
||||
Name: "Control D - No Ads",
|
||||
Type: ResolverTypeDOQ,
|
||||
Endpoint: "p2.freedns.controld.com",
|
||||
@@ -165,21 +214,32 @@ func (c *Config) FirstUpstream() *UpstreamConfig {
|
||||
|
||||
// ServiceConfig specifies the general ctrld config.
|
||||
type ServiceConfig struct {
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp kea-dhcp4"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
|
||||
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
||||
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
|
||||
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
|
||||
DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"`
|
||||
DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"`
|
||||
RefetchTime *int `mapstructure:"refetch_time" toml:"refetch_time,omitempty"`
|
||||
ForceRefetchWaitTime *int `mapstructure:"force_refetch_wait_time" toml:"force_refetch_wait_time,omitempty"`
|
||||
LeakOnUpstreamFailure *bool `mapstructure:"leak_on_upstream_failure" toml:"leak_on_upstream_failure,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
|
||||
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
|
||||
@@ -192,7 +252,7 @@ type NetworkConfig struct {
|
||||
// UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to.
|
||||
type UpstreamConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy sdns ''"`
|
||||
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty"`
|
||||
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
|
||||
Domain string `mapstructure:"-" toml:"-"`
|
||||
@@ -201,6 +261,9 @@ type UpstreamConfig struct {
|
||||
// The caller should not access this field directly.
|
||||
// Use UpstreamSendClientInfo instead.
|
||||
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
|
||||
// The caller should not access this field directly.
|
||||
// Use IsDiscoverable instead.
|
||||
Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"`
|
||||
|
||||
g singleflight.Group
|
||||
rebootstrap atomic.Bool
|
||||
@@ -216,14 +279,17 @@ type UpstreamConfig struct {
|
||||
http3RoundTripper6 http.RoundTripper
|
||||
certPool *x509.CertPool
|
||||
u *url.URL
|
||||
fallbackOnce sync.Once
|
||||
uid string
|
||||
}
|
||||
|
||||
// ListenerConfig specifies the networks configuration that ctrld will run on.
|
||||
type ListenerConfig struct {
|
||||
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
|
||||
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
|
||||
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
|
||||
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
|
||||
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
|
||||
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
|
||||
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
|
||||
AllowWanClients bool `mapstructure:"allow_wan_clients" toml:"allow_wan_clients,omitempty"`
|
||||
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
|
||||
}
|
||||
|
||||
// IsDirectDnsListener reports whether ctrld can be a direct listener on port 53.
|
||||
@@ -249,6 +315,7 @@ type ListenerPolicyConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
|
||||
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
@@ -260,8 +327,13 @@ type Rule map[string][]string
|
||||
|
||||
// Init initialized necessary values for an UpstreamConfig.
|
||||
func (uc *UpstreamConfig) Init() {
|
||||
if err := uc.initDnsStamps(); err != nil {
|
||||
ProxyLogger.Load().Fatal().Err(err).Msg("invalid DNS Stamps")
|
||||
}
|
||||
uc.initDoHScheme()
|
||||
uc.uid = upstreamUID()
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
uc.Domain = u.Host
|
||||
uc.Domain = u.Hostname()
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
uc.u = u
|
||||
@@ -279,7 +351,7 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
if uc.IPStack == "" {
|
||||
if uc.isControlD() {
|
||||
if uc.IsControlD() {
|
||||
uc.IPStack = IpStackSplit
|
||||
} else {
|
||||
uc.IPStack = IpStackBoth
|
||||
@@ -287,6 +359,15 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyMsg creates and returns a new DNS message could be used for testing upstream health.
|
||||
func (uc *UpstreamConfig) VerifyMsg() *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.RecursionDesired = true
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
msg.SetEdns0(4096, false) // ensure handling of large DNS response
|
||||
return msg
|
||||
}
|
||||
|
||||
// VerifyDomain returns the domain name that could be resolved by the upstream endpoint.
|
||||
// It returns empty for non-ControlD upstream endpoint.
|
||||
func (uc *UpstreamConfig) VerifyDomain() string {
|
||||
@@ -317,13 +398,28 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
if uc.isControlD() {
|
||||
if uc.IsControlD() || uc.isNextDNS() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDiscoverable reports whether the upstream can be used for PTR discovery.
|
||||
// The caller must ensure uc.Init() was called before calling this.
|
||||
func (uc *UpstreamConfig) IsDiscoverable() bool {
|
||||
if uc.Discoverable != nil {
|
||||
return *uc.Discoverable
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate, ResolverTypeLocal:
|
||||
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BootstrapIPs returns the bootstrap IPs list of upstreams.
|
||||
func (uc *UpstreamConfig) BootstrapIPs() []string {
|
||||
return uc.bootstrapIPs
|
||||
@@ -334,19 +430,26 @@ func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) {
|
||||
uc.certPool = cp
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
uc.setupBootstrapIP(true)
|
||||
// UID returns the unique identifier of the upstream.
|
||||
func (uc *UpstreamConfig) UID() string {
|
||||
return uc.uid
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
// The upstream domain will be looked up using following orders:
|
||||
//
|
||||
// - Current system DNS settings.
|
||||
// - Direct IPs table for ControlD upstreams.
|
||||
// - ControlD Bootstrap DNS 76.76.2.22
|
||||
//
|
||||
// The setup process will block until there's usable IPs found.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
||||
isControlD := uc.isControlD()
|
||||
isControlD := uc.IsControlD()
|
||||
nss := initDefaultOsResolver()
|
||||
for {
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss)
|
||||
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
||||
// filtering them out here to prevent weird behavior.
|
||||
if isControlD {
|
||||
@@ -359,6 +462,15 @@ func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
}
|
||||
}
|
||||
uc.bootstrapIPs = uc.bootstrapIPs[:n]
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain)
|
||||
ProxyLogger.Load().Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain)
|
||||
}
|
||||
}
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
ProxyLogger.Load().Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP)
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")})
|
||||
|
||||
}
|
||||
if len(uc.bootstrapIPs) > 0 {
|
||||
break
|
||||
@@ -384,8 +496,9 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
return
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
uc.rebootstrap.Store(true)
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
@@ -411,7 +524,7 @@ func (uc *UpstreamConfig) setupDOHTransport() {
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
if hasIPv6() {
|
||||
if HasIPv6() {
|
||||
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
@@ -428,6 +541,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0),
|
||||
}
|
||||
|
||||
// Prevent bad tcp connection hanging the requests for too long.
|
||||
// See: https://github.com/golang/go/issues/36026
|
||||
if t2, err := http2.ConfigureTransports(transport); err == nil {
|
||||
t2.ReadIdleTimeout = 10 * time.Second
|
||||
t2.PingTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
dialerTimeoutMs := 2000
|
||||
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
|
||||
dialerTimeoutMs = uc.Timeout
|
||||
@@ -436,7 +556,24 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
if uc.BootstrapIP != "" {
|
||||
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
|
||||
// Create custom dialer with socket protection - matches working example pattern
|
||||
dialer := &net.Dialer{
|
||||
Timeout: dialerTimeout,
|
||||
KeepAlive: dialerTimeout,
|
||||
}
|
||||
// Access underlying socket fd before connecting to it
|
||||
dialer.Control = func(network, address string, c syscall.RawConn) error {
|
||||
return c.Control(func(fd uintptr) {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "Received DoH socket fd %d for %s", fd, address)
|
||||
i := int(fd)
|
||||
// Protect socket from VPN routing
|
||||
if err := ProtectSocket(i); err != nil {
|
||||
Log(ctx, ProxyLogger.Load().Warn(), "Failed to protect DoH socket fd=%d: %v", i, err)
|
||||
} else {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "Protected DoH socket fd=%d", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
addr := net.JoinHostPort(uc.BootstrapIP, port)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", addr)
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
@@ -448,10 +585,25 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs)
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Protect DoH socket from VPN routing
|
||||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||||
if rawConn, err := tcpConn.SyscallConn(); err == nil {
|
||||
rawConn.Control(func(fd uintptr) {
|
||||
i := int(fd)
|
||||
if err := ProtectSocket(i); err != nil {
|
||||
Log(ctx, ProxyLogger.Load().Warn(), "Failed to protect DoH socket fd=%d: %v", i, err)
|
||||
} else {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "Protected DoH socket fd=%d", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", conn.RemoteAddr())
|
||||
return conn, nil
|
||||
}
|
||||
@@ -463,38 +615,61 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
|
||||
// Ping warms up the connection to DoH/DoH3 upstream.
|
||||
func (uc *UpstreamConfig) Ping() {
|
||||
if err := uc.ping(); err != nil {
|
||||
ProxyLogger.Load().Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint)
|
||||
_ = uc.FallbackToDirectIP()
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPing is like Ping, but return an error if any.
|
||||
func (uc *UpstreamConfig) ErrorPing() error {
|
||||
return uc.ping()
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) ping() error {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
default:
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
ping := func(t http.RoundTripper) {
|
||||
ping := func(t http.RoundTripper) error {
|
||||
if t == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
req, _ := http.NewRequestWithContext(ctx, "HEAD", uc.Endpoint, nil)
|
||||
resp, _ := t.RoundTrip(req)
|
||||
if resp == nil {
|
||||
return
|
||||
req, err := http.NewRequestWithContext(ctx, "HEAD", uc.Endpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := t.RoundTrip(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
ping(uc.dohTransport(typ))
|
||||
if err := ping(uc.dohTransport(typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
case ResolverTypeDOH3:
|
||||
ping(uc.doh3Transport(typ))
|
||||
if err := ping(uc.doh3Transport(typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) isControlD() bool {
|
||||
// IsControlD reports whether this is a ControlD upstream.
|
||||
func (uc *UpstreamConfig) IsControlD() bool {
|
||||
domain := uc.Domain
|
||||
if domain == "" {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
@@ -509,6 +684,16 @@ func (uc *UpstreamConfig) isControlD() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) isNextDNS() bool {
|
||||
domain := uc.Domain
|
||||
if domain == "" {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
domain = u.Hostname()
|
||||
}
|
||||
}
|
||||
return domain == "dns.nextdns.io"
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
@@ -543,7 +728,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
case dns.TypeA:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
default:
|
||||
if hasIPv6() {
|
||||
if HasIPv6() {
|
||||
return pick(uc.bootstrapIPs6)
|
||||
}
|
||||
return pick(uc.bootstrapIPs4)
|
||||
@@ -565,7 +750,7 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
case dns.TypeA:
|
||||
return "tcp4-tls", "udp4"
|
||||
default:
|
||||
if hasIPv6() {
|
||||
if HasIPv6() {
|
||||
return "tcp6-tls", "udp6"
|
||||
}
|
||||
return "tcp4-tls", "udp4"
|
||||
@@ -574,6 +759,104 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
return "tcp-tls", "udp"
|
||||
}
|
||||
|
||||
// initDoHScheme initializes the endpoint scheme for DoH/DoH3 upstream if not present.
|
||||
func (uc *UpstreamConfig) initDoHScheme() {
|
||||
if strings.HasPrefix(uc.Endpoint, endpointPrefixH3) && uc.Type == "" {
|
||||
uc.Type = ResolverTypeDOH3
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
case ResolverTypeDOH3:
|
||||
if after, found := strings.CutPrefix(uc.Endpoint, endpointPrefixH3); found {
|
||||
uc.Endpoint = endpointPrefixHTTPS + after
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(uc.Endpoint, endpointPrefixHTTPS) {
|
||||
uc.Endpoint = endpointPrefixHTTPS + uc.Endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// initDnsStamps initializes upstream config based on encoded DNS Stamps Endpoint.
|
||||
func (uc *UpstreamConfig) initDnsStamps() error {
|
||||
if strings.HasPrefix(uc.Endpoint, endpointPrefixSdns) && uc.Type == "" {
|
||||
uc.Type = ResolverTypeSDNS
|
||||
}
|
||||
if uc.Type != ResolverTypeSDNS {
|
||||
return nil
|
||||
}
|
||||
sdns, err := dnsstamps.NewServerStampFromString(uc.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ip, port, _ := net.SplitHostPort(sdns.ServerAddrStr)
|
||||
providerName, port2, _ := net.SplitHostPort(sdns.ProviderName)
|
||||
if port2 != "" {
|
||||
port = port2
|
||||
}
|
||||
if providerName == "" {
|
||||
providerName = sdns.ProviderName
|
||||
}
|
||||
switch sdns.Proto {
|
||||
case dnsstamps.StampProtoTypeDoH:
|
||||
uc.Type = ResolverTypeDOH
|
||||
host := sdns.ProviderName
|
||||
if port != "" && port != defaultPortFor(uc.Type) {
|
||||
host = net.JoinHostPort(providerName, port)
|
||||
}
|
||||
uc.Endpoint = "https://" + host + sdns.Path
|
||||
case dnsstamps.StampProtoTypeTLS:
|
||||
uc.Type = ResolverTypeDOT
|
||||
uc.Endpoint = net.JoinHostPort(providerName, port)
|
||||
case dnsstamps.StampProtoTypeDoQ:
|
||||
uc.Type = ResolverTypeDOQ
|
||||
uc.Endpoint = net.JoinHostPort(providerName, port)
|
||||
case dnsstamps.StampProtoTypePlain:
|
||||
uc.Type = ResolverTypeLegacy
|
||||
uc.Endpoint = sdns.ServerAddrStr
|
||||
default:
|
||||
return fmt.Errorf("unsupported stamp protocol %q", sdns.Proto)
|
||||
}
|
||||
uc.BootstrapIP = ip
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context returns a new context with timeout set from upstream config.
|
||||
func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if uc.Timeout > 0 {
|
||||
return context.WithTimeout(ctx, time.Millisecond*time.Duration(uc.Timeout))
|
||||
}
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
|
||||
// FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain.
|
||||
func (uc *UpstreamConfig) FallbackToDirectIP() bool {
|
||||
if !uc.IsControlD() {
|
||||
return false
|
||||
}
|
||||
if uc.u == nil || uc.Domain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
done := false
|
||||
uc.fallbackOnce.Do(func() {
|
||||
var ip string
|
||||
switch {
|
||||
case dns.IsSubDomain(PremiumDnsDomain, uc.Domain):
|
||||
ip = PremiumDNSBoostrapIP
|
||||
case dns.IsSubDomain(FreeDnsDomain, uc.Domain):
|
||||
ip = FreeDNSBoostrapIP
|
||||
default:
|
||||
return
|
||||
}
|
||||
ProxyLogger.Load().Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip)
|
||||
uc.u.Host = ip
|
||||
done = true
|
||||
})
|
||||
return done
|
||||
}
|
||||
|
||||
// Init initialized necessary values for an ListenerConfig.
|
||||
func (lc *ListenerConfig) Init() {
|
||||
if lc.Policy != nil {
|
||||
@@ -626,6 +909,24 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
|
||||
return
|
||||
}
|
||||
|
||||
// Empty type is ok only for endpoints starts with "h3://" and "sdns://".
|
||||
if uc.Type == "" && !strings.HasPrefix(uc.Endpoint, endpointPrefixH3) && !strings.HasPrefix(uc.Endpoint, endpointPrefixSdns) {
|
||||
sl.ReportError(uc.Endpoint, "type", "type", "oneof", "doh doh3 dot doq os legacy sdns")
|
||||
return
|
||||
}
|
||||
|
||||
// initDoHScheme/initDnsStamps may change upstreams information,
|
||||
// so restoring changed values after validation to keep original one.
|
||||
defer func(ep, typ string) {
|
||||
uc.Endpoint = ep
|
||||
uc.Type = typ
|
||||
}(uc.Endpoint, uc.Type)
|
||||
|
||||
if err := uc.initDnsStamps(); err != nil {
|
||||
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
|
||||
return
|
||||
}
|
||||
uc.initDoHScheme()
|
||||
// DoH/DoH3 requires endpoint is an HTTP url.
|
||||
if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 {
|
||||
u, err := url.Parse(uc.Endpoint)
|
||||
@@ -633,10 +934,6 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
|
||||
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
|
||||
return
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -658,13 +955,19 @@ func defaultPortFor(typ string) string {
|
||||
// - If endpoint is an IP address -> ResolverTypeLegacy
|
||||
// - If endpoint starts with "https://" -> ResolverTypeDOH
|
||||
// - If endpoint starts with "quic://" -> ResolverTypeDOQ
|
||||
// - If endpoint starts with "h3://" -> ResolverTypeDOH3
|
||||
// - If endpoint starts with "sdns://" -> ResolverTypeSDNS
|
||||
// - For anything else -> ResolverTypeDOT
|
||||
func ResolverTypeFromEndpoint(endpoint string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(endpoint, "https://"):
|
||||
case strings.HasPrefix(endpoint, endpointPrefixHTTPS):
|
||||
return ResolverTypeDOH
|
||||
case strings.HasPrefix(endpoint, "quic://"):
|
||||
case strings.HasPrefix(endpoint, endpointPrefixQUIC):
|
||||
return ResolverTypeDOQ
|
||||
case strings.HasPrefix(endpoint, endpointPrefixH3):
|
||||
return ResolverTypeDOH3
|
||||
case strings.HasPrefix(endpoint, endpointPrefixSdns):
|
||||
return ResolverTypeSDNS
|
||||
}
|
||||
host := endpoint
|
||||
if strings.Contains(endpoint, ":") {
|
||||
@@ -679,3 +982,39 @@ func ResolverTypeFromEndpoint(endpoint string) string {
|
||||
func pick(s []string) string {
|
||||
return s[rand.Intn(len(s))]
|
||||
}
|
||||
|
||||
// upstreamUID generates an unique identifier for an upstream.
|
||||
func upstreamUID() string {
|
||||
b := make([]byte, 4)
|
||||
for {
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
|
||||
continue
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the UpstreamConfig for logging.
|
||||
func (uc *UpstreamConfig) String() string {
|
||||
if uc == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}",
|
||||
uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack)
|
||||
}
|
||||
|
||||
// bootstrapIPsFromControlDDomain returns bootstrap IPs for ControlD domain.
|
||||
func bootstrapIPsFromControlDDomain(domain string) []string {
|
||||
switch {
|
||||
case dns.IsSubDomain(PremiumDnsDomain, domain):
|
||||
return []string{PremiumDNSBoostrapIP, PremiumDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(FreeDnsDomain, domain):
|
||||
return []string{FreeDNSBoostrapIP, FreeDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(premiumDnsDomainDev, domain):
|
||||
return []string{premiumDNSBoostrapIP, premiumDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(freeDnsDomainDev, domain):
|
||||
return []string{freeDNSBoostrapIP, freeDNSBoostrapIPv6}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,24 +8,49 @@ import (
|
||||
)
|
||||
|
||||
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
|
||||
uc := &UpstreamConfig{
|
||||
Name: "test",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
}{
|
||||
{
|
||||
name: "doh/doh3",
|
||||
uc: &UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "doq/dot",
|
||||
uc: &UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: ResolverTypeDOT,
|
||||
Endpoint: "p2.freedns.controld.com",
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
uc.Init()
|
||||
uc.setupBootstrapIP(false)
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
t.Log(nameservers())
|
||||
t.Fatal("could not bootstrap ip without bootstrap DNS")
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed.
|
||||
// t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.SetupBootstrapIP()
|
||||
if len(tc.uc.bootstrapIPs) == 0 {
|
||||
t.Log(defaultNameservers())
|
||||
t.Fatalf("could not bootstrap ip: %s", tc.uc.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Log(uc)
|
||||
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_Init(t *testing.T) {
|
||||
u1, _ := url.Parse("https://example.com")
|
||||
u2, _ := url.Parse("https://example.com?k=v")
|
||||
u3, _ := url.Parse("https://freedns.controld.com/p1")
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
@@ -178,6 +203,152 @@ func TestUpstreamConfig_Init(t *testing.T) {
|
||||
u: u2,
|
||||
},
|
||||
},
|
||||
{
|
||||
"h3",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "h3://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"h3 without type",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Endpoint: "h3://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> doh",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AgMAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29tAy9wMQ",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "doh",
|
||||
Endpoint: "https://freedns.controld.com/p1",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u3,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> dot",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AwcAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29t",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:843",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> doq",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://BAcAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29t",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "doq",
|
||||
Endpoint: "freedns.controld.com:784",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> legacy",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AAcAAAAAAAAACjc2Ljc2LjIuMTE",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "legacy",
|
||||
Endpoint: "76.76.2.11:53",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "76.76.2.11",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns without type",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Endpoint: "sdns://AAcAAAAAAAAACjc2Ljc2LjIuMTE",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "legacy",
|
||||
Endpoint: "76.76.2.11:53",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "76.76.2.11",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -185,6 +356,7 @@ func TestUpstreamConfig_Init(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.uid = "" // we don't care about the uid.
|
||||
assert.Equal(t, tc.expected, tc.uc)
|
||||
})
|
||||
}
|
||||
@@ -278,6 +450,61 @@ func TestUpstreamConfig_UpstreamSendClientInfo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
discoverable bool
|
||||
}{
|
||||
{
|
||||
"loopback",
|
||||
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"rfc1918",
|
||||
&UpstreamConfig{Endpoint: "192.168.1.1", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"CGNAT",
|
||||
&UpstreamConfig{Endpoint: "100.66.67.68", Type: ResolverTypeLegacy},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Public IP",
|
||||
&UpstreamConfig{Endpoint: "8.8.8.8", Type: ResolverTypeLegacy},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"override discoverable",
|
||||
&UpstreamConfig{Endpoint: "127.0.0.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(false)},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"override non-public",
|
||||
&UpstreamConfig{Endpoint: "1.1.1.1", Type: ResolverTypeLegacy, Discoverable: ptrBool(true)},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"non-legacy upstream",
|
||||
&UpstreamConfig{Endpoint: "https://192.168.1.1/custom-doh", Type: ResolverTypeDOH},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ptrBool(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build !qf
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
@@ -8,14 +6,12 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
@@ -28,9 +24,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
if HasIPv6() {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
@@ -40,10 +34,9 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
rt := &http3.RoundTripper{}
|
||||
rt := &http3.Transport{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
domain := addr
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
@@ -57,20 +50,23 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
|
||||
return quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
||||
}
|
||||
dialAddrs := make([]string, len(addrs))
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
pd := &quicParallelDialer{}
|
||||
conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg)
|
||||
conn, err := pd.Dial(ctx, dialAddrs, tlsCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
|
||||
rt.CloseIdleConnections()
|
||||
})
|
||||
return rt
|
||||
}
|
||||
|
||||
@@ -100,20 +96,22 @@ func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
|
||||
// - quic dialer is different with net.Dialer
|
||||
// - simplification for quic free version
|
||||
type parallelDialerResult struct {
|
||||
conn quic.EarlyConnection
|
||||
conn *quic.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
type quicParallelDialer struct{}
|
||||
|
||||
// Dial performs parallel dialing to the given address list.
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
ch := make(chan *parallelDialerResult, len(addrs))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(addrs))
|
||||
@@ -122,11 +120,6 @@ func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []st
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
go func(addr string) {
|
||||
defer wg.Done()
|
||||
@@ -135,9 +128,22 @@ func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []st
|
||||
ch <- ¶llelDialerResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
|
||||
ch <- ¶llelDialerResult{conn: conn, err: err}
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
ch <- ¶llelDialerResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
conn, err := quic.DialEarly(ctx, udpConn, remoteAddr, tlsCfg, cfg)
|
||||
select {
|
||||
case ch <- ¶llelDialerResult{conn: conn, err: err}:
|
||||
case <-done:
|
||||
if conn != nil {
|
||||
conn.CloseWithError(quic.ApplicationErrorCode(http3.ErrCodeNoError), "")
|
||||
}
|
||||
if udpConn != nil {
|
||||
udpConn.Close()
|
||||
}
|
||||
}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build qf
|
||||
|
||||
package ctrld
|
||||
|
||||
import "net/http"
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {}
|
||||
|
||||
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { return nil }
|
||||
106
config_test.go
106
config_test.go
@@ -1,9 +1,11 @@
|
||||
package ctrld_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/spf13/viper"
|
||||
@@ -21,6 +23,8 @@ func TestLoadConfig(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "info", cfg.Service.LogLevel)
|
||||
assert.Equal(t, "/path/to/log.log", cfg.Service.LogPath)
|
||||
assert.Equal(t, false, *cfg.Service.DnsWatchdogEnabled)
|
||||
assert.Equal(t, time.Duration(20*time.Second), *cfg.Service.DnsWatchdogInvterval)
|
||||
|
||||
assert.Len(t, cfg.Network, 2)
|
||||
assert.Contains(t, cfg.Network, "0")
|
||||
@@ -54,7 +58,12 @@ func TestLoadDefaultConfig(t *testing.T) {
|
||||
cfg := defaultConfig(t)
|
||||
validate := validator.New()
|
||||
require.NoError(t, ctrld.ValidateConfig(validate, cfg))
|
||||
assert.Len(t, cfg.Listener, 1)
|
||||
if assert.Len(t, cfg.Listener, 1) {
|
||||
l0 := cfg.Listener["0"]
|
||||
require.NotNil(t, l0.Policy)
|
||||
assert.Len(t, l0.Policy.Networks, 1)
|
||||
assert.Len(t, l0.Policy.Rules, 2)
|
||||
}
|
||||
assert.Len(t, cfg.Upstream, 2)
|
||||
}
|
||||
|
||||
@@ -96,6 +105,13 @@ func TestConfigValidation(t *testing.T) {
|
||||
{"lease file format required if lease file exist", configWithExistedLeaseFile(t), true},
|
||||
{"invalid lease file format", configWithInvalidLeaseFileFormat(t), true},
|
||||
{"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true},
|
||||
{"invalid client id pref", configWithInvalidClientIDPref(t), true},
|
||||
{"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false},
|
||||
{"doh endpoint without type", dohUpstreamEndpointWithoutType(t), true},
|
||||
{"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false},
|
||||
{"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false},
|
||||
{"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true},
|
||||
{"kea dhcp4 format", configWithDhcp4KeaFormat(t), false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -115,6 +131,44 @@ func TestConfigValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidationDoNotChangeEndpoint(t *testing.T) {
|
||||
cfg := configWithInvalidDoHEndpoint(t)
|
||||
endpointMap := map[string]struct{}{}
|
||||
for _, uc := range cfg.Upstream {
|
||||
endpointMap[uc.Endpoint] = struct{}{}
|
||||
}
|
||||
validate := validator.New()
|
||||
_ = ctrld.ValidateConfig(validate, cfg)
|
||||
for _, uc := range cfg.Upstream {
|
||||
if _, ok := endpointMap[uc.Endpoint]; !ok {
|
||||
t.Fatalf("expected endpoint '%s' to exist", uc.Endpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigDiscoverOverride(t *testing.T) {
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "test_config_discover_override")
|
||||
v.SetConfigType("toml")
|
||||
configStr := `
|
||||
[service]
|
||||
discover_arp = false
|
||||
discover_dhcp = false
|
||||
discover_hosts = false
|
||||
discover_mdns = false
|
||||
discover_ptr = false
|
||||
`
|
||||
require.NoError(t, v.ReadConfig(strings.NewReader(configStr)))
|
||||
cfg := ctrld.Config{}
|
||||
require.NoError(t, v.Unmarshal(&cfg))
|
||||
|
||||
require.False(t, *cfg.Service.DiscoverARP)
|
||||
require.False(t, *cfg.Service.DiscoverDHCP)
|
||||
require.False(t, *cfg.Service.DiscoverHosts)
|
||||
require.False(t, *cfg.Service.DiscoverMDNS)
|
||||
require.False(t, *cfg.Service.DiscoverPtr)
|
||||
}
|
||||
|
||||
func defaultConfig(t *testing.T) *ctrld.Config {
|
||||
v := viper.New()
|
||||
ctrld.InitConfig(v, "test_load_default_config")
|
||||
@@ -138,6 +192,33 @@ func invalidUpstreamType(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func dohUpstreamEndpointWithoutScheme(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "freedns.controld.com/p1"
|
||||
return cfg
|
||||
}
|
||||
|
||||
func dohUpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "https://freedns.controld.com/p1"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func doh3UpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "h3://freedns.controld.com/p1"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func sdnsUpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "sdns://AgMAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29tAy9wMQ"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func invalidUpstreamTimeout(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Timeout = -1
|
||||
@@ -227,9 +308,30 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithDhcp4KeaFormat(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.DHCPLeaseFileFormat = "kea-dhcp4"
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "1.1.1.1"
|
||||
cfg.Upstream["0"].Endpoint = "/1.1.1.1"
|
||||
cfg.Upstream["0"].Type = ctrld.ResolverTypeDOH
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidClientIDPref(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.ClientIDPref = "foo"
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidFlushCacheDomain(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.CacheFlushDomains = make([]string, 257)
|
||||
for i := range cfg.Service.CacheFlushDomains {
|
||||
cfg.Service.CacheFlushDomains[i] = fmt.Sprintf("%d.com", i)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
7
desktop_darwin.go
Normal file
7
desktop_darwin.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return true
|
||||
}
|
||||
9
desktop_others.go
Normal file
9
desktop_others.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return false
|
||||
}
|
||||
7
desktop_windows.go
Normal file
7
desktop_windows.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
30
dns.go
Normal file
30
dns.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// SetCacheReply extracts and stores the necessary data from the message for a cached answer.
|
||||
func SetCacheReply(answer, msg *dns.Msg, code int) {
|
||||
answer.SetRcode(msg, code)
|
||||
cCookie := getEdns0Cookie(msg.IsEdns0())
|
||||
sCookie := getEdns0Cookie(answer.IsEdns0())
|
||||
if cCookie != nil && sCookie != nil {
|
||||
// Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc7873#section-4
|
||||
sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:]
|
||||
}
|
||||
}
|
||||
|
||||
// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present.
|
||||
func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE {
|
||||
if opt == nil {
|
||||
return nil
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if e, ok := o.(*dns.EDNS0_COOKIE); ok {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -8,7 +8,7 @@
|
||||
# - Non-cgo ctrld binary.
|
||||
#
|
||||
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
||||
FROM golang:bullseye as base
|
||||
FROM golang:1.20-bullseye as base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
32
docker/Dockerfile.debug
Normal file
32
docker/Dockerfile.debug
Normal file
@@ -0,0 +1,32 @@
|
||||
# Using Debian bullseye for building regular image.
|
||||
# Using scratch image for minimal image size.
|
||||
# The final image has:
|
||||
#
|
||||
# - Timezone info file.
|
||||
# - CA certs file.
|
||||
# - /etc/{passwd,group} file.
|
||||
# - Non-cgo ctrld binary.
|
||||
#
|
||||
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
||||
FROM golang:bullseye as base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y upx-ucl
|
||||
|
||||
COPY . .
|
||||
|
||||
ARG tag=master
|
||||
ENV CI_COMMIT_TAG=$tag
|
||||
RUN CTRLD_NO_QF=yes CGO_ENABLED=0 ./scripts/build.sh
|
||||
|
||||
FROM alpine
|
||||
|
||||
COPY --from=base /usr/share/zoneinfo /usr/share/zoneinfo
|
||||
COPY --from=base /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
COPY --from=base /etc/passwd /etc/passwd
|
||||
COPY --from=base /etc/group /etc/group
|
||||
|
||||
COPY --from=base /app/ctrld-linux-*-nocgo ctrld
|
||||
|
||||
ENTRYPOINT ["./ctrld", "run"]
|
||||
178
docs/config.md
178
docs/config.md
@@ -14,7 +14,7 @@ The config file allows for advanced configuration of the `ctrld` utility to cove
|
||||
|
||||
|
||||
## Config Location
|
||||
`ctrld` uses [TOML](toml_link) format for its configuration file. Default configuration file is `ctrld.toml` found in following order:
|
||||
`ctrld` uses [TOML][toml_link] format for its configuration file. Default configuration file is `ctrld.toml` found in following order:
|
||||
|
||||
- `/etc/controld` on *nix.
|
||||
- User's home directory on Windows.
|
||||
@@ -157,9 +157,15 @@ stale cached records (regardless of their TTLs) until upstream comes online.
|
||||
- Required: no
|
||||
- Default: false
|
||||
|
||||
### cache_flush_domains
|
||||
When `ctrld` receives query with domain name in `cache_flush_domains`, the local cache will be discarded
|
||||
before serving the query.
|
||||
|
||||
- Type: array of strings
|
||||
- Required: no
|
||||
|
||||
### max_concurrent_requests
|
||||
The number of concurrent requests that will be handled, must be a non-negative integer.
|
||||
Tweaking this value depends on the capacity of your system.
|
||||
|
||||
- Type: number
|
||||
- Required: no
|
||||
@@ -172,6 +178,8 @@ Perform LAN client discovery using mDNS. This will spawn a listener on port 5353
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_arp
|
||||
Perform LAN client discovery using ARP.
|
||||
|
||||
@@ -179,6 +187,8 @@ Perform LAN client discovery using ARP.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_dhcp
|
||||
Perform LAN client discovery using DHCP leases files. Common file locations are auto-discovered.
|
||||
|
||||
@@ -186,6 +196,8 @@ Perform LAN client discovery using DHCP leases files. Common file locations are
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_ptr
|
||||
Perform LAN client discovery using PTR queries.
|
||||
|
||||
@@ -193,6 +205,25 @@ Perform LAN client discovery using PTR queries.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_hosts
|
||||
Perform LAN client discovery using hosts file.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_refresh_interval
|
||||
Time in seconds between each discovery refresh loop to update new client information data.
|
||||
The default value is 120 seconds, lower this value to make the discovery process run more aggressively.
|
||||
|
||||
- Type: integer
|
||||
- Required: no
|
||||
- Default: 120
|
||||
|
||||
### dhcp_lease_file_path
|
||||
Relative or absolute path to a custom DHCP leases file location.
|
||||
|
||||
@@ -205,9 +236,65 @@ DHCP leases file format.
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Valid values: `dnsmasq`, `isc-dhcp`
|
||||
- Valid values: `dnsmasq`, `isc-dhcp`, `kea-dhcp4`
|
||||
- Default: ""
|
||||
|
||||
### client_id_preference
|
||||
Decide how the client ID is generated. By default client ID will use both MAC address and Hostname i.e. `hash(mac + host)`. To override this behavior, select one of the 2 allowed values to scope client ID to just MAC address OR Hostname.
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Valid values: `mac`, `host`
|
||||
- Default: ""
|
||||
|
||||
### metrics_query_stats
|
||||
If set to `true`, collect and export the query counters, and show them in `clients list` command.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: false
|
||||
|
||||
### metrics_listener
|
||||
Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus metrics will be available on: `http://ip:port/metrics`. You can also append `/metrics/json` to get the same data in json format.
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Default: ""
|
||||
|
||||
### dns_watchdog_enabled
|
||||
Watches all physical interfaces for DNS changes and reverts them to ctrld's settings.The DNS watchdog process only runs on Windows and MacOS.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
### dns_watchdog_interval
|
||||
Time duration between each DNS watchdog iteration.
|
||||
|
||||
A duration string is a possibly signed sequence of decimal numbers, each with optional fraction and a unit suffix,
|
||||
such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
|
||||
If the time duration is non-positive, default value will be used.
|
||||
|
||||
- Type: time duration string
|
||||
- Required: no
|
||||
- Default: 20s
|
||||
|
||||
### refetch_time
|
||||
Time in seconds between each iteration that reloads custom config from the API.
|
||||
|
||||
The value must be a positive number, any invalid value will be ignored and default value will be used.
|
||||
- Type: number
|
||||
- Required: no
|
||||
- Default: 3600
|
||||
|
||||
### leak_on_upstream_failure
|
||||
If a remote upstream fails to resolve a query or is unreachable, `ctrld` will forward the queries to the default DNS resolver on the network. If failures persist, `ctrld` will remove itself from all networking interfaces until connectivity is restored.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true on Windows, MacOS and non-router Linux.
|
||||
|
||||
## Upstream
|
||||
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
||||
|
||||
@@ -292,7 +379,7 @@ The protocol that `ctrld` will use to send DNS requests to upstream.
|
||||
|
||||
- Type: string
|
||||
- Required: yes
|
||||
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os`
|
||||
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`
|
||||
|
||||
### ip_stack
|
||||
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.
|
||||
@@ -312,6 +399,24 @@ If `ip_stack` is empty, or undefined:
|
||||
- Default value is `both` for non-Control D resolvers.
|
||||
- Default value is `split` for Control D resolvers.
|
||||
|
||||
### send_client_info
|
||||
Specifying whether to include client info when sending query to upstream. **This will only work with `doh` or `doh3` type upstreams.**
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default:
|
||||
- `true` for ControlD upstreams.
|
||||
- `false` for other upstreams.
|
||||
|
||||
### discoverable
|
||||
Specifying whether the upstream can be used for PTR discovery.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default:
|
||||
- `true` for loopback/RFC1918/CGNAT IP address.
|
||||
- `false` for public IP address.
|
||||
|
||||
## Network
|
||||
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.
|
||||
|
||||
@@ -369,7 +474,14 @@ Port number that the listener will listen on for incoming requests. If `port` is
|
||||
- Default: 0 or 53 or 5354 (depending on platform)
|
||||
|
||||
### restricted
|
||||
If set to `true` makes the listener `REFUSE` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
|
||||
If set to `true`, makes the listener `REFUSED` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
|
||||
|
||||
- Type: bool
|
||||
- Required: no
|
||||
- Default: false
|
||||
|
||||
### allow_wan_clients
|
||||
The listener will refuse DNS queries from WAN IPs using `REFUSED` RCODE by default. Set to `true` to disable this behavior, but this is not recommended.
|
||||
|
||||
- Type: bool
|
||||
- Required: no
|
||||
@@ -379,7 +491,15 @@ If set to `true` makes the listener `REFUSE` DNS queries from all source IP addr
|
||||
Allows `ctrld` to set policy rules to determine which upstreams the requests will be forwarded to.
|
||||
If no `policy` is defined or the requests do not match any policy rules, it will be forwarded to corresponding upstream of the listener. For example, the request to `listener.0` will be forwarded to `upstream.0`.
|
||||
|
||||
The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either the `network` or a domain. Value is the list of the upstreams. For example:
|
||||
The policy `rule` syntax is a simple `toml` inline table with exactly one key/value pair per rule. `key` is either:
|
||||
|
||||
- Network.
|
||||
- Domain.
|
||||
- Mac Address.
|
||||
|
||||
Value is the list of the upstreams.
|
||||
|
||||
For example:
|
||||
|
||||
```toml
|
||||
[listener.0.policy]
|
||||
@@ -393,12 +513,18 @@ rules = [
|
||||
{"*.local" = ["upstream.1"]},
|
||||
{"test.com" = ["upstream.2", "upstream.1"]},
|
||||
]
|
||||
|
||||
macs = [
|
||||
{"14:54:4a:8e:08:2d" = ["upstream.3"]},
|
||||
]
|
||||
```
|
||||
|
||||
Above policy will:
|
||||
- Forward requests on `listener.0` from `network.0` to `upstream.1`.
|
||||
|
||||
- Forward requests on `listener.0` for `.local` suffixed domains to `upstream.1`.
|
||||
- Forward requests on `listener.0` for `test.com` to `upstream.2`. If timeout is reached, retry on `upstream.1`.
|
||||
- Forward requests on `listener.0` from client with Mac `14:54:4a:8e:08:2d` to `upstream.3`.
|
||||
- Forward requests on `listener.0` from `network.0` to `upstream.1`.
|
||||
- All other requests on `listener.0` that do not match above conditions will be forwarded to `upstream.0`.
|
||||
|
||||
An empty upstream would not route the request to any defined upstreams, and use the OS default resolver.
|
||||
@@ -412,6 +538,27 @@ rules = [
|
||||
]
|
||||
```
|
||||
|
||||
If there is no explicitly defined rules, LAN queries will be handled solely by the OS resolver.
|
||||
|
||||
These following domains are considered LAN queries:
|
||||
|
||||
- Queries does not have dot `.` in domain name, like `machine1`, `example`, ... (1)
|
||||
- Queries have domain ends with: `.domain`, `.lan`, `.local`. (2)
|
||||
- All `SRV` queries of LAN hostname (1) + (2).
|
||||
- `PTR` queries with private IPs.
|
||||
|
||||
---
|
||||
|
||||
Note that the order of matching preference:
|
||||
|
||||
```
|
||||
rules => macs => networks
|
||||
```
|
||||
|
||||
And within each policy, the rules are processed from top to bottom.
|
||||
|
||||
---
|
||||
|
||||
#### name
|
||||
`name` is the name for the policy.
|
||||
|
||||
@@ -433,10 +580,23 @@ rules = [
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
---
|
||||
|
||||
Note that the domain comparisons are done in case in-sensitive manner following [RFC 1034](https://datatracker.ietf.org/doc/html/rfc1034#section-3.1)
|
||||
|
||||
---
|
||||
|
||||
### macs:
|
||||
`macs` is the list of mac rules within the policy. Mac address value is case-insensitive.
|
||||
|
||||
- Type: array of macs
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
### failover_rcodes
|
||||
For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`.
|
||||
|
||||
- Type: array of string
|
||||
- Type: array of strings
|
||||
- Required: no
|
||||
- Default: []
|
||||
-
|
||||
@@ -453,7 +613,7 @@ networks = [
|
||||
|
||||
If `upstream.0` returns a NXDOMAIN response, the request will be forwarded to `upstream.1` instead of returning immediately to the client.
|
||||
|
||||
See all available DNS Rcodes value [here](rcode_link).
|
||||
See all available DNS Rcodes value [here][rcode_link].
|
||||
|
||||
[toml_link]: https://toml.io/en
|
||||
[rcode_link]: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6
|
||||
|
||||
BIN
docs/ctrldsplash.png
Normal file
BIN
docs/ctrldsplash.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 458 KiB |
42
docs/known-issues.md
Normal file
42
docs/known-issues.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Known Issues
|
||||
|
||||
This document outlines known issues with ctrld and their current status, workarounds, and recommendations.
|
||||
|
||||
## macOS (Darwin) Issues
|
||||
|
||||
### Self-Upgrade Issue on Darwin 15.5
|
||||
|
||||
**Issue**: ctrld self-upgrading functionality may not work on macOS Darwin 15.5.
|
||||
|
||||
**Status**: Under investigation
|
||||
|
||||
**Description**: Users on macOS Darwin 15.5 may experience issues when ctrld attempts to perform automatic self-upgrades. The upgrade process would be triggered, but ctrld won't be upgraded.
|
||||
|
||||
**Workarounds**:
|
||||
1. **Recommended**: Upgrade your macOS system to Darwin 15.6 or later, which has been tested and verified to work correctly with ctrld self-upgrade functionality.
|
||||
2. **Alternative**: Run `ctrld upgrade prod` directly to manually upgrade ctrld to the latest version on Darwin 15.5.
|
||||
|
||||
**Affected Versions**: ctrld v1.4.2 and later on macOS Darwin 15.5
|
||||
|
||||
**Last Updated**: 05/09/2025
|
||||
|
||||
---
|
||||
|
||||
## Contributing to Known Issues
|
||||
|
||||
If you encounter an issue not listed here, please:
|
||||
|
||||
1. Check the [GitHub Issues](https://github.com/Control-D-Inc/ctrld/issues) to see if it's already reported
|
||||
2. If not reported, create a new issue with:
|
||||
- Detailed description of the problem
|
||||
- Steps to reproduce
|
||||
- Expected vs actual behavior
|
||||
- System information (OS, version, architecture)
|
||||
- ctrld version
|
||||
|
||||
## Issue Status Legend
|
||||
|
||||
- **Under investigation**: Issue is confirmed and being analyzed
|
||||
- **Workaround available**: Temporary solution exists while permanent fix is developed
|
||||
- **Fixed**: Issue has been resolved in a specific version
|
||||
- **Won't fix**: Issue is acknowledged but will not be addressed due to technical limitations or design decisions
|
||||
187
doh.go
187
doh.go
@@ -2,29 +2,75 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cuonglm/osinfo"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
dohMacHeader = "x-cd-mac"
|
||||
dohIPHeader = "x-cd-ip"
|
||||
dohHostHeader = "x-cd-host"
|
||||
headerApplicationDNS = "application/dns-message"
|
||||
dohMacHeader = "x-cd-mac"
|
||||
dohIPHeader = "x-cd-ip"
|
||||
dohHostHeader = "x-cd-host"
|
||||
dohOsHeader = "x-cd-os"
|
||||
dohClientIDPrefHeader = "x-cd-cpref"
|
||||
headerApplicationDNS = "application/dns-message"
|
||||
)
|
||||
|
||||
// EncodeOsNameMap provides mapping from OS name to a shorter string, used for encoding x-cd-os value.
|
||||
var EncodeOsNameMap = map[string]string{
|
||||
"windows": "1",
|
||||
"darwin": "2",
|
||||
"linux": "3",
|
||||
"freebsd": "4",
|
||||
}
|
||||
|
||||
// DecodeOsNameMap provides mapping from encoded OS name to real value, used for decoding x-cd-os value.
|
||||
var DecodeOsNameMap = map[string]string{}
|
||||
|
||||
// EncodeArchNameMap provides mapping from OS arch to a shorter string, used for encoding x-cd-os value.
|
||||
var EncodeArchNameMap = map[string]string{
|
||||
"amd64": "1",
|
||||
"arm64": "2",
|
||||
"arm": "3",
|
||||
"386": "4",
|
||||
"mips": "5",
|
||||
"mipsle": "6",
|
||||
"mips64": "7",
|
||||
}
|
||||
|
||||
// DecodeArchNameMap provides mapping from encoded OS arch to real value, used for decoding x-cd-os value.
|
||||
var DecodeArchNameMap = map[string]string{}
|
||||
|
||||
func init() {
|
||||
for k, v := range EncodeOsNameMap {
|
||||
DecodeOsNameMap[v] = k
|
||||
}
|
||||
for k, v := range EncodeArchNameMap {
|
||||
DecodeArchNameMap[v] = k
|
||||
}
|
||||
}
|
||||
|
||||
var dohOsHeaderValue = sync.OnceValue(func() string {
|
||||
oi := osinfo.New()
|
||||
return strings.Join([]string{EncodeOsNameMap[runtime.GOOS], EncodeArchNameMap[runtime.GOARCH], oi.Dist}, "-")
|
||||
})()
|
||||
|
||||
func newDohResolver(uc *UpstreamConfig) *dohResolver {
|
||||
r := &dohResolver{
|
||||
endpoint: uc.u,
|
||||
isDoH3: uc.Type == ResolverTypeDOH3,
|
||||
http3RoundTripper: uc.http3RoundTripper,
|
||||
sendClientInfo: uc.UpstreamSendClientInfo(),
|
||||
uc: uc,
|
||||
}
|
||||
return r
|
||||
@@ -35,9 +81,9 @@ type dohResolver struct {
|
||||
endpoint *url.URL
|
||||
isDoH3 bool
|
||||
http3RoundTripper http.RoundTripper
|
||||
sendClientInfo bool
|
||||
}
|
||||
|
||||
// Resolve performs DNS query with given DNS message using DOH protocol.
|
||||
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
@@ -54,7 +100,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create request: %w", err)
|
||||
}
|
||||
addHeader(ctx, req, r.sendClientInfo)
|
||||
addHeader(ctx, req, r.uc)
|
||||
dnsTyp := uint16(0)
|
||||
if len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
@@ -68,7 +114,14 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
c.Transport = transport
|
||||
}
|
||||
resp, err := c.Do(req)
|
||||
if err != nil && r.uc.FallbackToDirectIP() {
|
||||
retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx))
|
||||
defer cancel()
|
||||
Log(ctx, ProxyLogger.Load().Warn().Err(err), "retrying request after fallback to direct ip")
|
||||
resp, err = c.Do(req.Clone(retryCtx))
|
||||
}
|
||||
if err != nil {
|
||||
err = wrapUrlError(err)
|
||||
if r.isDoH3 {
|
||||
if closer, ok := c.Transport.(io.Closer); ok {
|
||||
closer.Close()
|
||||
@@ -94,21 +147,115 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
|
||||
req.Header.Set("Content-Type", headerApplicationDNS)
|
||||
req.Header.Set("Accept", headerApplicationDNS)
|
||||
if sendClientInfo {
|
||||
// addHeader adds necessary HTTP header to request based on upstream config.
|
||||
func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
|
||||
printed := false
|
||||
dohHeader := make(http.Header)
|
||||
if uc.UpstreamSendClientInfo() {
|
||||
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
|
||||
if ci.Mac != "" {
|
||||
req.Header.Set(dohMacHeader, ci.Mac)
|
||||
}
|
||||
if ci.IP != "" {
|
||||
req.Header.Set(dohIPHeader, ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set(dohHostHeader, ci.Hostname)
|
||||
printed = ci.Mac != "" || ci.IP != "" || ci.Hostname != ""
|
||||
switch {
|
||||
case uc.IsControlD():
|
||||
dohHeader = newControlDHeaders(ci)
|
||||
case uc.isNextDNS():
|
||||
dohHeader = newNextDNSHeaders(ci)
|
||||
}
|
||||
}
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug().Interface("header", req.Header), "sending request header")
|
||||
if printed {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending request header: %v", dohHeader)
|
||||
}
|
||||
dohHeader.Set("Content-Type", headerApplicationDNS)
|
||||
dohHeader.Set("Accept", headerApplicationDNS)
|
||||
req.Header = dohHeader
|
||||
}
|
||||
|
||||
// newControlDHeaders returns DoH/Doh3 HTTP request headers for ControlD upstream.
|
||||
func newControlDHeaders(ci *ClientInfo) http.Header {
|
||||
header := make(http.Header)
|
||||
if ci.Mac != "" {
|
||||
header.Set(dohMacHeader, ci.Mac)
|
||||
}
|
||||
if ci.IP != "" {
|
||||
header.Set(dohIPHeader, ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
header.Set(dohHostHeader, ci.Hostname)
|
||||
}
|
||||
if ci.Self {
|
||||
header.Set(dohOsHeader, dohOsHeaderValue)
|
||||
}
|
||||
switch ci.ClientIDPref {
|
||||
case "mac":
|
||||
header.Set(dohClientIDPrefHeader, "1")
|
||||
case "host":
|
||||
header.Set(dohClientIDPrefHeader, "2")
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
||||
// newNextDNSHeaders returns DoH/Doh3 HTTP request headers for nextdns upstream.
|
||||
// https://github.com/nextdns/nextdns/blob/v1.41.0/resolver/doh.go#L100
|
||||
func newNextDNSHeaders(ci *ClientInfo) http.Header {
|
||||
header := make(http.Header)
|
||||
if ci.Mac != "" {
|
||||
// https: //github.com/nextdns/nextdns/blob/v1.41.0/run.go#L543
|
||||
header.Set("X-Device-Model", "mac:"+ci.Mac[:8])
|
||||
}
|
||||
if ci.IP != "" {
|
||||
header.Set("X-Device-Ip", ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
header.Set("X-Device-Name", ci.Hostname)
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
||||
// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer.
|
||||
// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output.
|
||||
// If no certificate-related information is available, it simply returns the original error unmodified.
|
||||
func wrapCertificateVerificationError(err error) error {
|
||||
var tlsErr *tls.CertificateVerificationError
|
||||
if errors.As(err, &tlsErr) {
|
||||
if len(tlsErr.UnverifiedCertificates) > 0 {
|
||||
cert := tlsErr.UnverifiedCertificates[0]
|
||||
// Extract a more user-friendly issuer name
|
||||
var issuer string
|
||||
var organization string
|
||||
if len(cert.Issuer.Organization) > 0 {
|
||||
organization = cert.Issuer.Organization[0]
|
||||
issuer = organization
|
||||
} else if cert.Issuer.CommonName != "" {
|
||||
issuer = cert.Issuer.CommonName
|
||||
} else {
|
||||
issuer = cert.Issuer.String()
|
||||
}
|
||||
|
||||
// Get the organization from the subject field as well
|
||||
if len(cert.Subject.Organization) > 0 {
|
||||
organization = cert.Subject.Organization[0]
|
||||
}
|
||||
|
||||
// Extract the subject information
|
||||
subjectCN := cert.Subject.CommonName
|
||||
if subjectCN == "" && len(cert.Subject.Organization) > 0 {
|
||||
subjectCN = cert.Subject.Organization[0]
|
||||
}
|
||||
return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context.
|
||||
func wrapUrlError(err error) error {
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
var tlsErr *tls.CertificateVerificationError
|
||||
if errors.As(urlErr.Err, &tlsErr) {
|
||||
urlErr.Err = wrapCertificateVerificationError(tlsErr)
|
||||
return urlErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
266
doh_test.go
Normal file
266
doh_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
func Test_dohOsHeaderValue(t *testing.T) {
|
||||
val := dohOsHeaderValue
|
||||
if val == "" {
|
||||
t.Fatalf("empty %s", dohOsHeader)
|
||||
}
|
||||
t.Log(val)
|
||||
|
||||
encodedOs := EncodeOsNameMap[runtime.GOOS]
|
||||
if encodedOs == "" {
|
||||
t.Fatalf("missing encoding value for: %q", runtime.GOOS)
|
||||
}
|
||||
decodedOs := DecodeOsNameMap[encodedOs]
|
||||
if decodedOs == "" {
|
||||
t.Fatalf("missing decoding value for: %q", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapUrlError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "No wrapping for non-URL errors",
|
||||
err: errors.New("plain error"),
|
||||
wantErr: "plain error",
|
||||
},
|
||||
{
|
||||
name: "URL error without TLS error",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: errors.New("underlying error"),
|
||||
},
|
||||
wantErr: "Get \"https://example.com\": underlying error",
|
||||
},
|
||||
{
|
||||
name: "TLS error with missing unverified certificate data",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: &tls.CertificateVerificationError{
|
||||
UnverifiedCertificates: nil,
|
||||
Err: &x509.UnknownAuthorityError{},
|
||||
},
|
||||
},
|
||||
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority`,
|
||||
},
|
||||
{
|
||||
name: "TLS error with valid certificate data",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: &tls.CertificateVerificationError{
|
||||
UnverifiedCertificates: []*x509.Certificate{
|
||||
{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "BadSubjectCN",
|
||||
Organization: []string{"BadSubjectOrg"},
|
||||
},
|
||||
Issuer: pkix.Name{
|
||||
CommonName: "BadIssuerCN",
|
||||
Organization: []string{"BadIssuerOrg"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Err: &x509.UnknownAuthorityError{},
|
||||
},
|
||||
},
|
||||
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority: BadSubjectCN, BadSubjectOrg, BadIssuerOrg`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotErr := wrapUrlError(tt.err)
|
||||
if gotErr.Error() != tt.wantErr {
|
||||
t.Errorf("wrapCertificateVerificationError() error = %v, want %v", gotErr, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ClientCertificateVerificationError(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/dns-message")
|
||||
})
|
||||
tlsServer, cert := testTLSServer(t, handler)
|
||||
tlsServerUrl, err := url.Parse(tlsServer.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
quicServer := newTestQUICServer(t)
|
||||
http3Server := newTestHTTP3Server(t, handler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
}{
|
||||
{
|
||||
"doh",
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: tlsServer.URL,
|
||||
Timeout: 1000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"doh3",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: ResolverTypeDOH3,
|
||||
Endpoint: http3Server.addr,
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"doq",
|
||||
&UpstreamConfig{
|
||||
Name: "doq",
|
||||
Type: ResolverTypeDOQ,
|
||||
Endpoint: quicServer.addr,
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"dot",
|
||||
&UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: ResolverTypeDOT,
|
||||
Endpoint: net.JoinHostPort(tlsServerUrl.Hostname(), tlsServerUrl.Port()),
|
||||
Timeout: 1000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.SetupBootstrapIP()
|
||||
r, err := NewResolver(tc.uc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("verify.controld.com.", dns.TypeA)
|
||||
msg.RecursionDesired = true
|
||||
_, err = r.Resolve(context.Background(), msg)
|
||||
// Verify the error contains the expected certificate information
|
||||
if err == nil {
|
||||
t.Fatal("expected certificate verification error, got nil")
|
||||
}
|
||||
|
||||
// You can check the error contains information about the test certificate
|
||||
if !strings.Contains(err.Error(), cert.Issuer.CommonName) {
|
||||
t.Fatalf("error should contain issuer information %q, got: %v", cert.Issuer.CommonName, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTLSServer creates an HTTPS test server with a self-signed certificate
|
||||
// returns the server and its certificate for verification testing
|
||||
// testTLSServer creates an HTTPS test server with a self-signed certificate
|
||||
func testTLSServer(t *testing.T, handler http.Handler) (*httptest.Server, *x509.Certificate) {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
|
||||
// Create a test server
|
||||
server := httptest.NewUnstartedServer(handler)
|
||||
server.TLS = &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
}
|
||||
server.StartTLS()
|
||||
|
||||
// Add cleanup
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
return server, testCert.cert
|
||||
}
|
||||
|
||||
// testHTTP3Server represents a structure for an HTTP/3 test server with its server instance, TLS certificate, and address.
|
||||
type testHTTP3Server struct {
|
||||
server *http3.Server
|
||||
cert *x509.Certificate
|
||||
addr string
|
||||
}
|
||||
|
||||
// newTestHTTP3Server creates and starts a test HTTP/3 server with a given handler and returns the server instance.
|
||||
func newTestHTTP3Server(t *testing.T, handler http.Handler) *testHTTP3Server {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
|
||||
// First create a listener to get the actual port
|
||||
udpAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create UDP listener: %v", err)
|
||||
}
|
||||
|
||||
// Get the actual address
|
||||
actualAddr := udpConn.LocalAddr().String()
|
||||
|
||||
// Create TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
NextProtos: []string{"h3"}, // HTTP/3 protocol identifier
|
||||
}
|
||||
|
||||
// Create HTTP/3 server
|
||||
server := &http3.Server{
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
// Start the server with the existing UDP connection
|
||||
go func() {
|
||||
if err := server.Serve(udpConn); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Logf("HTTP/3 server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
h3Server := &testHTTP3Server{
|
||||
server: server,
|
||||
cert: testCert.cert,
|
||||
addr: actualAddr,
|
||||
}
|
||||
|
||||
// Add cleanup
|
||||
t.Cleanup(func() {
|
||||
server.Close()
|
||||
udpConn.Close()
|
||||
})
|
||||
|
||||
// Wait a bit for the server to be ready
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return h3Server
|
||||
}
|
||||
4
doq.go
4
doq.go
@@ -43,7 +43,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, wrapCertificateVerificationError(err)
|
||||
}
|
||||
return answer, nil
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
|
||||
}
|
||||
|
||||
func doResolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.Config) (*dns.Msg, error) {
|
||||
session, err := quic.DialAddr(endpoint, tlsConfig, nil)
|
||||
session, err := quic.DialAddr(ctx, endpoint, tlsConfig, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build qf
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type doqResolver struct {
|
||||
uc *UpstreamConfig
|
||||
}
|
||||
|
||||
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
return nil, errors.New("DoQ is not supported")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user