mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Compare commits
379 Commits
v1.3.6
...
release-br
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e53fa4274 | ||
|
|
d0e66b83d0 | ||
|
|
34fef77ff7 | ||
|
|
7006e967e4 | ||
|
|
f9d026334a | ||
|
|
36d4192c05 | ||
|
|
90eddb8268 | ||
|
|
c13a3c3c17 | ||
|
|
d42a78cba9 | ||
|
|
92f32ba16e | ||
|
|
4c838f6a5e | ||
|
|
adc0e1a51e | ||
|
|
3afdaef6e6 | ||
|
|
ef7432df55 | ||
|
|
fb807d7c37 | ||
|
|
f7c124d99d | ||
|
|
ed826f7a95 | ||
|
|
56f8113bb0 | ||
|
|
a04babbbc3 | ||
|
|
59b98245d3 | ||
|
|
f6be1ab1fb | ||
|
|
54f58cc2e5 | ||
|
|
eb8c5bc3fa | ||
|
|
3bcad10f92 | ||
|
|
d87a0a69c8 | ||
|
|
b7202f8469 | ||
|
|
a084c87370 | ||
|
|
5d87bd07ca | ||
|
|
3412d1f8b9 | ||
|
|
a72ff1e769 | ||
|
|
2c98b2c545 | ||
|
|
4792183c0d | ||
|
|
d88c860cac | ||
|
|
8b605da861 | ||
|
|
7cda5d7646 | ||
|
|
0cd873a88f | ||
|
|
1ff5d1f05a | ||
|
|
954395fa29 | ||
|
|
ea98a59aba | ||
|
|
6971d392b7 | ||
|
|
a2f8313668 | ||
|
|
af05cb2d94 | ||
|
|
5f0b9a24b9 | ||
|
|
37523fdc45 | ||
|
|
ca505f1140 | ||
|
|
a22f0579d5 | ||
|
|
9f656269ac | ||
|
|
13de41d854 | ||
|
|
42ea5f7fed | ||
|
|
af9386568f | ||
|
|
13b15e642d | ||
|
|
d4df2e7f72 | ||
|
|
5b8ed3a72f | ||
|
|
59fe94112d | ||
|
|
0ab51cdad7 | ||
|
|
0a1d6fa4db | ||
|
|
6e10bba7fe | ||
|
|
fc8268b70a | ||
|
|
b5f101f667 | ||
|
|
69b192c6fa | ||
|
|
ec85b1621d | ||
|
|
ddbb0f0db4 | ||
|
|
2996a161cd | ||
|
|
35e2a20019 | ||
|
|
65a300a807 | ||
|
|
a67aea88be | ||
|
|
84d4491a18 | ||
|
|
05d183c94b | ||
|
|
41282d0f51 | ||
|
|
f7fb555c89 | ||
|
|
2e63624f6c | ||
|
|
b2a54db4b5 | ||
|
|
0ef02bc15e | ||
|
|
b18cd7ee83 | ||
|
|
a16b25ad1d | ||
|
|
59ece456b1 | ||
|
|
d5cb327620 | ||
|
|
7a2277bc18 | ||
|
|
c736f4c1e9 | ||
|
|
f0cb810dd6 | ||
|
|
64632fa640 | ||
|
|
b9b9cfcade | ||
|
|
fc527dbdfb | ||
|
|
5641aab5bd | ||
|
|
31517ce750 | ||
|
|
51e58b64a5 | ||
|
|
3ca559e5a4 | ||
|
|
0e3f764299 | ||
|
|
e52402eb0c | ||
|
|
2133f31854 | ||
|
|
a198a5cd65 | ||
|
|
eb2b231bd2 | ||
|
|
7af29cfbc0 | ||
|
|
ce1a165348 | ||
|
|
fd48e6d795 | ||
|
|
d71d1341b6 | ||
|
|
21855df4af | ||
|
|
66e2d3a40a | ||
|
|
26257cf24a | ||
|
|
36a7423634 | ||
|
|
e616091249 | ||
|
|
0948161529 | ||
|
|
ce29b5d217 | ||
|
|
de24fa293e | ||
|
|
6663925c4d | ||
|
|
b9ece6d7b9 | ||
|
|
c4efa1ab97 | ||
|
|
7cea5305e1 | ||
|
|
a20fbf95de | ||
|
|
628c4302aa | ||
|
|
8dc34f8bf5 | ||
|
|
b4faf82f76 | ||
|
|
a983dfaee2 | ||
|
|
62f73bcaa2 | ||
|
|
00e9d2bdd3 | ||
|
|
ace3b1e66e | ||
|
|
d1ea1ba08c | ||
|
|
c06c8aa859 | ||
|
|
0c2cc00c4f | ||
|
|
8d6ea91f35 | ||
|
|
7dfb77228f | ||
|
|
24910f1fa6 | ||
|
|
433a61d2ee | ||
|
|
3937e885f0 | ||
|
|
c651003cc4 | ||
|
|
b7ccfcb8b4 | ||
|
|
a9ed70200b | ||
|
|
c6365e6b74 | ||
|
|
dacc67e50f | ||
|
|
c60cf33af3 | ||
|
|
f27cbe3525 | ||
|
|
2de1b9929a | ||
|
|
8bf654aece | ||
|
|
84376ed719 | ||
|
|
7a136b8874 | ||
|
|
58c0e4f15a | ||
|
|
e0d35d8ba2 | ||
|
|
3b2e48761e | ||
|
|
b27064008e | ||
|
|
1ad63827e1 | ||
|
|
20e61550c2 | ||
|
|
020b814402 | ||
|
|
e578867118 | ||
|
|
46a1039f21 | ||
|
|
cc9e27de5f | ||
|
|
6ab3ab9faf | ||
|
|
e68bfa795a | ||
|
|
e60a92e93e | ||
|
|
62fe14f76b | ||
|
|
a0c5062e3a | ||
|
|
49eb152d02 | ||
|
|
b05056423a | ||
|
|
c7168739c7 | ||
|
|
5b1faf1ce3 | ||
|
|
513a6f9ec7 | ||
|
|
8db6fa4232 | ||
|
|
5036de2602 | ||
|
|
332f8ccc37 | ||
|
|
a582195cec | ||
|
|
9fe36ae984 | ||
|
|
54cb455522 | ||
|
|
8bd3b9e474 | ||
|
|
eff5ff580b | ||
|
|
c45f863ed8 | ||
|
|
414d4e356d | ||
|
|
ef697eb781 | ||
|
|
0631ffe831 | ||
|
|
7444d8517a | ||
|
|
3480043e40 | ||
|
|
619b6e7516 | ||
|
|
0123ca44fb | ||
|
|
7929aafe2a | ||
|
|
dc433f8dc9 | ||
|
|
8ccaeeab60 | ||
|
|
043a28eb33 | ||
|
|
c329402f5d | ||
|
|
23e6ad6e1f | ||
|
|
e6de78c1fa | ||
|
|
a670708f93 | ||
|
|
4ebe2fb5f4 | ||
|
|
3403b2039d | ||
|
|
e30ad31e0f | ||
|
|
81e0bad739 | ||
|
|
7d07d738dc | ||
|
|
0fae584e65 | ||
|
|
9e83085f2a | ||
|
|
41a00c68ac | ||
|
|
e3b99bf339 | ||
|
|
5007a87d3a | ||
|
|
60e65a37a6 | ||
|
|
d37d0e942c | ||
|
|
98042d8dbd | ||
|
|
af4b826b68 | ||
|
|
253a57ca01 | ||
|
|
caf98b4dfe | ||
|
|
398f71fd00 | ||
|
|
e1301ade96 | ||
|
|
7a23f82192 | ||
|
|
715bcc4aa1 | ||
|
|
0c74838740 | ||
|
|
4b05b6da7b | ||
|
|
375844ff1a | ||
|
|
1d207379cb | ||
|
|
fb49cb71e3 | ||
|
|
9618efbcde | ||
|
|
bb2210b06a | ||
|
|
917052723d | ||
|
|
fef85cadeb | ||
|
|
4a05fb6b28 | ||
|
|
6644ce53f2 | ||
|
|
72f0b89fdc | ||
|
|
41a97a6609 | ||
|
|
38064d6ad5 | ||
|
|
ae6945cedf | ||
|
|
3132d1b032 | ||
|
|
2716ae29bd | ||
|
|
1c50c2b6af | ||
|
|
cf6d16b439 | ||
|
|
60686f55ff | ||
|
|
47d7ace3a7 | ||
|
|
2d3779ec27 | ||
|
|
595071b608 | ||
|
|
57ef717080 | ||
|
|
eb27d1482b | ||
|
|
f57972ead7 | ||
|
|
168eaf538b | ||
|
|
1560455ca3 | ||
|
|
028475a193 | ||
|
|
f7a6dbe39b | ||
|
|
e573a490c9 | ||
|
|
ce3281e70d | ||
|
|
0fbfd160c9 | ||
|
|
20759017e6 | ||
|
|
69e0aab73e | ||
|
|
7ed6733fb7 | ||
|
|
9718ab8579 | ||
|
|
2687a4a018 | ||
|
|
2d9c60dea1 | ||
|
|
841be069b7 | ||
|
|
7833132917 | ||
|
|
e9e63b0983 | ||
|
|
4df470b869 | ||
|
|
89600f6091 | ||
|
|
f986a575e8 | ||
|
|
9c2fe8d21f | ||
|
|
8bcbb9249e | ||
|
|
a95d50c0af | ||
|
|
5db7d3577b | ||
|
|
c53a0ca1c4 | ||
|
|
6fd3d1788a | ||
|
|
087c1975e5 | ||
|
|
3713cbecc3 | ||
|
|
6046789fa4 | ||
|
|
3ea69b180c | ||
|
|
db6e977e3a | ||
|
|
a5c776c846 | ||
|
|
5a566c028a | ||
|
|
ff43c74d8d | ||
|
|
3c7255569c | ||
|
|
4a92ec4d2d | ||
|
|
9bbccb4082 | ||
|
|
4f62314646 | ||
|
|
cb49d0d947 | ||
|
|
89f7874fc6 | ||
|
|
221917e80b | ||
|
|
37d41bd215 | ||
|
|
8a96b8bec4 | ||
|
|
02ee113b95 | ||
|
|
f71dd78915 | ||
|
|
cd5619a05b | ||
|
|
a63a30c76b | ||
|
|
f5ba8be182 | ||
|
|
a9f76322bd | ||
|
|
ed39269c80 | ||
|
|
09426dcd36 | ||
|
|
17941882a9 | ||
|
|
70ab8032a0 | ||
|
|
8360bdc50a | ||
|
|
6837176ec7 | ||
|
|
5e9b4244e7 | ||
|
|
9b6a308958 | ||
|
|
71e327653a | ||
|
|
a56711796f | ||
|
|
09495f2a7c | ||
|
|
484643e114 | ||
|
|
da91aabc35 | ||
|
|
c654398981 | ||
|
|
47a90ec2a1 | ||
|
|
2875e22d0b | ||
|
|
c5d14e0075 | ||
|
|
84e06c363c | ||
|
|
5b9ccc5065 | ||
|
|
6ca1a7ccc7 | ||
|
|
9d666be5d4 | ||
|
|
65de7edcde | ||
|
|
0cdff0d368 | ||
|
|
f87220a908 | ||
|
|
30ea0c6499 | ||
|
|
9501e35c60 | ||
|
|
5ac9d17bdf | ||
|
|
cb14992ddc | ||
|
|
e88372fc8c | ||
|
|
b320662d67 | ||
|
|
ce353cd4d9 | ||
|
|
4befd33866 | ||
|
|
4b36e3ac44 | ||
|
|
f507bc8f9e | ||
|
|
14c88f4a6d | ||
|
|
3e388c2857 | ||
|
|
cfe1209d61 | ||
|
|
5a88a7c22c | ||
|
|
8c661c4401 | ||
|
|
e6f256d640 | ||
|
|
ede354166b | ||
|
|
282a8ce78e | ||
|
|
08fe04f1ee | ||
|
|
082d14a9ba | ||
|
|
617674ce43 | ||
|
|
7088df58dd | ||
|
|
9cbd9b3e44 | ||
|
|
e6586fd360 | ||
|
|
33a6db2599 | ||
|
|
70b0c4f7b9 | ||
|
|
5af3ec4f7b | ||
|
|
79476add12 | ||
|
|
1634a06330 | ||
|
|
a007394f60 | ||
|
|
62a0ba8731 | ||
|
|
e8d3ed1acd | ||
|
|
8b98faa441 | ||
|
|
30320ec9c7 | ||
|
|
5f4a399850 | ||
|
|
82e0d4b0c4 | ||
|
|
95a9df826d | ||
|
|
3b71d26cf3 | ||
|
|
c233ad9b1b | ||
|
|
12d6484b1c | ||
|
|
bc7b1cc6d8 | ||
|
|
ec684348ed | ||
|
|
18a19a3aa2 | ||
|
|
905f2d08c5 | ||
|
|
04947b4d87 | ||
|
|
72bf80533e | ||
|
|
9ddedf926e | ||
|
|
139dd62ff3 | ||
|
|
50ef00526e | ||
|
|
80cf79b9cb | ||
|
|
e6ad39b070 | ||
|
|
56f9c72569 | ||
|
|
dc48c908b8 | ||
|
|
9b0f0e792a | ||
|
|
b3eebb19b6 | ||
|
|
c24589a5be | ||
|
|
1e1c5a4dc8 | ||
|
|
339023421a | ||
|
|
a00d2a431a | ||
|
|
5aca118dbb | ||
|
|
411f7434f4 | ||
|
|
34801382f5 | ||
|
|
b9f2259ae4 | ||
|
|
19020a96bf | ||
|
|
96085147ff | ||
|
|
f3dd344026 | ||
|
|
486096416f | ||
|
|
5710f2e984 | ||
|
|
09936f1f07 | ||
|
|
0d6ca57536 | ||
|
|
3ddcb84db8 | ||
|
|
1012bf063f | ||
|
|
b8155e6182 | ||
|
|
9a34df61bb | ||
|
|
fbb879edf9 | ||
|
|
ac97c88876 | ||
|
|
a1fda2c0de | ||
|
|
f499770d45 | ||
|
|
4769da4ef4 | ||
|
|
c2556a8e39 | ||
|
|
29bf329f6a | ||
|
|
1dee4305bc |
8
.github/workflows/ci.yml
vendored
8
.github/workflows/ci.yml
vendored
@@ -9,18 +9,18 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["windows-latest", "ubuntu-latest", "macOS-latest"]
|
||||
go: ["1.21.x"]
|
||||
go: ["1.24.x"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: WillAbides/setup-go-faster@v1.8.0
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: "go test -race ./..."
|
||||
- uses: dominikh/staticcheck-action@v1.2.0
|
||||
- uses: dominikh/staticcheck-action@v1.4.0
|
||||
with:
|
||||
version: "2023.1.2"
|
||||
version: "2025.1.1"
|
||||
install-go: false
|
||||
cache-key: ${{ matrix.go }}
|
||||
|
||||
195
README.md
195
README.md
@@ -4,12 +4,13 @@
|
||||
[](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
|
||||
[](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
|
||||
|
||||

|
||||
|
||||
A highly configurable DNS forwarding proxy with support for:
|
||||
- Multiple listeners for incoming queries
|
||||
- Multiple upstreams with fallbacks
|
||||
- Multiple network policy driven DNS query steering
|
||||
- Multiple network policy driven DNS query steering (via network cidr, MAC address or FQDN)
|
||||
- Policy driven domain based "split horizon" DNS with wildcard support
|
||||
- Integrations with common router vendors and firmware
|
||||
- LAN client discovery via DHCP, mDNS, ARP, NDP, hosts file parsing
|
||||
- Prometheus metrics exporter
|
||||
|
||||
@@ -24,61 +25,58 @@ All DNS protocols are supported, including:
|
||||
- `DNS-over-QUIC`
|
||||
|
||||
# Use Cases
|
||||
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy routers, legacy OSes, TVs, smart toasters).
|
||||
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy OSes, TVs, smart toasters).
|
||||
2. Create source IP based DNS routing policies with variable secure DNS upstreams. Subnet 1 (admin) uses upstream resolver A, while Subnet 2 (employee) uses upstream resolver B.
|
||||
3. Create destination IP based DNS routing policies with variable secure DNS upstreams. Listener 1 uses upstream resolver C, while Listener 2 uses upstream resolver D.
|
||||
4. Create domain level "split horizon" DNS routing policies to send internal domains (*.company.int) to a local DNS server, while everything else goes to another upstream.
|
||||
5. Deploy on a router and create LAN client specific DNS routing policies from a web GUI (When using ControlD.com).
|
||||
|
||||
|
||||
## OS Support
|
||||
- Windows (386, amd64, arm)
|
||||
- Mac (amd64, arm64)
|
||||
- Windows Desktop (386, amd64, arm)
|
||||
- MacOS (amd64, arm64)
|
||||
- Linux (386, amd64, arm, mips)
|
||||
- FreeBSD
|
||||
- Common routers (See Router Mode below)
|
||||
- FreeBSD (386, amd64, arm)
|
||||
|
||||
# Install
|
||||
There are several ways to download and install `ctrld.
|
||||
There are several ways to download and install `ctrld`.
|
||||
|
||||
## Quick Install
|
||||
The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform:
|
||||
|
||||
```shell
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"'
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl?version=2)"'
|
||||
```
|
||||
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative cmd:
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative PowerShell:
|
||||
```shell
|
||||
powershell -Command "(Invoke-WebRequest -Uri 'https://api.controld.com/dl' -UseBasicParsing).Content | Set-Content 'ctrld_install.bat'" && ctrld_install.bat
|
||||
(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1?version=2' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'"
|
||||
```
|
||||
|
||||
Or you can pull and run a Docker container from [Docker Hub](https://hub.docker.com/r/controldns/ctrld)
|
||||
```
|
||||
$ docker pull controldns/ctrld
|
||||
```shell
|
||||
docker run -d --name=ctrld -p 127.0.0.1:53:53/tcp -p 127.0.0.1:53:53/udp controldns/ctrld:latest
|
||||
```
|
||||
|
||||
## Download Manually
|
||||
Alternatively, if you know what you're doing you can download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section for the appropriate platform.
|
||||
|
||||
## Build
|
||||
Lastly, you can build `ctrld` from source which requires `go1.21+`:
|
||||
Lastly, you can build `ctrld` from source which requires `go1.23+`:
|
||||
|
||||
```shell
|
||||
$ go build ./cmd/ctrld
|
||||
go build ./cmd/ctrld
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
$ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
$ docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
$ docker run -d --name=ctrld -p 53:53/tcp -p 53:53/udp controldns/ctrld --cd=RESOLVER_ID_GOES_HERE -vv
|
||||
```shell
|
||||
docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
```
|
||||
|
||||
|
||||
@@ -99,99 +97,118 @@ Usage:
|
||||
|
||||
Available Commands:
|
||||
run Run the DNS proxy server
|
||||
service Manage ctrld service
|
||||
start Quick start service and configure DNS on interface
|
||||
stop Quick stop service and remove DNS from interface
|
||||
restart Restart the ctrld service
|
||||
reload Reload the ctrld service
|
||||
status Show status of the ctrld service
|
||||
uninstall Stop and uninstall the ctrld service
|
||||
service Manage ctrld service
|
||||
clients Manage clients
|
||||
upgrade Upgrading ctrld to latest version
|
||||
log Manage runtime debug logs
|
||||
|
||||
Flags:
|
||||
-h, --help help for ctrld
|
||||
-s, --silent do not write any log output
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug level logging
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug logging
|
||||
--version version for ctrld
|
||||
|
||||
Use "ctrld [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Basic Run Mode
|
||||
To start the server with default configuration, simply run: `./ctrld run`. This will create a generic `ctrld.toml` file in the **working directory** and start the application in foreground.
|
||||
1. Start the server
|
||||
```
|
||||
$ sudo ./ctrld run
|
||||
This is the most basic way to run `ctrld`, in foreground mode. Unless you already have a config file, a default one will be generated.
|
||||
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe run
|
||||
```
|
||||
|
||||
2. Run a test query using a DNS client, for example, `dig`:
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld run
|
||||
```
|
||||
|
||||
You can then run a test query using a DNS client, for example, `dig`:
|
||||
```
|
||||
$ dig verify.controld.com @127.0.0.1 +short
|
||||
api.controld.com.
|
||||
147.185.34.1
|
||||
```
|
||||
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file and go nuts with it. To enforce a new config, restart the server.
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file that was generated. To enforce a new config, restart the server.
|
||||
|
||||
## Service Mode
|
||||
To run the application in service mode on any Windows, MacOS, Linux distibution or supported router, simply run: `./ctrld start` as system/root user. This will create a generic `ctrld.toml` file in the **user home** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and configure the listener on the default network interface. Service will start on OS boot.
|
||||
This mode will run the application as a background system service on any Windows, MacOS, Linux or FreeBSD distribution. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot.
|
||||
|
||||
When Control D upstreams are used, `ctrld` willl [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
### Command
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `./ctrld stop`. If you wish to stop and uninstall the service permanently, run `./ctrld uninstall`.
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start
|
||||
```
|
||||
|
||||
Linux or Macos
|
||||
```
|
||||
sudo ctrld start
|
||||
```
|
||||
|
||||
### Supported Routers
|
||||
You can run `ctrld` on any supported router, which will function similarly to the Service Mode mentioned above. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- Firewalla
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense / OPNsense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
If `ctrld` is not in your system path (you installed it manually), you will need to run the above commands from the directory where you installed `ctrld`.
|
||||
|
||||
`ctrld` will attempt to interface with dnsmasq whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly.
|
||||
In order to stop the service, and restore your DNS to original state, simply run `ctrld stop`. If you wish to stop and uninstall the service permanently, run `ctrld uninstall`.
|
||||
|
||||
## Unmanaged Service Mode
|
||||
This mode functions similarly to the "Service Mode" above except it will simply start a system service and the config defined listeners, but **will not make any changes to any network interfaces**. You can then set the `ctrld` listener(s) IP on the desired network interfaces manually.
|
||||
|
||||
### Control D Auto Configuration
|
||||
Application can be started with a specific resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `run` (foreground) or `start` (service) modes.
|
||||
### Command
|
||||
|
||||
The following command will start the application in foreground mode, using the free "p2" resolver, which blocks Ads & Trackers.
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe service start
|
||||
```
|
||||
|
||||
```shell
|
||||
./ctrld run --cd p2
|
||||
```
|
||||
|
||||
Alternatively, you can use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Device.
|
||||
|
||||
```shell
|
||||
./ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above commands (in service mode only), the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- Your default network interface will be updated to use the listener started by the service
|
||||
- All OS DNS queries will be sent to the listener
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld service start
|
||||
```
|
||||
|
||||
# Configuration
|
||||
See [Configuration Docs](docs/config.md).
|
||||
`ctrld` can be configured in variety of different ways, which include: API, local config file or via cli launch args.
|
||||
|
||||
## Example
|
||||
- Start `listener.0` on 127.0.0.1:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `upstream.0` via DoH protocol
|
||||
## API Based Auto Configuration
|
||||
Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. This mode is used when the 1 liner installer command from the Control D onboarding guide is executed.
|
||||
|
||||
### Default Config
|
||||
The following command will use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Endpoint.
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --cd abcd1234
|
||||
```
|
||||
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above command, the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- All physical network interface will be updated to use the listener started by the service
|
||||
- All DNS queries will be sent to the listener
|
||||
|
||||
## Manual Configuration
|
||||
`ctrld` is entirely config driven and can be configured in many different ways, please see [Configuration Docs](docs/config.md).
|
||||
|
||||
### Example
|
||||
```toml
|
||||
[listener]
|
||||
|
||||
[listener.0]
|
||||
ip = ""
|
||||
port = 0
|
||||
restricted = false
|
||||
ip = '0.0.0.0'
|
||||
port = 53
|
||||
|
||||
[network]
|
||||
|
||||
@@ -199,10 +216,6 @@ See [Configuration Docs](docs/config.md).
|
||||
cidrs = ["0.0.0.0/0"]
|
||||
name = "Network 0"
|
||||
|
||||
[service]
|
||||
log_level = "info"
|
||||
log_path = ""
|
||||
|
||||
[upstream]
|
||||
|
||||
[upstream.0]
|
||||
@@ -211,28 +224,26 @@ See [Configuration Docs](docs/config.md).
|
||||
name = "Control D - Anti-Malware"
|
||||
timeout = 5000
|
||||
type = "doh"
|
||||
|
||||
[upstream.1]
|
||||
bootstrap_ip = "76.76.2.11"
|
||||
endpoint = "p2.freedns.controld.com"
|
||||
name = "Control D - No Ads"
|
||||
timeout = 3000
|
||||
type = "doq"
|
||||
|
||||
```
|
||||
|
||||
`ctrld` will pick a working config for `listener.0` then writing the default config to disk for the first run.
|
||||
The above basic config will:
|
||||
- Start listener on 0.0.0.0:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `https://freedns.controld.com/p1` using DoH protocol
|
||||
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
## CLI Args
|
||||
If you're unable to use a config file, `ctrld` can be be supplied with basic configuration via launch arguments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
|
||||
You can also supply configuration via launch argeuments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
### Example
|
||||
```
|
||||
ctrld run --listen=127.0.0.1:53 --primary_upstream=https://freedns.controld.com/p2 --secondary_upstream=10.0.10.1:53 --domains=*.company.int,very-secure.local --log /path/to/log.log
|
||||
```
|
||||
|
||||
The above will start a foreground process and:
|
||||
- Listen on `127.0.0.1:53` for DNS queries
|
||||
- Forward all queries to `https://freedns.controld.com/p2` using DoH protocol, while...
|
||||
- Excluding `*.company.int` and `very-secure.local` matching queries, that are forwarded to `10.0.10.1:53`
|
||||
- Write a debug log to `/path/to/log.log`
|
||||
|
||||
## Contributing
|
||||
See [Contribution Guideline](./docs/contributing.md)
|
||||
|
||||
## Roadmap
|
||||
The following functionality is on the roadmap and will be available in future releases.
|
||||
- DNS intercept mode
|
||||
- Direct listener mode
|
||||
- Support for more routers (let us know which ones)
|
||||
|
||||
10
cmd/cli/ad_others.go
Normal file
10
cmd/cli/ad_others.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule if present.
|
||||
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
|
||||
73
cmd/cli/ad_windows.go
Normal file
73
cmd/cli/ad_windows.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
hh "github.com/microsoft/wmi/pkg/hardware/host"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule for domain if it's part of active directory.
|
||||
func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Msgf("Unable to get active directory domain: %v", err)
|
||||
return false
|
||||
}
|
||||
if domain == "" {
|
||||
mainLog.Load().Debug().Msg("No active directory domain found")
|
||||
return false
|
||||
}
|
||||
// Network rules are lowercase during toml config marshaling,
|
||||
// lowercase the domain here too for consistency.
|
||||
domain = strings.ToLower(domain)
|
||||
domainRuleAdded := addSplitDnsRule(cfg, domain)
|
||||
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
|
||||
return domainRuleAdded || wildcardDomainRuleRuleAdded
|
||||
}
|
||||
|
||||
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
|
||||
// The return value indicates whether the split-rule was added or not.
|
||||
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
|
||||
for n, lc := range cfg.Listener {
|
||||
if lc.Policy == nil {
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{}
|
||||
}
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
if _, ok := rule[domain]; ok {
|
||||
mainLog.Load().Debug().Msgf("Split-rule %q already existed for listener.%s", domain, n)
|
||||
return false
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Adding split-rule %q for listener.%s", domain, n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
cs, err := hh.GetComputerSystem(whost)
|
||||
if cs != nil {
|
||||
defer cs.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pod, err := cs.GetPropertyPartOfDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pod {
|
||||
return cs.GetPropertyDomain()
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
71
cmd/cli/ad_windows_test.go
Normal file
71
cmd/cli/ad_windows_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_getActiveDirectoryDomain(t *testing.T) {
|
||||
start := time.Now()
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
domainPowershell, err := getActiveDirectoryDomainPowershell()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if domain != domainPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func getActiveDirectoryDomainPowershell() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
func Test_addSplitDnsRule(t *testing.T) {
|
||||
newCfg := func(domains ...string) *ctrld.Config {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
lc := cfg.Listener["0"]
|
||||
for _, domain := range domains {
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *ctrld.Config
|
||||
domain string
|
||||
added bool
|
||||
}{
|
||||
{"added", newCfg(), "example.com", true},
|
||||
{"TLD existed", newCfg("example.com"), "*.example.com", true},
|
||||
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
|
||||
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
|
||||
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
added := addSplitDnsRule(tc.cfg, tc.domain)
|
||||
assert.Equal(t, tc.added, added)
|
||||
})
|
||||
}
|
||||
}
|
||||
5
cmd/cli/cgo.go
Normal file
5
cmd/cli/cgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = true
|
||||
1906
cmd/cli/cli.go
1906
cmd/cli/cli.go
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@ func Test_writeConfigFile(t *testing.T) {
|
||||
_, err := os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
assert.NoError(t, writeConfigFile())
|
||||
assert.NoError(t, writeConfigFile(&cfg))
|
||||
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
141
cmd/cli/commands_clients.go
Normal file
141
cmd/cli/commands_clients.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
||||
)
|
||||
|
||||
// ClientsCommand handles clients-related operations
|
||||
type ClientsCommand struct {
|
||||
controlClient *controlClient
|
||||
}
|
||||
|
||||
// NewClientsCommand creates a new clients command handler
|
||||
func NewClientsCommand() (*ClientsCommand, error) {
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find ctrld home dir: %w", err)
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
return &ClientsCommand{
|
||||
controlClient: cc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListClients lists all connected clients
|
||||
func (cc *ClientsCommand) ListClients(cmd *cobra.Command, args []string) error {
|
||||
// Check service status first
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := cc.controlClient.post(listClientsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get clients: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var clients []*clientinfo.Client
|
||||
if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil {
|
||||
return fmt.Errorf("failed to decode clients result: %w", err)
|
||||
}
|
||||
|
||||
map2Slice := func(m map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
if k == "" { // skip empty source from output.
|
||||
continue
|
||||
}
|
||||
s = append(s, k)
|
||||
}
|
||||
sort.Strings(s)
|
||||
return s
|
||||
}
|
||||
|
||||
// If metrics is enabled, server set this for all clients, so we can check only the first one.
|
||||
// Ideally, we may have a field in response to indicate that query count should be shown, but
|
||||
// it would break earlier version of ctrld, which only look list of clients in response.
|
||||
withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount
|
||||
data := make([][]string, len(clients))
|
||||
for i, c := range clients {
|
||||
row := []string{
|
||||
c.IP.String(),
|
||||
c.Hostname,
|
||||
c.Mac,
|
||||
strings.Join(map2Slice(c.Source), ","),
|
||||
}
|
||||
if withQueryCount {
|
||||
row = append(row, strconv.FormatInt(c.QueryCount, 10))
|
||||
}
|
||||
data[i] = row
|
||||
}
|
||||
|
||||
table := tablewriter.NewWriter(os.Stdout)
|
||||
headers := []string{"IP", "Hostname", "Mac", "Discovered"}
|
||||
if withQueryCount {
|
||||
headers = append(headers, "Queries")
|
||||
}
|
||||
table.SetHeader(headers)
|
||||
table.SetAutoFormatHeaders(false)
|
||||
table.AppendBulk(data)
|
||||
table.Render()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitClientsCmd creates the clients command with proper logic
|
||||
func InitClientsCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
listClientsCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List clients that ctrld discovered",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cc, err := NewClientsCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cc.ListClients(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
clientsCmd := &cobra.Command{
|
||||
Use: "clients",
|
||||
Short: "Manage clients",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
listClientsCmd.Use,
|
||||
},
|
||||
}
|
||||
clientsCmd.AddCommand(listClientsCmd)
|
||||
rootCmd.AddCommand(clientsCmd)
|
||||
|
||||
return clientsCmd
|
||||
}
|
||||
87
cmd/cli/commands_interfaces.go
Normal file
87
cmd/cli/commands_interfaces.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// InterfacesCommand handles interfaces-related operations
|
||||
type InterfacesCommand struct{}
|
||||
|
||||
// NewInterfacesCommand creates a new interfaces command handler
|
||||
func NewInterfacesCommand() (*InterfacesCommand, error) {
|
||||
return &InterfacesCommand{}, nil
|
||||
}
|
||||
|
||||
// ListInterfaces lists all network interfaces
|
||||
func (ic *InterfacesCommand) ListInterfaces(cmd *cobra.Command, args []string) error {
|
||||
withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error {
|
||||
fmt.Printf("Index : %d\n", i.Index)
|
||||
fmt.Printf("Name : %s\n", i.Name)
|
||||
var status string
|
||||
if i.Flags&net.FlagUp != 0 {
|
||||
status = "Up"
|
||||
} else {
|
||||
status = "Down"
|
||||
}
|
||||
fmt.Printf("Status: %s\n", status)
|
||||
addrs, _ := i.Addrs()
|
||||
for i, ipaddr := range addrs {
|
||||
if i == 0 {
|
||||
fmt.Printf("Addrs : %v\n", ipaddr)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" %v\n", ipaddr)
|
||||
}
|
||||
nss, err := currentStaticDNS(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Failed to get DNS")
|
||||
}
|
||||
if len(nss) == 0 {
|
||||
nss = currentDNS(i)
|
||||
}
|
||||
for i, dns := range nss {
|
||||
if i == 0 {
|
||||
fmt.Printf("DNS : %s\n", dns)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" : %s\n", dns)
|
||||
}
|
||||
println()
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitInterfacesCmd creates the interfaces command with proper logic
|
||||
func InitInterfacesCmd(_ *cobra.Command) *cobra.Command {
|
||||
listInterfacesCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List network interfaces",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ic, err := NewInterfacesCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ic.ListInterfaces(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
interfacesCmd := &cobra.Command{
|
||||
Use: "interfaces",
|
||||
Short: "Manage network interfaces",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
listInterfacesCmd.Use,
|
||||
},
|
||||
}
|
||||
interfacesCmd.AddCommand(listInterfacesCmd)
|
||||
|
||||
return interfacesCmd
|
||||
}
|
||||
175
cmd/cli/commands_log.go
Normal file
175
cmd/cli/commands_log.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/docker/go-units"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// LogCommand handles log-related operations
|
||||
type LogCommand struct {
|
||||
controlClient *controlClient
|
||||
}
|
||||
|
||||
// NewLogCommand creates a new log command handler
|
||||
func NewLogCommand() (*LogCommand, error) {
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find ctrld home dir: %w", err)
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
return &LogCommand{
|
||||
controlClient: cc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// warnRuntimeLoggingNotEnabled logs a warning about runtime logging not being enabled
|
||||
func (lc *LogCommand) warnRuntimeLoggingNotEnabled() {
|
||||
mainLog.Load().Warn().Msg("Runtime debug logging is not enabled")
|
||||
mainLog.Load().Warn().Msg(`ctrld may be running without "--cd" flag or logging is already enabled`)
|
||||
}
|
||||
|
||||
// SendLogs sends runtime debug logs to ControlD
|
||||
func (lc *LogCommand) SendLogs(cmd *cobra.Command, args []string) error {
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := lc.controlClient.post(sendLogsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send logs: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusServiceUnavailable:
|
||||
mainLog.Load().Warn().Msg("Runtime logs could only be sent once per minute")
|
||||
return nil
|
||||
case http.StatusMovedPermanently:
|
||||
lc.warnRuntimeLoggingNotEnabled()
|
||||
return nil
|
||||
}
|
||||
|
||||
var logs logSentResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
|
||||
return fmt.Errorf("failed to decode sent logs result: %w", err)
|
||||
}
|
||||
|
||||
if logs.Error != "" {
|
||||
return fmt.Errorf("failed to send logs: %s", logs.Error)
|
||||
}
|
||||
|
||||
mainLog.Load().Notice().Msgf("Sent %s of runtime logs", units.BytesSize(float64(logs.Size)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ViewLogs views current runtime debug logs
|
||||
func (lc *LogCommand) ViewLogs(cmd *cobra.Command, args []string) error {
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := lc.controlClient.post(viewLogsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get logs: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusMovedPermanently:
|
||||
lc.warnRuntimeLoggingNotEnabled()
|
||||
return nil
|
||||
case http.StatusBadRequest:
|
||||
mainLog.Load().Warn().Msg("Runtime debug logs are not available")
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to read response body")
|
||||
}
|
||||
mainLog.Load().Warn().Msgf("ctrld process response:\n\n%s\n", string(buf))
|
||||
return nil
|
||||
case http.StatusOK:
|
||||
}
|
||||
|
||||
var logs logViewResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
|
||||
return fmt.Errorf("failed to decode view logs result: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(logs.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitLogCmd creates the log command with proper logic
|
||||
func InitLogCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
lc, err := NewLogCommand()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create log command: %v", err))
|
||||
}
|
||||
|
||||
logSendCmd := &cobra.Command{
|
||||
Use: "send",
|
||||
Short: "Send runtime debug logs to ControlD",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: lc.SendLogs,
|
||||
}
|
||||
|
||||
logViewCmd := &cobra.Command{
|
||||
Use: "view",
|
||||
Short: "View current runtime debug logs",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: lc.ViewLogs,
|
||||
}
|
||||
|
||||
logCmd := &cobra.Command{
|
||||
Use: "log",
|
||||
Short: "Manage runtime debug logs",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
logSendCmd.Use,
|
||||
logViewCmd.Use,
|
||||
},
|
||||
}
|
||||
logCmd.AddCommand(logSendCmd)
|
||||
logCmd.AddCommand(logViewCmd)
|
||||
rootCmd.AddCommand(logCmd)
|
||||
|
||||
return logCmd
|
||||
}
|
||||
59
cmd/cli/commands_run.go
Normal file
59
cmd/cli/commands_run.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// RunCommand handles run-related operations
|
||||
type RunCommand struct {
|
||||
// Add any dependencies here if needed in the future
|
||||
}
|
||||
|
||||
// NewRunCommand creates a new run command handler
|
||||
func NewRunCommand() *RunCommand {
|
||||
return &RunCommand{}
|
||||
}
|
||||
|
||||
// Run implements the logic for the run command
|
||||
func (rc *RunCommand) Run(cmd *cobra.Command, args []string) {
|
||||
RunCobraCommand(cmd)
|
||||
}
|
||||
|
||||
// InitRunCmd creates the run command with proper logic
|
||||
func InitRunCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
rc := NewRunCommand()
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "Run the DNS proxy server",
|
||||
Args: cobra.NoArgs,
|
||||
Run: rc.Run,
|
||||
}
|
||||
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
|
||||
runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
|
||||
runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
|
||||
runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
|
||||
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = runCmd.Flags().MarkHidden("dev")
|
||||
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
runCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
|
||||
|
||||
runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true}
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
return runCmd
|
||||
}
|
||||
256
cmd/cli/commands_service.go
Normal file
256
cmd/cli/commands_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// filterEmptyStrings removes empty strings from a slice
|
||||
// This is used to clean up command line arguments and configuration values
|
||||
func filterEmptyStrings(slice []string) []string {
|
||||
var result []string
|
||||
for _, s := range slice {
|
||||
if s != "" {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ServiceCommand handles service-related operations
|
||||
// This encapsulates all service management functionality for the CLI
|
||||
type ServiceCommand struct {
|
||||
serviceManager *ServiceManager
|
||||
}
|
||||
|
||||
// initializeServiceManager creates a service manager with default configuration
|
||||
// This sets up the basic service infrastructure needed for all service operations
|
||||
func (sc *ServiceCommand) initializeServiceManager() (service.Service, *prog, error) {
|
||||
svcConfig := sc.createServiceConfig()
|
||||
return sc.initializeServiceManagerWithServiceConfig(svcConfig)
|
||||
}
|
||||
|
||||
// initializeServiceManagerWithServiceConfig creates a service manager with the given configuration
|
||||
// This allows for custom service configuration while maintaining the same initialization pattern
|
||||
func (sc *ServiceCommand) initializeServiceManagerWithServiceConfig(svcConfig *service.Config) (service.Service, *prog, error) {
|
||||
p := &prog{}
|
||||
|
||||
s, err := sc.newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
sc.serviceManager = &ServiceManager{prog: p, svc: s}
|
||||
return s, p, nil
|
||||
}
|
||||
|
||||
// newService creates a new service instance using the provided program and configuration.
|
||||
// This abstracts the service creation process for different operating systems
|
||||
func (sc *ServiceCommand) newService(p *prog, svcConfig *service.Config) (service.Service, error) {
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewServiceCommand creates a new service command handler
|
||||
// This provides a clean factory method for creating service command instances
|
||||
func NewServiceCommand() *ServiceCommand {
|
||||
return &ServiceCommand{}
|
||||
}
|
||||
|
||||
// createServiceConfig creates a properly initialized service configuration
|
||||
// This ensures consistent service naming and description across all platforms
|
||||
func (sc *ServiceCommand) createServiceConfig() *service.Config {
|
||||
return &service.Config{
|
||||
Name: ctrldServiceName,
|
||||
DisplayName: "Control-D Helper Service",
|
||||
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
|
||||
Option: service.KeyValue{},
|
||||
}
|
||||
}
|
||||
|
||||
// InitServiceCmd creates the service command with proper logic and aliases
|
||||
// This sets up all service-related subcommands with appropriate permissions and flags
|
||||
func InitServiceCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
// Create service command handlers
|
||||
sc := NewServiceCommand()
|
||||
|
||||
startCmd, startCmdAlias := createStartCommands(sc)
|
||||
rootCmd.AddCommand(startCmdAlias)
|
||||
|
||||
// Stop command
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "Stop the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Stop,
|
||||
}
|
||||
stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`)
|
||||
_ = stopCmd.Flags().MarkHidden("pin")
|
||||
|
||||
// Restart command
|
||||
restartCmd := &cobra.Command{
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Restart,
|
||||
}
|
||||
|
||||
// Status command
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: sc.Status,
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
// On darwin, running status command without privileges may return wrong information.
|
||||
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
}
|
||||
}
|
||||
|
||||
// Reload command
|
||||
reloadCmd := &cobra.Command{
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Reload,
|
||||
}
|
||||
|
||||
// Uninstall command
|
||||
uninstallCmd := &cobra.Command{
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Uninstall,
|
||||
}
|
||||
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`)
|
||||
_ = uninstallCmd.Flags().MarkHidden("pin")
|
||||
uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`)
|
||||
|
||||
// Interfaces command - use the existing InitInterfacesCmd function
|
||||
interfacesCmd := InitInterfacesCmd(rootCmd)
|
||||
|
||||
stopCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Quick stop service and remove DNS from interface",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return stopCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
|
||||
rootCmd.AddCommand(stopCmdAlias)
|
||||
|
||||
// Create aliases for other service commands
|
||||
restartCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return restartCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(restartCmdAlias)
|
||||
|
||||
reloadCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return reloadCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(reloadCmdAlias)
|
||||
|
||||
statusCmdAlias := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: statusCmd.RunE,
|
||||
}
|
||||
rootCmd.AddCommand(statusCmdAlias)
|
||||
|
||||
uninstallCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return uninstallCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags())
|
||||
rootCmd.AddCommand(uninstallCmdAlias)
|
||||
|
||||
// Create service command
|
||||
serviceCmd := &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "Manage ctrld service",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
}
|
||||
serviceCmd.ValidArgs = make([]string, 7)
|
||||
serviceCmd.ValidArgs[0] = startCmd.Use
|
||||
serviceCmd.ValidArgs[1] = stopCmd.Use
|
||||
serviceCmd.ValidArgs[2] = restartCmd.Use
|
||||
serviceCmd.ValidArgs[3] = reloadCmd.Use
|
||||
serviceCmd.ValidArgs[4] = statusCmd.Use
|
||||
serviceCmd.ValidArgs[5] = uninstallCmd.Use
|
||||
serviceCmd.ValidArgs[6] = interfacesCmd.Use
|
||||
|
||||
serviceCmd.AddCommand(startCmd)
|
||||
serviceCmd.AddCommand(stopCmd)
|
||||
serviceCmd.AddCommand(restartCmd)
|
||||
serviceCmd.AddCommand(reloadCmd)
|
||||
serviceCmd.AddCommand(statusCmd)
|
||||
serviceCmd.AddCommand(uninstallCmd)
|
||||
serviceCmd.AddCommand(interfacesCmd)
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
|
||||
return serviceCmd
|
||||
}
|
||||
41
cmd/cli/commands_service_manager.go
Normal file
41
cmd/cli/commands_service_manager.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// dialSocketControlServerTimeout is the default timeout to wait when ping control server.
|
||||
const dialSocketControlServerTimeout = 30 * time.Second
|
||||
|
||||
// ServiceManager handles service operations
|
||||
type ServiceManager struct {
|
||||
prog *prog
|
||||
svc service.Service
|
||||
}
|
||||
|
||||
// NewServiceManager creates a new service manager
|
||||
func NewServiceManager() (*ServiceManager, error) {
|
||||
p := &prog{}
|
||||
|
||||
// Create a proper service configuration
|
||||
svcConfig := &service.Config{
|
||||
Name: ctrldServiceName,
|
||||
DisplayName: "Control-D Helper Service",
|
||||
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
|
||||
Option: service.KeyValue{},
|
||||
}
|
||||
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
return &ServiceManager{prog: p, svc: s}, nil
|
||||
}
|
||||
|
||||
// Status returns the current service status
|
||||
func (sm *ServiceManager) Status() (service.Status, error) {
|
||||
return sm.svc.Status()
|
||||
}
|
||||
67
cmd/cli/commands_service_reload.go
Normal file
67
cmd/cli/commands_service_reload.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Reload implements the logic from cmdReload.Run
|
||||
func (sc *ServiceCommand) Reload(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service reload command started")
|
||||
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
logger.Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Failed to find ctrld home dir")
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
resp, err := cc.post(reloadPath, nil)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Failed to send reload signal to ctrld")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
logger.Notice().Msg("Service reloaded")
|
||||
case http.StatusCreated:
|
||||
logger.Warn().Msg("Service was reloaded, but new config requires service restart.")
|
||||
logger.Warn().Msg("Restarting service")
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
return sc.Restart(cmd, args)
|
||||
default:
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Could not read response from control server")
|
||||
}
|
||||
logger.Error().Err(err).Msgf("Failed to reload ctrld: %s", string(buf))
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service reload command completed")
|
||||
return nil
|
||||
}
|
||||
111
cmd/cli/commands_service_restart.go
Normal file
111
cmd/cli/commands_service_restart.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Restart implements the logic from cmdRestart.Run
|
||||
func (sc *ServiceCommand) Restart(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service restart command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
cdUID = curCdUID()
|
||||
cdMode := cdUID != ""
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
var validateConfigErr error
|
||||
if cdMode {
|
||||
logger.Debug().Msg("Validating ControlD remote config")
|
||||
validateConfigErr = doValidateCdRemoteConfig(cdUID, false)
|
||||
if validateConfigErr != nil {
|
||||
logger.Warn().Err(validateConfigErr).Msg("ControlD remote config validation failed")
|
||||
}
|
||||
}
|
||||
|
||||
if ir := runningIface(s); ir != nil {
|
||||
iface = ir.Name
|
||||
}
|
||||
|
||||
doRestart := func() bool {
|
||||
logger.Debug().Msg("Starting service restart sequence")
|
||||
|
||||
tasks := []task{
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
time.Sleep(time.Second * 1)
|
||||
return nil
|
||||
}, false, "Waiting for service to stop"},
|
||||
}
|
||||
if !doTasks(tasks) {
|
||||
logger.Error().Msg("Service stop tasks failed")
|
||||
return false
|
||||
}
|
||||
tasks = []task{
|
||||
{s.Start, true, "Start"},
|
||||
}
|
||||
success := doTasks(tasks)
|
||||
if success {
|
||||
logger.Debug().Msg("Service restart sequence completed successfully")
|
||||
} else {
|
||||
logger.Error().Msg("Service restart sequence failed")
|
||||
}
|
||||
return success
|
||||
}
|
||||
|
||||
if doRestart() {
|
||||
if dir, err := socketDir(); err == nil {
|
||||
timeout := dialSocketControlServerTimeout
|
||||
if validateConfigErr != nil {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
if cc := newSocketControlClientWithTimeout(context.TODO(), s, dir, timeout); cc != nil {
|
||||
_, _ = cc.post(ifacePath, nil)
|
||||
logger.Debug().Msg("Control server ping successful")
|
||||
} else {
|
||||
logger.Warn().Err(err).Msg("Service was restarted, but ctrld process may not be ready yet")
|
||||
}
|
||||
} else {
|
||||
logger.Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
|
||||
}
|
||||
logger.Notice().Msg("Service restarted")
|
||||
} else {
|
||||
logger.Error().Msg("Service restart failed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service restart command completed")
|
||||
return nil
|
||||
}
|
||||
387
cmd/cli/commands_service_start.go
Normal file
387
cmd/cli/commands_service_start.go
Normal file
@@ -0,0 +1,387 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Start implements the logic from cmdStart.Run
|
||||
func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service start command started")
|
||||
|
||||
checkStrFlagEmpty(cmd, cdUidFlagName)
|
||||
checkStrFlagEmpty(cmd, cdOrgFlagName)
|
||||
validateCdAndNextDNSFlags()
|
||||
|
||||
svcConfig := sc.createServiceConfig()
|
||||
osArgs := os.Args[2:]
|
||||
osArgs = filterEmptyStrings(osArgs)
|
||||
if os.Args[1] == "service" {
|
||||
osArgs = os.Args[3:]
|
||||
}
|
||||
setDependencies(svcConfig)
|
||||
svcConfig.Arguments = append([]string{"run"}, osArgs...)
|
||||
|
||||
// Initialize service manager with proper configuration
|
||||
s, p, err := sc.initializeServiceManagerWithServiceConfig(svcConfig)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
p.preRun()
|
||||
|
||||
status, err := s.Status()
|
||||
isCtrldRunning := status == service.StatusRunning
|
||||
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
|
||||
|
||||
// Get current running iface, if any.
|
||||
var currentIface *ifaceResponse
|
||||
|
||||
// If pin code was set, do not allow running start command.
|
||||
if isCtrldRunning {
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
currentIface = runningIface(s)
|
||||
logger.Debug().Msgf("Current interface on start: %v", currentIface)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
reportSetDnsOk := func(sockDir string) {
|
||||
if cc := newSocketControlClient(ctx, s, sockDir); cc != nil {
|
||||
if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK {
|
||||
if iface == "auto" {
|
||||
iface = defaultIfaceName()
|
||||
}
|
||||
res := &ifaceResponse{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(res); err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get iface info")
|
||||
return
|
||||
}
|
||||
if res.OK {
|
||||
name := res.Name
|
||||
if iff, err := net.InterfaceByName(name); err == nil {
|
||||
_, _ = patchNetIfaceName(iff)
|
||||
name = iff.Name
|
||||
}
|
||||
logger := logger.With().Str("iface", name)
|
||||
logger.Debug().Msg("Setting DNS successfully")
|
||||
if res.All {
|
||||
// Log that DNS is set for other interfaces.
|
||||
withEachPhysicalInterfaces(
|
||||
name,
|
||||
"set DNS",
|
||||
func(i *net.Interface) error { return nil },
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No config path, generating config in HOME directory.
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
|
||||
logServerStarted := make(chan struct{})
|
||||
stopLogCh := make(chan struct{})
|
||||
ud, err := userHomeDir()
|
||||
sockDir := ud
|
||||
var logServerSocketPath string
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get user home directory")
|
||||
logger.Warn().Msg("Log server did not start")
|
||||
close(logServerStarted)
|
||||
} else {
|
||||
setWorkingDirectory(svcConfig, ud)
|
||||
if configPath == "" && writeDefaultConfig {
|
||||
defaultConfigFile = filepath.Join(ud, defaultConfigFile)
|
||||
}
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--homedir="+ud)
|
||||
if d, err := socketDir(); err == nil {
|
||||
sockDir = d
|
||||
}
|
||||
logServerSocketPath = filepath.Join(sockDir, ctrldLogUnixSock)
|
||||
_ = os.Remove(logServerSocketPath)
|
||||
go func() {
|
||||
defer os.Remove(logServerSocketPath)
|
||||
|
||||
close(logServerStarted)
|
||||
|
||||
// Start HTTP log server
|
||||
if err := httpLogServer(logServerSocketPath, stopLogCh); err != nil && err != http.ErrServerClosed {
|
||||
logger.Warn().Err(err).Msg("Failed to serve HTTP log server")
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
<-logServerStarted
|
||||
|
||||
if !startOnly {
|
||||
startOnly = len(osArgs) == 0
|
||||
}
|
||||
// If user run "ctrld start" and ctrld is already installed, starting existing service.
|
||||
if startOnly && isCtrldInstalled {
|
||||
tryReadingConfigWithNotice(false, true)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
logger.Fatal().Msgf("Failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
// if already running, dont restart
|
||||
if isCtrldRunning {
|
||||
logger.Notice().Msg("Service is already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
tasks := []task{
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}, false, "Save current DNS"},
|
||||
{func() error {
|
||||
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||
}, false, "Configure service failure actions"},
|
||||
{s.Start, true, "Start"},
|
||||
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||
}
|
||||
logger.Notice().Msg("Starting existing ctrld service")
|
||||
if doTasks(tasks) {
|
||||
logger.Notice().Msg("Service started")
|
||||
sockDir, err := socketDir()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get socket directory")
|
||||
os.Exit(1)
|
||||
}
|
||||
reportSetDnsOk(sockDir)
|
||||
} else {
|
||||
logger.Error().Err(err).Msg("Failed to start existing ctrld service")
|
||||
os.Exit(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if cdUID != "" {
|
||||
_ = doValidateCdRemoteConfig(cdUID, true)
|
||||
} else if uid := cdUIDFromProvToken(); uid != "" {
|
||||
cdUID = uid
|
||||
logger.Debug().Msg("Using uid from provision token")
|
||||
removeOrgFlagsFromArgs(svcConfig)
|
||||
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--cd="+cdUID)
|
||||
}
|
||||
if cdUID != "" {
|
||||
validateCdUpstreamProtocol()
|
||||
}
|
||||
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
|
||||
tryReadingConfigWithNotice(writeDefaultConfig, true)
|
||||
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
logger.Fatal().Msgf("Failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
if nextdns != "" {
|
||||
removeNextDNSFromArgs(svcConfig)
|
||||
}
|
||||
|
||||
// Explicitly passing config, so on system where home directory could not be obtained,
|
||||
// or sub-process env is different with the parent, we still behave correctly and use
|
||||
// the expected config file.
|
||||
if configPath == "" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--config="+defaultConfigFile)
|
||||
}
|
||||
|
||||
tasks := []task{
|
||||
{s.Stop, false, "Stop"},
|
||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
|
||||
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
|
||||
//resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}, false, "Save current DNS"},
|
||||
{s.Install, false, "Install"},
|
||||
{func() error {
|
||||
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||
}, false, "Configure Windows service failure actions"},
|
||||
{s.Start, true, "Start"},
|
||||
// Note that startCmd do not actually write ControlD config, but the config file was
|
||||
// generated after s.Start, so we notice users here for consistent with nextdns mode.
|
||||
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||
}
|
||||
logger.Notice().Msg("Starting service")
|
||||
if doTasks(tasks) {
|
||||
// add a small delay to ensure the service is started and did not crash
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
ok, status, err := selfCheckStatus(ctx, s, sockDir)
|
||||
switch {
|
||||
case ok && status == service.StatusRunning:
|
||||
logger.Notice().Msg("Service started")
|
||||
default:
|
||||
marker := append(bytes.Repeat([]byte("="), 32), '\n')
|
||||
// If ctrld service is not running, emitting log obtained from ctrld process.
|
||||
if status != service.StatusRunning || ctx.Err() != nil {
|
||||
logger.Error().Msg("Ctrld service may not have started due to an error or misconfiguration, service log:")
|
||||
_, _ = logger.Write(marker)
|
||||
|
||||
// Wait for log collection to complete
|
||||
<-stopLogCh
|
||||
|
||||
// Retrieve logs from HTTP server if available
|
||||
if logServerSocketPath != "" {
|
||||
hlc := newHTTPLogClient(logServerSocketPath)
|
||||
logs, err := hlc.GetLogs()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get logs from HTTP log server")
|
||||
}
|
||||
if len(logs) == 0 {
|
||||
logger.Write([]byte("<no log output is obtained from ctrld process>\n"))
|
||||
} else {
|
||||
logger.Write(logs)
|
||||
logger.Write([]byte("\n"))
|
||||
}
|
||||
} else {
|
||||
logger.Write([]byte("<no log output from HTTP log server>\n"))
|
||||
}
|
||||
}
|
||||
// Report any error if occurred.
|
||||
if err != nil {
|
||||
_, _ = logger.Write(marker)
|
||||
msg := fmt.Sprintf("An error occurred while performing test query: %s\n", err)
|
||||
logger.Write([]byte(msg))
|
||||
}
|
||||
// If ctrld service is running but selfCheckStatus failed, it could be related
|
||||
// to user's system firewall configuration, notice users about it.
|
||||
if status == service.StatusRunning && err == nil {
|
||||
_, _ = logger.Write(marker)
|
||||
logger.Write([]byte("ctrld service was running, but a DNS query could not be sent to its listener\n"))
|
||||
logger.Write([]byte("Please check your system firewall if it is configured to block/intercept/redirect DNS queries\n"))
|
||||
}
|
||||
|
||||
_, _ = logger.Write(marker)
|
||||
uninstall(p, s)
|
||||
os.Exit(1)
|
||||
}
|
||||
reportSetDnsOk(sockDir)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service start command completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createStartCommands creates the start command and its alias
|
||||
func createStartCommands(sc *ServiceCommand) (*cobra.Command, *cobra.Command) {
|
||||
// Start command
|
||||
startCmd := &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "Install and start the ctrld service",
|
||||
Long: `Install and start the ctrld service
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
args = filterEmptyStrings(args)
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
|
||||
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Start,
|
||||
}
|
||||
// Keep these flags in sync with runCmd above, except for "-d"/"--nextdns".
|
||||
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
|
||||
startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
|
||||
startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
|
||||
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = startCmd.Flags().MarkHidden("dev")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
|
||||
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`)
|
||||
startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service")
|
||||
_ = startCmd.Flags().MarkHidden("start_only")
|
||||
startCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
|
||||
|
||||
// Start command alias
|
||||
startCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
Long: `Quick start service and configure DNS on interface
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
args = filterEmptyStrings(args)
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
|
||||
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(os.Args) == 2 {
|
||||
startOnly = true
|
||||
}
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return startCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||
|
||||
return startCmd, startCmdAlias
|
||||
}
|
||||
41
cmd/cli/commands_service_status.go
Normal file
41
cmd/cli/commands_service_status.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Status implements the logic from cmdStatus.Run
|
||||
func (sc *ServiceCommand) Status(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service status command started")
|
||||
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
logger.Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch status {
|
||||
case service.StatusUnknown:
|
||||
logger.Notice().Msg("Unknown status")
|
||||
os.Exit(2)
|
||||
case service.StatusRunning:
|
||||
logger.Notice().Msg("Service is running")
|
||||
os.Exit(0)
|
||||
case service.StatusStopped:
|
||||
logger.Notice().Msg("Service is stopped")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service status command completed")
|
||||
return nil
|
||||
}
|
||||
61
cmd/cli/commands_service_stop.go
Normal file
61
cmd/cli/commands_service_stop.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Stop implements the logic from cmdStop.Run
|
||||
func (sc *ServiceCommand) Stop(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service stop command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
logger.Warn().Msg("Service is already stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Stopping service")
|
||||
if doTasks([]task{{s.Stop, true, "Stop"}}) {
|
||||
logger.Notice().Msg("Service stopped")
|
||||
} else {
|
||||
logger.Error().Msg("Service stop failed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service stop command completed")
|
||||
return nil
|
||||
}
|
||||
106
cmd/cli/commands_service_uninstall.go
Normal file
106
cmd/cli/commands_service_uninstall.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Uninstall implements the logic from cmdUninstall.Run
|
||||
func (sc *ServiceCommand) Uninstall(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service uninstall command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Starting service uninstall")
|
||||
uninstall(p, s)
|
||||
|
||||
if cleanup {
|
||||
logger.Debug().Msg("Performing cleanup operations")
|
||||
var files []string
|
||||
// Config file.
|
||||
files = append(files, v.ConfigFileUsed())
|
||||
// Log file and backup log file.
|
||||
// For safety, only process if log file path is absolute.
|
||||
if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) {
|
||||
files = append(files, logFile)
|
||||
oldLogFile := logFile + oldLogSuffix
|
||||
if _, err := os.Stat(oldLogFile); err == nil {
|
||||
files = append(files, oldLogFile)
|
||||
}
|
||||
}
|
||||
// Socket files.
|
||||
if dir, _ := socketDir(); dir != "" {
|
||||
files = append(files, filepath.Join(dir, ctrldControlUnixSock))
|
||||
files = append(files, filepath.Join(dir, ctrldLogUnixSock))
|
||||
}
|
||||
// Static DNS settings files.
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
file := ctrld.SavedStaticDnsSettingsFilePath(i)
|
||||
files = append(files, file)
|
||||
return nil
|
||||
})
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get executable path")
|
||||
}
|
||||
if bin != "" && supportedSelfDelete {
|
||||
files = append(files, bin)
|
||||
}
|
||||
// Backup file after upgrading.
|
||||
oldBin := bin + oldBinSuffix
|
||||
if _, err := os.Stat(oldBin); err == nil {
|
||||
files = append(files, oldBin)
|
||||
}
|
||||
for _, file := range files {
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
if err := os.Remove(file); err == nil {
|
||||
logger.Notice().Str("file", file).Msg("File removed during cleanup")
|
||||
} else {
|
||||
logger.Debug().Err(err).Str("file", file).Msg("Failed to remove file during cleanup")
|
||||
}
|
||||
}
|
||||
// Self-delete the ctrld binary if supported
|
||||
if err := selfDeleteExe(); err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to delete ctrld binary")
|
||||
} else {
|
||||
if !supportedSelfDelete {
|
||||
logger.Debug().Msgf("File removed: %s", bin)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Cleanup operations completed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service uninstall command completed")
|
||||
return nil
|
||||
}
|
||||
197
cmd/cli/commands_test.go
Normal file
197
cmd/cli/commands_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestBasicCommandStructure tests the actual root command structure
|
||||
func TestBasicCommandStructure(t *testing.T) {
|
||||
// Test the actual root command that's returned from initCLI()
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has basic properties
|
||||
assert.Equal(t, "ctrld", rootCmd.Use)
|
||||
assert.NotEmpty(t, rootCmd.Short, "Root command should have a short description")
|
||||
|
||||
// Test that root command has subcommands
|
||||
commands := rootCmd.Commands()
|
||||
assert.NotNil(t, commands, "Root command should have subcommands")
|
||||
assert.Greater(t, len(commands), 0, "Root command should have at least one subcommand")
|
||||
|
||||
// Test that expected commands exist
|
||||
expectedCommands := []string{"run", "service", "clients", "upgrade", "log"}
|
||||
for _, cmdName := range expectedCommands {
|
||||
found := false
|
||||
for _, cmd := range commands {
|
||||
if cmd.Name() == cmdName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected command %s not found in root command", cmdName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceCommandCreation tests service command creation
|
||||
func TestServiceCommandCreation(t *testing.T) {
|
||||
sc := NewServiceCommand()
|
||||
require.NotNil(t, sc, "ServiceCommand should be created")
|
||||
|
||||
// Test service config creation
|
||||
config := sc.createServiceConfig()
|
||||
require.NotNil(t, config, "Service config should be created")
|
||||
assert.Equal(t, ctrldServiceName, config.Name)
|
||||
assert.Equal(t, "Control-D Helper Service", config.DisplayName)
|
||||
assert.Equal(t, "A highly configurable, multi-protocol DNS forwarding proxy", config.Description)
|
||||
}
|
||||
|
||||
// TestServiceCommandSubCommands tests service command sub commands
|
||||
func TestServiceCommandSubCommands(t *testing.T) {
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "ctrld",
|
||||
Short: "DNS forwarding proxy",
|
||||
}
|
||||
|
||||
serviceCmd := InitServiceCmd(rootCmd)
|
||||
require.NotNil(t, serviceCmd, "Service command should be created")
|
||||
|
||||
// Test that service command has subcommands
|
||||
subcommands := serviceCmd.Commands()
|
||||
assert.Greater(t, len(subcommands), 0, "Service command should have subcommands")
|
||||
|
||||
// Test specific subcommands exist
|
||||
expectedCommands := []string{"start", "stop", "restart", "reload", "status", "uninstall", "interfaces"}
|
||||
|
||||
for _, cmdName := range expectedCommands {
|
||||
found := false
|
||||
for _, cmd := range subcommands {
|
||||
if cmd.Name() == cmdName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected service subcommand %s not found", cmdName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommandHelp tests basic help functionality
|
||||
func TestCommandHelp(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test help command execution
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Help command should execute without error")
|
||||
assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description")
|
||||
}
|
||||
|
||||
// TestCommandVersion tests version command
|
||||
func TestCommandVersion(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
// Test version command
|
||||
rootCmd.SetArgs([]string{"--version"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Version command should execute without error")
|
||||
assert.Contains(t, buf.String(), "version", "Version output should contain version information")
|
||||
}
|
||||
|
||||
// TestCommandErrorHandling tests error handling
|
||||
func TestCommandErrorHandling(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test invalid flag instead of invalid command
|
||||
rootCmd.SetArgs([]string{"--invalid-flag"})
|
||||
err := rootCmd.Execute()
|
||||
assert.Error(t, err, "Invalid flag should return error")
|
||||
}
|
||||
|
||||
// TestCommandFlags tests flag functionality
|
||||
func TestCommandFlags(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has expected flags
|
||||
verboseFlag := rootCmd.PersistentFlags().Lookup("verbose")
|
||||
assert.NotNil(t, verboseFlag, "Verbose flag should exist")
|
||||
assert.Equal(t, "v", verboseFlag.Shorthand)
|
||||
|
||||
silentFlag := rootCmd.PersistentFlags().Lookup("silent")
|
||||
assert.NotNil(t, silentFlag, "Silent flag should exist")
|
||||
assert.Equal(t, "s", silentFlag.Shorthand)
|
||||
}
|
||||
|
||||
// TestCommandExecution tests basic command execution
|
||||
func TestCommandExecution(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command can be executed (help command)
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command should execute without error")
|
||||
assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description")
|
||||
}
|
||||
|
||||
// TestCommandArgs tests argument handling
|
||||
func TestCommandArgs(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command can handle arguments properly
|
||||
// Test with no args (should succeed)
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command with no args should execute")
|
||||
|
||||
// Test with help flag (should succeed)
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err = rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command with help flag should execute")
|
||||
}
|
||||
|
||||
// TestCommandSubcommands tests subcommand functionality
|
||||
func TestCommandSubcommands(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has subcommands
|
||||
commands := rootCmd.Commands()
|
||||
assert.Greater(t, len(commands), 0, "Root command should have subcommands")
|
||||
|
||||
// Test that specific subcommands exist and can be executed
|
||||
expectedSubcommands := []string{"run", "service", "clients", "upgrade", "log"}
|
||||
for _, subCmdName := range expectedSubcommands {
|
||||
// Find the subcommand
|
||||
var subCmd *cobra.Command
|
||||
for _, cmd := range commands {
|
||||
if cmd.Name() == subCmdName {
|
||||
subCmd = cmd
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, subCmd, "Subcommand %s should exist", subCmdName)
|
||||
|
||||
// Test that subcommand has help
|
||||
assert.NotEmpty(t, subCmd.Short, "Subcommand %s should have a short description", subCmdName)
|
||||
}
|
||||
}
|
||||
192
cmd/cli/commands_upgrade.go
Normal file
192
cmd/cli/commands_upgrade.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/minio/selfupdate"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
upgradeChannelDev = "dev"
|
||||
upgradeChannelProd = "prod"
|
||||
upgradeChannelDefault = "default"
|
||||
)
|
||||
|
||||
// UpgradeCommand handles upgrade-related operations
|
||||
type UpgradeCommand struct {
|
||||
}
|
||||
|
||||
// NewUpgradeCommand creates a new upgrade command handler
|
||||
func NewUpgradeCommand() (*UpgradeCommand, error) {
|
||||
return &UpgradeCommand{}, nil
|
||||
}
|
||||
|
||||
// Upgrade performs the upgrade operation
|
||||
func (uc *UpgradeCommand) Upgrade(cmd *cobra.Command, args []string) error {
|
||||
upgradeChannel := map[string]string{
|
||||
upgradeChannelDefault: "https://dl.controld.dev",
|
||||
upgradeChannelDev: "https://dl.controld.dev",
|
||||
upgradeChannelProd: "https://dl.controld.com",
|
||||
}
|
||||
if isStableVersion(curVersion()) {
|
||||
upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd]
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to get current ctrld binary path")
|
||||
}
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
svcCmd := NewServiceCommand()
|
||||
s, p, err := svcCmd.initializeServiceManager()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
svcInstalled := true
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
svcInstalled = false
|
||||
}
|
||||
|
||||
oldBin := bin + oldBinSuffix
|
||||
baseUrl := upgradeChannel[upgradeChannelDefault]
|
||||
if len(args) > 0 {
|
||||
channel := args[0]
|
||||
switch channel {
|
||||
case upgradeChannelProd, upgradeChannelDev: // ok
|
||||
default:
|
||||
mainLog.Load().Fatal().Msgf("Upgrade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev)
|
||||
}
|
||||
baseUrl = upgradeChannel[channel]
|
||||
}
|
||||
|
||||
dlUrl := upgradeUrl(baseUrl)
|
||||
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl)
|
||||
|
||||
resp, err := getWithRetry(dlUrl, downloadServerIp)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to download binary")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
mainLog.Load().Fatal().Msgf("Could not download binary: %s", http.StatusText(resp.StatusCode))
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msg("Updating current binary")
|
||||
if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil {
|
||||
if rerr := selfupdate.RollbackError(err); rerr != nil {
|
||||
mainLog.Load().Error().Err(rerr).Msg("Could not rollback old binary")
|
||||
}
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to update current binary")
|
||||
}
|
||||
|
||||
doRestart := func() bool {
|
||||
if !svcInstalled {
|
||||
return true
|
||||
}
|
||||
tasks := []task{
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
time.Sleep(time.Second * 1)
|
||||
return nil
|
||||
}, false, "Waiting for service to stop"},
|
||||
}
|
||||
doTasks(tasks)
|
||||
|
||||
tasks = []task{
|
||||
{s.Start, true, "Start"},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
if dir, err := socketDir(); err == nil {
|
||||
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
|
||||
_, _ = cc.post(ifacePath, nil)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if svcInstalled {
|
||||
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
|
||||
}
|
||||
|
||||
if doRestart() {
|
||||
_ = os.Remove(oldBin)
|
||||
_ = os.Chmod(bin, 0755)
|
||||
ver := "unknown version"
|
||||
out, err := exec.Command(bin, "--version").CombinedOutput()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version")
|
||||
}
|
||||
if after, found := strings.CutPrefix(string(out), "ctrld version "); found {
|
||||
ver = after
|
||||
}
|
||||
mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver)
|
||||
return nil
|
||||
}
|
||||
|
||||
mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin)
|
||||
if err := os.Remove(bin); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to remove new binary")
|
||||
}
|
||||
if err := os.Rename(oldBin, bin); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to restore old binary")
|
||||
}
|
||||
if doRestart() {
|
||||
mainLog.Load().Notice().Msg("Restored previous binary successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitUpgradeCmd creates the upgrade command with proper logic
|
||||
func InitUpgradeCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
upgradeCmd := &cobra.Command{
|
||||
Use: "upgrade",
|
||||
Short: "Upgrading ctrld to latest version",
|
||||
ValidArgs: []string{upgradeChannelDev, upgradeChannelProd},
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
uc, err := NewUpgradeCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return uc.Upgrade(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(upgradeCmd)
|
||||
|
||||
return upgradeCmd
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// logConn wraps a net.Conn, override the Write behavior.
|
||||
// runCmd uses this wrapper, so as long as startCmd finished,
|
||||
// ctrld log won't be flushed with un-necessary write errors.
|
||||
type logConn struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (lc *logConn) Read(b []byte) (n int, err error) {
|
||||
return lc.conn.Read(b)
|
||||
}
|
||||
|
||||
func (lc *logConn) Close() error {
|
||||
return lc.conn.Close()
|
||||
}
|
||||
|
||||
func (lc *logConn) LocalAddr() net.Addr {
|
||||
return lc.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (lc *logConn) RemoteAddr() net.Addr {
|
||||
return lc.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (lc *logConn) SetDeadline(t time.Time) error {
|
||||
return lc.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) SetReadDeadline(t time.Time) error {
|
||||
return lc.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) SetWriteDeadline(t time.Time) error {
|
||||
return lc.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) Write(b []byte) (int, error) {
|
||||
// Write performs writes with underlying net.Conn, ignore any errors happen.
|
||||
// "ctrld run" command use this wrapper to report errors to "ctrld start".
|
||||
// If no error occurred, "ctrld start" may finish before "ctrld run" attempt
|
||||
// to close the connection, so ignore errors conservatively here, prevent
|
||||
// un-necessary error "write to closed connection" flushed to ctrld log.
|
||||
_, _ = lc.conn.Write(b)
|
||||
return len(b), nil
|
||||
}
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// controlClient represents an HTTP client for communicating with the control server
|
||||
type controlClient struct {
|
||||
c *http.Client
|
||||
}
|
||||
|
||||
// newControlClient creates a new control client with Unix socket transport
|
||||
func newControlClient(addr string) *controlClient {
|
||||
return &controlClient{c: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
@@ -25,6 +27,10 @@ func newControlClient(addr string) *controlClient {
|
||||
}
|
||||
|
||||
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
|
||||
// for log/send, set the timeout to 5 minutes
|
||||
if path == sendLogsPath {
|
||||
c.c.Timeout = time.Minute * 5
|
||||
}
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -10,9 +12,11 @@ import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,14 +26,25 @@ const (
|
||||
reloadPath = "/reload"
|
||||
deactivationPath = "/deactivation"
|
||||
cdPath = "/cd"
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
// controlServer represents an HTTP server for handling control requests
|
||||
type controlServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
addr string
|
||||
}
|
||||
|
||||
// newControlServer creates a new control server instance
|
||||
func newControlServer(addr string) (*controlServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
s := &controlServer{
|
||||
@@ -66,33 +81,81 @@ func (s *controlServer) register(pattern string, handler http.Handler) {
|
||||
|
||||
func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
p.Debug().Msg("Handling list clients request")
|
||||
|
||||
clients := p.ciTable.ListClients()
|
||||
p.Debug().Int("client_count", len(clients)).Msg("Retrieved clients list")
|
||||
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
if p.cfg.Service.MetricsQueryStats {
|
||||
for _, client := range clients {
|
||||
p.Debug().Msg("Sorted clients by IP address")
|
||||
|
||||
if p.metricsQueryStats.Load() {
|
||||
p.Debug().Msg("Metrics query stats enabled, collecting query counts")
|
||||
|
||||
for idx, client := range clients {
|
||||
p.Debug().
|
||||
Int("index", idx).
|
||||
Str("ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("Processing client metrics")
|
||||
|
||||
client.IncludeQueryCount = true
|
||||
dm := &dto.Metric{}
|
||||
|
||||
if statsClientQueriesCount.MetricVec == nil {
|
||||
p.Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("Skipping metrics collection: MetricVec is nil")
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
|
||||
client.IP.String(),
|
||||
client.Mac,
|
||||
client.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("could not get metrics for client: %v", client)
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("Failed to get metrics for client")
|
||||
continue
|
||||
}
|
||||
if err := m.Write(dm); err == nil {
|
||||
|
||||
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
||||
client.QueryCount = int64(dm.Counter.GetValue())
|
||||
p.Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Int64("query_count", client.QueryCount).
|
||||
Msg("Successfully collected query count")
|
||||
} else if err != nil {
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("Failed to write metric")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
p.Debug().Msg("Metrics query stats disabled, skipping query counts")
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
||||
p.Error().
|
||||
Err(err).
|
||||
Int("client_count", len(clients)).
|
||||
Msg("Failed to encode clients response")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.Debug().
|
||||
Int("client_count", len(clients)).
|
||||
Msg("Successfully sent clients list response")
|
||||
}))
|
||||
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
select {
|
||||
@@ -114,14 +177,14 @@ func (p *prog) registerControlServerHandler() {
|
||||
oldSvc := p.cfg.Service
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not send reload signal")
|
||||
p.Error().Err(err).Msg("Could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.reloadDoneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
http.Error(w, "Timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -149,8 +212,26 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
p.cs.register(deactivationPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
// Non-cd mode or pin code not set, always allowing deactivation.
|
||||
if cdUID == "" || deactivationPinNotSet() {
|
||||
// Non-cd mode always allowing deactivation.
|
||||
if cdUID == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
// Re-fetch pin code from API.
|
||||
if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev); rc != nil {
|
||||
if rc.DeactivationPin != nil {
|
||||
cdDeactivationPin.Store(*rc.DeactivationPin)
|
||||
} else {
|
||||
cdDeactivationPin.Store(defaultDeactivationPin)
|
||||
}
|
||||
} else {
|
||||
p.Warn().Err(err).Msg("Could not re-fetch deactivation pin code")
|
||||
}
|
||||
|
||||
// If pin code not set, allowing deactivation.
|
||||
if !deactivationPinSet() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
@@ -158,14 +239,18 @@ func (p *prog) registerControlServerHandler() {
|
||||
var req deactivationRequest
|
||||
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusPreconditionFailed)
|
||||
mainLog.Load().Err(err).Msg("invalid deactivation request")
|
||||
p.Error().Err(err).Msg("Invalid deactivation request")
|
||||
return
|
||||
}
|
||||
|
||||
code := http.StatusForbidden
|
||||
switch req.Pin {
|
||||
case cdDeactivationPin:
|
||||
case cdDeactivationPin.Load():
|
||||
code = http.StatusOK
|
||||
select {
|
||||
case p.pinCodeValidCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
case defaultDeactivationPin:
|
||||
// If the pin code was set, but users do not provide --pin, return proper code to client.
|
||||
code = http.StatusBadRequest
|
||||
@@ -175,12 +260,87 @@ func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if cdUID != "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(cdUID))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
res := &ifaceResponse{Name: iface}
|
||||
// p.setDNS is only called when running as a service
|
||||
if !service.Interactive() {
|
||||
<-p.csSetDnsDone
|
||||
if p.csSetDnsOk {
|
||||
res.Name = p.runningIface
|
||||
res.All = p.requiredMultiNICsConfig
|
||||
res.OK = true
|
||||
}
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(res); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
lr, err := p.logReaderRaw()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer lr.r.Close()
|
||||
if lr.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(lr.r)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if time.Since(p.internalLogSent) < logWriterSentInterval {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r, err := p.logReaderNoColor()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
req := &controld.LogsRequest{
|
||||
UID: cdUID,
|
||||
Data: r.r,
|
||||
}
|
||||
p.Debug().Msg("Sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil {
|
||||
p.Error().Msgf("Could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
p.Debug().Msg("Sending log file successfully")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
p.internalLogSent = time.Now()
|
||||
}))
|
||||
}
|
||||
|
||||
// jsonResponse wraps an HTTP handler to set JSON content type
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
package cli
|
||||
|
||||
//lint:ignore U1000 use in os_linux.go
|
||||
type getDNS func(iface string) []string
|
||||
1728
cmd/cli/dns_proxy.go
1728
cmd/cli/dns_proxy.go
File diff suppressed because it is too large
Load Diff
@@ -30,6 +30,7 @@ func Test_wildcardMatches(t *testing.T) {
|
||||
{"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
||||
{"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
||||
{"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
||||
{"domain - case-insensitive", "*.WINDSCRIBE.com", "anything.windscribe.com", true},
|
||||
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
|
||||
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
|
||||
@@ -74,8 +75,10 @@ func Test_canonicalName(t *testing.T) {
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.logger.Store(mainLog.Load())
|
||||
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
for _, nc := range p.cfg.Network {
|
||||
@@ -140,9 +143,94 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_upstreamForWithCustomMatching(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
prog.logger.Store(mainLog.Load())
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a custom policy with domain-first matching order
|
||||
customPolicy := &ctrld.ListenerPolicyConfig{
|
||||
Name: "Custom Policy",
|
||||
Networks: []ctrld.Rule{
|
||||
{"network.0": []string{"upstream.1", "upstream.0"}},
|
||||
},
|
||||
Macs: []ctrld.Rule{
|
||||
{"14:45:A0:67:83:0A": []string{"upstream.2"}},
|
||||
},
|
||||
Rules: []ctrld.Rule{
|
||||
{"*.ru": []string{"upstream.1"}},
|
||||
},
|
||||
Matching: &ctrld.MatchingConfig{
|
||||
Order: []string{"domain", "mac", "network"},
|
||||
},
|
||||
}
|
||||
|
||||
customListener := &ctrld.ListenerConfig{
|
||||
Policy: customPolicy,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
mac string
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
}{
|
||||
{
|
||||
name: "Domain rule should match first with custom order",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "14:45:A0:67:83:0A",
|
||||
domain: "example.ru",
|
||||
upstreams: []string{"upstream.1"},
|
||||
matched: true,
|
||||
},
|
||||
{
|
||||
name: "MAC rule should match when no domain rule",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "14:45:A0:67:83:0A",
|
||||
domain: "example.com",
|
||||
upstreams: []string{"upstream.2"},
|
||||
matched: true,
|
||||
},
|
||||
{
|
||||
name: "Network rule should match when no domain or MAC rule",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "00:11:22:33:44:55",
|
||||
domain: "example.com",
|
||||
upstreams: []string{"upstream.1", "upstream.0"},
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr, err := net.ResolveUDPAddr("udp", tc.ip)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
ufr := prog.upstreamFor(ctx, "0", customListener, addr, tc.mac, tc.domain)
|
||||
|
||||
assert.Equal(t, tc.matched, ufr.matched)
|
||||
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
prog.logger.Store(mainLog.Load())
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
@@ -364,6 +452,9 @@ func Test_isLanHostnameQuery(t *testing.T) {
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
||||
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
||||
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
@@ -413,6 +504,27 @@ func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSrvLanLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isSrvLookup bool
|
||||
}{
|
||||
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
||||
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -438,3 +550,254 @@ func Test_isWanClient(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldStartRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
hasExistingRecovery bool
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "network change with existing recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: true,
|
||||
description: "should cancel existing recovery and start new one for network change",
|
||||
},
|
||||
{
|
||||
name: "network change without existing recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for network change",
|
||||
},
|
||||
{
|
||||
name: "regular failure with existing recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: false,
|
||||
description: "should skip duplicate recovery for regular failure",
|
||||
},
|
||||
{
|
||||
name: "regular failure without existing recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for regular failure",
|
||||
},
|
||||
{
|
||||
name: "OS failure with existing recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: false,
|
||||
description: "should skip duplicate recovery for OS failure",
|
||||
},
|
||||
{
|
||||
name: "OS failure without existing recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for OS failure",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
// Setup existing recovery if needed
|
||||
if tc.hasExistingRecovery {
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = func() {} // Mock cancel function
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
result := p.shouldStartRecovery(tc.reason)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createRecoveryContext(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
ctx, cleanup := p.createRecoveryContext()
|
||||
|
||||
// Verify context is created
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, cleanup)
|
||||
|
||||
// Verify recoveryCancel is set
|
||||
p.recoveryCancelMu.Lock()
|
||||
assert.NotNil(t, p.recoveryCancel)
|
||||
p.recoveryCancelMu.Unlock()
|
||||
|
||||
// Test cleanup function
|
||||
cleanup()
|
||||
|
||||
// Verify recoveryCancel is cleared
|
||||
p.recoveryCancelMu.Lock()
|
||||
assert.Nil(t, p.recoveryCancel)
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
func Test_prepareForRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "regular failure",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "network change",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.prepareForRecovery(tc.reason)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify recoveryRunning is set to true
|
||||
assert.True(t, p.recoveryRunning.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_completeRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
recovered string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "regular failure recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
recovered: "upstream1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "network change recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
recovered: "upstream2",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
recovered: "upstream3",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.completeRecovery(tc.reason, tc.recovered)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify recoveryRunning is set to false
|
||||
assert.False(t, p.recoveryRunning.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_reinitializeOSResolver(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.reinitializeOSResolver("Test message")
|
||||
|
||||
// This function should not return an error under normal circumstances
|
||||
// The actual behavior depends on the OS resolver implementation
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_handleRecovery_Integration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "network change recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "regular failure recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
// This is an integration test that exercises the full recovery flow
|
||||
// In a real test environment, you would mock the dependencies
|
||||
// For now, we're just testing that the method doesn't panic
|
||||
// and that the recovery logic flows correctly
|
||||
assert.NotPanics(t, func() {
|
||||
// Test only the preparation phase to avoid actual upstream checking
|
||||
if !p.shouldStartRecovery(tc.reason) {
|
||||
return
|
||||
}
|
||||
|
||||
_, cleanup := p.createRecoveryContext()
|
||||
defer cleanup()
|
||||
|
||||
if err := p.prepareForRecovery(tc.reason); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the actual upstream recovery check for this test
|
||||
// as it requires properly configured upstreams
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newTestProg creates a properly initialized *prog for testing.
|
||||
func newTestProg(t *testing.T) *prog {
|
||||
p := &prog{cfg: testhelper.SampleConfig(t)}
|
||||
p.logger.Store(mainLog.Load())
|
||||
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
||||
return p
|
||||
}
|
||||
|
||||
18
cmd/cli/hostname.go
Normal file
18
cmd/cli/hostname.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import "regexp"
|
||||
|
||||
// validHostname reports whether hostname is a valid hostname.
|
||||
// A valid hostname contains 3 -> 64 characters and conform to RFC1123.
|
||||
// This function validates hostnames to ensure they meet DNS naming standards
|
||||
// and prevents invalid hostnames from being used in DNS configurations
|
||||
func validHostname(hostname string) bool {
|
||||
hostnameLen := len(hostname)
|
||||
if hostnameLen < 3 || hostnameLen > 64 {
|
||||
return false
|
||||
}
|
||||
// RFC1123 regex pattern ensures hostnames follow DNS naming conventions
|
||||
// This prevents issues with DNS resolution and system compatibility
|
||||
validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
|
||||
return validHostnameRfc1123.MatchString(hostname)
|
||||
}
|
||||
35
cmd/cli/hostname_test.go
Normal file
35
cmd/cli/hostname_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_validHostname(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
valid bool
|
||||
}{
|
||||
{"localhost", "localhost", true},
|
||||
{"localdomain", "localhost.localdomain", true},
|
||||
{"localhost6", "localhost6.localdomain6", true},
|
||||
{"ip6", "ip6-localhost", true},
|
||||
{"non-domain", "controld", true},
|
||||
{"domain", "controld.com", true},
|
||||
{"empty", "", false},
|
||||
{"min length", "fo", false},
|
||||
{"max length", strings.Repeat("a", 65), false},
|
||||
{"special char", "foo!", false},
|
||||
{"non-ascii", "fooΩ", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.hostname, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, validHostname(tc.hostname) == tc.valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
172
cmd/cli/http_log.go
Normal file
172
cmd/cli/http_log.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HTTP log server endpoint constants
|
||||
const (
|
||||
httpLogEndpointPing = "/ping"
|
||||
httpLogEndpointLogs = "/logs"
|
||||
httpLogEndpointExit = "/exit"
|
||||
)
|
||||
|
||||
// httpLogClient sends logs to an HTTP server via POST requests.
|
||||
// This replaces the logConn functionality with HTTP-based communication.
|
||||
type httpLogClient struct {
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// newHTTPLogClient creates a new HTTP log client
|
||||
func newHTTPLogClient(sockPath string) *httpLogClient {
|
||||
return &httpLogClient{
|
||||
baseURL: "http://unix",
|
||||
client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Write sends log data to the HTTP server via POST request
|
||||
func (hlc *httpLogClient) Write(b []byte) (int, error) {
|
||||
// Send log data via HTTP POST to /logs endpoint
|
||||
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointLogs, "text/plain", bytes.NewReader(b))
|
||||
if err != nil {
|
||||
// Ignore errors to prevent log pollution, just like the original logConn
|
||||
return len(b), nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Ping tests if the HTTP log server is available
|
||||
func (hlc *httpLogClient) Ping() error {
|
||||
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointPing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close sends exit signal to the HTTP server
|
||||
func (hlc *httpLogClient) Close() error {
|
||||
// Send exit signal via HTTP POST with empty body
|
||||
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLogs retrieves all collected logs from the HTTP server
|
||||
func (hlc *httpLogClient) GetLogs() ([]byte, error) {
|
||||
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNoContent {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// httpLogServer starts an HTTP server listening on unix socket to collect logs from runCmd.
|
||||
func httpLogServer(sockPath string, stopLogCh chan struct{}) error {
|
||||
addr, err := net.ResolveUnixAddr("unix", sockPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid log sock path: %w", err)
|
||||
}
|
||||
|
||||
ln, err := net.ListenUnix("unix", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not listen log socket: %w", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
// Create a log writer to store all logs
|
||||
logWriter := newLogWriter()
|
||||
|
||||
// Use a sync.Once to ensure channel is only closed once
|
||||
var channelClosed sync.Once
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(httpLogEndpointPing, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mux.HandleFunc(httpLogEndpointLogs, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
// POST /logs - Store log data
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store log data in log writer
|
||||
logWriter.Write(body)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
case http.MethodGet:
|
||||
// GET /logs - Retrieve all logs
|
||||
// Get all logs from the log writer
|
||||
logWriter.mu.Lock()
|
||||
logs := logWriter.buf.Bytes()
|
||||
logWriter.mu.Unlock()
|
||||
|
||||
if len(logs) == 0 {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(logs)
|
||||
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc(httpLogEndpointExit, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Close the stop channel to signal completion (only once)
|
||||
channelClosed.Do(func() {
|
||||
close(stopLogCh)
|
||||
})
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
server := &http.Server{Handler: mux}
|
||||
return server.Serve(ln)
|
||||
}
|
||||
747
cmd/cli/http_log_test.go
Normal file
747
cmd/cli/http_log_test.go
Normal file
@@ -0,0 +1,747 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/nettest"
|
||||
)
|
||||
|
||||
func unixDomainSocketPath(t *testing.T) string {
|
||||
t.Helper()
|
||||
sockPath, err := nettest.LocalPath()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
return sockPath
|
||||
}
|
||||
|
||||
func TestHTTPLogServer(t *testing.T) {
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Ping endpoint", func(t *testing.T) {
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointPing)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to ping server: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Ping endpoint wrong method", func(t *testing.T) {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointPing, "text/plain", bytes.NewReader([]byte("test")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send POST to ping: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Log endpoint", func(t *testing.T) {
|
||||
testLog := "test log message"
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(testLog)))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if log was stored by retrieving it
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), testLog) {
|
||||
t.Errorf("Expected log '%s' not found in stored logs", testLog)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Log endpoint wrong method", func(t *testing.T) {
|
||||
// Test unsupported method (PUT) on /logs endpoint
|
||||
req, err := http.NewRequest("PUT", "http://unix"+httpLogEndpointLogs, bytes.NewReader([]byte("test")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create PUT request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send PUT to logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Exit endpoint", func(t *testing.T) {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send exit: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if channel is closed by trying to read from it
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Exit endpoint wrong method", func(t *testing.T) {
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointExit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send GET to exit: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Multiple log messages", func(t *testing.T) {
|
||||
logs := []string{"log1", "log2", "log3"}
|
||||
|
||||
for _, log := range logs {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log '%s': %v", log, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Check if all logs were stored by retrieving them
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
for i, expectedLog := range logs {
|
||||
if !strings.Contains(logContent, expectedLog) {
|
||||
t.Errorf("Log %d: expected '%s' not found in stored logs", i, expectedLog)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Large log message", func(t *testing.T) {
|
||||
largeLog := strings.Repeat("a", 1024*10) // 10KB log message
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(largeLog)))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send large log: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if large log was stored by retrieving it
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), largeLog) {
|
||||
t.Error("Large log message was not stored correctly")
|
||||
}
|
||||
})
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogServerInvalidSocketPath(t *testing.T) {
|
||||
// Test with invalid socket path
|
||||
invalidPath := "/invalid/path/that/does/not/exist.sock"
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
err := httpLogServer(invalidPath, stopLogCh)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid socket path")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "could not listen log socket") {
|
||||
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerSocketInUse(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create the first server
|
||||
stopLogCh1 := make(chan struct{})
|
||||
serverErr1 := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr1 <- httpLogServer(sockPath, stopLogCh1)
|
||||
}()
|
||||
|
||||
// Wait for first server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to create a second server on the same socket
|
||||
stopLogCh2 := make(chan struct{})
|
||||
err := httpLogServer(sockPath, stopLogCh2)
|
||||
if err == nil {
|
||||
t.Error("Expected error when socket is already in use")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "could not listen log socket") {
|
||||
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerConcurrentRequests(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Send concurrent requests
|
||||
numRequests := 10
|
||||
done := make(chan bool, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(i int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
logMsg := fmt.Sprintf("concurrent log %d", i)
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to send concurrent log %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for request %d, got %d", i, resp.StatusCode)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all requests to complete
|
||||
for i := 0; i < numRequests; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
// Request completed
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Errorf("Timeout waiting for concurrent request %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all logs were stored by retrieving them
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
// Verify all logs were stored
|
||||
for i := 0; i < numRequests; i++ {
|
||||
expectedLog := fmt.Sprintf("concurrent log %d", i)
|
||||
if !strings.Contains(logContent, expectedLog) {
|
||||
t.Errorf("Log '%s' was not stored", expectedLog)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerErrorHandling(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Invalid request body", func(t *testing.T) {
|
||||
// Test with malformed request - this will fail at HTTP level, not server level
|
||||
// The server will return 400 Bad Request for invalid body
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", strings.NewReader(""))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Empty body should still be processed successfully
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHTTPLogServer(b *testing.B) {
|
||||
// Create a temporary socket path
|
||||
tmpDir := b.TempDir()
|
||||
sockPath := filepath.Join(tmpDir, "bench.sock")
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Benchmark log sending
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logMsg := fmt.Sprintf("benchmark log %d", i)
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to send log: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogClient(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Ping server", func(t *testing.T) {
|
||||
err := client.Ping()
|
||||
if err != nil {
|
||||
t.Errorf("Ping failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write logs", func(t *testing.T) {
|
||||
testLog := "test log message from client"
|
||||
n, err := client.Write([]byte(testLog))
|
||||
if err != nil {
|
||||
t.Errorf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(testLog) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
|
||||
}
|
||||
|
||||
// Check if log was stored by retrieving it
|
||||
logs, err := client.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(logs), testLog) {
|
||||
t.Errorf("Expected log '%s' not found in stored logs", testLog)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close client", func(t *testing.T) {
|
||||
err := client.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// Check if channel is closed (signaling completion)
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPLogClientServerUnavailable(t *testing.T) {
|
||||
// Create client with non-existent socket
|
||||
sockPath := "/non/existent/socket.sock"
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Ping unavailable server", func(t *testing.T) {
|
||||
err := client.Ping()
|
||||
if err == nil {
|
||||
t.Error("Expected ping to fail for unavailable server")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write to unavailable server", func(t *testing.T) {
|
||||
testLog := "test log message"
|
||||
n, err := client.Write([]byte(testLog))
|
||||
if err != nil {
|
||||
t.Errorf("Write should not return error (ignores errors): %v", err)
|
||||
}
|
||||
if n != len(testLog) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close unavailable server", func(t *testing.T) {
|
||||
err := client.Close()
|
||||
if err == nil {
|
||||
t.Error("Expected close to fail for unavailable server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHTTPLogClient(b *testing.B) {
|
||||
// Create a temporary socket path
|
||||
tmpDir := b.TempDir()
|
||||
sockPath := filepath.Join(tmpDir, "bench.sock")
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
// Benchmark client writes
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logMsg := fmt.Sprintf("benchmark write %d", i)
|
||||
client.Write([]byte(logMsg))
|
||||
}
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogServerWithLogWriter(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Store and retrieve logs", func(t *testing.T) {
|
||||
// Send multiple log messages
|
||||
logs := []string{"log message 1", "log message 2", "log message 3"}
|
||||
|
||||
for _, log := range logs {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log '%s': %v", log, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Retrieve all logs
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs response: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
for _, log := range logs {
|
||||
if !strings.Contains(logContent, log) {
|
||||
t.Errorf("Expected log '%s' not found in retrieved logs", log)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty logs endpoint", func(t *testing.T) {
|
||||
// Create a new server for this test
|
||||
sockPath2 := unixDomainSocketPath(t)
|
||||
stopLogCh2 := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
httpLogServer(sockPath2, stopLogCh2)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client2 := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath2)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client2.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("Expected status 204, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
os.Remove(sockPath2)
|
||||
})
|
||||
|
||||
t.Run("Channel closure on exit", func(t *testing.T) {
|
||||
// Send exit signal
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send exit: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if channel is closed by trying to read from it
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPLogClientGetLogs(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Get logs from client", func(t *testing.T) {
|
||||
// Send some logs
|
||||
testLogs := []string{"client log 1", "client log 2", "client log 3"}
|
||||
for _, log := range testLogs {
|
||||
client.Write([]byte(log + "\n"))
|
||||
}
|
||||
|
||||
// Retrieve logs using client method
|
||||
logs, err := client.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(logs)
|
||||
for _, log := range testLogs {
|
||||
if !strings.Contains(logContent, log) {
|
||||
t.Errorf("Expected log '%s' not found in retrieved logs", log)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get empty logs", func(t *testing.T) {
|
||||
// Create a new client for empty logs test
|
||||
sockPath2 := unixDomainSocketPath(t)
|
||||
stopLogCh2 := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
httpLogServer(sockPath2, stopLogCh2)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client2 := newHTTPLogClient(sockPath2)
|
||||
logs, err := client2.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get empty logs: %v", err)
|
||||
}
|
||||
|
||||
if len(logs) != 0 {
|
||||
t.Errorf("Expected empty logs, got %d bytes", len(logs))
|
||||
}
|
||||
|
||||
os.Remove(sockPath2)
|
||||
})
|
||||
}
|
||||
@@ -1,7 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
// This allows mobile applications to customize behavior without modifying core CLI code
|
||||
type AppCallback struct {
|
||||
HostName func() string
|
||||
LanIp func() string
|
||||
@@ -10,10 +18,94 @@ type AppCallback struct {
|
||||
}
|
||||
|
||||
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
|
||||
// This provides a clean interface for mobile apps to configure ctrld behavior
|
||||
type AppConfig struct {
|
||||
CdUID string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
CdUID string
|
||||
ProvisionID string
|
||||
CustomHostname string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// Network and HTTP configuration constants
|
||||
const (
|
||||
// defaultHTTPTimeout provides reasonable timeout for HTTP operations
|
||||
// This prevents hanging requests while allowing sufficient time for network delays
|
||||
defaultHTTPTimeout = 30 * time.Second
|
||||
|
||||
// defaultMaxRetries provides retry attempts for failed HTTP requests
|
||||
// This improves reliability in unstable network conditions
|
||||
defaultMaxRetries = 3
|
||||
|
||||
// downloadServerIp is the fallback IP for download operations
|
||||
// This ensures downloads work even when DNS resolution fails
|
||||
downloadServerIp = "23.171.240.151"
|
||||
)
|
||||
|
||||
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
|
||||
// This ensures reliable HTTP operations by preferring IPv4 and handling timeouts gracefully
|
||||
func httpClientWithFallback(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
// Prefer IPv4 over IPv6
|
||||
// This improves compatibility with networks that have IPv6 issues
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
FallbackDelay: 1 * time.Millisecond, // Very small delay to prefer IPv4
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry performs an HTTP request with retries
|
||||
// This improves reliability by automatically retrying failed requests with exponential backoff
|
||||
func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) {
|
||||
var lastErr error
|
||||
client := httpClientWithFallback(defaultHTTPTimeout)
|
||||
var ipReq *http.Request
|
||||
if ip != "" {
|
||||
ipReq = req.Clone(req.Context())
|
||||
ipReq.Host = ip
|
||||
ipReq.URL.Host = ip
|
||||
}
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Linear backoff reduces server load and improves success rate
|
||||
time.Sleep(time.Second * time.Duration(attempt+1))
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if ipReq != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("Dial to %q failed", req.Host)
|
||||
mainLog.Load().Warn().Msgf("Fallback to direct ip to download prod version: %q", ip)
|
||||
resp, err = client.Do(ipReq)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
mainLog.Load().Debug().Err(err).
|
||||
Str("method", req.Method).
|
||||
Str("url", req.URL.String()).
|
||||
Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries)
|
||||
}
|
||||
return nil, fmt.Errorf("failed after %d attempts to %s %s: %v", maxRetries, req.Method, req.URL, lastErr)
|
||||
}
|
||||
|
||||
// Helper for making GET requests with retries
|
||||
// This provides a simplified interface for common GET operations with built-in retry logic
|
||||
func getWithRetry(url string, ip string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doWithRetry(req, defaultMaxRetries, ip)
|
||||
}
|
||||
|
||||
407
cmd/cli/log_writer.go
Normal file
407
cmd/cli/log_writer.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Log writer constants for buffer management and log formatting
|
||||
const (
|
||||
// logWriterSize is the default buffer size for log writers
|
||||
// This provides sufficient space for runtime logs without excessive memory usage
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
|
||||
// logWriterSmallSize is used for memory-constrained environments
|
||||
// This reduces memory footprint while still maintaining log functionality
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
|
||||
// logWriterInitialSize is the initial buffer allocation
|
||||
// This provides immediate space for early log entries
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
|
||||
// logWriterSentInterval controls how often logs are sent to external systems
|
||||
// This balances real-time logging with system performance
|
||||
logWriterSentInterval = time.Minute
|
||||
|
||||
// logWriterInitEndMarker marks the end of initialization logs
|
||||
// This helps separate startup logs from runtime logs
|
||||
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
|
||||
// logWriterLogEndMarker marks the end of log sections
|
||||
// This provides clear boundaries for log parsing and analysis
|
||||
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
)
|
||||
|
||||
// Custom level encoders that handle NOTICE level
|
||||
// Since NOTICE and WARN share the same numeric value (1), we handle them specially
|
||||
// in the encoder to display NOTICE messages with the "NOTICE" prefix.
|
||||
// Note: WARN messages will also display as "NOTICE" because they share the same level value.
|
||||
// This is the intended behavior for visual distinction.
|
||||
|
||||
// noticeLevelEncoder provides custom level encoding for NOTICE level
|
||||
// This ensures NOTICE messages are clearly distinguished from other log levels
|
||||
func noticeLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
|
||||
switch l {
|
||||
case ctrld.NoticeLevel:
|
||||
enc.AppendString("NOTICE")
|
||||
default:
|
||||
zapcore.CapitalLevelEncoder(l, enc)
|
||||
}
|
||||
}
|
||||
|
||||
// noticeColorLevelEncoder provides colored level encoding for NOTICE level
|
||||
// This uses cyan color to make NOTICE messages visually distinct in terminal output
|
||||
func noticeColorLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
|
||||
switch l {
|
||||
case ctrld.NoticeLevel:
|
||||
enc.AppendString("\x1b[36mNOTICE\x1b[0m") // Cyan color for NOTICE
|
||||
default:
|
||||
zapcore.CapitalColorLevelEncoder(l, enc)
|
||||
}
|
||||
}
|
||||
|
||||
// logViewResponse represents the response structure for log viewing requests
|
||||
// This provides a consistent JSON format for log data retrieval
|
||||
type logViewResponse struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// logSentResponse represents the response structure for log sending operations
|
||||
// This includes size information and error details for debugging
|
||||
type logSentResponse struct {
|
||||
Size int64 `json:"size"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// logReader provides read access to log data with size information.
|
||||
//
|
||||
// This struct encapsulates log reading functionality for external consumers,
|
||||
// providing both the log content and metadata about the log size. It supports
|
||||
// reading from both internal log buffers (when no external logging is configured)
|
||||
// and external log files (when logging to file is enabled).
|
||||
//
|
||||
// Fields:
|
||||
// - r: An io.ReadCloser that provides access to the log content
|
||||
// - size: The total size of the log data in bytes
|
||||
//
|
||||
// The logReader is used by the control server to serve log content to clients
|
||||
// and by various CLI commands that need to display or process log data.
|
||||
type logReader struct {
|
||||
r io.ReadCloser
|
||||
size int64
|
||||
}
|
||||
|
||||
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
|
||||
// This provides in-memory log storage for debugging and monitoring purposes
|
||||
type logWriter struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
}
|
||||
|
||||
// newLogWriter creates an internal log writer.
|
||||
// This provides the default log writer with standard buffer size
|
||||
func newLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSize)
|
||||
}
|
||||
|
||||
// newSmallLogWriter creates an internal log writer with small buffer size.
|
||||
// This is used in memory-constrained environments or for temporary logging
|
||||
func newSmallLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSmallSize)
|
||||
}
|
||||
|
||||
// newLogWriterWithSize creates an internal log writer with a given buffer size.
|
||||
// This allows customization of log buffer size based on specific requirements
|
||||
func newLogWriterWithSize(size int) *logWriter {
|
||||
lw := &logWriter{size: size}
|
||||
return lw
|
||||
}
|
||||
|
||||
// Write implements io.Writer interface for logWriter
|
||||
// This manages buffer overflow by discarding old data while preserving important markers
|
||||
func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
|
||||
// If writing p causes overflows, discard old data.
|
||||
// This prevents unbounded memory growth while maintaining recent logs
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
haveEndMarker := false
|
||||
// If there's init end marker already, preserve the data til the marker.
|
||||
// This ensures initialization logs are always available for debugging
|
||||
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
|
||||
buf = buf[:idx+len(logWriterInitEndMarker)]
|
||||
haveEndMarker = true
|
||||
} else {
|
||||
// Otherwise, preserve the initial size data.
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
}
|
||||
}
|
||||
lw.buf.Reset()
|
||||
lw.buf.Write(buf)
|
||||
if !haveEndMarker {
|
||||
lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated.
|
||||
}
|
||||
}
|
||||
// If p is bigger than buffer size, truncate p by half until its size is smaller.
|
||||
for len(p)+lw.buf.Len() > lw.size {
|
||||
p = p[len(p)/2:]
|
||||
}
|
||||
return lw.buf.Write(p)
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func (p *prog) initLogging(backup bool) {
|
||||
logCores := initLoggingWithBackup(backup)
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logCores)
|
||||
p.logger.Store(mainLog.Load())
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
func (p *prog) initInternalLogging(externalCores []zapcore.Core) {
|
||||
if !p.needInternalLogging() {
|
||||
return
|
||||
}
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
p.Notice().Msg("Internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
})
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create zap cores for different writers
|
||||
var cores []zapcore.Core
|
||||
cores = append(cores, externalCores...)
|
||||
|
||||
// Add core for internal log writer.
|
||||
// Run the internal logging at debug level, so we could
|
||||
// have enough information for troubleshooting.
|
||||
internalCore := newHumanReadableZapCore(lw, zapcore.DebugLevel)
|
||||
cores = append(cores, internalCore)
|
||||
|
||||
// Add core for internal warn log writer
|
||||
warnCore := newHumanReadableZapCore(wlw, zapcore.WarnLevel)
|
||||
cores = append(cores, warnCore)
|
||||
|
||||
// Create a multi-core logger
|
||||
multiCore := zapcore.NewTee(cores...)
|
||||
logger := zap.New(multiCore)
|
||||
|
||||
mainLog.Store(&ctrld.Logger{Logger: logger})
|
||||
}
|
||||
|
||||
// needInternalLogging reports whether prog needs to run internal logging.
|
||||
func (p *prog) needInternalLogging() bool {
|
||||
// Do not run in non-cd mode.
|
||||
if cdUID == "" {
|
||||
return false
|
||||
}
|
||||
// Do not run if there's already log file.
|
||||
if p.cfg.Service.LogPath != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// logReaderNoColor returns a logReader with ANSI color codes stripped from the log content.
|
||||
//
|
||||
// This method is useful when log content needs to be processed by tools that don't
|
||||
// handle ANSI escape sequences properly, or when storing logs in plain text format.
|
||||
// It internally calls logReader(true) to strip color codes.
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance with color codes removed, or nil if no logs available
|
||||
// - error: Any error encountered during log reading (e.g., empty logs, file access issues)
|
||||
//
|
||||
// Use cases:
|
||||
// - Log processing pipelines that require plain text
|
||||
// - Storing logs in databases or text files
|
||||
// - Displaying logs in environments that don't support color
|
||||
func (p *prog) logReaderNoColor() (*logReader, error) {
|
||||
return p.logReader(true)
|
||||
}
|
||||
|
||||
// logReaderRaw returns a logReader with ANSI color codes preserved in the log content.
|
||||
//
|
||||
// This method maintains the original formatting of log entries including color codes,
|
||||
// which is useful for displaying logs in terminals that support ANSI colors or when
|
||||
// the original visual formatting needs to be preserved. It internally calls logReader(false).
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance with color codes preserved, or nil if no logs available
|
||||
// - error: Any error encountered during log reading (e.g., empty logs, file access issues)
|
||||
//
|
||||
// Use cases:
|
||||
// - Terminal-based log viewers that support color
|
||||
// - Interactive debugging sessions
|
||||
// - Preserving original log formatting for display
|
||||
func (p *prog) logReaderRaw() (*logReader, error) {
|
||||
return p.logReader(false)
|
||||
}
|
||||
|
||||
// logReader creates a logReader instance for accessing log content with optional color stripping.
|
||||
//
|
||||
// This is the core method that handles log reading from different sources based on the
|
||||
// current logging configuration. It supports both internal logging (when no external
|
||||
// logging is configured) and external file logging (when logging to file is enabled).
|
||||
//
|
||||
// Behavior:
|
||||
// - Internal logging: Reads from internal log buffers (normal logs + warning logs)
|
||||
// and combines them with appropriate markers for separation
|
||||
// - External logging: Reads directly from the configured log file
|
||||
// - Empty logs: Returns appropriate error messages when no log content is available
|
||||
//
|
||||
// Parameters:
|
||||
// - stripColor: If true, removes ANSI color codes from log content; if false, preserves them
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance providing access to log content and size metadata
|
||||
// - error: Any error encountered during log reading, including:
|
||||
// - "nil internal log writer" - Internal logging not properly initialized
|
||||
// - "nil internal warn log writer" - Warning log writer not properly initialized
|
||||
// - "internal log is empty" - No content in internal log buffers
|
||||
// - "log file is empty" - External log file exists but contains no data
|
||||
// - File system errors when accessing external log files
|
||||
//
|
||||
// The method handles thread-safe access to internal log buffers and provides
|
||||
// comprehensive error handling for various edge cases.
|
||||
func (p *prog) logReader(stripColor bool) (*logReader, error) {
|
||||
if p.needInternalLogging() {
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
return nil, errors.New("nil internal log writer")
|
||||
}
|
||||
if wlw == nil {
|
||||
return nil, errors.New("nil internal warn log writer")
|
||||
}
|
||||
|
||||
// Normal log content.
|
||||
lw.mu.Lock()
|
||||
lwReader := newLogReader(&lw.buf, stripColor)
|
||||
lwSize := lw.buf.Len()
|
||||
lw.mu.Unlock()
|
||||
// Warn log content.
|
||||
wlw.mu.Lock()
|
||||
wlwReader := newLogReader(&wlw.buf, stripColor)
|
||||
wlwSize := wlw.buf.Len()
|
||||
wlw.mu.Unlock()
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
|
||||
lr := &logReader{r: io.NopCloser(reader)}
|
||||
lr.size = int64(lwSize + wlwSize)
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("internal log is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
if p.cfg.Service.LogPath == "" {
|
||||
return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil
|
||||
}
|
||||
f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lr := &logReader{r: f}
|
||||
if st, err := f.Stat(); err == nil {
|
||||
lr.size = st.Size()
|
||||
} else {
|
||||
return nil, fmt.Errorf("f.Stat: %w", err)
|
||||
}
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("log file is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
|
||||
// newHumanReadableZapCore creates a zap core optimized for human-readable log output.
|
||||
//
|
||||
// Features:
|
||||
// - Uses development encoder configuration for enhanced readability
|
||||
// - Console encoding with colored log levels for easy visual scanning
|
||||
// - Millisecond precision timestamps in human-friendly format
|
||||
// - Structured field output with clear key-value pairs
|
||||
// - Ideal for development, debugging, and interactive terminal sessions
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The output writer (e.g., os.Stdout, file, buffer)
|
||||
// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error)
|
||||
//
|
||||
// Returns a zapcore.Core configured for human consumption.
|
||||
func newHumanReadableZapCore(w io.Writer, level zapcore.Level) zapcore.Core {
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli)
|
||||
encoderConfig.EncodeLevel = noticeColorLevelEncoder
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
return zapcore.NewCore(encoder, zapcore.AddSync(w), level)
|
||||
}
|
||||
|
||||
// newMachineFriendlyZapCore creates a zap core optimized for machine processing and log aggregation.
|
||||
//
|
||||
// Features:
|
||||
// - Uses production encoder configuration for consistent, parseable output
|
||||
// - Console encoding with non-colored log levels for log parsing tools
|
||||
// - Millisecond precision timestamps in ISO-like format
|
||||
// - Structured field output optimized for log aggregation systems
|
||||
// - Ideal for production environments, log shipping, and automated analysis
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The output writer (e.g., os.Stdout, file, buffer)
|
||||
// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error)
|
||||
//
|
||||
// Returns a zapcore.Core configured for machine consumption and log aggregation.
|
||||
func newMachineFriendlyZapCore(w io.Writer, level zapcore.Level) zapcore.Core {
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli)
|
||||
encoderConfig.EncodeLevel = noticeLevelEncoder
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
return zapcore.NewCore(encoder, zapcore.AddSync(w), level)
|
||||
}
|
||||
|
||||
// ansiRegex is a regular expression to match ANSI color codes.
|
||||
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
|
||||
// newLogReader creates a reader for log buffer content with optional ANSI color stripping.
|
||||
//
|
||||
// This function provides flexible log content access by allowing consumers to choose
|
||||
// between raw log data (with ANSI color codes) or stripped content (without color codes).
|
||||
// The color stripping is useful when logs need to be processed by tools that don't
|
||||
// handle ANSI escape sequences properly, or when storing logs in plain text format.
|
||||
//
|
||||
// Parameters:
|
||||
// - buf: The log buffer containing the log data to read
|
||||
// - stripColor: If true, strips ANSI color codes from the log content;
|
||||
// if false, returns raw log content with color codes preserved
|
||||
//
|
||||
// Returns an io.Reader that provides access to the processed log content.
|
||||
func newLogReader(buf *bytes.Buffer, stripColor bool) io.Reader {
|
||||
if stripColor {
|
||||
return strings.NewReader(ansiRegex.ReplaceAllString(buf.String(), ""))
|
||||
}
|
||||
return strings.NewReader(buf.String())
|
||||
}
|
||||
417
cmd/cli/log_writer_test.go
Normal file
417
cmd/cli/log_writer_test.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_logWriter_Write(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
data := strings.Repeat("A", size)
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
newData := "B"
|
||||
halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker
|
||||
lw.Write([]byte(newData))
|
||||
if lw.buf.String() != halfData+newData {
|
||||
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
|
||||
}
|
||||
|
||||
bigData := strings.Repeat("B", 256*1024)
|
||||
expected := halfData + strings.Repeat("B", 16*1024)
|
||||
lw.Write([]byte(bigData))
|
||||
if lw.buf.String() != expected {
|
||||
t.Fatalf("unexpected big buf content: %v", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_ConcurrentWrite(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lw.Write([]byte(strings.Repeat("A", i)))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if lw.buf.Len() > lw.size {
|
||||
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_MarkerInitEnd(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
|
||||
paddingSize := 10
|
||||
// Writing half of the size, minus len(end marker) and padding size.
|
||||
dataSize := size/2 - len(logWriterInitEndMarker) - paddingSize
|
||||
data := strings.Repeat("A", dataSize)
|
||||
// Inserting newline for making partial init data
|
||||
data += "\n"
|
||||
// Filling left over buffer to make the log full.
|
||||
// The data length: len(end marker) + padding size - 1 (for newline above) + size/2
|
||||
data += strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2))
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
lw.Write([]byte("B"))
|
||||
lw.Write([]byte(strings.Repeat("B", 256*1024)))
|
||||
firstIdx := strings.Index(lw.buf.String(), logWriterInitEndMarker)
|
||||
lastIdx := strings.LastIndex(lw.buf.String(), logWriterInitEndMarker)
|
||||
// Check if init end marker present.
|
||||
if firstIdx == -1 || lastIdx == -1 {
|
||||
t.Fatalf("missing init end marker: %s", lw.buf.String())
|
||||
}
|
||||
// Check if init end marker appears only once.
|
||||
if firstIdx != lastIdx {
|
||||
t.Fatalf("log init end marker appears more than once: %s", lw.buf.String())
|
||||
}
|
||||
// Ensure that we have the correct init log data.
|
||||
if !strings.Contains(lw.buf.String(), strings.Repeat("A", dataSize)+logWriterInitEndMarker) {
|
||||
t.Fatalf("unexpected log content: %s", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNoticeLevel tests that the custom NOTICE level works correctly
|
||||
func TestNoticeLevel(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Create encoder config with custom NOTICE level support
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout("15:04:05.000")
|
||||
encoderConfig.EncodeLevel = noticeLevelEncoder
|
||||
|
||||
// Test with NOTICE level
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
core := zapcore.NewCore(encoder, zapcore.AddSync(&buf), ctrld.NoticeLevel)
|
||||
logger := zap.New(core)
|
||||
ctrldLogger := &ctrld.Logger{Logger: logger}
|
||||
|
||||
// Log messages at different levels
|
||||
ctrldLogger.Debug().Msg("This is a DEBUG message")
|
||||
ctrldLogger.Info().Msg("This is an INFO message")
|
||||
ctrldLogger.Notice().Msg("This is a NOTICE message")
|
||||
ctrldLogger.Warn().Msg("This is a WARN message")
|
||||
ctrldLogger.Error().Msg("This is an ERROR message")
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Verify that DEBUG and INFO messages are NOT logged (filtered out)
|
||||
if strings.Contains(output, "DEBUG") {
|
||||
t.Error("DEBUG message should not be logged when level is NOTICE")
|
||||
}
|
||||
if strings.Contains(output, "INFO") {
|
||||
t.Error("INFO message should not be logged when level is NOTICE")
|
||||
}
|
||||
|
||||
// Verify that NOTICE, WARN, and ERROR messages ARE logged
|
||||
if !strings.Contains(output, "NOTICE") {
|
||||
t.Error("NOTICE message should be logged when level is NOTICE")
|
||||
}
|
||||
if !strings.Contains(output, "WARN") {
|
||||
t.Error("WARN message should be logged when level is NOTICE")
|
||||
}
|
||||
if !strings.Contains(output, "ERROR") {
|
||||
t.Error("ERROR message should be logged when level is NOTICE")
|
||||
}
|
||||
|
||||
// Verify the NOTICE message content
|
||||
if !strings.Contains(output, "This is a NOTICE message") {
|
||||
t.Error("NOTICE message content should be present")
|
||||
}
|
||||
|
||||
t.Logf("Log output with NOTICE level:\n%s", output)
|
||||
}
|
||||
|
||||
func TestNewLogReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bufContent string
|
||||
stripColor bool
|
||||
expected string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty_buffer_no_color_strip",
|
||||
bufContent: "",
|
||||
stripColor: false,
|
||||
expected: "",
|
||||
description: "Empty buffer should return empty reader",
|
||||
},
|
||||
{
|
||||
name: "empty_buffer_with_color_strip",
|
||||
bufContent: "",
|
||||
stripColor: true,
|
||||
expected: "",
|
||||
description: "Empty buffer with color strip should return empty reader",
|
||||
},
|
||||
{
|
||||
name: "plain_text_no_color_strip",
|
||||
bufContent: "This is plain text without any color codes",
|
||||
stripColor: false,
|
||||
expected: "This is plain text without any color codes",
|
||||
description: "Plain text should be returned as-is when not stripping colors",
|
||||
},
|
||||
{
|
||||
name: "plain_text_with_color_strip",
|
||||
bufContent: "This is plain text without any color codes",
|
||||
stripColor: true,
|
||||
expected: "This is plain text without any color codes",
|
||||
description: "Plain text should be returned as-is when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "text_with_ansi_codes_no_strip",
|
||||
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
stripColor: false,
|
||||
expected: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
description: "ANSI color codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "text_with_ansi_codes_with_strip",
|
||||
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
stripColor: true,
|
||||
expected: "Normal text red text normal again",
|
||||
description: "ANSI color codes should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "multiple_ansi_codes_no_strip",
|
||||
bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
stripColor: false,
|
||||
expected: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
description: "Multiple ANSI codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "multiple_ansi_codes_with_strip",
|
||||
bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
stripColor: true,
|
||||
expected: "Bold Green Blue text",
|
||||
description: "Multiple ANSI codes should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "complex_ansi_sequences_no_strip",
|
||||
bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
stripColor: false,
|
||||
expected: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
description: "Complex ANSI sequences should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "complex_ansi_sequences_with_strip",
|
||||
bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
stripColor: true,
|
||||
expected: "Bold red on green Orange",
|
||||
description: "Complex ANSI sequences should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "ansi_codes_with_newlines_no_strip",
|
||||
bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
stripColor: false,
|
||||
expected: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
description: "ANSI codes with newlines should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "ansi_codes_with_newlines_with_strip",
|
||||
bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
stripColor: true,
|
||||
expected: "Line 1\nRed line\nLine 3",
|
||||
description: "ANSI codes with newlines should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "malformed_ansi_codes_no_strip",
|
||||
bufContent: "Text \x1b[invalidm \x1b[0m normal",
|
||||
stripColor: false,
|
||||
expected: "Text \x1b[invalidm \x1b[0m normal",
|
||||
description: "Malformed ANSI codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "malformed_ansi_codes_with_strip",
|
||||
bufContent: "Text \x1b[invalidm \x1b[0m normal",
|
||||
stripColor: true,
|
||||
expected: "Text \x1b[invalidm normal",
|
||||
description: "Non-matching ANSI sequences should be preserved when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "large_buffer_no_strip",
|
||||
bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
stripColor: false,
|
||||
expected: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
description: "Large buffer should handle ANSI codes correctly when not stripping",
|
||||
},
|
||||
{
|
||||
name: "large_buffer_with_strip",
|
||||
bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
stripColor: true,
|
||||
expected: strings.Repeat("A", 10000) + strings.Repeat("B", 1000),
|
||||
description: "Large buffer should remove ANSI codes correctly when stripping",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a buffer with the test content
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(tt.bufContent)
|
||||
|
||||
// Create the log reader
|
||||
reader := newLogReader(buf, tt.stripColor)
|
||||
|
||||
// Read all content from the reader
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read from log reader: %v", err)
|
||||
}
|
||||
|
||||
// Verify the content matches expected
|
||||
actual := string(content)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("Expected content: %q, got: %q", tt.expected, actual)
|
||||
t.Logf("Description: %s", tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ReaderBehavior(t *testing.T) {
|
||||
// Test that the returned reader behaves correctly
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("Test content with \x1b[31mred\x1b[0m text")
|
||||
|
||||
// Test with color stripping
|
||||
reader := newLogReader(buf, true)
|
||||
|
||||
// Test reading in chunks
|
||||
chunk1 := make([]byte, 10)
|
||||
n1, err := reader.Read(chunk1)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Unexpected error reading first chunk: %v", err)
|
||||
}
|
||||
if n1 != 10 {
|
||||
t.Errorf("Expected to read 10 bytes, got %d", n1)
|
||||
}
|
||||
|
||||
// Test reading remaining content
|
||||
remaining, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read remaining content: %v", err)
|
||||
}
|
||||
|
||||
// Verify total content
|
||||
totalContent := string(chunk1[:n1]) + string(remaining)
|
||||
expected := "Test content with red text"
|
||||
if totalContent != expected {
|
||||
t.Errorf("Expected total content: %q, got: %q", expected, totalContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ConcurrentAccess(t *testing.T) {
|
||||
// Test concurrent access to the same buffer
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("Concurrent test with \x1b[32mgreen\x1b[0m text")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
results := make(chan string, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reader := newLogReader(buf, true)
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read content: %v", err)
|
||||
return
|
||||
}
|
||||
results <- string(content)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Verify all goroutines got the same result
|
||||
expected := "Concurrent test with green text"
|
||||
for result := range results {
|
||||
if result != expected {
|
||||
t.Errorf("Expected: %q, got: %q", expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ANSIRegexEdgeCases(t *testing.T) {
|
||||
// Test edge cases for ANSI regex matching
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty_escape_sequence",
|
||||
input: "Text \x1b[m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "multiple_semicolons",
|
||||
input: "Text \x1b[1;2;3;4m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "numeric_only",
|
||||
input: "Text \x1b[123m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "mixed_numeric_semicolon",
|
||||
input: "Text \x1b[1;23;456m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "no_closing_bracket",
|
||||
input: "Text \x1b[31 normal",
|
||||
expected: "Text \x1b[31 normal",
|
||||
},
|
||||
{
|
||||
name: "no_opening_bracket",
|
||||
input: "Text 31m normal",
|
||||
expected: "Text 31m normal",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(tt.input)
|
||||
|
||||
reader := newLogReader(buf, true)
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read content: %v", err)
|
||||
}
|
||||
|
||||
actual := string(content)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("Expected: %q, got: %q", tt.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -84,7 +84,7 @@ func (p *prog) detectLoop(msg *dns.Msg) {
|
||||
//
|
||||
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
|
||||
func (p *prog) checkDnsLoop() {
|
||||
mainLog.Load().Debug().Msg("start checking DNS loop")
|
||||
p.Debug().Msg("Start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for n, uc := range p.cfg.Upstream {
|
||||
@@ -93,7 +93,7 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
// Do not send test query to external upstream.
|
||||
if !canBeLocalUpstream(uc.Domain) {
|
||||
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
|
||||
p.Debug().Msgf("Skipping external: upstream.%s", n)
|
||||
continue
|
||||
}
|
||||
uid := uc.UID()
|
||||
@@ -102,19 +102,24 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
// Skipping upstream which is being marked as down.
|
||||
if uc == nil {
|
||||
continue
|
||||
}
|
||||
resolver, err := ctrld.NewResolver(loggerCtx, uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
p.Warn().Err(err).Msgf("Could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
}
|
||||
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
p.Warn().Err(err).Msgf("Could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msg("end checking DNS loop")
|
||||
p.Debug().Msg("End checking DNS loop")
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
@@ -133,7 +138,7 @@ func (p *prog) checkDnsLoopTicker(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// loopTestMsg generates DNS message for checking loop.
|
||||
// loopTestMsg creates a DNS test message for loop detection
|
||||
func loopTestMsg(uid string) *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)
|
||||
|
||||
156
cmd/cli/main.go
156
cmd/cli/main.go
@@ -5,14 +5,16 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/rs/zerolog"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Global variables for CLI configuration and state management
|
||||
// These are used across multiple commands and need to persist throughout the application lifecycle
|
||||
var (
|
||||
configPath string
|
||||
configBase64 string
|
||||
@@ -29,38 +31,53 @@ var (
|
||||
silent bool
|
||||
cdUID string
|
||||
cdOrg string
|
||||
customHostname string
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
nextdns string
|
||||
cdUpstreamProto string
|
||||
deactivationPin int64
|
||||
skipSelfChecks bool
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
rfc1918 bool
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
noConfigStart bool
|
||||
mainLog atomic.Pointer[ctrld.Logger]
|
||||
consoleWriter zapcore.Core
|
||||
consoleWriterLevel zapcore.Level
|
||||
noConfigStart bool
|
||||
)
|
||||
|
||||
// Flag name constants for consistent reference across the codebase
|
||||
// Using constants prevents typos and makes refactoring easier
|
||||
const (
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
nextdnsFlagName = "nextdns"
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
customHostnameFlagName = "custom-hostname"
|
||||
nextdnsFlagName = "nextdns"
|
||||
)
|
||||
|
||||
// init initializes the default logger before any CLI commands are executed
|
||||
// This ensures logging is available even during early initialization phases
|
||||
func init() {
|
||||
l := zerolog.New(io.Discard)
|
||||
mainLog.Store(&l)
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
}
|
||||
|
||||
// Main is the entry point for the CLI application
|
||||
// It initializes configuration, sets up the CLI structure, and executes the root command
|
||||
func Main() {
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
initCLI()
|
||||
rootCmd := initCLI()
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeLogFilePath converts relative log file paths to absolute paths
|
||||
// This ensures log files are created in predictable locations regardless of working directory
|
||||
func normalizeLogFilePath(logFilePath string) string {
|
||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||
return logFilePath
|
||||
@@ -76,29 +93,36 @@ func normalizeLogFilePath(logFilePath string) string {
|
||||
}
|
||||
|
||||
// initConsoleLogging initializes console logging, then storing to mainLog.
|
||||
// This sets up human-readable logging output for interactive use
|
||||
func initConsoleLogging() {
|
||||
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
mainLog.Store(&l)
|
||||
consoleWriterLevel = ctrld.NoticeLevel
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
// For silent mode, use a no-op logger to suppress all output
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
case verbose == 1:
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
// Info level provides basic operational information
|
||||
consoleWriterLevel = zapcore.InfoLevel
|
||||
case verbose > 1:
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
// Debug level provides detailed diagnostic information
|
||||
consoleWriterLevel = zapcore.DebugLevel
|
||||
}
|
||||
consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel)
|
||||
l := zap.New(consoleWriter)
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func initLogging() {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(true)
|
||||
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
|
||||
// to be used for all interactive commands.
|
||||
//
|
||||
// Current log file config will also be ignored.
|
||||
// This prevents log file conflicts during interactive command execution
|
||||
func initInteractiveLogging() {
|
||||
old := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
initLoggingWithBackup(false)
|
||||
cfg.Service.LogPath = old
|
||||
}
|
||||
|
||||
// initLoggingWithBackup initializes log setup base on current config.
|
||||
@@ -107,67 +131,101 @@ func initLogging() {
|
||||
// This is only used in runCmd for special handling in case of logging config
|
||||
// change in cd mode. Without special reason, the caller should use initLogging
|
||||
// wrapper instead of calling this function directly.
|
||||
func initLoggingWithBackup(doBackup bool) {
|
||||
writers := []io.Writer{io.Discard}
|
||||
func initLoggingWithBackup(doBackup bool) []zapcore.Core {
|
||||
var writers []io.Writer
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
// This ensures log files can be created even if the directory doesn't exist
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
mainLog.Load().Error().Msgf("failed to create log path: %v", err)
|
||||
mainLog.Load().Error().Msgf("Failed to create log path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Default open log file in append mode.
|
||||
// This preserves existing log entries across restarts
|
||||
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
|
||||
if doBackup {
|
||||
// Backup old log file with .1 suffix.
|
||||
if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) {
|
||||
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
|
||||
// This prevents log file corruption during rotation
|
||||
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
|
||||
mainLog.Load().Error().Msgf("Could not backup old log file: %v", err)
|
||||
} else {
|
||||
// Backup was created, set flags for truncating old log file.
|
||||
// This ensures a clean start for the new log file
|
||||
flags = os.O_CREATE | os.O_RDWR
|
||||
}
|
||||
}
|
||||
logFile, err := os.OpenFile(logFilePath, flags, os.FileMode(0o600))
|
||||
logFile, err := openLogFile(logFilePath, flags)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
|
||||
mainLog.Load().Error().Msgf("Failed to create log file: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
writers = append(writers, logFile)
|
||||
}
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
// Create zap cores for different writers
|
||||
// Multiple cores allow logging to both console and file simultaneously
|
||||
var cores []zapcore.Core
|
||||
cores = append(cores, consoleWriter)
|
||||
|
||||
// Determine log level based on verbosity and configuration
|
||||
// This provides flexible logging control for different use cases
|
||||
logLevel := cfg.Service.LogLevel
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
return
|
||||
// For silent mode, use a no-op logger to suppress all output
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
return cores
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
if logLevel == "" {
|
||||
return
|
||||
|
||||
// Parse log level string to zapcore.Level
|
||||
// This provides human-readable log level configuration
|
||||
var level zapcore.Level
|
||||
switch logLevel {
|
||||
case "debug":
|
||||
level = zapcore.DebugLevel
|
||||
case "info":
|
||||
level = zapcore.InfoLevel
|
||||
case "notice":
|
||||
level = ctrld.NoticeLevel
|
||||
case "warn":
|
||||
level = zapcore.WarnLevel
|
||||
case "error":
|
||||
level = zapcore.ErrorLevel
|
||||
default:
|
||||
level = zapcore.InfoLevel // default level
|
||||
}
|
||||
level, err := zerolog.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set log level")
|
||||
return
|
||||
|
||||
consoleWriter.Enabled(level)
|
||||
// Add cores for all writers
|
||||
// This enables multi-destination logging (console + file)
|
||||
for _, writer := range writers {
|
||||
core := newMachineFriendlyZapCore(writer, level)
|
||||
cores = append(cores, core)
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
|
||||
// Create a multi-core logger
|
||||
// This allows simultaneous logging to multiple destinations
|
||||
multiCore := zapcore.NewTee(cores...)
|
||||
logger := zap.New(multiCore)
|
||||
mainLog.Store(&ctrld.Logger{Logger: logger})
|
||||
|
||||
return cores
|
||||
}
|
||||
|
||||
// initCache initializes DNS cache configuration
|
||||
// This improves performance by caching frequently requested DNS responses
|
||||
func initCache() {
|
||||
if !cfg.Service.CacheEnable {
|
||||
return
|
||||
}
|
||||
if cfg.Service.CacheSize == 0 {
|
||||
// Default cache size provides good balance between memory usage and performance
|
||||
cfg.Service.CacheSize = 4096
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,13 +5,28 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
var logOutput strings.Builder
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
l := zerolog.New(&logOutput)
|
||||
mainLog.Store(&l)
|
||||
// Create a custom writer that writes to logOutput
|
||||
writer := zapcore.AddSync(&logOutput)
|
||||
|
||||
// Create zap encoder
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
|
||||
// Create core that writes to our string builder
|
||||
core := zapcore.NewCore(encoder, writer, zap.DebugLevel)
|
||||
|
||||
// Create logger
|
||||
l := zap.New(core)
|
||||
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// metricsServer represents a server to expose Prometheus metrics via HTTP.
|
||||
// This provides monitoring and observability for the DNS proxy service
|
||||
type metricsServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
@@ -24,6 +25,7 @@ type metricsServer struct {
|
||||
}
|
||||
|
||||
// newMetricsServer returns new metrics server.
|
||||
// This initializes the HTTP server for exposing Prometheus metrics
|
||||
func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
ms := &metricsServer{
|
||||
@@ -37,11 +39,13 @@ func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, er
|
||||
}
|
||||
|
||||
// register adds handlers for given pattern.
|
||||
// This provides a clean interface for adding HTTP endpoints to the metrics server
|
||||
func (ms *metricsServer) register(pattern string, handler http.Handler) {
|
||||
ms.mux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
// registerMetricsServerHandler adds handlers for metrics server.
|
||||
// This sets up both Prometheus format and JSON format endpoints for metrics
|
||||
func (ms *metricsServer) registerMetricsServerHandler() {
|
||||
ms.register("/metrics", promhttp.HandlerFor(
|
||||
ms.reg,
|
||||
@@ -74,6 +78,7 @@ func (ms *metricsServer) registerMetricsServerHandler() {
|
||||
}
|
||||
|
||||
// start runs the metricsServer.
|
||||
// This starts the HTTP server for metrics exposure
|
||||
func (ms *metricsServer) start() error {
|
||||
listener, err := net.Listen("tcp", ms.addr)
|
||||
if err != nil {
|
||||
@@ -85,6 +90,7 @@ func (ms *metricsServer) start() error {
|
||||
}
|
||||
|
||||
// stop shutdowns the metricsServer within 2 seconds timeout.
|
||||
// This ensures graceful shutdown of the metrics server
|
||||
func (ms *metricsServer) stop() error {
|
||||
if !ms.started {
|
||||
return nil
|
||||
@@ -95,6 +101,7 @@ func (ms *metricsServer) stop() error {
|
||||
}
|
||||
|
||||
// runMetricsServer initializes metrics stats and runs the metrics server if enabled.
|
||||
// This sets up the complete metrics infrastructure including Prometheus collectors
|
||||
func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
if !p.metricsEnabled() {
|
||||
return
|
||||
@@ -107,7 +114,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
// Register queries count stats if enabled.
|
||||
if cfg.Service.MetricsQueryStats {
|
||||
if p.metricsQueryStats.Load() {
|
||||
reg.MustRegister(statsQueriesCount)
|
||||
reg.MustRegister(statsClientQueriesCount)
|
||||
}
|
||||
@@ -115,7 +122,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
addr := p.cfg.Service.MetricsListener
|
||||
ms, err := newMetricsServer(addr, reg)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create new metrics server")
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not create new metrics server")
|
||||
return
|
||||
}
|
||||
// Only start listener address if defined.
|
||||
@@ -130,9 +137,9 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
statsVersion.WithLabelValues(commit, runtime.Version(), curVersion()).Inc()
|
||||
reg.MustRegister(statsTimeStart)
|
||||
statsTimeStart.Set(float64(time.Now().Unix()))
|
||||
mainLog.Load().Debug().Msgf("starting metrics server on: %s", addr)
|
||||
mainLog.Load().Debug().Msgf("Starting metrics server on: %s", addr)
|
||||
if err := ms.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start metrics server")
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not start metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -144,7 +151,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
}
|
||||
|
||||
if err := ms.stop(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not stop metrics server")
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not stop metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package cli
|
||||
|
||||
import "strings"
|
||||
|
||||
// Copied from https://gist.github.com/Ultraporing/fe52981f678be6831f747c206a4861cb
|
||||
|
||||
// Mac Address parts to look for, and identify non-physical devices. There may be more, update me!
|
||||
var macAddrPartsToFilter = []string{
|
||||
"00:03:FF", // Microsoft Hyper-V, Virtual Server, Virtual PC
|
||||
"0A:00:27", // VirtualBox
|
||||
"00:00:00:00:00", // Teredo Tunneling Pseudo-Interface
|
||||
"00:50:56", // VMware ESX 3, Server, Workstation, Player
|
||||
"00:1C:14", // VMware ESX 3, Server, Workstation, Player
|
||||
"00:0C:29", // VMware ESX 3, Server, Workstation, Player
|
||||
"00:05:69", // VMware ESX 3, Server, Workstation, Player
|
||||
"00:1C:42", // Microsoft Hyper-V, Virtual Server, Virtual PC
|
||||
"00:0F:4B", // Virtual Iron 4
|
||||
"00:16:3E", // Red Hat Xen, Oracle VM, XenSource, Novell Xen
|
||||
"08:00:27", // Sun xVM VirtualBox
|
||||
"7A:79", // Hamachi
|
||||
}
|
||||
|
||||
// Filters the possible physical interface address by comparing it to known popular VM Software addresses
|
||||
// and Teredo Tunneling Pseudo-Interface.
|
||||
//
|
||||
//lint:ignore U1000 use in net_windows.go
|
||||
func isPhysicalInterface(addr string) bool {
|
||||
for _, macPart := range macAddrPartsToFilter {
|
||||
if strings.HasPrefix(strings.ToLower(addr), strings.ToLower(macPart)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface")
|
||||
}
|
||||
return nil
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
@@ -43,20 +44,8 @@ func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one, which includes:
|
||||
//
|
||||
// - en0: physical wireless
|
||||
// - en1: Thunderbolt 1
|
||||
// - en2: Thunderbolt 2
|
||||
// - en3: Thunderbolt 3
|
||||
// - en4: Thunderbolt 4
|
||||
//
|
||||
// For full list, see: https://unix.stackexchange.com/questions/603506/what-are-these-ifconfig-interfaces-on-macos
|
||||
func validInterface(iface *net.Interface) bool {
|
||||
switch iface.Name {
|
||||
case "en0", "en1", "en2", "en3", "en4":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
17
cmd/cli/net_linux.go
Normal file
17
cmd/cli/net_linux.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// patchNetIfaceName patches network interface names on Linux
|
||||
// This is a no-op on Linux as interface names don't need special handling
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// Only non-virtual interfaces are considered valid.
|
||||
// This prevents DNS configuration on virtual interfaces like docker, veth, etc.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
@@ -1,9 +1,13 @@
|
||||
//go:build !darwin && !windows
|
||||
//go:build !darwin && !windows && !linux
|
||||
|
||||
package cli
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
||||
// patchNetIfaceName patches network interface names on non-Linux/Darwin platforms
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
func validInterface(iface *net.Interface) bool { return true }
|
||||
// validInterface checks if an interface is valid on non-Linux/Darwin platforms
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||
|
||||
@@ -4,18 +4,13 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
return nil
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// On Windows, only physical interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface) bool {
|
||||
if iface == nil {
|
||||
return false
|
||||
}
|
||||
if isPhysicalInterface(iface.HardwareAddr.String()) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
47
cmd/cli/net_windows_test.go
Normal file
47
cmd/cli/net_windows_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_validInterfaces(t *testing.T) {
|
||||
verbose = 3
|
||||
initConsoleLogging()
|
||||
start := time.Now()
|
||||
im := ctrld.ValidInterfaces(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
ifaces := slices.Collect(maps.Keys(im))
|
||||
|
||||
start = time.Now()
|
||||
ifacesPowershell := validInterfacesPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(ifaces)
|
||||
slices.Sort(ifacesPowershell)
|
||||
if !slices.Equal(ifaces, ifacesPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces)
|
||||
}
|
||||
}
|
||||
|
||||
func validInterfacesPowershell() []string {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var res []string
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
res = append(res, ifaceName)
|
||||
}
|
||||
return res
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {
|
||||
@@ -12,7 +14,7 @@ func (p *prog) watchLinkState(ctx context.Context) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
if err := netlink.LinkSubscribe(ch, done); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
|
||||
p.Warn().Err(err).Msg("Could not subscribe link")
|
||||
return
|
||||
}
|
||||
for {
|
||||
@@ -24,9 +26,9 @@ func (p *prog) watchLinkState(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
p.Debug().Msgf("Link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,66 +23,67 @@ systemd-resolved=false
|
||||
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
|
||||
|
||||
// hasNetworkManager reports whether NetworkManager executable found.
|
||||
// hasNetworkManager checks if NetworkManager is available on the system
|
||||
func hasNetworkManager() bool {
|
||||
exe, _ := exec.LookPath("NetworkManager")
|
||||
return exe != ""
|
||||
}
|
||||
|
||||
func setupNetworkManager() error {
|
||||
func (p *prog) setupNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
|
||||
mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
p.Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
return nil
|
||||
}
|
||||
err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644))
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
p.Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not write NetworkManager ctrld config file")
|
||||
p.Debug().Err(err).Msg("Could not write NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
reloadNetworkManager()
|
||||
mainLog.Load().Debug().Msg("setup NetworkManager done")
|
||||
p.reloadNetworkManager()
|
||||
p.Debug().Msg("Setup NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
func (p *prog) restoreNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
err := os.Remove(networkManagerCtrldConfFile)
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
p.Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not remove NetworkManager ctrld config file")
|
||||
p.Debug().Err(err).Msg("Could not remove NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
reloadNetworkManager()
|
||||
mainLog.Load().Debug().Msg("restore NetworkManager done")
|
||||
p.reloadNetworkManager()
|
||||
p.Debug().Msg("Restore NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func reloadNetworkManager() {
|
||||
func (p *prog) reloadNetworkManager() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
conn, err := dbus.NewSystemConnectionContext(ctx)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not create new system connection")
|
||||
p.Error().Err(err).Msg("Could not create new system connection")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
waitCh := make(chan string)
|
||||
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
p.Debug().Err(err).Msg("Could not reload NetworkManager")
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
package cli
|
||||
|
||||
func setupNetworkManager() error {
|
||||
reloadNetworkManager()
|
||||
func (p *prog) setupNetworkManager() error {
|
||||
p.reloadNetworkManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
reloadNetworkManager()
|
||||
func (p *prog) restoreNetworkManager() error {
|
||||
p.reloadNetworkManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func reloadNetworkManager() {}
|
||||
func (p *prog) reloadNetworkManager() {}
|
||||
|
||||
@@ -8,11 +8,12 @@ import (
|
||||
|
||||
const nextdnsURL = "https://dns.nextdns.io"
|
||||
|
||||
// generateNextDNSConfig generates NextDNS configuration for the given UID
|
||||
func generateNextDNSConfig(uid string) {
|
||||
if uid == "" {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver")
|
||||
mainLog.Load().Info().Msg("Generating ctrld config for NextDNS resolver")
|
||||
cfg = ctrld.Config{
|
||||
Listener: map[string]*ctrld.ListenerConfig{
|
||||
"0": {
|
||||
|
||||
5
cmd/cli/nocgo.go
Normal file
5
cmd/cli/nocgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = false
|
||||
@@ -8,26 +8,31 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// allocate loopback ip
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
// sudo ifconfig lo0 alias 127.0.0.2 up
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
|
||||
mainLog.Load().Error().Err(err).Msg("AllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", "-alias", ip)
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -47,12 +52,19 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
|
||||
// TODO(cuonglm): use system API
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
// Note that networksetup won't modify search domains settings,
|
||||
// This assignment is just a placeholder to silent linter.
|
||||
_ = searchDomains
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name}
|
||||
args = append(args, nameservers...)
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -70,21 +82,30 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
|
||||
// TODO(cuonglm): use system API
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
if err := setDNS(iface, ns); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name, "empty"}
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if ns := ctrld.SavedStaticNameservers(iface); len(ns) > 0 {
|
||||
err = setDNS(iface, ns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return ctrld.CurrentNameserversFromResolvconf()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
|
||||
@@ -5,27 +5,36 @@ import (
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// allocate loopback ip
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
// sudo ifconfig lo0 127.0.0.53 alias
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", ip, "alias")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", ip, "-alias")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,9 +45,11 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -47,10 +58,22 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
if err := r.SetDNS(dns.OSConfig{Nameservers: ns}); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to set DNS")
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list")
|
||||
}
|
||||
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to set DNS")
|
||||
return err
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,22 +82,34 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
// resetDNS resets DNS servers for the specified interface
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to rollback DNS setting")
|
||||
return err
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return ctrld.CurrentNameserversFromResolvconf()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
|
||||
@@ -17,30 +17,40 @@ import (
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/client6"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
|
||||
type getDNS func(iface string) []string
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ip a add 127.0.0.2/24 dev lo
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ip", "a", "add", ip+"/24", "dev", "lo")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("allocateIP failed: %s", string(out))
|
||||
mainLog.Load().Error().Err(err).Msgf("AllocateIP failed: %s", string(out))
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ip", "a", "del", ip+"/24", "dev", "lo")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -52,9 +62,11 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
}
|
||||
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create dns os configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -67,35 +79,31 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list")
|
||||
}
|
||||
trySystemdResolve := false
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
break
|
||||
}
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
goto systemdResolve
|
||||
}
|
||||
if useSystemdResolved {
|
||||
if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out))
|
||||
}
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
systemdResolve:
|
||||
if trySystemdResolve {
|
||||
// Stop systemd-networkd and retry setting DNS.
|
||||
if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil {
|
||||
@@ -115,8 +123,10 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -126,6 +136,8 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
}
|
||||
|
||||
func resetDNS(iface *net.Interface) (err error) {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
return
|
||||
@@ -134,10 +146,10 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
if exe, _ := exec.LookPath("/lib/systemd/systemd-networkd"); exe != "" {
|
||||
_ = exec.Command("systemctl", "start", "systemd-networkd").Run()
|
||||
}
|
||||
if r, oerr := dns.NewOSConfigurator(logf, iface.Name); oerr == nil {
|
||||
if r, oerr := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name); oerr == nil {
|
||||
_ = r.SetDNS(dns.OSConfig{})
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to rollback dns setting")
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
@@ -165,17 +177,18 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
// TODO(cuonglm): handle DHCPv6 properly.
|
||||
mainLog.Load().Debug().Msg("Checking for ipv6 availability")
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
if err != nil && !errAddrInUse(err) {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not exchange DHCPv6")
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not exchange dhcpv6")
|
||||
}
|
||||
for _, packet := range conversation {
|
||||
if packet.Type() == dhcpv6.MessageTypeReply {
|
||||
msg, err := packet.GetInnerMessage()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not get inner DHCPv6 message")
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not get inner dhcpv6 message")
|
||||
return nil
|
||||
}
|
||||
nameservers := msg.Options.DNS()
|
||||
@@ -184,6 +197,8 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("IPv6 is not available")
|
||||
}
|
||||
|
||||
return ignoringEINTR(func() error {
|
||||
@@ -191,8 +206,15 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
})
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
|
||||
resolvconfFunc := func(_ string) []string { return ctrld.CurrentNameserversFromResolvconf() }
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
return ns
|
||||
}
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
package cli
|
||||
|
||||
// TODO(cuonglm): implement.
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
func allocateIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(cuonglm): implement.
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,26 +4,21 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const (
|
||||
forwardersFilename = ".forwarders.txt"
|
||||
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
var (
|
||||
setDNSOnce sync.Once
|
||||
resetDNSOnce sync.Once
|
||||
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
@@ -36,33 +31,41 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
if len(nameservers) == 0 {
|
||||
return errors.New("empty DNS nameservers")
|
||||
}
|
||||
setDNSOnce.Do(func() {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
file := absHomeDir(forwardersFilename)
|
||||
if data, _ := os.ReadFile(file); len(data) > 0 {
|
||||
if err := removeDnsServerForwarders(strings.Split(string(data), ",")); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not remove current forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("removed current forwarders settings.")
|
||||
}
|
||||
}
|
||||
if err := os.WriteFile(file, []byte(strings.Join(nameservers, ",")), 0600); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||
}
|
||||
if err := addDnsServerForwarders(nameservers); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("setDNS: %w", err)
|
||||
}
|
||||
var (
|
||||
serversV4 []netip.Addr
|
||||
serversV6 []netip.Addr
|
||||
)
|
||||
for _, ns := range nameservers {
|
||||
if addr, err := netip.ParseAddr(ns); err == nil {
|
||||
if addr.Is4() {
|
||||
serversV4 = append(serversV4, addr)
|
||||
} else {
|
||||
serversV6 = append(serversV6, addr)
|
||||
}
|
||||
}
|
||||
})
|
||||
primaryDNS := nameservers[0]
|
||||
if err := setPrimaryDNS(iface, primaryDNS, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(nameservers) > 1 {
|
||||
secondaryDNS := nameservers[1]
|
||||
_ = addSecondaryDNS(iface, secondaryDNS)
|
||||
|
||||
// Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function.
|
||||
// searchDomains is still implemented for Windows just in case Windows API changes in future versions.
|
||||
_ = searchDomains
|
||||
|
||||
if len(serversV4) == 0 && len(serversV6) == 0 {
|
||||
return errors.New("invalid DNS nameservers")
|
||||
}
|
||||
if len(serversV4) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
||||
}
|
||||
}
|
||||
if len(serversV6) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -72,38 +75,26 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
// TODO(cuonglm): should we use system API?
|
||||
// resetDNS resets DNS servers for the specified interface
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
resetDNSOnce.Do(func() {
|
||||
// See corresponding comment in setDNS.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
file := absHomeDir(forwardersFilename)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
||||
return
|
||||
}
|
||||
nameservers := strings.Split(string(content), ",")
|
||||
if err := removeDnsServerForwarders(nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Restoring ipv6 first.
|
||||
if ctrldnet.SupportsIPv6ListenLocal() {
|
||||
if output, err := netsh("interface", "ipv6", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp"); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("failed to reset ipv6 DNS: %s", string(output))
|
||||
}
|
||||
}
|
||||
// Restoring ipv4 DHCP.
|
||||
output, err := netsh("interface", "ipv4", "set", "dnsserver", strconv.Itoa(iface.Index), "dhcp")
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", string(output), err)
|
||||
return fmt.Errorf("resetDNS: %w", err)
|
||||
}
|
||||
// If there's static DNS saved, restoring it.
|
||||
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
||||
// Restoring DHCP settings.
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
||||
}
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if nss := ctrld.SavedStaticNameservers(iface); len(nss) > 0 {
|
||||
v4ns := make([]string, 0, 2)
|
||||
v6ns := make([]string, 0, 2)
|
||||
for _, ns := range nss {
|
||||
@@ -114,67 +105,48 @@ func resetDNS(iface *net.Interface) error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, ns := range [][]string{v4ns, v6ns} {
|
||||
if len(ns) == 0 {
|
||||
continue
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoreDNS: %w", err)
|
||||
}
|
||||
|
||||
if len(v4ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
|
||||
if err := setDNS(iface, v4ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4): %w", err)
|
||||
}
|
||||
primaryDNS := ns[0]
|
||||
if err := setPrimaryDNS(iface, primaryDNS, false); err != nil {
|
||||
return err
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv4 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
|
||||
}
|
||||
if len(ns) > 1 {
|
||||
secondaryDNS := ns[1]
|
||||
_ = addSecondaryDNS(iface, secondaryDNS)
|
||||
}
|
||||
|
||||
if len(v6ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
|
||||
if err := setDNS(iface, v6ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv6 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setPrimaryDNS(iface *net.Interface, dns string, disablev6 bool) error {
|
||||
ipVer := "ipv4"
|
||||
if ctrldnet.IsIPv6(dns) {
|
||||
ipVer = "ipv6"
|
||||
}
|
||||
idx := strconv.Itoa(iface.Index)
|
||||
output, err := netsh("interface", ipVer, "set", "dnsserver", idx, "static", dns)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("failed to set primary DNS: %s", string(output))
|
||||
return err
|
||||
}
|
||||
if disablev6 && ipVer == "ipv4" && ctrldnet.SupportsIPv6ListenLocal() {
|
||||
// Disable IPv6 DNS, so the query will be fallback to IPv4.
|
||||
_, _ = netsh("interface", "ipv6", "set", "dnsserver", idx, "static", "::1", "primary")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addSecondaryDNS(iface *net.Interface, dns string) error {
|
||||
ipVer := "ipv4"
|
||||
if ctrldnet.IsIPv6(dns) {
|
||||
ipVer = "ipv6"
|
||||
}
|
||||
output, err := netsh("interface", ipVer, "add", "dns", strconv.Itoa(iface.Index), dns, "index=2")
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("failed to add secondary DNS: %s", string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func netsh(args ...string) ([]byte, error) {
|
||||
return exec.Command("netsh", args...).Output()
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get interface LUID")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to get interface LUID")
|
||||
return nil
|
||||
}
|
||||
nameservers, err := luid.DNS()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get interface DNS")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to get interface DNS")
|
||||
return nil
|
||||
}
|
||||
ns := make([]string, 0, len(nameservers))
|
||||
@@ -184,53 +156,65 @@ func currentDNS(iface *net.Interface) []string {
|
||||
return ns
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
|
||||
// like "NameServer" and "ProfileNameServer".
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
|
||||
}
|
||||
|
||||
var ns []string
|
||||
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
|
||||
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
|
||||
for _, path := range keyPaths {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
ns = append(ns, strings.Split(string(out), ",")...)
|
||||
}
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer k.Close()
|
||||
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
|
||||
value, _, err := k.GetStringValue(keyName)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Error reading %s registry key", keyName)
|
||||
continue
|
||||
}
|
||||
if len(value) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Found static DNS for interface %q: %s", iface.Name, value)
|
||||
parsed := parseDNSServers(value)
|
||||
for _, pns := range parsed {
|
||||
if !slices.Contains(ns, pns) {
|
||||
ns = append(ns, pns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("No static DNS values found for interface %q", iface.Name)
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// addDnsServerForwarders adds given nameservers to DNS server forwarders list.
|
||||
func addDnsServerForwarders(nameservers []string) error {
|
||||
for _, ns := range nameservers {
|
||||
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", ns)
|
||||
if out, err := powershell(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
// parseDNSServers splits a DNS server string that may be comma- or space-separated,
|
||||
// and trims any extraneous whitespace or null characters.
|
||||
func parseDNSServers(val string) []string {
|
||||
fields := strings.FieldsFunc(val, func(r rune) bool {
|
||||
return r == ' ' || r == ','
|
||||
})
|
||||
var servers []string
|
||||
for _, f := range fields {
|
||||
trimmed := strings.TrimSpace(f)
|
||||
if len(trimmed) > 0 {
|
||||
servers = append(servers, trimmed)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeDnsServerForwarders removes given nameservers from DNS server forwarders list.
|
||||
func removeDnsServerForwarders(nameservers []string) error {
|
||||
for _, ns := range nameservers {
|
||||
cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns)
|
||||
if out, err := powershell(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return servers
|
||||
}
|
||||
|
||||
76
cmd/cli/os_windows_test.go
Normal file
76
cmd/cli/os_windows_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func Test_currentStaticDNS(t *testing.T) {
|
||||
iface, err := net.InterfaceByName(defaultIfaceName())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
start := time.Now()
|
||||
staticDns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(staticDns)
|
||||
slices.Sort(staticDnsPowershell)
|
||||
if !slices.Equal(staticDns, staticDnsPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
|
||||
}
|
||||
}
|
||||
|
||||
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
1112
cmd/cli/prog.go
1112
cmd/cli/prog.go
File diff suppressed because it is too large
Load Diff
@@ -4,8 +4,10 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// setDependencies sets service dependencies for Darwin
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// setDependencies sets service dependencies for FreeBSD
|
||||
func setDependencies(svc *service.Config) {
|
||||
// TODO(cuonglm): remove once https://github.com/kardianos/service/issues/359 fixed.
|
||||
_ = os.MkdirAll("/usr/local/etc/rc.d", 0755)
|
||||
}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {}
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo"); err == nil {
|
||||
if r, err := newLoopbackOSConfigurator(); err == nil {
|
||||
useSystemdResolved = r.Mode() == "systemd-resolved"
|
||||
}
|
||||
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911
|
||||
@@ -18,18 +21,42 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// setDependencies sets service dependencies for Linux
|
||||
func setDependencies(svc *service.Config) {
|
||||
svc.Dependencies = []string{
|
||||
"Wants=network-online.target",
|
||||
"After=network-online.target",
|
||||
"Wants=NetworkManager-wait-online.service",
|
||||
"After=NetworkManager-wait-online.service",
|
||||
"Wants=systemd-networkd-wait-online.service",
|
||||
"Wants=nss-lookup.target",
|
||||
"After=nss-lookup.target",
|
||||
}
|
||||
if out, _ := exec.Command("networkctl", "--no-pager").CombinedOutput(); len(out) > 0 {
|
||||
if wantsSystemDNetworkdWaitOnline(bytes.NewReader(out)) {
|
||||
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
// wantsSystemDNetworkdWaitOnline reports whether "systemd-networkd-wait-online" service
|
||||
// is required to be added to ctrld dependencies services.
|
||||
// The input reader r is the output of "networkctl --no-pager" command.
|
||||
func wantsSystemDNetworkdWaitOnline(r io.Reader) bool {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Skip header
|
||||
scanner.Scan()
|
||||
configured := false
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if len(fields) > 0 && fields[len(fields)-1] == "configured" {
|
||||
configured = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return configured
|
||||
}
|
||||
|
||||
48
cmd/cli/prog_linux_test.go
Normal file
48
cmd/cli/prog_linux_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
networkctlUnmanagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||
1 lo loopback carrier unmanaged
|
||||
2 wlp0s20f3 wlan routable unmanaged
|
||||
3 tailscale0 none routable unmanaged
|
||||
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||
5 docker0 bridge no-carrier unmanaged
|
||||
|
||||
5 links listed.
|
||||
`
|
||||
networkctlManagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||
1 lo loopback carrier unmanaged
|
||||
2 wlp0s20f3 wlan routable configured
|
||||
3 tailscale0 none routable unmanaged
|
||||
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||
5 docker0 bridge no-carrier unmanaged
|
||||
|
||||
5 links listed.
|
||||
`
|
||||
)
|
||||
|
||||
func Test_wantsSystemDNetworkdWaitOnline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
r io.Reader
|
||||
required bool
|
||||
}{
|
||||
{"unmanaged", strings.NewReader(networkctlUnmanagedOutput), false},
|
||||
{"managed", strings.NewReader(networkctlManagedOutput), true},
|
||||
{"empty", strings.NewReader(""), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if required := wantsSystemDNetworkdWaitOnline(tc.r); required != tc.required {
|
||||
t.Errorf("wants %v got %v", tc.required, required)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
33
cmd/cli/prog_log.go
Normal file
33
cmd/cli/prog_log.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cli
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld"
|
||||
|
||||
// Debug starts a new message with debug level.
|
||||
func (p *prog) Debug() *ctrld.LogEvent {
|
||||
return p.logger.Load().Debug()
|
||||
}
|
||||
|
||||
// Warn starts a new message with warn level.
|
||||
func (p *prog) Warn() *ctrld.LogEvent {
|
||||
return p.logger.Load().Warn()
|
||||
}
|
||||
|
||||
// Info starts a new message with info level.
|
||||
func (p *prog) Info() *ctrld.LogEvent {
|
||||
return p.logger.Load().Info()
|
||||
}
|
||||
|
||||
// Fatal starts a new message with fatal level.
|
||||
func (p *prog) Fatal() *ctrld.LogEvent {
|
||||
return p.logger.Load().Fatal()
|
||||
}
|
||||
|
||||
// Error starts a new message with error level.
|
||||
func (p *prog) Error() *ctrld.LogEvent {
|
||||
return p.logger.Load().Error()
|
||||
}
|
||||
|
||||
// Notice starts a new message with notice level.
|
||||
func (p *prog) Notice() *ctrld.LogEvent {
|
||||
return p.logger.Load().Notice()
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
//go:build !linux && !freebsd && !darwin
|
||||
//go:build !linux && !freebsd && !darwin && !windows
|
||||
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
// setDependencies sets service dependencies for other platforms
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
|
||||
266
cmd/cli/prog_test.go
Normal file
266
cmd/cli/prog_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is true.
|
||||
assert.True(t, p.dnsWatchdogEnabled())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
}{
|
||||
{"enabled", true},
|
||||
{"disabled", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogEnabled = &tc.enabled
|
||||
assert.Equal(t, tc.enabled, p.dnsWatchdogEnabled())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_dnsWatchdogInterval(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is 20s.
|
||||
assert.Equal(t, dnsWatchdogDefaultInterval, p.dnsWatchdogDuration())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
expected time.Duration
|
||||
}{
|
||||
{"valid", time.Minute, time.Minute},
|
||||
{"zero", 0, dnsWatchdogDefaultInterval},
|
||||
{"nagative", time.Duration(-1 * time.Minute), dnsWatchdogDefaultInterval},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogInvterval = &tc.duration
|
||||
assert.Equal(t, tc.expected, p.dnsWatchdogDuration())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldUpgrade(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is empty",
|
||||
},
|
||||
{
|
||||
name: "invalid version target",
|
||||
versionTarget: "invalid-version",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is invalid",
|
||||
},
|
||||
{
|
||||
name: "same version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version equals current version",
|
||||
},
|
||||
{
|
||||
name: "older version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.1.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version is older than current version",
|
||||
},
|
||||
{
|
||||
name: "patch upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow patch version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "minor upgrade allowed",
|
||||
versionTarget: "v1.1.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow minor version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "major upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "major downgrade blocked",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v2.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version downgrade",
|
||||
},
|
||||
{
|
||||
name: "version without v prefix",
|
||||
versionTarget: "1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should handle version target without v prefix",
|
||||
},
|
||||
{
|
||||
name: "complex version upgrade allowed",
|
||||
versionTarget: "v1.5.3",
|
||||
currentVersion: makeVersion("v1.4.2"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow complex version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "complex major upgrade blocked",
|
||||
versionTarget: "v3.1.0",
|
||||
currentVersion: makeVersion("v2.5.3"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block complex major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "pre-release version upgrade allowed",
|
||||
versionTarget: "v1.0.1-beta.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow pre-release version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "pre-release major upgrade blocked",
|
||||
versionTarget: "v2.0.0-alpha.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block pre-release major version upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
|
||||
// Call the function and capture the result
|
||||
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_selfUpgradeCheck(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow upgrade and attempt to perform it",
|
||||
},
|
||||
{
|
||||
name: "upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block upgrade and not attempt to perform it",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
|
||||
// Call the function and capture the result
|
||||
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_performUpgrade(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid version target",
|
||||
versionTarget: "v1.0.1",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade with valid version target",
|
||||
},
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade even with empty version target",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
// Call the function and capture the result
|
||||
result := performUpgrade(tc.versionTarget, testLogger)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
12
cmd/cli/prog_windows.go
Normal file
12
cmd/cli/prog_windows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
// setDependencies sets service dependencies for Windows
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package cli
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
// Prometheus metrics label constants for consistent labeling across all metrics
|
||||
// These ensure standardized metric labeling for monitoring and alerting
|
||||
const (
|
||||
metricsLabelListener = "listener"
|
||||
metricsLabelClientSourceIP = "client_source_ip"
|
||||
@@ -13,17 +15,21 @@ const (
|
||||
)
|
||||
|
||||
// statsVersion represent ctrld version.
|
||||
// This metric provides version information for monitoring and debugging
|
||||
var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_build_info",
|
||||
Help: "Version of ctrld process.",
|
||||
}, []string{"gitref", "goversion", "version"})
|
||||
|
||||
// statsTimeStart represents start time of ctrld service.
|
||||
// This metric tracks service uptime and helps with monitoring service restarts
|
||||
var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ctrld_time_seconds",
|
||||
Help: "Start time of the ctrld process since unix epoch in seconds.",
|
||||
})
|
||||
|
||||
// statsQueriesCountLabels defines the labels for query count metrics
|
||||
// These labels provide detailed breakdown of DNS query statistics
|
||||
var statsQueriesCountLabels = []string{
|
||||
metricsLabelListener,
|
||||
metricsLabelClientSourceIP,
|
||||
@@ -35,6 +41,7 @@ var statsQueriesCountLabels = []string{
|
||||
}
|
||||
|
||||
// statsQueriesCount counts total number of queries.
|
||||
// This provides comprehensive DNS query statistics for monitoring and alerting
|
||||
var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_queries_count",
|
||||
Help: "Total number of queries.",
|
||||
@@ -44,14 +51,16 @@ var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
//
|
||||
// The labels "client_source_ip", "client_mac", "client_hostname" are unbounded,
|
||||
// thus this stat is highly inefficient if there are many devices.
|
||||
// This metric should be used carefully in high-client environments
|
||||
var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_client_queries_count",
|
||||
Help: "Total number queries of a client.",
|
||||
}, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname})
|
||||
|
||||
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
|
||||
// This provides conditional metric collection to avoid performance impact when metrics are disabled
|
||||
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
|
||||
if p.cfg.Service.MetricsQueryStats {
|
||||
if p.metricsQueryStats.Load() {
|
||||
c.WithLabelValues(lvs...).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// notifyReloadSigCh sends reload signal to the channel
|
||||
func notifyReloadSigCh(ch chan os.Signal) {
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
// sendReloadSignal sends a reload signal to the current process
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// notifyReloadSigCh is a no-op on Windows platforms
|
||||
func notifyReloadSigCh(ch chan os.Signal) {}
|
||||
|
||||
// sendReloadSignal sends a reload signal to the program
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
select {
|
||||
case p.reloadCh <- struct{}{}:
|
||||
|
||||
@@ -4,37 +4,58 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
const (
|
||||
resolvConfPath = "/etc/resolv.conf"
|
||||
resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
)
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
// This function parses the system DNS configuration to understand current nameserver settings
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
return resolvconffile.NameserversFromFile(path)
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
mainLog.Load().Debug().Msg("start watching /etc/resolv.conf file")
|
||||
// This ensures that DNS settings are not overridden by other applications or system processes
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
resolvConfPath := "/etc/resolv.conf"
|
||||
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
|
||||
// This handles systems where resolv.conf is a symlink to another location
|
||||
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
|
||||
resolvConfPath = rp
|
||||
}
|
||||
p.Debug().Msgf("Start watching %s file", resolvConfPath)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
p.Warn().Err(err).Msg("Could not create watcher for /etc/resolv.conf")
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
// We watch /etc instead of /etc/resolv.conf directly,
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
// This is necessary because some systems don't properly notify on file changes
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not add /etc/resolv.conf to watcher list")
|
||||
p.Warn().Err(err).Msgf("Could not add %s to watcher list", watchDir)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
p.Debug().Msgf("Stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -42,24 +63,91 @@ func watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
// This allows us to detect when the resolv.conf has been modified
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
p.Error().Err(err).Msg("Failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
// This handles cases where the file is being written but not yet complete
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
// This handles temporary file states during updates
|
||||
if retry < maxRetries-1 {
|
||||
p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
p.Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
|
||||
p.Error().Err(err).Msg("Could not get event for /etc/resolv.conf")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,15 +3,44 @@ package cli
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
servers := make([]string, len(ns))
|
||||
for i := range ns {
|
||||
servers[i] = ns[i].String()
|
||||
}
|
||||
return setDNS(iface, servers)
|
||||
if err := setDNS(iface, servers); err != nil {
|
||||
return err
|
||||
}
|
||||
slices.Sort(servers)
|
||||
curNs := currentDNS(iface)
|
||||
slices.Sort(curNs)
|
||||
if !slices.Equal(curNs, servers) {
|
||||
c, err := resolvconffile.ParseFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Nameservers = ns
|
||||
f, err := os.Create(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := c.Write(f); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
|
||||
@@ -6,14 +6,16 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter.
|
||||
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
|
||||
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -22,12 +24,17 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
oc.SearchDomains = sds
|
||||
} else {
|
||||
p.Debug().Err(err).Msg("Failed to get search domains list when reverting resolv.conf file")
|
||||
}
|
||||
return r.SetDNS(oc)
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter.
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -38,3 +45,8 @@ func shouldWatchResolvconf() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface.
|
||||
func newLoopbackOSConfigurator() (dns.OSConfigurator, error) {
|
||||
return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo")
|
||||
}
|
||||
|
||||
51
cmd/cli/resolvconf_test.go
Normal file
51
cmd/cli/resolvconf_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
func oldParseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// Test_prog_parseResolvConfNameservers tests the parsing of nameservers from resolv.conf content.
|
||||
// Note: The previous implementation was removed to reduce code duplication and consolidate
|
||||
// the resolv.conf handling logic into a single unified approach. All resolv.conf parsing
|
||||
// is now handled by the resolvconffile package, which provides a consistent interface
|
||||
// for both reading and modifying resolv.conf files across different platforms.
|
||||
func Test_prog_parseResolvConfNameservers(t *testing.T) {
|
||||
oldNss, _ := oldParseResolvConfNameservers(resolvconffile.Path)
|
||||
p := &prog{}
|
||||
nss, _ := p.parseResolvConfNameservers(resolvconffile.Path)
|
||||
slices.Sort(oldNss)
|
||||
slices.Sort(nss)
|
||||
if !slices.Equal(oldNss, nss) {
|
||||
t.Errorf("result mismatched, old: %v, new: %v", oldNss, nss)
|
||||
}
|
||||
t.Logf("result: %v", nss)
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
func (p *prog) setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
14
cmd/cli/search_domains_unix.go
Normal file
14
cmd/cli/search_domains_unix.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
return resolvconffile.SearchDomains()
|
||||
}
|
||||
43
cmd/cli/search_domains_windows.go
Normal file
43
cmd/cli/search_domains_windows.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err)
|
||||
}
|
||||
|
||||
var sds []dnsname.FQDN
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
|
||||
d, err := dnsname.ToFQDN(a.String())
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Failed to parse domain: %s", a.String())
|
||||
continue
|
||||
}
|
||||
sds = append(sds, d)
|
||||
}
|
||||
}
|
||||
return sds, nil
|
||||
}
|
||||
8
cmd/cli/self_delete_others.go
Normal file
8
cmd/cli/self_delete_others.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
var supportedSelfDelete = true
|
||||
|
||||
// selfDeleteExe performs self-deletion on non-Windows platforms
|
||||
func selfDeleteExe() error { return nil }
|
||||
138
cmd/cli/self_delete_windows.go
Normal file
138
cmd/cli/self_delete_windows.go
Normal file
@@ -0,0 +1,138 @@
|
||||
// Copied from https://github.com/secur30nly/go-self-delete
|
||||
// with modification to suitable for ctrld usage.
|
||||
|
||||
/*
|
||||
License: MIT Licence
|
||||
|
||||
References:
|
||||
- https://github.com/LloydLabs/delete-self-poc
|
||||
- https://twitter.com/jonasLyk/status/1350401461985955840
|
||||
*/
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var supportedSelfDelete = false
|
||||
|
||||
type FILE_RENAME_INFO struct {
|
||||
Union struct {
|
||||
ReplaceIfExists bool
|
||||
Flags uint32
|
||||
}
|
||||
RootDirectory windows.Handle
|
||||
FileNameLength uint32
|
||||
FileName [1]uint16
|
||||
}
|
||||
|
||||
type FILE_DISPOSITION_INFO struct {
|
||||
DeleteFile bool
|
||||
}
|
||||
|
||||
// dsOpenHandle opens a handle to the specified file with DELETE access
|
||||
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
|
||||
handle, err := windows.CreateFile(
|
||||
pwPath,
|
||||
windows.DELETE,
|
||||
0,
|
||||
nil,
|
||||
windows.OPEN_EXISTING,
|
||||
windows.FILE_ATTRIBUTE_NORMAL,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
// dsRenameHandle renames a file handle to a stream name
|
||||
func dsRenameHandle(hHandle windows.Handle) error {
|
||||
var fRename FILE_RENAME_INFO
|
||||
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lpwStream := &DS_STREAM_RENAME[0]
|
||||
fRename.FileNameLength = uint32(unsafe.Sizeof(lpwStream))
|
||||
|
||||
windows.NewLazyDLL("kernel32.dll").NewProc("RtlCopyMemory").Call(
|
||||
uintptr(unsafe.Pointer(&fRename.FileName[0])),
|
||||
uintptr(unsafe.Pointer(lpwStream)),
|
||||
unsafe.Sizeof(lpwStream),
|
||||
)
|
||||
|
||||
err = windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileRenameInfo,
|
||||
(*byte)(unsafe.Pointer(&fRename)),
|
||||
uint32(unsafe.Sizeof(fRename)+unsafe.Sizeof(lpwStream)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dsDepositeHandle marks a file handle for deletion
|
||||
func dsDepositeHandle(hHandle windows.Handle) error {
|
||||
var fDelete FILE_DISPOSITION_INFO
|
||||
fDelete.DeleteFile = true
|
||||
|
||||
err := windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileDispositionInfo,
|
||||
(*byte)(unsafe.Pointer(&fDelete)),
|
||||
uint32(unsafe.Sizeof(fDelete)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// selfDeleteExe performs self-deletion on Windows platforms
|
||||
func selfDeleteExe() error {
|
||||
var wcPath [windows.MAX_PATH + 1]uint16
|
||||
var hCurrent windows.Handle
|
||||
|
||||
_, err := windows.GetModuleFileName(0, &wcPath[0], windows.MAX_PATH)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsRenameHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsDepositeHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
|
||||
return windows.CloseHandle(hCurrent)
|
||||
}
|
||||
17
cmd/cli/self_kill_others.go
Normal file
17
cmd/cli/self_kill_others.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build !unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// selfUninstall performs self-uninstallation on non-Unix platforms
|
||||
func selfUninstall(p *prog, logger *ctrld.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, false) {
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
47
cmd/cli/self_kill_unix.go
Normal file
47
cmd/cli/self_kill_unix.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// selfUninstall performs self-uninstallation on Unix platforms
|
||||
func selfUninstall(p *prog, logger *ctrld.Logger) {
|
||||
if runtime.GOOS == "linux" {
|
||||
selfUninstallLinux(p, logger)
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Could not determine executable")
|
||||
}
|
||||
args := []string{"uninstall"}
|
||||
if deactivationPinSet() {
|
||||
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin.Load()))
|
||||
}
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
if err := cmd.Start(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("Could not start self uninstall command")
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
_ = cmd.Wait()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// selfUninstallLinux performs self-uninstallation on Linux platforms
|
||||
func selfUninstallLinux(p *prog, logger *ctrld.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, true) {
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
12
cmd/cli/self_upgrade_others.go
Normal file
12
cmd/cli/self_upgrade_others.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running a detached child command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{Setsid: true}
|
||||
}
|
||||
18
cmd/cli/self_upgrade_windows.go
Normal file
18
cmd/cli/self_upgrade_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags?redirectedfrom=MSDN
|
||||
|
||||
// SYSCALL_CREATE_NO_WINDOW set flag to run process without a console window.
|
||||
const SYSCALL_CREATE_NO_WINDOW = 0x08000000
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running self-upgrade command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | SYSCALL_CREATE_NO_WINDOW,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
@@ -1,24 +1,31 @@
|
||||
package cli
|
||||
|
||||
// semaphore provides a simple synchronization mechanism
|
||||
type semaphore interface {
|
||||
acquire()
|
||||
release()
|
||||
}
|
||||
|
||||
// noopSemaphore is a no-operation implementation of semaphore
|
||||
type noopSemaphore struct{}
|
||||
|
||||
// acquire performs a no-operation for the noop semaphore
|
||||
func (n noopSemaphore) acquire() {}
|
||||
|
||||
// release performs a no-operation for the noop semaphore
|
||||
func (n noopSemaphore) release() {}
|
||||
|
||||
// chanSemaphore is a channel-based implementation of semaphore
|
||||
type chanSemaphore struct {
|
||||
ready chan struct{}
|
||||
}
|
||||
|
||||
// acquire blocks until a slot is available in the semaphore
|
||||
func (c *chanSemaphore) acquire() {
|
||||
c.ready <- struct{}{}
|
||||
}
|
||||
|
||||
// release signals that a slot has been freed in the semaphore
|
||||
func (c *chanSemaphore) release() {
|
||||
<-c.ready
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/unit"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
// newService wraps service.New call to return service.Service
|
||||
@@ -20,14 +21,13 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
|
||||
return &procd{&sysV{s}}, nil
|
||||
case router.IsGLiNet():
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "unix-systemv":
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "linux-systemd":
|
||||
return &systemd{s}, nil
|
||||
case s.Platform() == "darwin-launchd":
|
||||
return newLaunchd(s), nil
|
||||
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
|
||||
// sysV wraps a service.Service, and provide start/stop/status command
|
||||
// base on "/etc/init.d/<service_name>".
|
||||
//
|
||||
// Use this on system where "service" command is not available, like GL.iNET router.
|
||||
// Use this on system where "service" command is not available.
|
||||
type sysV struct {
|
||||
service.Service
|
||||
}
|
||||
@@ -82,32 +82,7 @@ func (s *sysV) Status() (service.Status, error) {
|
||||
return unixSystemVServiceStatus()
|
||||
}
|
||||
|
||||
// procd wraps a service.Service, and provide start/stop command
|
||||
// base on "/etc/init.d/<service_name>", status command base on parsing "ps" command output.
|
||||
//
|
||||
// Use this on system where "/etc/init.d/<service_name> status" command is not available,
|
||||
// like old GL.iNET Opal router.
|
||||
type procd struct {
|
||||
*sysV
|
||||
}
|
||||
|
||||
func (s *procd) Status() (service.Status, error) {
|
||||
if !s.installed() {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
// Looking for something like "/sbin/ctrld run ".
|
||||
shellCmd := fmt.Sprintf("ps | grep -q %q", exe+" [r]un ")
|
||||
if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
return service.StatusRunning, nil
|
||||
}
|
||||
|
||||
// procd wraps a service.Service, and provide status command to
|
||||
// systemd wraps a service.Service, and provide status command to
|
||||
// report the status correctly.
|
||||
type systemd struct {
|
||||
service.Service
|
||||
@@ -121,29 +96,113 @@ func (s *systemd) Status() (service.Status, error) {
|
||||
return s.Service.Status()
|
||||
}
|
||||
|
||||
func (s *systemd) Start() error {
|
||||
const systemdUnitFile = "/etc/systemd/system/ctrld.service"
|
||||
f, err := os.Open(systemdUnitFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if opts, change := ensureSystemdKillMode(f); change {
|
||||
mode := os.FileMode(0644)
|
||||
buf, err := io.ReadAll(unit.Serialize(opts))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(systemdUnitFile, buf, mode); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("systemctl daemon-reload failed: %w\n%s", err, string(out))
|
||||
}
|
||||
mainLog.Load().Debug().Msg("Set KillMode=process successfully")
|
||||
}
|
||||
return s.Service.Start()
|
||||
}
|
||||
|
||||
// ensureSystemdKillMode ensure systemd unit file is configured with KillMode=process.
|
||||
// This is necessary for running self-upgrade flow.
|
||||
func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
|
||||
opts, err := unit.DeserializeOptions(r)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to deserialize options")
|
||||
return
|
||||
}
|
||||
change = true
|
||||
needKillModeOpt := true
|
||||
killModeOpt := unit.NewUnitOption("Service", "KillMode", "process")
|
||||
for _, opt := range opts {
|
||||
if opt.Match(killModeOpt) {
|
||||
needKillModeOpt = false
|
||||
change = false
|
||||
break
|
||||
}
|
||||
if opt.Section == killModeOpt.Section && opt.Name == killModeOpt.Name {
|
||||
opt.Value = killModeOpt.Value
|
||||
needKillModeOpt = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needKillModeOpt {
|
||||
opts = append(opts, killModeOpt)
|
||||
}
|
||||
return opts, change
|
||||
}
|
||||
|
||||
// newLaunchd creates a new launchd service wrapper
|
||||
func newLaunchd(s service.Service) *launchd {
|
||||
return &launchd{
|
||||
Service: s,
|
||||
statusErrMsg: "Permission denied",
|
||||
}
|
||||
}
|
||||
|
||||
// launchd wraps a service.Service, and provide status command to
|
||||
// report the status correctly when not running as root on Darwin.
|
||||
//
|
||||
// TODO: remove this wrapper once https://github.com/kardianos/service/issues/400 fixed.
|
||||
type launchd struct {
|
||||
service.Service
|
||||
statusErrMsg string
|
||||
}
|
||||
|
||||
func (l *launchd) Status() (service.Status, error) {
|
||||
if os.Geteuid() != 0 {
|
||||
return service.StatusUnknown, errors.New(l.statusErrMsg)
|
||||
}
|
||||
return l.Service.Status()
|
||||
}
|
||||
|
||||
type task struct {
|
||||
f func() error
|
||||
abortOnError bool
|
||||
Name string
|
||||
}
|
||||
|
||||
// doTasks executes a list of tasks and returns success status
|
||||
func doTasks(tasks []task) bool {
|
||||
var prevErr error
|
||||
for _, task := range tasks {
|
||||
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error())
|
||||
mainLog.Load().Error().Msgf("Error running task %s: %v", task.Name, err)
|
||||
return false
|
||||
}
|
||||
prevErr = err
|
||||
// if this is darwin stop command, dont print debug
|
||||
// since launchctl complains on every start
|
||||
if runtime.GOOS != "darwin" || task.Name != "Stop" {
|
||||
mainLog.Load().Debug().Msgf("Error running task %s: %v", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkHasElevatedPrivilege checks if the process has elevated privileges and exits if not
|
||||
func checkHasElevatedPrivilege() {
|
||||
ok, err := hasElevatedPrivilege()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("could not detect user privilege: %v", err)
|
||||
mainLog.Load().Error().Msgf("Could not detect user privilege: %v", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
@@ -152,6 +211,7 @@ func checkHasElevatedPrivilege() {
|
||||
}
|
||||
}
|
||||
|
||||
// unixSystemVServiceStatus checks the status of a Unix System V service
|
||||
func unixSystemVServiceStatus() (service.Status, error) {
|
||||
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,15 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// hasElevatedPrivilege checks if the current process has elevated privileges
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
return os.Geteuid() == 0, nil
|
||||
}
|
||||
|
||||
// openLogFile opens a log file with the specified flags
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions is a no-op on non-Windows platforms
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
|
||||
28
cmd/cli/service_test.go
Normal file
28
cmd/cli/service_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_ensureSystemdKillMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
unitFile string
|
||||
wantChange bool
|
||||
}{
|
||||
{"no KillMode", "[Service]\nExecStart=/bin/sleep 1", true},
|
||||
{"not KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=mixed", true},
|
||||
{"KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
{"invalid unit file", "[Service\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, change := ensureSystemdKillMode(strings.NewReader(tc.unitFile)); tc.wantChange != change {
|
||||
t.Errorf("ensureSystemdKillMode(%q) = %v, want %v", tc.unitFile, change, tc.wantChange)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,17 @@
|
||||
package cli
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
// hasElevatedPrivilege checks if the current process has elevated privileges on Windows
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
var sid *windows.SID
|
||||
if err := windows.AllocateAndInitializeSid(
|
||||
@@ -22,3 +32,117 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
token := windows.Token(0)
|
||||
return token.IsMember(sid)
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions checks if the given service
|
||||
// has the correct failure actions configured, and updates them if not.
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil // no-op on non-Windows
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// 1. Retrieve the current config
|
||||
cfg, err := s.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update the Description
|
||||
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||
|
||||
// 3. Apply the updated config
|
||||
if err := s.UpdateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then proceed with existing actions, e.g. setting failure actions
|
||||
actions := []mgr.RecoveryAction{
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
}
|
||||
|
||||
// Set the recovery actions (3 restarts, reset period = 120).
|
||||
err = s.SetRecoveryActions(actions, 120)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that failure actions are NOT triggered on user-initiated stops.
|
||||
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
|
||||
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
|
||||
|
||||
if err := windows.ChangeServiceConfig2(
|
||||
s.Handle,
|
||||
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
|
||||
(*byte)(unsafe.Pointer(&failureActionsFlag)),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// openLogFile opens a log file with the specified mode on Windows
|
||||
func openLogFile(path string, mode int) (*os.File, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||
}
|
||||
|
||||
pathP, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var access uint32
|
||||
switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
|
||||
case os.O_RDONLY:
|
||||
access = windows.GENERIC_READ
|
||||
case os.O_WRONLY:
|
||||
access = windows.GENERIC_WRITE
|
||||
case os.O_RDWR:
|
||||
access = windows.GENERIC_READ | windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_CREATE != 0 {
|
||||
access |= windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_APPEND != 0 {
|
||||
access &^= windows.GENERIC_WRITE
|
||||
access |= windows.FILE_APPEND_DATA
|
||||
}
|
||||
|
||||
shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
|
||||
|
||||
var sa *syscall.SecurityAttributes
|
||||
|
||||
var createMode uint32
|
||||
switch {
|
||||
case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL):
|
||||
createMode = windows.CREATE_NEW
|
||||
case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC):
|
||||
createMode = windows.CREATE_ALWAYS
|
||||
case mode&os.O_CREATE == os.O_CREATE:
|
||||
createMode = windows.OPEN_ALWAYS
|
||||
case mode&os.O_TRUNC == os.O_TRUNC:
|
||||
createMode = windows.TRUNCATE_EXISTING
|
||||
default:
|
||||
createMode = windows.OPEN_EXISTING
|
||||
}
|
||||
|
||||
handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0)
|
||||
if err != nil {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: err}
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(handle), path), nil
|
||||
}
|
||||
|
||||
@@ -1,39 +1,46 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 100
|
||||
maxFailureRequest = 50
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
cfg *ctrld.Config
|
||||
logger atomic.Pointer[ctrld.Logger]
|
||||
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
checking map[string]bool
|
||||
down map[string]bool
|
||||
failureReq map[string]uint64
|
||||
recovered map[string]bool
|
||||
|
||||
// failureTimerActive tracks if a timer is already running for a given upstream.
|
||||
failureTimerActive map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
// newUpstreamMonitor creates a new upstream monitor instance
|
||||
func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
recovered: make(map[string]bool),
|
||||
failureTimerActive: make(map[string]bool),
|
||||
}
|
||||
um.logger.Store(logger)
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
um.reset(upstream)
|
||||
@@ -42,14 +49,47 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increase failed queries count for an upstream by 1.
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down
|
||||
// within 10 seconds if failures persist, without spawning duplicate goroutines.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
if um.recovered[upstream] {
|
||||
um.logger.Load().Debug().Msgf("Upstream %q is recovered, skipping failure count increase", upstream)
|
||||
return
|
||||
}
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
um.down[upstream] = failedCount >= maxFailureRequest
|
||||
|
||||
// Log the updated failure count.
|
||||
um.logger.Load().Debug().Msgf("Upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// If this is the first failure and no timer is running, start a 10-second timer.
|
||||
if failedCount == 1 && !um.failureTimerActive[upstream] {
|
||||
um.failureTimerActive[upstream] = true
|
||||
go func(upstream string) {
|
||||
time.Sleep(10 * time.Second)
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
// If no success occurred during the 10-second window (i.e. counter remains > 0)
|
||||
// and the upstream is not in a recovered state, mark it as down.
|
||||
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
|
||||
um.down[upstream] = true
|
||||
um.logger.Load().Warn().Msgf("Upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
}
|
||||
// Reset the timer flag so that a new timer can be spawned if needed.
|
||||
um.failureTimerActive[upstream] = false
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
um.logger.Load().Warn().Msgf("Upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
@@ -63,50 +103,28 @@ func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
um.failureReq[upstream] = 0
|
||||
um.down[upstream] = false
|
||||
}
|
||||
|
||||
// checkUpstream checks the given upstream status, periodically sending query to upstream
|
||||
// until successfully. An upstream status/counter will be reset once it becomes reachable.
|
||||
func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
|
||||
um.mu.Lock()
|
||||
isChecking := um.checking[upstream]
|
||||
if isChecking {
|
||||
um.mu.Unlock()
|
||||
return
|
||||
}
|
||||
um.checking[upstream] = true
|
||||
um.recovered[upstream] = true
|
||||
um.mu.Unlock()
|
||||
defer func() {
|
||||
go func() {
|
||||
// debounce the recovery to avoid incrementing failure counts already in flight
|
||||
time.Sleep(1 * time.Second)
|
||||
um.mu.Lock()
|
||||
um.checking[upstream] = false
|
||||
um.recovered[upstream] = false
|
||||
um.mu.Unlock()
|
||||
}()
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
|
||||
return
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
|
||||
check := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
uc.ReBootstrap()
|
||||
_, err := resolver.Resolve(ctx, msg)
|
||||
return err
|
||||
}
|
||||
for {
|
||||
if err := check(); err == nil {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
|
||||
um.reset(upstream)
|
||||
return
|
||||
}
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
}
|
||||
}
|
||||
|
||||
// countHealthy returns the number of upstreams in the provided map that are considered healthy.
|
||||
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
|
||||
var count int
|
||||
um.mu.RLock()
|
||||
for _, upstream := range upstreams {
|
||||
if !um.down[upstream] {
|
||||
count++
|
||||
}
|
||||
}
|
||||
um.mu.RUnlock()
|
||||
return count
|
||||
}
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package main
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cli.Main()
|
||||
// make sure we exit with 0 if there are no errors
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -28,22 +28,24 @@ type AppCallback interface {
|
||||
// Start configures utility with config.toml from provided directory.
|
||||
// This function will block until Stop is called
|
||||
// Check port availability prior to calling it.
|
||||
func (c *Controller) Start(CdUID string, HomeDir string, UpstreamProto string, logLevel int, logPath string) {
|
||||
func (c *Controller) Start(CdUID string, ProvisionID string, CustomHostname string, HomeDir string, UpstreamProto string, logLevel int, logPath string) {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
c.Config = cli.AppConfig{
|
||||
CdUID: CdUID,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
CdUID: CdUID,
|
||||
ProvisionID: ProvisionID,
|
||||
CustomHostname: CustomHostname,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
}
|
||||
appCallback := mapCallback(c.AppCallback)
|
||||
cli.RunMobile(&c.Config, &appCallback, c.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
// As workaround to avoid circular dependency between cli and ctrld_library module
|
||||
// mapCallback maps the AppCallback interface to cli.AppCallback to avoid circular dependency
|
||||
func mapCallback(callback AppCallback) cli.AppCallback {
|
||||
return cli.AppCallback{
|
||||
HostName: func() string {
|
||||
|
||||
394
config.go
394
config.go
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -22,9 +23,11 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ameshkov/dnsstamps"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/tsaddr"
|
||||
@@ -50,14 +53,36 @@ const (
|
||||
FreeDnsDomain = "freedns.controld.com"
|
||||
// FreeDNSBoostrapIP is the IP address of freedns.controld.com.
|
||||
FreeDNSBoostrapIP = "76.76.2.11"
|
||||
// FreeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com.
|
||||
FreeDNSBoostrapIPv6 = "2606:1a40::11"
|
||||
// PremiumDnsDomain is the domain name of premium ControlD service.
|
||||
PremiumDnsDomain = "dns.controld.com"
|
||||
// PremiumDNSBoostrapIP is the IP address of dns.controld.com.
|
||||
PremiumDNSBoostrapIP = "76.76.2.22"
|
||||
// PremiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.com.
|
||||
PremiumDNSBoostrapIPv6 = "2606:1a40::22"
|
||||
|
||||
// freeDnsDomainDev is the domain name of free ControlD service on dev env.
|
||||
freeDnsDomainDev = "freedns.controld.dev"
|
||||
// freeDNSBoostrapIP is the IP address of freedns.controld.dev.
|
||||
freeDNSBoostrapIP = "176.125.239.11"
|
||||
// freeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com.
|
||||
freeDNSBoostrapIPv6 = "2606:1a40:f000::11"
|
||||
// premiumDnsDomainDev is the domain name of premium ControlD service on dev env.
|
||||
premiumDnsDomainDev = "dns.controld.dev"
|
||||
// premiumDNSBoostrapIP is the IP address of dns.controld.dev.
|
||||
premiumDNSBoostrapIP = "176.125.239.22"
|
||||
// premiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.dev.
|
||||
premiumDNSBoostrapIPv6 = "2606:1a40:f000::22"
|
||||
|
||||
controlDComDomain = "controld.com"
|
||||
controlDNetDomain = "controld.net"
|
||||
controlDDevDomain = "controld.dev"
|
||||
|
||||
endpointPrefixHTTPS = "https://"
|
||||
endpointPrefixQUIC = "quic://"
|
||||
endpointPrefixH3 = "h3://"
|
||||
endpointPrefixSdns = "sdns://"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -89,6 +114,10 @@ func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
|
||||
|
||||
// InitConfig initializes default config values for given *viper.Viper instance.
|
||||
func InitConfig(v *viper.Viper, name string) {
|
||||
ctx := context.Background()
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "Config initialization started")
|
||||
|
||||
v.SetDefault("listener", map[string]*ListenerConfig{
|
||||
"0": {
|
||||
IP: "",
|
||||
@@ -127,6 +156,8 @@ func InitConfig(v *viper.Viper, name string) {
|
||||
Timeout: 3000,
|
||||
},
|
||||
})
|
||||
|
||||
Log(ctx, logger.Debug(), "Config initialization completed")
|
||||
}
|
||||
|
||||
// Config represents ctrld supported configuration.
|
||||
@@ -188,27 +219,32 @@ func (c *Config) FirstUpstream() *UpstreamConfig {
|
||||
|
||||
// ServiceConfig specifies the general ctrld config.
|
||||
type ServiceConfig struct {
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
|
||||
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
||||
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
|
||||
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"`
|
||||
CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"`
|
||||
CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"`
|
||||
CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"`
|
||||
CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"`
|
||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp kea-dhcp4"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
DiscoverPtr *bool `mapstructure:"discover_ptr" toml:"discover_ptr,omitempty"`
|
||||
DiscoverHosts *bool `mapstructure:"discover_hosts" toml:"discover_hosts,omitempty"`
|
||||
DiscoverRefreshInterval int `mapstructure:"discover_refresh_interval" toml:"discover_refresh_interval,omitempty"`
|
||||
ClientIDPref string `mapstructure:"client_id_preference" toml:"client_id_preference,omitempty" validate:"omitempty,oneof=host mac"`
|
||||
MetricsQueryStats bool `mapstructure:"metrics_query_stats" toml:"metrics_query_stats,omitempty"`
|
||||
MetricsListener string `mapstructure:"metrics_listener" toml:"metrics_listener,omitempty"`
|
||||
DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"`
|
||||
DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"`
|
||||
RefetchTime *int `mapstructure:"refetch_time" toml:"refetch_time,omitempty"`
|
||||
ForceRefetchWaitTime *int `mapstructure:"force_refetch_wait_time" toml:"force_refetch_wait_time,omitempty"`
|
||||
LeakOnUpstreamFailure *bool `mapstructure:"leak_on_upstream_failure" toml:"leak_on_upstream_failure,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
|
||||
// NetworkConfig specifies configuration for networks where ctrld will handle requests.
|
||||
@@ -221,7 +257,7 @@ type NetworkConfig struct {
|
||||
// UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to.
|
||||
type UpstreamConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy sdns ''"`
|
||||
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty"`
|
||||
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
|
||||
Domain string `mapstructure:"-" toml:"-"`
|
||||
@@ -248,6 +284,7 @@ type UpstreamConfig struct {
|
||||
http3RoundTripper6 http.RoundTripper
|
||||
certPool *x509.CertPool
|
||||
u *url.URL
|
||||
fallbackOnce sync.Once
|
||||
uid string
|
||||
}
|
||||
|
||||
@@ -278,14 +315,20 @@ func (lc *ListenerConfig) IsDirectDnsListener() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// MatchingConfig defines the configuration for rule matching behavior
|
||||
type MatchingConfig struct {
|
||||
Order []string `mapstructure:"order" toml:"order,omitempty"`
|
||||
}
|
||||
|
||||
// ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests.
|
||||
type ListenerPolicyConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
|
||||
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
|
||||
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
|
||||
Matching *MatchingConfig `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
|
||||
// Rule is a map from source to list of upstreams.
|
||||
@@ -294,11 +337,15 @@ type ListenerPolicyConfig struct {
|
||||
type Rule map[string][]string
|
||||
|
||||
// Init initialized necessary values for an UpstreamConfig.
|
||||
func (uc *UpstreamConfig) Init() {
|
||||
func (uc *UpstreamConfig) Init(ctx context.Context) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
if err := uc.initDnsStamps(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("Invalid dns stamps")
|
||||
}
|
||||
uc.initDoHScheme()
|
||||
uc.uid = upstreamUID()
|
||||
uc.uid = upstreamUID(ctx)
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
uc.Domain = u.Host
|
||||
uc.Domain = u.Hostname()
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
uc.u = u
|
||||
@@ -316,7 +363,10 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
if uc.IPStack == "" {
|
||||
if uc.isControlD() {
|
||||
// Set default IP stack based on upstream type
|
||||
// Control-D upstreams use split stack for better IPv4/IPv6 handling,
|
||||
// while other upstreams use both stacks for maximum compatibility
|
||||
if uc.IsControlD() {
|
||||
uc.IPStack = IpStackSplit
|
||||
} else {
|
||||
uc.IPStack = IpStackBoth
|
||||
@@ -324,6 +374,15 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyMsg creates and returns a new DNS message could be used for testing upstream health.
|
||||
func (uc *UpstreamConfig) VerifyMsg() *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.RecursionDesired = true
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
msg.SetEdns0(4096, false) // ensure handling of large DNS response
|
||||
return msg
|
||||
}
|
||||
|
||||
// VerifyDomain returns the domain name that could be resolved by the upstream endpoint.
|
||||
// It returns empty for non-ControlD upstream endpoint.
|
||||
func (uc *UpstreamConfig) VerifyDomain() string {
|
||||
@@ -354,7 +413,7 @@ func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
if uc.isControlD() || uc.isNextDNS() {
|
||||
if uc.IsControlD() || uc.isNextDNS() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -386,24 +445,23 @@ func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) {
|
||||
uc.certPool = cp
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
uc.setupBootstrapIP(true)
|
||||
}
|
||||
|
||||
// UID returns the unique identifier of the upstream.
|
||||
func (uc *UpstreamConfig) UID() string {
|
||||
return uc.uid
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
// SetupBootstrapIP sets up bootstrap IPs for the upstream config.
|
||||
// The setup process will block until there's usable IPs found.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "Setting up bootstrap IPs for upstream: %s", uc.Name)
|
||||
|
||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
||||
isControlD := uc.isControlD()
|
||||
isControlD := uc.IsControlD()
|
||||
nss := initDefaultOsResolver(ctx)
|
||||
for {
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
||||
Log(ctx, logger.Debug(), "Looking up bootstrap IPs for domain: %s", uc.Domain)
|
||||
uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, nss)
|
||||
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
||||
// filtering them out here to prevent weird behavior.
|
||||
if isControlD {
|
||||
@@ -416,11 +474,20 @@ func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
}
|
||||
}
|
||||
uc.bootstrapIPs = uc.bootstrapIPs[:n]
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain)
|
||||
logger.Warn().Msgf("No record found for %q, lookup from direct ip table", uc.Domain)
|
||||
}
|
||||
}
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
logger.Warn().Msgf("No record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP)
|
||||
uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")})
|
||||
|
||||
}
|
||||
if len(uc.bootstrapIPs) > 0 {
|
||||
break
|
||||
}
|
||||
ProxyLogger.Load().Warn().Msg("could not resolve bootstrap IPs, retrying...")
|
||||
logger.Warn().Msg("Could not resolve bootstrap ips, retrying...")
|
||||
b.BackOff(context.Background(), errors.New("no bootstrap IPs"))
|
||||
}
|
||||
for _, ip := range uc.bootstrapIPs {
|
||||
@@ -430,11 +497,12 @@ func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip)
|
||||
}
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs)
|
||||
logger.Debug().Msgf("Bootstrap ips: %v", uc.bootstrapIPs)
|
||||
Log(ctx, logger.Debug(), "Bootstrap IP setup completed for upstream: %s", uc.Name)
|
||||
}
|
||||
|
||||
// ReBootstrap re-setup the bootstrap IP and the transport.
|
||||
func (uc *UpstreamConfig) ReBootstrap() {
|
||||
func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
default:
|
||||
@@ -442,7 +510,8 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "Re-bootstrapping upstream: %s", uc.Name)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
@@ -450,35 +519,35 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
|
||||
// SetupTransport initializes the network transport used to connect to upstream server.
|
||||
// For now, only DoH upstream is supported.
|
||||
func (uc *UpstreamConfig) SetupTransport() {
|
||||
func (uc *UpstreamConfig) SetupTransport(ctx context.Context) {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
uc.setupDOHTransport()
|
||||
uc.setupDOHTransport(ctx)
|
||||
case ResolverTypeDOH3:
|
||||
uc.setupDOH3Transport()
|
||||
uc.setupDOH3Transport(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOHTransport() {
|
||||
func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
if hasIPv6() {
|
||||
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
}
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.MaxIdleConnsPerHost = 100
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
@@ -486,17 +555,25 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
ClientSessionCache: tls.NewLRUClientSessionCache(0),
|
||||
}
|
||||
|
||||
// Prevent bad tcp connection hanging the requests for too long.
|
||||
// See: https://github.com/golang/go/issues/36026
|
||||
if t2, err := http2.ConfigureTransports(transport); err == nil {
|
||||
t2.ReadIdleTimeout = 10 * time.Second
|
||||
t2.PingTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
dialerTimeoutMs := 2000
|
||||
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
|
||||
dialerTimeoutMs = uc.Timeout
|
||||
}
|
||||
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
|
||||
logger := LoggerFromCtx(ctx)
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
if uc.BootstrapIP != "" {
|
||||
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
|
||||
addr := net.JoinHostPort(uc.BootstrapIP, port)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", addr)
|
||||
Log(ctx, logger.Debug(), "Sending doh request to: %s", addr)
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
pd := &ctrldnet.ParallelDialer{}
|
||||
@@ -506,11 +583,11 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs)
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs, logger.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", conn.RemoteAddr())
|
||||
Log(ctx, logger.Debug(), "Sending doh request to: %s", conn.RemoteAddr())
|
||||
return conn, nil
|
||||
}
|
||||
runtime.SetFinalizer(transport, func(transport *http.Transport) {
|
||||
@@ -520,16 +597,20 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
}
|
||||
|
||||
// Ping warms up the connection to DoH/DoH3 upstream.
|
||||
func (uc *UpstreamConfig) Ping() {
|
||||
_ = uc.ping()
|
||||
func (uc *UpstreamConfig) Ping(ctx context.Context) {
|
||||
if err := uc.ping(ctx); err != nil {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Err(err).Msgf("Upstream ping failed: %s", uc.Endpoint)
|
||||
_ = uc.FallbackToDirectIP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPing is like Ping, but return an error if any.
|
||||
func (uc *UpstreamConfig) ErrorPing() error {
|
||||
return uc.ping()
|
||||
func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error {
|
||||
return uc.ping(ctx)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) ping() error {
|
||||
func (uc *UpstreamConfig) ping(ctx context.Context) error {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
default:
|
||||
@@ -558,12 +639,11 @@ func (uc *UpstreamConfig) ping() error {
|
||||
for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
|
||||
if err := ping(uc.dohTransport(typ)); err != nil {
|
||||
if err := ping(uc.dohTransport(ctx, typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
case ResolverTypeDOH3:
|
||||
if err := ping(uc.doh3Transport(typ)); err != nil {
|
||||
if err := ping(uc.doh3Transport(ctx, typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -572,7 +652,8 @@ func (uc *UpstreamConfig) ping() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) isControlD() bool {
|
||||
// IsControlD reports whether this is a ControlD upstream.
|
||||
func (uc *UpstreamConfig) IsControlD() bool {
|
||||
domain := uc.Domain
|
||||
if domain == "" {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
@@ -597,12 +678,12 @@ func (uc *UpstreamConfig) isNextDNS() bool {
|
||||
return domain == "dns.nextdns.io"
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
@@ -618,7 +699,7 @@ func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
return uc.transport
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return pick(uc.bootstrapIPs)
|
||||
@@ -631,7 +712,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
case dns.TypeA:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
default:
|
||||
if hasIPv6() {
|
||||
if HasIPv6(ctx) {
|
||||
return pick(uc.bootstrapIPs6)
|
||||
}
|
||||
return pick(uc.bootstrapIPs4)
|
||||
@@ -640,7 +721,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
return pick(uc.bootstrapIPs)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return "tcp-tls", "udp"
|
||||
@@ -653,7 +734,7 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
case dns.TypeA:
|
||||
return "tcp4-tls", "udp4"
|
||||
default:
|
||||
if hasIPv6() {
|
||||
if HasIPv6(ctx) {
|
||||
return "tcp6-tls", "udp6"
|
||||
}
|
||||
return "tcp4-tls", "udp4"
|
||||
@@ -664,24 +745,117 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
|
||||
// initDoHScheme initializes the endpoint scheme for DoH/DoH3 upstream if not present.
|
||||
func (uc *UpstreamConfig) initDoHScheme() {
|
||||
if strings.HasPrefix(uc.Endpoint, endpointPrefixH3) && uc.Type == "" {
|
||||
uc.Type = ResolverTypeDOH3
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
case ResolverTypeDOH:
|
||||
case ResolverTypeDOH3:
|
||||
if after, found := strings.CutPrefix(uc.Endpoint, endpointPrefixH3); found {
|
||||
uc.Endpoint = endpointPrefixHTTPS + after
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(uc.Endpoint, "https://") {
|
||||
uc.Endpoint = "https://" + uc.Endpoint
|
||||
if !strings.HasPrefix(uc.Endpoint, endpointPrefixHTTPS) {
|
||||
uc.Endpoint = endpointPrefixHTTPS + uc.Endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// initDnsStamps initializes upstream config based on encoded DNS Stamps Endpoint.
|
||||
func (uc *UpstreamConfig) initDnsStamps() error {
|
||||
if strings.HasPrefix(uc.Endpoint, endpointPrefixSdns) && uc.Type == "" {
|
||||
uc.Type = ResolverTypeSDNS
|
||||
}
|
||||
if uc.Type != ResolverTypeSDNS {
|
||||
return nil
|
||||
}
|
||||
sdns, err := dnsstamps.NewServerStampFromString(uc.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ip, port, _ := net.SplitHostPort(sdns.ServerAddrStr)
|
||||
providerName, port2, _ := net.SplitHostPort(sdns.ProviderName)
|
||||
if port2 != "" {
|
||||
port = port2
|
||||
}
|
||||
if providerName == "" {
|
||||
providerName = sdns.ProviderName
|
||||
}
|
||||
switch sdns.Proto {
|
||||
case dnsstamps.StampProtoTypeDoH:
|
||||
uc.Type = ResolverTypeDOH
|
||||
host := sdns.ProviderName
|
||||
if port != "" && port != defaultPortFor(uc.Type) {
|
||||
host = net.JoinHostPort(providerName, port)
|
||||
}
|
||||
uc.Endpoint = "https://" + host + sdns.Path
|
||||
case dnsstamps.StampProtoTypeTLS:
|
||||
uc.Type = ResolverTypeDOT
|
||||
uc.Endpoint = net.JoinHostPort(providerName, port)
|
||||
case dnsstamps.StampProtoTypeDoQ:
|
||||
uc.Type = ResolverTypeDOQ
|
||||
uc.Endpoint = net.JoinHostPort(providerName, port)
|
||||
case dnsstamps.StampProtoTypePlain:
|
||||
uc.Type = ResolverTypeLegacy
|
||||
uc.Endpoint = sdns.ServerAddrStr
|
||||
default:
|
||||
return fmt.Errorf("unsupported stamp protocol %q", sdns.Proto)
|
||||
}
|
||||
uc.BootstrapIP = ip
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context returns a new context with timeout set from upstream config.
|
||||
func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if uc.Timeout > 0 {
|
||||
return context.WithTimeout(ctx, time.Millisecond*time.Duration(uc.Timeout))
|
||||
}
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
|
||||
// FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain.
|
||||
func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool {
|
||||
if !uc.IsControlD() {
|
||||
return false
|
||||
}
|
||||
if uc.u == nil || uc.Domain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
done := false
|
||||
uc.fallbackOnce.Do(func() {
|
||||
var ip string
|
||||
switch {
|
||||
case dns.IsSubDomain(PremiumDnsDomain, uc.Domain):
|
||||
ip = PremiumDNSBoostrapIP
|
||||
case dns.IsSubDomain(FreeDnsDomain, uc.Domain):
|
||||
ip = FreeDNSBoostrapIP
|
||||
default:
|
||||
return
|
||||
}
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Warn(), "Using direct IP for %q: %s", uc.Endpoint, ip)
|
||||
uc.u.Host = ip
|
||||
done = true
|
||||
})
|
||||
return done
|
||||
}
|
||||
|
||||
// Init initialized necessary values for an ListenerConfig.
|
||||
func (lc *ListenerConfig) Init() {
|
||||
logger := LoggerFromCtx(context.Background())
|
||||
Log(context.Background(), logger.Debug(), "Initializing listener config")
|
||||
|
||||
if lc.Policy != nil {
|
||||
lc.Policy.FailoverRcodeNumbers = make([]int, len(lc.Policy.FailoverRcodes))
|
||||
for i, rcode := range lc.Policy.FailoverRcodes {
|
||||
lc.Policy.FailoverRcodeNumbers[i] = dnsrcode.FromString(rcode)
|
||||
}
|
||||
Log(context.Background(), logger.Debug(), "Listener policy initialized with %d failover rcodes", len(lc.Policy.FailoverRcodes))
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(), "Listener config initialization completed")
|
||||
}
|
||||
|
||||
// ValidateConfig validates the given config.
|
||||
@@ -726,6 +900,23 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
|
||||
return
|
||||
}
|
||||
|
||||
// Empty type is ok only for endpoints starts with "h3://" and "sdns://".
|
||||
if uc.Type == "" && !strings.HasPrefix(uc.Endpoint, endpointPrefixH3) && !strings.HasPrefix(uc.Endpoint, endpointPrefixSdns) {
|
||||
sl.ReportError(uc.Endpoint, "type", "type", "oneof", "doh doh3 dot doq os legacy sdns")
|
||||
return
|
||||
}
|
||||
|
||||
// initDoHScheme/initDnsStamps may change upstreams information,
|
||||
// so restoring changed values after validation to keep original one.
|
||||
defer func(ep, typ string) {
|
||||
uc.Endpoint = ep
|
||||
uc.Type = typ
|
||||
}(uc.Endpoint, uc.Type)
|
||||
|
||||
if err := uc.initDnsStamps(); err != nil {
|
||||
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
|
||||
return
|
||||
}
|
||||
uc.initDoHScheme()
|
||||
// DoH/DoH3 requires endpoint is an HTTP url.
|
||||
if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 {
|
||||
@@ -755,13 +946,19 @@ func defaultPortFor(typ string) string {
|
||||
// - If endpoint is an IP address -> ResolverTypeLegacy
|
||||
// - If endpoint starts with "https://" -> ResolverTypeDOH
|
||||
// - If endpoint starts with "quic://" -> ResolverTypeDOQ
|
||||
// - If endpoint starts with "h3://" -> ResolverTypeDOH3
|
||||
// - If endpoint starts with "sdns://" -> ResolverTypeSDNS
|
||||
// - For anything else -> ResolverTypeDOT
|
||||
func ResolverTypeFromEndpoint(endpoint string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(endpoint, "https://"):
|
||||
case strings.HasPrefix(endpoint, endpointPrefixHTTPS):
|
||||
return ResolverTypeDOH
|
||||
case strings.HasPrefix(endpoint, "quic://"):
|
||||
case strings.HasPrefix(endpoint, endpointPrefixQUIC):
|
||||
return ResolverTypeDOQ
|
||||
case strings.HasPrefix(endpoint, endpointPrefixH3):
|
||||
return ResolverTypeDOH3
|
||||
case strings.HasPrefix(endpoint, endpointPrefixSdns):
|
||||
return ResolverTypeSDNS
|
||||
}
|
||||
host := endpoint
|
||||
if strings.Contains(endpoint, ":") {
|
||||
@@ -778,13 +975,38 @@ func pick(s []string) string {
|
||||
}
|
||||
|
||||
// upstreamUID generates an unique identifier for an upstream.
|
||||
func upstreamUID() string {
|
||||
func upstreamUID(ctx context.Context) string {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
b := make([]byte, 4)
|
||||
for {
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
|
||||
logger.Warn().Err(err).Msg("Could not generate uid for upstream, retrying...")
|
||||
continue
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the UpstreamConfig for logging.
|
||||
func (uc *UpstreamConfig) String() string {
|
||||
if uc == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}",
|
||||
uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack)
|
||||
}
|
||||
|
||||
// bootstrapIPsFromControlDDomain returns bootstrap IPs for ControlD domain.
|
||||
func bootstrapIPsFromControlDDomain(domain string) []string {
|
||||
switch {
|
||||
case dns.IsSubDomain(PremiumDnsDomain, domain):
|
||||
return []string{PremiumDNSBoostrapIP, PremiumDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(FreeDnsDomain, domain):
|
||||
return []string{FreeDNSBoostrapIP, FreeDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(premiumDnsDomainDev, domain):
|
||||
return []string{premiumDNSBoostrapIP, premiumDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(freeDnsDomainDev, domain):
|
||||
return []string{freeDNSBoostrapIP, freeDNSBoostrapIPv6}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
@@ -8,24 +9,49 @@ import (
|
||||
)
|
||||
|
||||
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
|
||||
uc := &UpstreamConfig{
|
||||
Name: "test",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
}{
|
||||
{
|
||||
name: "doh/doh3",
|
||||
uc: &UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "doq/dot",
|
||||
uc: &UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: ResolverTypeDOT,
|
||||
Endpoint: "p2.freedns.controld.com",
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
uc.Init()
|
||||
uc.setupBootstrapIP(false)
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
t.Log(nameservers())
|
||||
t.Fatal("could not bootstrap ip without bootstrap DNS")
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed.
|
||||
// t.Parallel()
|
||||
tc.uc.Init(context.Background())
|
||||
tc.uc.SetupBootstrapIP(context.Background())
|
||||
if len(tc.uc.bootstrapIPs) == 0 {
|
||||
t.Log(defaultNameservers(context.Background()))
|
||||
t.Fatalf("could not bootstrap ip: %s", tc.uc.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Log(uc)
|
||||
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_Init(t *testing.T) {
|
||||
u1, _ := url.Parse("https://example.com")
|
||||
u2, _ := url.Parse("https://example.com?k=v")
|
||||
u3, _ := url.Parse("https://freedns.controld.com/p1")
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
@@ -178,13 +204,159 @@ func TestUpstreamConfig_Init(t *testing.T) {
|
||||
u: u2,
|
||||
},
|
||||
},
|
||||
{
|
||||
"h3",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "h3://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"h3 without type",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Endpoint: "h3://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> doh",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AgMAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29tAy9wMQ",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "doh",
|
||||
Endpoint: "https://freedns.controld.com/p1",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u3,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> dot",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AwcAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29t",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:843",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> doq",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://BAcAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29t",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "doq",
|
||||
Endpoint: "freedns.controld.com:784",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> legacy",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AAcAAAAAAAAACjc2Ljc2LjIuMTE",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "legacy",
|
||||
Endpoint: "76.76.2.11:53",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "76.76.2.11",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns without type",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Endpoint: "sdns://AAcAAAAAAAAACjc2Ljc2LjIuMTE",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "legacy",
|
||||
Endpoint: "76.76.2.11:53",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "76.76.2.11",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.Init(context.Background())
|
||||
tc.uc.uid = "" // we don't care about the uid.
|
||||
assert.Equal(t, tc.expected, tc.uc)
|
||||
})
|
||||
@@ -326,7 +498,7 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.Init(context.Background())
|
||||
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
|
||||
}
|
||||
|
||||
@@ -14,34 +14,35 @@ import (
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
if hasIPv6() {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
}
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
rt := &http3.RoundTripper{}
|
||||
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
|
||||
rt := &http3.Transport{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", addr)
|
||||
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr)
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -61,21 +62,21 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) {
|
||||
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
|
||||
rt.CloseIdleConnections()
|
||||
})
|
||||
return rt
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
|
||||
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
@@ -96,14 +97,14 @@ func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
|
||||
// - quic dialer is different with net.Dialer
|
||||
// - simplification for quic free version
|
||||
type parallelDialerResult struct {
|
||||
conn quic.EarlyConnection
|
||||
conn *quic.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
type quicParallelDialer struct{}
|
||||
|
||||
// Dial performs parallel dialing to the given address list.
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/spf13/viper"
|
||||
@@ -22,6 +23,8 @@ func TestLoadConfig(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "info", cfg.Service.LogLevel)
|
||||
assert.Equal(t, "/path/to/log.log", cfg.Service.LogPath)
|
||||
assert.Equal(t, false, *cfg.Service.DnsWatchdogEnabled)
|
||||
assert.Equal(t, time.Duration(20*time.Second), *cfg.Service.DnsWatchdogInvterval)
|
||||
|
||||
assert.Len(t, cfg.Network, 2)
|
||||
assert.Contains(t, cfg.Network, "0")
|
||||
@@ -104,7 +107,11 @@ func TestConfigValidation(t *testing.T) {
|
||||
{"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true},
|
||||
{"invalid client id pref", configWithInvalidClientIDPref(t), true},
|
||||
{"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false},
|
||||
{"doh endpoint without type", dohUpstreamEndpointWithoutType(t), true},
|
||||
{"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false},
|
||||
{"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false},
|
||||
{"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true},
|
||||
{"kea dhcp4 format", configWithDhcp4KeaFormat(t), false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -124,6 +131,21 @@ func TestConfigValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidationDoNotChangeEndpoint(t *testing.T) {
|
||||
cfg := configWithInvalidDoHEndpoint(t)
|
||||
endpointMap := map[string]struct{}{}
|
||||
for _, uc := range cfg.Upstream {
|
||||
endpointMap[uc.Endpoint] = struct{}{}
|
||||
}
|
||||
validate := validator.New()
|
||||
_ = ctrld.ValidateConfig(validate, cfg)
|
||||
for _, uc := range cfg.Upstream {
|
||||
if _, ok := endpointMap[uc.Endpoint]; !ok {
|
||||
t.Fatalf("expected endpoint '%s' to exist", uc.Endpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigDiscoverOverride(t *testing.T) {
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "test_config_discover_override")
|
||||
@@ -176,6 +198,27 @@ func dohUpstreamEndpointWithoutScheme(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func dohUpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "https://freedns.controld.com/p1"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func doh3UpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "h3://freedns.controld.com/p1"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func sdnsUpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "sdns://AgMAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29tAy9wMQ"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func invalidUpstreamTimeout(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Timeout = -1
|
||||
@@ -265,6 +308,12 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithDhcp4KeaFormat(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.DHCPLeaseFileFormat = "kea-dhcp4"
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "/1.1.1.1"
|
||||
|
||||
10
desktop_darwin.go
Normal file
10
desktop_darwin.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return true }
|
||||
12
desktop_others.go
Normal file
12
desktop_others.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return false }
|
||||
22
desktop_windows.go
Normal file
22
desktop_windows.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package ctrld
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
|
||||
// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine.
|
||||
func isWindowsWorkStation() bool {
|
||||
// From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa
|
||||
const VER_NT_WORKSTATION = 0x0000001
|
||||
osvi := windows.RtlGetVersion()
|
||||
return osvi.ProductType == VER_NT_WORKSTATION
|
||||
}
|
||||
30
dns.go
Normal file
30
dns.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// SetCacheReply extracts and stores the necessary data from the message for a cached answer.
|
||||
func SetCacheReply(answer, msg *dns.Msg, code int) {
|
||||
answer.SetRcode(msg, code)
|
||||
cCookie := getEdns0Cookie(msg.IsEdns0())
|
||||
sCookie := getEdns0Cookie(answer.IsEdns0())
|
||||
if cCookie != nil && sCookie != nil {
|
||||
// Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc7873#section-4
|
||||
sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:]
|
||||
}
|
||||
}
|
||||
|
||||
// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present.
|
||||
func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE {
|
||||
if opt == nil {
|
||||
return nil
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if e, ok := o.(*dns.EDNS0_COOKIE); ok {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -8,7 +8,7 @@
|
||||
# - Non-cgo ctrld binary.
|
||||
#
|
||||
# CI_COMMIT_TAG is used to set the version of ctrld binary.
|
||||
FROM golang:1.20-bullseye as base
|
||||
FROM golang:bullseye as base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -18,10 +18,6 @@ The config file allows for advanced configuration of the `ctrld` utility to cove
|
||||
|
||||
- `/etc/controld` on *nix.
|
||||
- User's home directory on Windows.
|
||||
- Same directory with `ctrld` binary on these routers:
|
||||
- `ddwrt`
|
||||
- `merlin`
|
||||
- `freshtomato`
|
||||
- Current directory.
|
||||
|
||||
The user can choose to override default value using command line `--config` or `-c`:
|
||||
@@ -166,7 +162,6 @@ before serving the query.
|
||||
|
||||
### max_concurrent_requests
|
||||
The number of concurrent requests that will be handled, must be a non-negative integer.
|
||||
Tweaking this value depends on the capacity of your system.
|
||||
|
||||
- Type: number
|
||||
- Required: no
|
||||
@@ -179,6 +174,8 @@ Perform LAN client discovery using mDNS. This will spawn a listener on port 5353
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_arp
|
||||
Perform LAN client discovery using ARP.
|
||||
|
||||
@@ -186,6 +183,8 @@ Perform LAN client discovery using ARP.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_dhcp
|
||||
Perform LAN client discovery using DHCP leases files. Common file locations are auto-discovered.
|
||||
|
||||
@@ -193,6 +192,8 @@ Perform LAN client discovery using DHCP leases files. Common file locations are
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_ptr
|
||||
Perform LAN client discovery using PTR queries.
|
||||
|
||||
@@ -200,6 +201,8 @@ Perform LAN client discovery using PTR queries.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_hosts
|
||||
Perform LAN client discovery using hosts file.
|
||||
|
||||
@@ -207,6 +210,8 @@ Perform LAN client discovery using hosts file.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_refresh_interval
|
||||
Time in seconds between each discovery refresh loop to update new client information data.
|
||||
The default value is 120 seconds, lower this value to make the discovery process run more aggressively.
|
||||
@@ -252,6 +257,40 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus
|
||||
- Required: no
|
||||
- Default: ""
|
||||
|
||||
### dns_watchdog_enabled
|
||||
Watches all physical interfaces for DNS changes and reverts them to ctrld's settings.The DNS watchdog process only runs on Windows and MacOS.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
### dns_watchdog_interval
|
||||
Time duration between each DNS watchdog iteration.
|
||||
|
||||
A duration string is a possibly signed sequence of decimal numbers, each with optional fraction and a unit suffix,
|
||||
such as "300ms", "-1.5h" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
|
||||
|
||||
If the time duration is non-positive, default value will be used.
|
||||
|
||||
- Type: time duration string
|
||||
- Required: no
|
||||
- Default: 20s
|
||||
|
||||
### refetch_time
|
||||
Time in seconds between each iteration that reloads custom config from the API.
|
||||
|
||||
The value must be a positive number, any invalid value will be ignored and default value will be used.
|
||||
- Type: number
|
||||
- Required: no
|
||||
- Default: 3600
|
||||
|
||||
### leak_on_upstream_failure
|
||||
If a remote upstream fails to resolve a query or is unreachable, `ctrld` will forward the queries to the default DNS resolver on the network. If failures persist, `ctrld` will remove itself from all networking interfaces until connectivity is restored.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true on Windows, MacOS and Linux.
|
||||
|
||||
## Upstream
|
||||
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
||||
|
||||
@@ -336,7 +375,7 @@ The protocol that `ctrld` will use to send DNS requests to upstream.
|
||||
|
||||
- Type: string
|
||||
- Required: yes
|
||||
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os`
|
||||
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`
|
||||
|
||||
### ip_stack
|
||||
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.
|
||||
@@ -495,6 +534,15 @@ rules = [
|
||||
]
|
||||
```
|
||||
|
||||
If there is no explicitly defined rules, LAN queries will be handled solely by the OS resolver.
|
||||
|
||||
These following domains are considered LAN queries:
|
||||
|
||||
- Queries does not have dot `.` in domain name, like `machine1`, `example`, ... (1)
|
||||
- Queries have domain ends with: `.domain`, `.lan`, `.local`. (2)
|
||||
- All `SRV` queries of LAN hostname (1) + (2).
|
||||
- `PTR` queries with private IPs.
|
||||
|
||||
---
|
||||
|
||||
Note that the order of matching preference:
|
||||
@@ -528,6 +576,12 @@ And within each policy, the rules are processed from top to bottom.
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
---
|
||||
|
||||
Note that the domain comparisons are done in case in-sensitive manner following [RFC 1034](https://datatracker.ietf.org/doc/html/rfc1034#section-3.1)
|
||||
|
||||
---
|
||||
|
||||
### macs:
|
||||
`macs` is the list of mac rules within the policy. Mac address value is case-insensitive.
|
||||
|
||||
|
||||
BIN
docs/ctrldsplash.png
Normal file
BIN
docs/ctrldsplash.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 458 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user